From b3104b2a10ab7cb7532442177ae0d0c40acf9d03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E8=AF=91=E6=96=87?= <1020030101@qq.com> Date: Wed, 10 Apr 2024 15:09:36 +0800 Subject: [PATCH 001/413] [Bugfix] Fix logits processor when prompt_logprobs is not None (#3899) --- tests/samplers/test_logits_processor.py | 62 +++++++++++++++++++ .../model_executor/layers/logits_processor.py | 11 +++- 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 tests/samplers/test_logits_processor.py diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py new file mode 100644 index 0000000000000..3788e9e9752ff --- /dev/null +++ b/tests/samplers/test_logits_processor.py @@ -0,0 +1,62 @@ +import pytest +import torch + +from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_logits_processor_force_generate( + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + tokenizer = vllm_model.model.get_tokenizer() + repeat_times = 2 + enforced_answers = " vLLM" + vllm_token_ids = tokenizer.encode(enforced_answers, + add_special_tokens=False) + max_tokens = len(vllm_token_ids) * repeat_times + + def pick_vllm(token_ids, logits): + token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)] + logits[token_id] = torch.finfo(logits.dtype).max + return logits + + params_with_logprobs = SamplingParams( + logits_processors=[pick_vllm], + prompt_logprobs=3, + max_tokens=max_tokens, + ) + + # test logits_processors when prompt_logprobs is not None + vllm_model.model._add_request( + prompt=example_prompts[0], + sampling_params=params_with_logprobs, + prompt_token_ids=None, + ) + + # test prompt_logprobs is not None + vllm_model.model._add_request( + prompt=example_prompts[1], + sampling_params=SamplingParams( + prompt_logprobs=3, + max_tokens=max_tokens, + ), + prompt_token_ids=None, + ) + + # test grouped requests + vllm_model.model._add_request( + prompt=example_prompts[2], + sampling_params=SamplingParams(max_tokens=max_tokens), + prompt_token_ids=None, + ) + + outputs = vllm_model.model._run_engine(False) + + assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 28e8f6bb7e638..ec531f79ced52 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -86,8 +86,16 @@ def _apply_logits_processors( ) -> torch.Tensor: logits_row_idx = 0 found_logits_processors = False - for seq_ids, sampling_params in sampling_metadata.seq_groups: + for i, seq_group in enumerate(sampling_metadata.seq_groups): + seq_ids, sampling_params = seq_group logits_processors = sampling_params.logits_processors + # handle prompt_logprobs by skipping rows in logits added for + # the prompt tokens (prompt logprobs are not processed) + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + assert len(seq_ids) == 1 + logits_row_idx += sampling_metadata.prompt_lens[i] - 1 + if logits_processors: found_logits_processors = True for seq_id in seq_ids: @@ -100,5 +108,6 @@ def _apply_logits_processors( else: logits_row_idx += len(seq_ids) if found_logits_processors: + # verifies that no rows in logits were missed unexpectedly assert logits_row_idx == logits.shape[0] return logits From 0258b7a94b08321ca01cf170f867b67c1920af87 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 10 Apr 2024 02:39:56 -0600 Subject: [PATCH 002/413] [Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876) Signed-off-by: Travis Johnson --- tests/samplers/test_sampler.py | 116 +++++++++++++++++++++----- vllm/model_executor/layers/sampler.py | 19 ++++- 2 files changed, 112 insertions(+), 23 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1626b72282072..26e2d29ffd04c 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,3 +1,4 @@ +import itertools import random from typing import List, Optional, Tuple from unittest.mock import patch @@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): def create_sampling_params(min_tokens, eos_token_id=0, - stop_token_ids=None): + *, + stop_token_ids: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None): sampling_params = SamplingParams( min_tokens=min_tokens, max_tokens=9999, # keep higher than max of min_tokens stop_token_ids=stop_token_ids, + # requesting prompt_logprobs changes the structure of `logits` + prompt_logprobs=prompt_logprobs, ) sampling_params.eos_token_id = eos_token_id return sampling_params @@ -217,9 +222,9 @@ def generate_test_case(): expected_penalization = [] sequence_metadata_list = [] + # 20% chance to generate seq group metadata list with all prompts + is_prompt = random.random() < 0.2 while batch_size > 0: - # 20% chance to generate prompt seq group with single sequence - is_prompt = random.random() < 0.2 num_seqs = 1 if is_prompt else random.randint(1, batch_size) eos_token_id = random.randint(0, VOCAB_SIZE - 1) @@ -240,7 +245,7 @@ def generate_test_case(): seq_group_penalization = [] for _ in range(num_seqs): num_input = random.randint(1, 100) - num_generated = random.randint(1, 100) if not is_prompt else 0 + num_generated = 0 if is_prompt else random.randint(1, 100) seq_data[next(seq_id_counter)] = create_sequence_data( num_input=num_input, num_generated=num_generated) seq_group_penalization.append(num_generated < min_tokens) @@ -292,6 +297,21 @@ def generate_test_case(): ] } + prompt_with_penalization_and_prompt_logprobs = { + "expected_penalization": [False, False, True], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_1", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(num_input=3), + }, + sampling_params=create_sampling_params(1, prompt_logprobs=3), + block_tables={}, + ), + ] + } + stop_penalizing_after_min_tokens = { "expected_penalization": [False], "seq_group_metadata_list": [ @@ -309,8 +329,34 @@ def generate_test_case(): } stop_token_ids = [42, 99, 42, 0] # intentional duplication - simple_combination = { - "expected_penalization": [True, False, False], + prompt_combination = { + "expected_penalization": [False, True, False], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_2", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(num_input=2), + }, + sampling_params=create_sampling_params(1, prompt_logprobs=3), + block_tables={}, + ), + SequenceGroupMetadata( + request_id="test_3", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(), + }, + sampling_params=create_sampling_params( + 0, stop_token_ids=stop_token_ids), + block_tables={}, + ) + ] + } + + stop_token_ids = [1, 999, 37, 37] # intentional duplication + decode_combination = { + "expected_penalization": [True, False, False, True, False], "seq_group_metadata_list": [ SequenceGroupMetadata( request_id="test_1", @@ -327,14 +373,19 @@ def generate_test_case(): ), SequenceGroupMetadata( request_id="test_2", - is_prompt=True, + is_prompt=False, seq_data={ - next(seq_id_counter): create_sequence_data(), + next(seq_id_counter): + create_sequence_data(num_generated=20), + next(seq_id_counter): + create_sequence_data(num_generated=1), + next(seq_id_counter): + create_sequence_data(num_generated=10), }, sampling_params=create_sampling_params( - 0, stop_token_ids=stop_token_ids), + 10, prompt_logprobs=5, stop_token_ids=stop_token_ids), block_tables={}, - ) + ), ] } @@ -342,8 +393,10 @@ def generate_test_case(): test_cases = [ prompt_without_penalization, prompt_with_penalization, + prompt_with_penalization_and_prompt_logprobs, stop_penalizing_after_min_tokens, - simple_combination, + prompt_combination, + decode_combination, ] else: test_cases = [generate_test_case()] @@ -351,30 +404,49 @@ def generate_test_case(): def run_test_case(*, expected_penalization=None, seq_group_metadata_list=None): - assert expected_penalization, "Invalid test case" - assert seq_group_metadata_list, "Invalid test case" + assert expected_penalization, \ + "Invalid test case, need expected_penalization" + assert seq_group_metadata_list, \ + "Invalid test case, need seq_group_metadata_list" batch_size = 0 prompt_lens = [] - sampling_params_per_seq = [] + sampling_params_per_row = [] for sgm in seq_group_metadata_list: - num_seqs = len(sgm.seq_data) - batch_size += num_seqs sampling_params = sgm.sampling_params - for seq_id in sgm.seq_data: - prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len()) - sampling_params_per_seq.append(sampling_params) + + num_rows = len(sgm.seq_data) + if sgm.is_prompt: + # a prompt seq_group has only one sequence + seq_data = next(iter(sgm.seq_data.values())) + prompt_len = seq_data.get_prompt_len() + prompt_lens.append(prompt_len) + + if sgm.sampling_params.prompt_logprobs: + # with prompt_logprobs each token in the prompt has a row in + # logits + num_rows = prompt_len + + batch_size += num_rows + sampling_params_per_row.extend( + itertools.repeat(sampling_params, num_rows)) + + assert len( + expected_penalization + ) == batch_size, \ + ("Invalid test case, expected_penalization does not match computed" + "batch size") _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_metadata = model_runner._prepare_sample( seq_group_metadata_list, - prompt_lens=prompt_lens, - subquery_lens=prompt_lens) + prompt_lens=prompt_lens if prompt_lens else None, + subquery_lens=prompt_lens if prompt_lens else None) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) for logits_idx, (should_penalize, sampling_params) in enumerate( - zip(expected_penalization, sampling_params_per_seq)): + zip(expected_penalization, sampling_params_per_row)): tokens_to_check = [sampling_params.eos_token_id] if sampling_params.stop_token_ids: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index cb1480de03e3a..03bf38caebe0e 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -27,6 +27,12 @@ class Sampler(nn.Module): 6. Sample the next tokens. Here, each sequence group within the batch can have different sampling parameters (e.g., sampling method, temperature, top-p, top-k, etc.). + + The structure of the logits tensor is coupled with the seq_groups in + sampling_metadata. Typically, each sequence in each seq_group has one row in + logits for the next token to be sampled; however, for a seq_group with a + prompt request with the prompt_logprobs sampling parameter, there are rows + in logits for each token in the input prompt. """ def forward( @@ -106,7 +112,16 @@ def _apply_min_tokens_penalty( # list of indices in logits that will be set to -inf logits_to_penalize = [] start_idx = 0 - for seq_ids, sampling_params in sampling_metadata.seq_groups: + for i, seq_group in enumerate(sampling_metadata.seq_groups): + seq_ids, sampling_params = seq_group + + # handle prompt_logprobs by skipping rows in logits added for the prompt + # tokens (prompt logprobs are not penalized) + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + assert len(seq_ids) == 1 + start_idx += sampling_metadata.prompt_lens[i] - 1 + min_tokens = sampling_params.min_tokens if min_tokens > 0: seqs_to_penalize = [] @@ -132,6 +147,8 @@ def _apply_min_tokens_penalty( # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) logits[tuple(zip(*logits_to_penalize))] = -float("inf") + # verifies that no rows in logits were missed unexpectedly + assert start_idx == logits.shape[0] return logits From bd3c144e0b8e82c9b3c5c40c6d557fe8665de5a3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 10 Apr 2024 07:37:17 -0700 Subject: [PATCH 003/413] [Bugfix][ROCm] Add numba to Dockerfile.rocm (#3962) --- Dockerfile.rocm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 10b8bf1e7fabd..b1c5fac9d78ef 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -91,7 +91,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ COPY ./ /app/vllm -RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --upgrade pip numba RUN python3 -m pip install xformers==0.0.23 --no-deps RUN cd /app \ From 8b317c6dd09ce566f4b4abeb446585ac75262cce Mon Sep 17 00:00:00 2001 From: James Whedbee Date: Wed, 10 Apr 2024 10:12:00 -0500 Subject: [PATCH 004/413] [Model][AMD] ROCm support for 256 head dims for Gemma (#3972) --- vllm/attention/ops/triton_flash_attention.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index b86e845020b07..87cf30cbef79a 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -677,8 +677,7 @@ def check_args( assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype - # TODO: Fix assert to check head size <=256 once supported - assert head_size <= 128 + assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 @@ -729,7 +728,7 @@ def forward( o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128} + unpadded_head_dims = {32, 64, 128, 256} if head_size not in unpadded_head_dims: padded_d_model = None for i in unpadded_head_dims: From e35397468f36a857b8d2b7d92a472265e1c500cc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 10:03:02 -0700 Subject: [PATCH 005/413] [Doc] Add doc to state our model support policy (#3948) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- docs/source/models/supported_models.rst | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index e7bfdcb65316e..c09b0ff250437 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -168,3 +168,29 @@ Alternatively, you can raise an issue on our `GitHub `_ and `test_big_models.py `_ for the models that have passed this test. +2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. +3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests `_ and `examples `_ for the models that have passed this test. +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. From e4c4072c94b346053768691451566c56664e26a7 Mon Sep 17 00:00:00 2001 From: Daniel E Marasco Date: Wed, 10 Apr 2024 13:15:51 -0400 Subject: [PATCH 006/413] [Bugfix] Remove key sorting for `guided_json` parameter in OpenAi compatible Server (#3945) --- vllm/model_executor/guided_decoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index e56f74c7794fb..8e710f1ac2b53 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -91,7 +91,7 @@ def _get_guide_and_mode( json = request.guided_json if isinstance(json, dict): # turn dict into hashable string - json = json_dumps(json, sort_keys=True) + json = json_dumps(json) elif isinstance(json, BaseModel): # use pydantic signature so that different model classes # with the same fields will get hashed the same From 92cd2e2f21e8ec65b2cb635a9f15de38157a1359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=CE=B1n=C3=A7ois?= Date: Wed, 10 Apr 2024 20:05:52 +0200 Subject: [PATCH 007/413] [Doc] Fix getting stared to use publicly available model (#3963) --- docs/source/serving/openai_compatible_server.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 032fe5d03bd52..388b5daa79a92 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat You can start the server using Python, or using [Docker](deploying_with_docker.rst): ```bash -python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123 +python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123 ``` To call the server, you can use the official OpenAI Python client library, or any other HTTP client. @@ -16,9 +16,8 @@ client = OpenAI( ) completion = client.chat.completions.create( - model="meta-llama/Llama-2-7b-hf", + model="mistralai/Mistral-7B-Instruct-v0.2", messages=[ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"} ] ) @@ -38,9 +37,8 @@ Or directly merge them into the JSON payload if you are using HTTP call directly ```python completion = client.chat.completions.create( - model="meta-llama/Llama-2-7b-hf", + model="mistralai/Mistral-7B-Instruct-v0.2", messages=[ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], extra_body={ @@ -89,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode a chat template in its tokenizer configuration. The chat template is a Jinja2 template that specifies how are roles, messages, and other chat-specific tokens are encoded in the input. -An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12) +An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format) Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat From 934d3662f716d60abfb04cf9fdd6d20f6e75f140 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 10 Apr 2024 16:28:25 -0600 Subject: [PATCH 008/413] [Bugfix] handle hf_config with architectures == None (#3982) Signed-off-by: Travis Johnson Co-authored-by: Simon Mo --- vllm/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 753fc33e9b717..bca250e922288 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -158,7 +158,9 @@ def _verify_load_format(self) -> None: # TODO: Remove this check once HF updates the pt weights of Mixtral. architectures = getattr(self.hf_config, "architectures", []) - if "MixtralForCausalLM" in architectures and load_format == "pt": + # architectures can be None instead of [] + if architectures and "MixtralForCausalLM" in architectures \ + and load_format == "pt": raise ValueError( "Currently, the 'pt' format is not supported for Mixtral. " "Please use the 'safetensors' format instead. ") From 63e7176f265be43dcc425f5ab4ab45c90234f5c3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 15:33:30 -0700 Subject: [PATCH 009/413] [Core][Refactor] move parallel_utils into vllm/distributed (#3950) [WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (#3950) --- tests/conftest.py | 3 +-- tests/distributed/test_comm_ops.py | 6 +++--- tests/distributed/test_custom_all_reduce.py | 13 ++++++------- tests/distributed/test_pynccl.py | 4 ++-- tests/lora/conftest.py | 3 +-- vllm/distributed/__init__.py | 3 +++ .../communication_op.py | 14 ++++++++------ .../device_communicators}/__init__.py | 0 .../device_communicators}/custom_all_reduce.py | 5 +++-- .../device_communicators}/pynccl.py | 0 .../device_communicators}/pynccl_utils.py | 4 ++-- .../parallel_state.py | 4 ++-- .../parallel_utils => distributed}/utils.py | 0 vllm/lora/layers.py | 13 ++++++------- vllm/model_executor/layers/activation.py | 5 ++--- vllm/model_executor/layers/linear.py | 11 +++++------ vllm/model_executor/layers/logits_processor.py | 3 +-- .../layers/vocab_parallel_embedding.py | 8 +++----- vllm/model_executor/models/baichuan.py | 4 ++-- vllm/model_executor/models/bloom.py | 4 ++-- vllm/model_executor/models/chatglm.py | 3 +-- vllm/model_executor/models/commandr.py | 4 ++-- vllm/model_executor/models/dbrx.py | 7 +++---- vllm/model_executor/models/deepseek.py | 7 +++---- vllm/model_executor/models/falcon.py | 7 +++---- vllm/model_executor/models/gemma.py | 3 +-- vllm/model_executor/models/gpt2.py | 3 +-- vllm/model_executor/models/gpt_bigcode.py | 3 +-- vllm/model_executor/models/gpt_j.py | 3 +-- vllm/model_executor/models/gpt_neox.py | 3 +-- vllm/model_executor/models/internlm2.py | 3 +-- vllm/model_executor/models/jais.py | 4 ++-- vllm/model_executor/models/llama.py | 4 ++-- vllm/model_executor/models/minicpm.py | 7 +++---- vllm/model_executor/models/mixtral.py | 7 +++---- vllm/model_executor/models/mixtral_quant.py | 7 +++---- vllm/model_executor/models/mpt.py | 4 ++-- vllm/model_executor/models/olmo.py | 3 +-- vllm/model_executor/models/opt.py | 3 +-- vllm/model_executor/models/orion.py | 3 +-- vllm/model_executor/models/phi.py | 3 +-- vllm/model_executor/models/qwen.py | 3 +-- vllm/model_executor/models/qwen2.py | 3 +-- vllm/model_executor/models/qwen2_moe.py | 7 +++---- vllm/model_executor/models/stablelm.py | 3 +-- vllm/model_executor/models/starcoder2.py | 3 +-- vllm/model_executor/models/xverse.py | 3 +-- vllm/model_executor/parallel_utils/README.md | 1 - vllm/test_utils.py | 4 ++-- vllm/worker/cpu_worker.py | 7 +++---- vllm/worker/model_runner.py | 8 +++----- vllm/worker/worker.py | 12 ++++++------ 52 files changed, 111 insertions(+), 141 deletions(-) create mode 100644 vllm/distributed/__init__.py rename vllm/{model_executor/parallel_utils => distributed}/communication_op.py (94%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/__init__.py (100%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/custom_all_reduce.py (98%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/pynccl.py (100%) rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/pynccl_utils.py (91%) rename vllm/{model_executor/parallel_utils => distributed}/parallel_state.py (98%) rename vllm/{model_executor/parallel_utils => distributed}/utils.py (100%) delete mode 100644 vllm/model_executor/parallel_utils/README.md diff --git a/tests/conftest.py b/tests/conftest.py index e00f3eb871e37..a7e8963af0eda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,7 @@ from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel) +from vllm.distributed import destroy_model_parallel from vllm.sequence import MultiModalData from vllm.transformers_utils.tokenizer import get_tokenizer diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index d1811cb694db6..aa9e0537c6910 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -8,9 +8,9 @@ import ray import torch -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) +from vllm.distributed import (broadcast_tensor_dict, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 1e6e7f89a528c..3b1cd1773af19 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -6,9 +6,8 @@ import torch import torch.distributed as dist -from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators import custom_all_reduce from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) - custom_ar.init_custom_ar() + custom_all_reduce.init_custom_all_reduce() for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with custom_ar.capture(): + with custom_all_reduce.capture(): # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port): distributed_init_port) sz = 1024 - custom_ar.init_custom_ar() - fa = custom_ar.get_handle() + custom_all_reduce.init_custom_all_reduce() + fa = custom_all_reduce.get_handle() inp = torch.ones(sz, dtype=torch.float32, device=device) out = fa.all_reduce_unreg(inp) assert torch.allclose(out, inp * world_size) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 29782045130a6..b50eed1c8c722 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -4,8 +4,8 @@ import pytest import torch -from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, - ncclGetUniqueId) +from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, + ncclGetUniqueId) def distributed_run(fn, world_size): diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index acb5fa91e2012..207c635e2dc86 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -12,6 +12,7 @@ import vllm from vllm.config import LoRAConfig +from vllm.distributed import destroy_model_parallel, initialize_model_parallel from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) @@ -19,8 +20,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils.parallel_state import ( - destroy_model_parallel, initialize_model_parallel) def cleanup(): diff --git a/vllm/distributed/__init__.py b/vllm/distributed/__init__.py new file mode 100644 index 0000000000000..db325cfabf55e --- /dev/null +++ b/vllm/distributed/__init__.py @@ -0,0 +1,3 @@ +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/distributed/communication_op.py similarity index 94% rename from vllm/model_executor/parallel_utils/communication_op.py rename to vllm/distributed/communication_op.py index 9cbb40708dd5b..cf15db099b304 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -4,12 +4,10 @@ import torch from torch.distributed import ProcessGroup -from vllm.model_executor.parallel_utils import pynccl_utils -from vllm.model_executor.parallel_utils.custom_all_reduce import ( - custom_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce) +from .parallel_state import (get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + is_pynccl_enabled_for_all_reduce) def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: TLDR: always assume this function modifies its input, but use the return value as the output. """ + from vllm.distributed.device_communicators import pynccl_utils + from vllm.distributed.device_communicators.custom_all_reduce import ( + custom_all_reduce) + # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: return input_ diff --git a/vllm/model_executor/parallel_utils/__init__.py b/vllm/distributed/device_communicators/__init__.py similarity index 100% rename from vllm/model_executor/parallel_utils/__init__.py rename to vllm/distributed/device_communicators/__init__.py diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py similarity index 98% rename from vllm/model_executor/parallel_utils/custom_all_reduce.py rename to vllm/distributed/device_communicators/custom_all_reduce.py index bf8ee07070c8a..84238d2e46076 100644 --- a/vllm/model_executor/parallel_utils/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -5,8 +5,6 @@ import torch.distributed as dist from vllm.logger import init_logger -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) try: import pynvml @@ -25,6 +23,9 @@ def init_custom_ar() -> None: + from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + global _CA_HANDLE if _CA_HANDLE is not None: return diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/distributed/device_communicators/pynccl.py similarity index 100% rename from vllm/model_executor/parallel_utils/pynccl.py rename to vllm/distributed/device_communicators/pynccl.py diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py similarity index 91% rename from vllm/model_executor/parallel_utils/pynccl_utils.py rename to vllm/distributed/device_communicators/pynccl_utils.py index a099777aa0005..aeb73015733d1 100644 --- a/vllm/model_executor/parallel_utils/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -9,8 +9,8 @@ logger = init_logger(__name__) try: - from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, - ncclGetVersion) + from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, + ncclGetVersion) except Exception as e: # in non-NVIDIA environments, we can't import the nccl module # e.g. when running on machines with AMD GPUs diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/distributed/parallel_state.py similarity index 98% rename from vllm/model_executor/parallel_utils/parallel_state.py rename to vllm/distributed/parallel_state.py index 3bbfa1bd5443a..4bb77146295af 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -8,8 +8,6 @@ import torch -from vllm.model_executor.parallel_utils import pynccl_utils - # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Pipeline model parallel group that the current rank belongs to. @@ -266,6 +264,7 @@ def destroy_model_parallel(): _PIPELINE_MODEL_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None + from vllm.distributed.device_communicators import pynccl_utils # Destroy the pynccl states if any. pynccl_utils.destroy_process_group() @@ -279,6 +278,7 @@ def destroy_model_parallel(): @contextlib.contextmanager def with_pynccl_for_all_reduce(): + from vllm.distributed.device_communicators import pynccl_utils """use pynccl instead of torch.distributed for all reduce""" tp_size = get_tensor_model_parallel_world_size() if tp_size == 1: diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/distributed/utils.py similarity index 100% rename from vllm/model_executor/parallel_utils/utils.py rename to vllm/distributed/utils.py diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 0505014753951..dd33868f76302 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -10,6 +10,12 @@ from transformers import PretrainedConfig from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_gather) from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -18,13 +24,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, - tensor_model_parallel_gather) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import ( - split_tensor_along_last_dim) if TYPE_CHECKING: pass diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index f569a5a49cbdf..6786c48e0caba 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -7,10 +7,9 @@ import torch.nn.functional as F from vllm._C import ops +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import divide from vllm.model_executor.utils import set_weight_attrs diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3d4d1789db2d..8f42b3e8a4abe 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,13 +5,12 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) from vllm.logger import init_logger -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import ( - divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index ec531f79ced52..e556e31f99378 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -4,8 +4,7 @@ import torch import torch.nn as nn -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_gather) +from vllm.distributed import tensor_model_parallel_gather from vllm.model_executor.sampling_metadata import SamplingMetadata diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 73bbfac33ed13..088c0849243c0 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -4,11 +4,9 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.utils import divide +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.utils import set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index fa5a27b5a6974..30588aecdebe9 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -27,6 +27,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -38,8 +40,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index a9ff909090586..40966ab33631a 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,6 +24,8 @@ from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +35,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 4008896e48dd1..7b46ba306619a 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -10,6 +10,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -21,8 +22,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 29ba3844eb11d..aa27f0a96c745 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -29,6 +29,8 @@ from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -39,8 +41,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 14c0fece69214..49eb7f1b2c185 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -5,6 +5,9 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.linear import (LinearMethodBase, QKVParallelLinear, @@ -15,10 +18,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 2a2182ff4ebad..c7dd11d07e6da 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -28,6 +28,9 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -41,10 +44,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 77c19b227d213..4f1ebcd5fb43c 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -27,6 +27,9 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -37,10 +40,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 08609532b8b3e..fc1fc35570368 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -23,6 +23,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -35,8 +36,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3f816a9996be5..43f0d47fcb122 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,6 +24,7 @@ from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +34,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 07c647c2e1c41..cec2d771adfa8 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,6 +25,7 @@ from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -34,8 +35,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 94048efe48420..5660097652748 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -23,6 +23,7 @@ from transformers import GPTJConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +34,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index a5b5d717d9846..2f9e2171cf114 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -23,6 +23,7 @@ from transformers import GPTNeoXConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -33,8 +34,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index bdb48bf21042e..6e9cbd3f9f43f 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -6,6 +6,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -17,8 +18,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 12fc9dbd50732..a041b0c9a0452 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -26,6 +26,8 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -34,8 +36,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 72fe21df67d8a..c86e292e7df1a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,6 +29,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -40,8 +42,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 99d1b4eb97bb8..49eda9c9a8112 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -29,6 +29,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -42,10 +45,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 429bc8109b9f8..ff552a9d86536 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,6 +29,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -40,10 +43,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 75f86bc134ee3..1f0c0e912beea 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -30,6 +30,9 @@ from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, QKVParallelLinear, @@ -40,10 +43,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index a39f94359a948..af4cdce29d085 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -7,6 +7,8 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -16,8 +18,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 611a48a9aad2b..3513c72879102 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -44,6 +44,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -55,8 +56,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index c1ae1b2ae0f03..3a640850662c0 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -24,6 +24,7 @@ from transformers import OPTConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -34,8 +35,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index ee910563b20df..c606ac027e9d9 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -21,8 +22,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 40e068acaba7d..e91624da90955 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,6 +42,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -52,8 +53,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a63b9c8d63d13..6213a2ded65ab 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -22,8 +23,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8c92cd773f6b9..796e30e633e85 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -30,6 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -41,8 +42,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6b4a74198fd52..f920b4f5a40c7 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -30,6 +30,9 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -43,10 +46,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index b83637fd50dc7..651598b770f13 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -26,6 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, @@ -36,8 +37,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 50d23e0a3b6ef..76e8e48673413 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -25,6 +25,7 @@ from transformers import Starcoder2Config from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, @@ -35,8 +36,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 83d2ddb2bcf35..7e9ce9e5c8e15 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -28,6 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, @@ -39,8 +40,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/parallel_utils/README.md b/vllm/model_executor/parallel_utils/README.md deleted file mode 100644 index b25e3afddad9c..0000000000000 --- a/vllm/model_executor/parallel_utils/README.md +++ /dev/null @@ -1 +0,0 @@ -The files in this folder are ported from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core). We only keep the codes that are used in inference. \ No newline at end of file diff --git a/vllm/test_utils.py b/vllm/test_utils.py index bc220d3b8a430..0cf23e4bb7e75 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -1,7 +1,7 @@ import ray -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) from vllm.utils import get_open_port diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 42f0828b826e2..751384eb72af3 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -7,13 +7,12 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.model_runner import ModelRunner diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e7f20475ab1a7..1de4748b7bcc9 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,17 +9,15 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce +from vllm.distributed.device_communicators import (custom_all_reduce, + pynccl_utils) from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.parallel_state import ( - with_pynccl_for_all_reduce) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 19de33089b2db..3f0b2fd83f3e5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,14 +8,14 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.distributed.device_communicators import pynccl_utils +from vllm.distributed.device_communicators.custom_all_reduce import ( + init_custom_ar) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.model_executor.parallel_utils import pynccl_utils -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_tensor_dict) -from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar -from vllm.model_executor.parallel_utils.parallel_state import ( - ensure_model_parallel_initialized, init_distributed_environment) from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner From 67b4221a61ace91a79aff507df0a95a01978300e Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 11 Apr 2024 09:56:48 +0900 Subject: [PATCH 010/413] [Core][5/N] Fully working chunked prefill e2e (#3884) --- .buildkite/test-pipeline.yaml | 2 + benchmarks/benchmark_latency.py | 3 +- benchmarks/benchmark_throughput.py | 62 ++-- .../basic_correctness/test_chunked_prefill.py | 70 ++++ tests/core/test_chunked_prefill_scheduler.py | 16 +- .../test_basic_distributed_correctness.py | 7 +- .../test_chunked_prefill_distributed.py | 66 ++++ tests/entrypoints/test_openai_server.py | 2 +- tests/models/test_models.py | 2 +- tests/worker/test_model_runner.py | 189 ++++++++-- vllm/attention/__init__.py | 4 +- vllm/attention/backends/abstract.py | 42 ++- vllm/attention/backends/flash_attn.py | 85 +++-- vllm/attention/backends/rocm_flash_attn.py | 97 ++++-- vllm/attention/backends/torch_sdpa.py | 67 ++-- vllm/attention/backends/xformers.py | 138 ++++---- vllm/attention/layer.py | 5 +- vllm/attention/ops/paged_attn.py | 6 - vllm/config.py | 13 +- vllm/core/scheduler.py | 15 +- vllm/distributed/communication_op.py | 10 +- vllm/engine/arg_utils.py | 5 +- vllm/engine/llm_engine.py | 5 +- vllm/lora/layers.py | 5 +- vllm/sequence.py | 3 +- vllm/worker/model_runner.py | 323 +++++++++++++----- 26 files changed, 927 insertions(+), 315 deletions(-) create mode 100644 tests/basic_correctness/test_chunked_prefill.py create mode 100644 tests/distributed/test_chunked_prefill_distributed.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 27e44463a30a6..695290ed74ab5 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -29,6 +29,8 @@ steps: - pytest -v -s test_pynccl.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 91510dafc57a5..aadbc441713fc 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -177,8 +177,7 @@ def run_to_completion(profile_dir: Optional[str] = None): help='block size of key/value cache') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, + action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e71338273d1e5..6df1e1d628e6c 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -74,25 +74,31 @@ def run_vllm( quantization_param_path: Optional[str], device: str, enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, gpu_memory_utilization: float = 0.9, download_dir: Optional[str] = None, ) -> float: from vllm import LLM, SamplingParams - llm = LLM(model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir) + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) # Add the requests to the engine. for prompt, _, output_len in requests: @@ -213,15 +219,15 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.model, args.tokenizer, - args.quantization, args.tensor_parallel_size, - args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, - args.max_model_len, args.enforce_eager, - args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, - args.gpu_memory_utilization, args.download_dir) + elapsed_time = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -335,6 +341,14 @@ def main(args: argparse.Namespace): "--enable-prefix-caching", action='store_true', help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') parser.add_argument('--download-dir', type=str, default=None, diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py new file mode 100644 index 0000000000000..9ff07b3c09020 --- /dev/null +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -0,0 +1,70 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +It tests chunked prefill. Chunked prefill can be enabled by +enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens, +prefill requests are chunked. + +Run `pytest tests/models/test_chunked_prefill.py`. +""" +import pytest + +MODELS = [ + "facebook/opt-125m", + "meta-llama/Llama-2-7b-hf", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + tensor_parallel_size: int, +) -> None: + if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16 + and not enforce_eager): + pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " + "for high TP to save testing time.") + max_num_seqs = min(chunked_prefill_token_size, 256) + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + print(vllm_outputs[0]) + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 05e62ced5898f..cce396bf4953c 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -104,10 +104,10 @@ def test_chunk(): # One chunked prefill, and one decoding. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert set(get_sequence_groups(out)) == set(running) - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 + # The first one is prefill. Scheduler guarantees ordering. + assert seq_group_meta[0].token_chunk_size == 56 # The second one is a chunked prefill. - assert seq_group_meta[1].token_chunk_size == 56 + assert seq_group_meta[1].token_chunk_size == 1 assert out.num_prefill_groups == 1 assert out.num_batched_tokens == 57 @@ -157,12 +157,12 @@ def test_complex(): # Decoding & chunked prefill & first chunk of 3rd request is scheduled. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert len(get_sequence_groups(out)) == 3 - # The first one is decoding. - assert seq_group_meta[0].token_chunk_size == 1 - # The second one is a chunked prefill. + # The first one is the first chunked prefill. + assert seq_group_meta[0].token_chunk_size == 7 + # The second one is the second new chunked prefill. assert seq_group_meta[1].token_chunk_size == 56 - # The third one is also chunked. - assert seq_group_meta[2].token_chunk_size == 7 + # The last one is decode. + assert seq_group_meta[2].token_chunk_size == 1 # Two of them are in chunked prefill. assert out.num_prefill_groups == 2 assert out.num_batched_tokens == 64 diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 1eba14d7a6422..77aa90b12bf8f 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -33,11 +33,16 @@ def test_models( dtype: str, max_tokens: int, ) -> None: + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py new file mode 100644 index 0000000000000..737b1f3169519 --- /dev/null +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -0,0 +1,66 @@ +"""Compare the outputs of HF and distributed vLLM when using greedy sampling. +vLLM will allocate all the available memory, so we need to run the tests one +by one. The solution is to pass arguments (model name) by environment +variables. + +Run: +```sh +TEST_DIST_MODEL=facebook/opt-125m pytest \ + test_chunked_prefill_distributed.py +TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ + test_chunked_prefill_distributed.py +``` +""" +import os + +import pytest +import torch + +MODELS = [ + os.environ["TEST_DIST_MODEL"], +] + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + # Add a chunked prefill config. + max_num_seqs = min(chunked_prefill_token_size, 256) + assert chunked_prefill_token_size != -1 + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 442f8bdf3b4ba..6f2086c4dd269 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -141,7 +141,7 @@ def server(zephyr_lora_files): "--max-cpu-loras", "2", "--max-num-seqs", - "128" + "128", ]) ray.get(server_runner.ready.remote()) yield server_runner diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 53a80d4619646..cfe2539e3a052 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -12,7 +12,7 @@ "gpt2", "bigcode/tiny_starcoder_py", "EleutherAI/pythia-70m", - "bigscience/bloom-560m", + "bigscience/bloom-560m", # Testing alibi slopes. "microsoft/phi-2", "stabilityai/stablelm-3b-4e1t", # "allenai/OLMo-1B", # Broken diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 5b6f001f62fa7..dcaae4af4a6f8 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,14 +1,18 @@ import pytest import torch -from vllm.config import ModelConfig +from vllm.config import ModelConfig, SchedulerConfig from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - model_runner = ModelRunner(None, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(None, None, scheduler_config, None, None) model_runner.set_block_size(16) prompt_lens = [] @@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size): prompt_len - 1) selected_token_start_idx += prompt_len (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens + assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device @@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.prompt_lens_tensor, torch.tensor(prompt_lens, device=device)) assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.num_prompt_tokens == sum(prompt_lens) - assert attn_metadata.num_generation_tokens == 0 assert attn_metadata.max_prompt_len == max(prompt_lens) # Test subquery start locs. @@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size): assert torch.allclose(attn_metadata.block_tables, expected) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) - assert input_tokens.shape == (sum(prompt_lens), ) - assert input_positions.shape == (sum(prompt_lens), ) + assert len(input_tokens) == sum(prompt_lens) + assert len(input_positions) == sum(prompt_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - torch.testing.assert_close(input_tokens, input_positions) + assert input_tokens == input_positions actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size): revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config, None, None, None, None) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=False) + model_runner = ModelRunner(model_config, None, scheduler_config, None, + None) model_runner.set_block_size(16) prompt_lens = [] @@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _ = ( + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) + assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is False assert attn_metadata.prompt_lens is None - assert attn_metadata.num_prompt_tokens == 0 - assert attn_metadata.num_generation_tokens == expected_bs assert attn_metadata.max_prompt_len is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None @@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size): model_runner.get_max_block_per_batch()) # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is True - assert attn_metadata.kv_cache_dtype == "auto" - assert input_tokens.shape == (expected_bs, ) - assert input_positions.shape == (expected_bs, ) - torch.testing.assert_close(input_tokens, input_positions) + assert len(input_tokens) == expected_bs + assert len(input_positions) == expected_bs + assert input_tokens == input_positions # Verify Sampling expected_selected_token_indices = [] @@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) + + +def test_empty_seq_group(): + """Verify prepare prompt and decode returns empty output.""" + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=False, + ) + model_runner = ModelRunner(model_config, None, None, None, None) + model_runner.set_block_size(16) + seq_group_metadata_list = [] + input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( + model_runner._prepare_decode(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + + (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, + _, _, + slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + assert len(input_tokens) == 0 + assert len(input_positions) == 0 + assert attn_metadata is None + assert len(slot_mapping) == 0 + assert len(return_prompt_lens) == 0 + + +@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): + + def get_world_size(group=None): + return 1 + + def mock_get_process_group_ranks(group=None): + return [0] + + monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size) + monkeypatch.setattr(torch.distributed, "get_process_group_ranks", + mock_get_process_group_ranks) + + model_config = ModelConfig( + "facebook/opt-125m", + "facebook/opt-125m", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + enforce_eager=enforce_eager, + ) + scheduler_config = SchedulerConfig(100000, + 100000, + 100000, + enable_chunked_prefill=True) + model_runner = ModelRunner(model_config, + None, + scheduler_config, + None, + None, + is_driver_worker=True) + model_runner.set_block_size(16) + + # Add prefill requests. + prompt_lens = [] + seq_group_metadata_list = [] + prefill_metadata_list = [] + decode_metadata_list = [] + block_tables = {0: [1]} + prefill_batch_size = batch_size // 2 + decode_batch_size = batch_size - prefill_batch_size + for i in range(prefill_batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = SequenceData(list(range(prompt_len))) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables=block_tables, + ) + assert seq_group_metadata.token_chunk_size == seq_data.get_len() + seq_group_metadata_list.append(seq_group_metadata) + prefill_metadata_list.append(seq_group_metadata) + + # Add decode requests + for i in range(prefill_batch_size, batch_size): + # make sure all tokens fit into one block + prompt_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(prompt_len)) + seq_data = SequenceData(prompt_toks) + seq_group_metadata = SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=False, + seq_data={0: seq_data}, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + ) + assert seq_group_metadata.token_chunk_size == 1 + seq_group_metadata_list.append(seq_group_metadata) + decode_metadata_list.append(seq_group_metadata) + + (input_tokens, input_positions, attn_metadata, _, _, _, + _) = model_runner.prepare_input_tensors(seq_group_metadata_list) + + prefill_meta_actual = attn_metadata.prefill_metadata + decode_meta_actual = attn_metadata.decode_metadata + + assert len(attn_metadata.slot_mapping) == len(input_tokens) + assert len(input_positions) == len(input_tokens) + assert attn_metadata.kv_cache_dtype == "auto" + assert attn_metadata.num_prefills == prefill_batch_size + if enforce_eager: + assert attn_metadata.num_decode_tokens == decode_batch_size + else: + assert attn_metadata.num_decode_tokens == _get_graph_batch_size( + decode_batch_size) + assert attn_metadata.num_prefill_tokens == sum(prompt_lens) + + # Verify attn metadata is consistent. We don't need to test individual + # values here because they are tested above. + prefill_meta = model_runner._prepare_prompt( + prefill_metadata_list).attn_metadata + decode_meta = model_runner._prepare_decode( + decode_metadata_list).attn_metadata + + for attr_expected, attr_actual in zip(vars(prefill_meta), + vars(prefill_meta_actual)): + assert attr_expected[1] == attr_actual[1] + for attr_expected, attr_actual in zip(vars(decode_meta), + vars(decode_meta_actual)): + assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 9acb82c0df2c2..7636b34a16fed 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,5 +1,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,4 +9,5 @@ "AttentionMetadata", "Attention", "get_attn_backend", + "AttentionMetadataPerStage", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a03cf2dd7a6fa..7a4ccecf702f4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar import torch @@ -47,7 +47,8 @@ def copy_blocks( @dataclass -class AttentionMetadata: +class AttentionMetadataPerStage: + """Attention metadata for a specific stage. I.e., prefill or decode.""" def asdict_zerocopy(self) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" @@ -59,6 +60,41 @@ def asdict_zerocopy(self) -> Dict[str, Any]: } +T = TypeVar("T", bound=AttentionMetadataPerStage) + + +@dataclass +class AttentionMetadata(Generic[T]): + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # The attention metadata for prefill requests in a batch. + # None if there's no prefill requests in a batch. + prefill_metadata: Optional[T] + # The attention metadata for decode requests in a batch. + # None if there's no decode requests in a batch. + decode_metadata: Optional[T] + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + # The kv cache's data type. + kv_cache_dtype: str + + def __post_init__(self): + if self.num_prefill_tokens > 0: + assert self.num_prefills > 0 + assert self.prefill_metadata is not None + if self.num_decode_tokens > 0: + assert self.decode_metadata is not None + + class AttentionImpl(ABC): @abstractmethod @@ -80,7 +116,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4e0d9d1418b32..12e8c4404b94e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,7 +11,8 @@ from flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -53,7 +54,8 @@ def copy_blocks( @dataclass -class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class FlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -155,7 +162,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, + attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -188,52 +195,70 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - output = flash_attn_varlen_func( + out = flash_attn_varlen_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6019d917b4494..e55435cd2c947 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,7 +6,8 @@ import torch from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -51,7 +52,8 @@ def copy_blocks( @dataclass -class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -181,7 +188,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, + attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -218,9 +225,25 @@ def forward( kv_scale, ) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -230,63 +253,69 @@ def forward( key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) if self.use_naive_attn: - output = self.attn_fuc( + out = self.attn_fuc( query, key, value, - attn_metadata.prompt_lens, + prefill_meta.prompt_lens, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output, _ = self.attn_func( + out, _ = self.attn_func( query, key, value, None, - attn_metadata.seq_start_loc, - attn_metadata.seq_start_loc, - attn_metadata.max_prompt_len, - attn_metadata.max_prompt_len, + prefill_meta.seq_start_loc, + prefill_meta.seq_start_loc, + prefill_meta.max_prompt_len, + prefill_meta.max_prompt_len, True, self.scale, ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: - output = self.attn_func( + out = self.attn_func( q=query, k=key, v=value, - cu_seqlens_q=attn_metadata.seq_start_loc, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_q=attn_metadata.max_prompt_len, - max_seqlen_k=attn_metadata.max_prompt_len, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prompt_len, + max_seqlen_k=prefill_meta.max_prompt_len, softmax_scale=self.scale, causal=True, ) - + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention - output = PagedAttention.forward_prefix( + output[:num_prefill_tokens] = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: + + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9706e1910cb79..63904ea929870 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,7 +7,8 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -49,17 +50,14 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): +class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - slot_mapping: torch.Tensor prompt_lens: Optional[List[int]] prompt_lens_tensor: Optional[torch.Tensor] - num_prompt_tokens: int - num_generation_tokens: int max_subquery_len: Optional[int] = None max_prompt_len: Optional[int] = None @@ -113,7 +111,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, + attn_metadata: AttentionMetadata[TorchSDPAMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -142,36 +140,51 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: - if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + if (kv_cache is None or prefill_meta.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if attn_metadata.attn_bias is None: + if prefill_meta.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + prefill_meta.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + prefill_meta.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) - attn_metadata.attn_bias = att_masks + att_masks = [None] * len(prefill_meta.prompt_lens) + prefill_meta.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, - attn_metadata.attn_bias): + out = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(prefill_meta.prompt_lens, + prefill_meta.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], @@ -181,28 +194,32 @@ def forward( dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) - output[start:end, :, :] = sub_out + out[start:end, :, :] = sub_out start = end + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output = PagedAttention.forward_decode( - query, + out = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, kv_scale, ) + assert out.shape == output[num_prefill_tokens:].shape + output[num_prefill_tokens:] # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 05b68bba5e6eb..b745a04a143b4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -9,7 +9,8 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata) + AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -54,7 +55,7 @@ def copy_blocks( @dataclass -class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): +class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for XFormersbackend. NOTE: Any python object stored here is not updated when it is @@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The prompt length per sequence. None if it is a decoding. prompt_lens: Optional[List[int]] # prompt_lens stored as a tensor. prompt_lens_tensor: Optional[torch.Tensor] - # The number of prompt tokens. Doesn't include padding. - num_prompt_tokens: int - # The number of generation tokens. Doesn't include padding. - num_generation_tokens: int # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| @@ -123,18 +115,27 @@ def __post_init__(self): class XFormersImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens --------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| Generation tokens can contain padding when cuda-graph is used. Currently, prompt tokens don't contain any padding. The prompts might have different lengths, while the generation tokens always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. """ def __init__( @@ -170,7 +171,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: XFormersMetadata, + attn_metadata: AttentionMetadata[XFormersMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -202,59 +203,61 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - if attn_metadata.is_prompt: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or attn_metadata.block_tables.numel() == 0: + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention. # block tables are empty if the prompt does not have a cached # prefix. - if self.num_kv_heads != self.num_heads: - # As of Nov 2023, xformers only supports MHA. For MQA/GQA, - # project the key and value tensors to the desired number of - # heads. - # TODO(woosuk): Use MQA/GQA kernels for higher performance. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], - self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - output = self._run_memory_efficient_xformers_forward( - query, key, value, attn_metadata) + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta) + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention # TODO(Hai) this triton kernel has regression issue (broke) to # deal with different data types between KV and FP8 KV cache, # to be addressed separately. - output = PagedAttention.forward_prefix( + out = PagedAttention.forward_prefix( query, key, value, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.subquery_start_loc, - attn_metadata.prompt_lens_tensor, - attn_metadata.context_lens, - attn_metadata.max_subquery_len, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.prompt_lens_tensor, + prefill_meta.context_lens, + prefill_meta.max_subquery_len, self.alibi_slopes, ) - else: - # Decoding run. - output = PagedAttention.forward_decode( - query, + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + decode_meta.block_tables, + decode_meta.context_lens, + decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -275,13 +278,30 @@ def _run_memory_efficient_xformers_forward( """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. + Args: - output: shape = [num_prompt_tokens, num_heads, head_size] - query: shape = [num_prompt_tokens, num_heads, head_size] - key: shape = [num_prompt_tokens, num_kv_heads, head_size] - value: shape = [num_prompt_tokens, num_kv_heads, head_size] + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ + original_query = query + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. @@ -302,6 +322,7 @@ def _run_memory_efficient_xformers_forward( # TODO(woosuk): Too many view operations. Let's try to reduce # them in the future for code readability. if self.alibi_slopes is None: + # Add the batch dimension. query = query.unsqueeze(0) key = key.unsqueeze(0) value = value.unsqueeze(0) @@ -312,14 +333,13 @@ def _run_memory_efficient_xformers_forward( attn_bias=attn_metadata.attn_bias[0], p=0.0, scale=self.scale) - - return out.view_as(query) + return out.view_as(original_query) # Attention with alibi slopes. # FIXME(woosuk): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - output = torch.empty_like(query) + output = torch.empty_like(original_query) start = 0 for i, prompt_len in enumerate(attn_metadata.prompt_lens): end = start + prompt_len @@ -331,7 +351,7 @@ def _run_memory_efficient_xformers_forward( p=0.0, scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out.squeeze(0)) + output[start:end].copy_(out.view_as(original_query[start:end])) start += prompt_len return output diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9856654fc5f94..fc65ae108dbb1 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,7 +4,8 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.abstract import (AttentionMetadata, + AttentionMetadataPerStage) from vllm.attention.selector import get_attn_backend @@ -41,7 +42,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata, + attn_metadata: AttentionMetadata[AttentionMetadataPerStage], kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 256bffdf032eb..2d918491d6576 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -13,11 +13,6 @@ @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor # (batch_size,). The length of context (tokens stored in KV cache) per # sequence. WARNING: When it is a prefill request, it doesn't include new # tokens. When it is for decoding, it includes a new token. @@ -31,7 +26,6 @@ class PagedAttentionMetadata: # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # captured. block_tables: Optional[torch.Tensor] - kv_cache_dtype: str class PagedAttention: diff --git a/vllm/config.py b/vllm/config.py index bca250e922288..4102edbe01d35 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -565,9 +565,16 @@ def __init__( if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: - # If max_model_len is too short, use 2048 as the default value for - # higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + # For chunked prefill, choose the well-tuned batch size. + self.max_num_batched_tokens = 768 + else: + # If max_model_len is too short, use 2048 as the default value + # for higher throughput. + self.max_num_batched_tokens = max(max_model_len, 2048) + if enable_chunked_prefill: + logger.info("Chunked prefill is enabled (EXPERIMENTAL).") + self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.use_v2_block_manager = use_v2_block_manager diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ae53f9374960..2942eab735a92 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -140,7 +140,11 @@ def _sort_by_lora_ids(self) -> bool: @property def lora_requests(self) -> Set[LoRARequest]: - return {g.seq_group.lora_request for g in self.scheduled_seq_groups} + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } @dataclass @@ -826,13 +830,12 @@ def _schedule_chunked_prefill(self): # Update swapped requests. self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) - return SchedulerOutputs( scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.decode_seq_groups + running_scheduled.prefill_seq_groups + - swapped_in.decode_seq_groups + - swapped_in.prefill_seq_groups), + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), @@ -907,7 +910,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. - is_prompt = i < scheduler_outputs.num_prefill_groups + is_prompt = seq_group.is_prefill() seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=is_prompt, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index cf15db099b304..1004d626b6a4b 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -173,10 +173,18 @@ def broadcast_tensor_dict( torch.distributed.broadcast_object_list([metadata_list], src=src, group=group) + async_handles = [] for key, value in metadata_list: if isinstance(value, TensorMetadata): tensor = tensor_dict[key] - torch.distributed.broadcast(tensor, src=src, group=group) + async_handles.append( + torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True)) + for async_handle in async_handles: + async_handle.wait() + else: recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4b573992c06c..daefddc01b431 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -386,9 +386,8 @@ def add_cli_args( 'prompt latency) before scheduling next prompt.') parser.add_argument( '--enable-chunked-prefill', - type=bool, - default=False, - help='If True, the prefill requests can be chunked based on the ' + action='store_true', + help='If set, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') parser.add_argument( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c639af696544..ddfdda898a5c6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -633,7 +633,10 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - self._process_sequence_group_outputs(seq_group, outputs) + # If uncomputed tokens > 0, it means prefill is chunked. + # We don't need to process outputs in that case. + if seq_group.get_num_uncomputed_tokens() == 0: + self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index dd33868f76302..84a94091486d7 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -267,12 +267,13 @@ def set_mapping( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + embedding_len = self.indices_len[3] + indices = self.embeddings_indices[1][:embedding_len].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, self.lora_a_stacked_2d, ) - indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + indices = self.embeddings_indices[0][:embedding_len].view_as(x) full_output = self.base_layer.forward( x.add_(indices * added_tokens_mask)) diff --git a/vllm/sequence.py b/vllm/sequence.py index 576bbe8c4f6c4..77029908c2218 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -500,7 +500,8 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int): def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 for seq in self.get_seqs(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + if not seq.is_finished(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1de4748b7bcc9..47ad8f0c9b78b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,12 +1,14 @@ import contextlib import time -from typing import Dict, List, Optional, Set, Tuple +from enum import IntEnum +from typing import Dict, List, NamedTuple, Optional, Set, Tuple import numpy as np import torch import torch.nn as nn -from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, + get_attn_backend) from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -37,6 +39,66 @@ ] +class PreparePromptMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadataPerStage] + prompt_lens: List[int] + subquery_lens: List[int] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + multi_modal_input: Optional[torch.Tensor] + slot_mapping: List[int] + + @classmethod + def empty(cls): + return PreparePromptMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + prompt_lens=[], + subquery_lens=[], + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + multi_modal_input=None, + slot_mapping=[], + ) + + +class PrepareDecodeMetadata(NamedTuple): + input_tokens: List[int] + input_positions: List[int] + attn_metadata: Optional[AttentionMetadata] + lora_index_mapping: List[int] + lora_prompt_mapping: List[int] + lora_requests: Set[LoRARequest] + slot_mapping: List[int] + + @classmethod + def empty(cls): + return PrepareDecodeMetadata( + input_tokens=[], + input_positions=[], + attn_metadata=None, + lora_index_mapping=[], + lora_prompt_mapping=[], + lora_requests=set(), + slot_mapping=[], + ) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + class ModelRunner: def __init__( @@ -152,10 +214,7 @@ def get_max_block_per_batch(self) -> int: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], List[int], List[int], Set[LoRARequest], - torch.Tensor]: - assert len(seq_group_metadata_list) > 0 + ) -> PreparePromptMetadata: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -169,6 +228,9 @@ def _prepare_prompt( prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] + if len(seq_group_metadata_list) == 0: + return PreparePromptMetadata.empty() + for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -178,7 +240,8 @@ def _prepare_prompt( computed_block_nums = seq_group_metadata.computed_block_nums if (self.scheduler_config is not None and self.scheduler_config.chunked_prefill_enabled - and computed_block_nums is not None): + and not (computed_block_nums is None + or computed_block_nums == [])): raise RuntimeError( "chunked prefill cannot be used with prefix caching " "now.") @@ -190,13 +253,8 @@ def _prepare_prompt( # it contains output tokens. prefill_end = min(seq_data.get_len(), computed_len + token_chunk_size) - # TODO(sang): Rename it after chunked prefill is introduced. prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - prompt_len = len(prompt_tokens) - # Right now, the prefill_end is always same as the length of - # sequence. However, once chunked prefill is introduced, this - # assumption can be changed. - assert prefill_end == seq_data.get_len() + prompt_len = prefill_end prompt_lens.append(prompt_len) # NOTE: This only works for oooooooxxx style attention. @@ -206,6 +264,14 @@ def _prepare_prompt( computed_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[computed_len:] prefix_block_tables.append(computed_block_nums) + elif self.scheduler_config.chunked_prefill_enabled: + if seq_group_metadata.block_tables is not None: + # Prefill has chunked before. + block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) + else: + # The first prefill. + prefix_block_tables.append([]) else: prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this @@ -267,20 +333,8 @@ def _prepare_prompt( max_subquery_len = max(subquery_lens) max_prompt_len = max(prompt_lens) - num_prompt_tokens = len(input_tokens) assert max_subquery_len > 0 - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - lora_index_mapping = lora_index_mapping - context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -332,11 +386,8 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - slot_mapping=slot_mapping, prompt_lens=prompt_lens, prompt_lens_tensor=prompt_lens_tensor, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=0, max_subquery_len=max_subquery_len, max_context_len=None, max_prompt_len=max_prompt_len, @@ -345,18 +396,25 @@ def _prepare_prompt( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, - kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input) + + return PreparePromptMetadata( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + prompt_lens=prompt_lens, + subquery_lens=subquery_lens, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping, + ) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - List[int], Set[LoRARequest]]: - assert len(seq_group_metadata_list) > 0 + ) -> PrepareDecodeMetadata: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -366,6 +424,9 @@ def _prepare_decode( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + if len(seq_group_metadata_list) == 0: + return PrepareDecodeMetadata.empty() + for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt assert seq_group_metadata.token_chunk_size == 1 @@ -424,15 +485,6 @@ def _prepare_decode( lora_index_mapping.append(0) batch_size = graph_batch_size - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, device=self.device) @@ -440,9 +492,9 @@ def _prepare_decode( if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert context_lens.shape[0] == input_tokens.shape[0] - assert context_lens.shape[0] == input_positions.shape[0] - assert context_lens.shape[0] == slot_mapping.shape[0] + assert context_lens.shape[0] == len(input_tokens) + assert context_lens.shape[0] == len(input_positions) + assert context_lens.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -464,11 +516,8 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - slot_mapping=slot_mapping, prompt_lens=None, prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=len(input_tokens), max_subquery_len=None, max_context_len=max_context_len, max_prompt_len=None, @@ -477,10 +526,16 @@ def _prepare_decode( context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, - kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, lora_requests) + return PrepareDecodeMetadata( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + lora_index_mapping=lora_index_mapping, + lora_prompt_mapping=lora_prompt_mapping, + lora_requests=lora_requests, + slot_mapping=slot_mapping, + ) def _prepare_sample( self, @@ -586,26 +641,66 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[int], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt + prefill_reqs = [] + decode_reqs = [] + for seq_group_meta in seq_group_metadata_list: + if seq_group_meta.is_prompt: + prefill_reqs.append(seq_group_meta) + else: + decode_reqs.append(seq_group_meta) + # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, - subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests, multi_modal_input - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, attn_metadata, - lora_index_mapping, lora_prompt_mapping, - lora_requests) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] - subquery_lens = None - multi_modal_input = None + ( + input_tokens, + input_positions, + prefill_attn_metadata, + prompt_lens, + subquery_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_input, + slot_mapping, + ) = self._prepare_prompt(prefill_reqs) + ( + decode_input_tokens, + decode_input_positions, + decode_attn_metadata, + decode_lora_index_mapping, + decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping, + ) = self._prepare_decode(decode_reqs) sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) + if not self.scheduler_config.chunked_prefill_enabled: + assert (len(prefill_reqs) and len(decode_reqs)) == 0 + + num_prefills = len(prompt_lens) + num_prefill_tokens = len(input_tokens) + num_decode_tokens = len(decode_input_tokens) + + # Coalesce tensors. Note that attn_metadata is currently not + # coalesced for simplicity. + input_tokens.extend(decode_input_tokens) + input_positions.extend(decode_input_positions) + slot_mapping.extend(decode_slot_mapping) + lora_index_mapping.extend(decode_lora_index_mapping) + lora_prompt_mapping.extend(decode_lora_prompt_mapping) + lora_requests.update(decode_lora_requests) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + if self.lora_config: lora_mapping = LoRAMapping( lora_index_mapping, @@ -615,6 +710,16 @@ def prepare_input_tensors( lora_mapping = None # Broadcast the metadata. + # If batch contains both prefill and decode, it sends 2 broadcasts. + # If it only contains 1 type, it triggers a single broadcast. + if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): + batch_type = BatchType.MIXED + elif prefill_attn_metadata is not None: + batch_type = BatchType.PREFILL + else: + batch_type = BatchType.DECODE + metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -623,19 +728,49 @@ def prepare_input_tensors( "lora_requests": lora_requests, "lora_mapping": lora_mapping, "multi_modal_input": multi_modal_input, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + "batch_type": batch_type, } - metadata_dict.update(attn_metadata.asdict_zerocopy()) + if prefill_attn_metadata is not None: + metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) + else: + metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) + + # Broadcast decode attn metadata for mixed batch type. + # The additional broadcast costs 300us overhead on 4 A10 GPUs. + # We can potentially reduce the overhead by coelescing tensors. + if batch_type == BatchType.MIXED: + assert decode_attn_metadata is not None + metadata_dict = decode_attn_metadata.asdict_zerocopy() + broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") + slot_mapping = metadata_dict.pop("slot_mapping") + num_prefills = metadata_dict.pop("num_prefills") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") + num_decode_tokens = metadata_dict.pop("num_decode_tokens") + batch_type = metadata_dict.pop("batch_type") + + # Create an attention metadata. + prefill_attn_metadata = None + decode_attn_metadata = None + if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: + prefill_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, @@ -646,6 +781,23 @@ def prepare_input_tensors( perform_sampling=False, ) + # if it is a mixed batch, decode attn_metadata is broadcasted + # separately. + if batch_type == BatchType.MIXED: + metadata_dict = broadcast_tensor_dict(src=0) + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + + attn_metadata = AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=prefill_attn_metadata, + decode_metadata=decode_attn_metadata, + kv_cache_dtype=self.kv_cache_dtype, + ) + return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input) @@ -663,8 +815,10 @@ def execute_model( if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) - # Execute the model. - if attn_metadata.use_cuda_graph: + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: @@ -842,13 +996,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): # Create dummy attn_metadata. - attn_metadata = self.attn_backend.make_metadata( + decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - slot_mapping=slot_mapping[:batch_size], prompt_lens=None, prompt_lens_tensor=None, - num_prompt_tokens=0, - num_generation_tokens=batch_size, max_subquery_len=None, max_context_len=self.max_context_len_to_capture, max_prompt_len=None, @@ -857,6 +1008,14 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, + ) + attn_metadata = AttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], + prefill_metadata=None, + decode_metadata=decode_metadata, kv_cache_dtype=self.kv_cache_dtype, ) @@ -950,8 +1109,8 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.context_lens, - "block_tables": attn_metadata.block_tables, + "context_lens": attn_metadata.decode_metadata.context_lens, + "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} return @@ -972,10 +1131,10 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, - non_blocking=True) - self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, - non_blocking=True) + self.input_buffers["context_lens"].copy_( + attn_metadata.decode_metadata.context_lens, non_blocking=True) + self.input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. self.graph.replay() From caada5e50aa16cd5f59bd7889128a83588ca1f99 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 18:48:26 -0700 Subject: [PATCH 011/413] [Core][Model] torch.compile for layernorm in commandr (#3985) [Core][Model] Use torch.compile to accelerate layernorm in commandr (#3985) --- vllm/model_executor/models/commandr.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index aa27f0a96c745..aa9b28b676e0b 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -48,6 +48,18 @@ from vllm.sequence import SamplerOutput +@torch.compile +def layer_norm_func(hidden_states, weight, variance_epsilon): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + + variance_epsilon) + hidden_states = weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + + class LayerNorm(nn.Module): def __init__(self, param_shape=None, eps=1e-5): @@ -57,14 +69,9 @@ def __init__(self, param_shape=None, eps=1e-5): set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) def forward(self, hidden_states, residuals=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - mean = hidden_states.mean(-1, keepdim=True) - variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) - hidden_states = (hidden_states - - mean) * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = self.weight.to(torch.float32) * hidden_states - return hidden_states.to(input_dtype), residuals + hidden_states = layer_norm_func(hidden_states, self.weight, + self.variance_epsilon) + return hidden_states, residuals def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() From e42df7227d18e2b96785f8ee52053663ade05b63 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 11 Apr 2024 12:09:50 +0900 Subject: [PATCH 012/413] [Test] Add xformer and flash attn tests (#3961) Co-authored-by: Simon Mo --- tests/basic_correctness/test_basic_correctness.py | 6 ++++++ vllm/attention/selector.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 97cff623c5e1d..bd4c7ea3301be 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -4,6 +4,8 @@ """ import pytest +from vllm.attention.selector import VLLM_ATTENTION_BACKEND + MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -14,6 +16,7 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"]) def test_models( hf_runner, vllm_runner, @@ -22,7 +25,10 @@ def test_models( dtype: str, max_tokens: int, enforce_eager: bool, + attn_backend: str, + monkeypatch, ) -> None: + monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend) hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4c699aed48d49..554e802cd5513 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,4 +1,5 @@ import enum +import os from functools import lru_cache from typing import Type @@ -10,6 +11,8 @@ logger = init_logger(__name__) +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" + class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "Cannot use FlashAttention backend because the flash_attn package " "is not found. Please install it for better performance.") return _Backend.XFORMERS + + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var is not None: + return _Backend[backend_by_env_var] + + # Default case. return _Backend.FLASH_ATTN From e9da5a40c63ce7f8a85438d3c7d919b46e7939f5 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Thu, 11 Apr 2024 03:26:07 +0000 Subject: [PATCH 013/413] [Misc] Add indirection layer for custom ops (#3913) --- .../kernels/benchmark_paged_attention.py | 2 +- tests/kernels/test_attention.py | 6 +- tests/kernels/test_cache.py | 25 ++- vllm/_custom_ops.py | 193 ++++++++++++++++++ vllm/attention/ops/paged_attn.py | 10 +- vllm/model_executor/layers/activation.py | 2 +- .../layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/layernorm.py | 2 +- .../model_executor/layers/quantization/awq.py | 2 +- .../layers/quantization/gptq.py | 2 +- .../layers/quantization/marlin.py | 2 +- .../layers/quantization/squeezellm.py | 2 +- .../model_executor/layers/rotary_embedding.py | 2 +- vllm/utils.py | 4 +- 14 files changed, 224 insertions(+), 32 deletions(-) create mode 100644 vllm/_custom_ops.py diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index f71d1fcaaef50..5c3650fa72d17 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -5,7 +5,7 @@ import torch -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random NUM_BLOCKS = 1024 diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 03ea72924921e..9b1f3e30b6dca 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -7,7 +7,7 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm._C import cache_ops, ops +from vllm import _custom_ops as ops from vllm.utils import get_max_shared_memory_bytes, is_hip FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 @@ -237,14 +237,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(key_cache, dequantized_key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - cache_ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(value_cache, dequantized_value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4141aacafd0b2..d1051fd7e2f4d 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -4,7 +4,7 @@ import pytest import torch -from vllm._C import cache_ops +from vllm import _custom_ops as ops from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] @@ -80,7 +80,7 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + ops.copy_blocks(key_caches, value_caches, block_mapping) # Run the reference implementation. for src, dsts in block_mapping.items(): @@ -145,9 +145,9 @@ def test_reshape_and_cache( # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, cloned_key_cache) + ops.convert_fp8(key_cache, cloned_key_cache) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, cloned_value_cache) + ops.convert_fp8(value_cache, cloned_value_cache) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() @@ -156,14 +156,14 @@ def test_reshape_and_cache( kv_scale = 1.0 # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, + kv_cache_dtype, kv_scale) if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - cache_ops.convert_fp8(key_cache, result_key_cache) + ops.convert_fp8(key_cache, result_key_cache) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - cache_ops.convert_fp8(value_cache, result_value_cache) + ops.convert_fp8(value_cache, result_value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -251,9 +251,8 @@ def test_swap_blocks( src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. - cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) - cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], - block_mapping) + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping) for src, dst in block_mapping.items(): assert torch.allclose(src_key_caches_clone[src].cpu(), @@ -291,9 +290,9 @@ def test_fp8_conversion( cache.uniform_(low, high) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) - cache_ops.convert_fp8(cache, cache_fp8) + ops.convert_fp8(cache, cache_fp8) converted_cache = torch.empty_like(cache) - cache_ops.convert_fp8(cache_fp8, converted_cache) + ops.convert_fp8(cache_fp8, converted_cache) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py new file mode 100644 index 0000000000000..a0837a20875fe --- /dev/null +++ b/vllm/_custom_ops.py @@ -0,0 +1,193 @@ +from typing import Dict, Optional + +import torch + +try: + from vllm._C import cache_ops as vllm_cache_ops + from vllm._C import ops as vllm_ops +except ImportError: + pass + + +# activation ops +def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.silu_and_mul(out, x) + + +def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_and_mul(out, x) + + +def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_tanh_and_mul(out, x) + + +def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_fast(out, x) + + +def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: + vllm_ops.gelu_new(out, x) + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, + context_lens, block_size, max_context_len, + alibi_slopes, kv_cache_dtype, kv_scale) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, scale, + block_tables, context_lens, block_size, + max_context_len, alibi_slopes, kv_cache_dtype, + kv_scale) + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, + is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + vllm_ops.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + + +# layer norm ops +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + vllm_ops.rms_norm(out, input, weight, epsilon) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) + + +# quantization ops +# awq +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, + thy) + + +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, + scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, + bit: int) -> torch.Tensor: + return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + vllm_ops.gptq_shuffle(q_weight, q_perm, bit) + + +# squeezellm +def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, + lookup_table: torch.Tensor) -> None: + vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) + + +# marlin +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) + + +# moe +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + kv_scale: float, +) -> None: + vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype, kv_scale) + + +def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, + block_mapping: torch.Tensor) -> None: + vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: Dict[int, int]) -> None: + vllm_cache_ops.swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: + vllm_cache_ops.convert_fp8(output, input) + + +#TODO: cuda_utils, custom_ar diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 2d918491d6576..cd0690a4ba957 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -3,7 +3,7 @@ import torch -from vllm._C import cache_ops, ops +from vllm import _custom_ops as ops from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -69,7 +69,7 @@ def write_to_paged_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - cache_ops.reshape_and_cache( + ops.reshape_and_cache( key, value, key_cache, @@ -199,11 +199,11 @@ def swap_blocks( ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( @@ -212,4 +212,4 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 6786c48e0caba..baf1d4f266181 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch.nn.functional as F -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.quantization import QuantizationConfig diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1ec09f0cd4c28..377b6588dbf47 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,7 +8,7 @@ import triton import triton.language as tl -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.utils import is_hip diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index cb3cee2bad5ad..a6619714b8aab 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm._C import ops +from vllm import _custom_ops as ops class RMSNorm(nn.Module): diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2caef5f1ebf50..daea5ac73e429 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -3,7 +3,7 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 53baf710ed811..757ab1af8392e 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -6,7 +6,7 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 784229878edf4..a6482c059cc41 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -3,7 +3,7 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index ed25455e6ec1f..bb295df2acc3f 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -3,7 +3,7 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index d80e73bbe39e9..eb8d5f6dfb2a9 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn -from vllm._C import ops +from vllm import _custom_ops as ops def _rotate_neox(x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/utils.py b/vllm/utils.py index 8ba03333d3b6c..8ab8927512cc9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -279,10 +279,10 @@ def _generate_random_fp8( #-----|-------------|------------------- # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} - from vllm._C import cache_ops + from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - cache_ops.convert_fp8(tensor_tmp, tensor) + ops.convert_fp8(tensor_tmp, tensor) del tensor_tmp From f3d0bf7589d6e63a691dcbb9d1db538c184fde29 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 10 Apr 2024 20:33:02 -0700 Subject: [PATCH 014/413] [Doc][Installation] delete python setup.py develop (#3989) --- docs/source/getting_started/installation.rst | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 5dfb32080f97a..e7826114ffa9d 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -85,13 +85,3 @@ You can also build and install vLLM from source: $ nvcc --version # verify that nvcc is in your PATH $ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME - -.. note:: - If you are developing the C++ backend of vLLM, consider building vLLM with - - .. code-block:: console - - $ python setup.py develop - - since it will give you incremental builds. The downside is that this method - is `deprecated by setuptools `_. From c1dc547129f5faaa2ca5ba557145b8ec8838693c Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Thu, 11 Apr 2024 07:50:00 -0700 Subject: [PATCH 015/413] [Kernel] Fused MoE Config for Mixtral 8x22 (#4002) --- ...048,device_name=NVIDIA_A100-SXM4-80GB.json | 146 ++++++++++++++++++ ...048,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++ ...096,device_name=NVIDIA_A100-SXM4-80GB.json | 146 ++++++++++++++++++ ...096,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++ 4 files changed, 584 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000000000..0bb423b28f5ab --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..26bcbf26970c7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 0000000000000..dbc624731f5cb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000000..32c0c9da471cb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} From 08ccee1e830d39ecdb3c6cf382c843dbf5ae830e Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Thu, 11 Apr 2024 23:59:26 +0800 Subject: [PATCH 016/413] punica fix-bgmv-kernel-640 (#4007) --- csrc/punica/bgmv/bgmv_config.h | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 2219d960ae62f..1084a0f20df6b 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 128) \ f(in_T, out_T, W_T, narrow, 256) \ f(in_T, out_T, W_T, narrow, 512) \ + f(in_T, out_T, W_T, narrow, 640) \ f(in_T, out_T, W_T, narrow, 768) \ f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1152) \ From 8afca50889bad6ad987c523c48c31fc52fcb72e4 Mon Sep 17 00:00:00 2001 From: bigPYJ1151 Date: Fri, 12 Apr 2024 02:56:49 +0800 Subject: [PATCH 017/413] [Hardware][Intel] Isolate CPUModelRunner and ModelRunner for better maintenance (#3824) --- vllm/attention/backends/torch_sdpa.py | 72 ++--- vllm/executor/cpu_executor.py | 10 + vllm/utils.py | 1 - vllm/worker/cpu_model_runner.py | 408 ++++++++++++++++++++++++++ vllm/worker/cpu_worker.py | 13 +- 5 files changed, 443 insertions(+), 61 deletions(-) create mode 100644 vllm/worker/cpu_model_runner.py diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 63904ea929870..d21b54b16db4b 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -50,20 +50,15 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, + AttentionMetadataPerStage): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool + slot_mapping: torch.Tensor prompt_lens: Optional[List[int]] - prompt_lens_tensor: Optional[torch.Tensor] - - max_subquery_len: Optional[int] = None - max_prompt_len: Optional[int] = None - subquery_start_loc: Optional[torch.Tensor] = None - seq_start_loc: Optional[torch.Tensor] = None - use_cuda_graph: bool = False def __post_init__(self): # Set during the execution of the first attention op. @@ -111,7 +106,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[TorchSDPAMetadata], + attn_metadata: TorchSDPAMetadata, kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -140,51 +135,36 @@ def forward( attn_metadata.kv_cache_dtype, kv_scale) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - if (kv_cache is None or prefill_meta.block_tables.numel() == 0): + if attn_metadata.is_prompt: + if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if prefill_meta.attn_bias is None: + if attn_metadata.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - prefill_meta.prompt_lens) # type: ignore + attn_metadata.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - prefill_meta.prompt_lens, self.sliding_window, + attn_metadata.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(prefill_meta.prompt_lens) - prefill_meta.attn_bias = att_masks + att_masks = [None] * len(attn_metadata.prompt_lens) + attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 - out = torch.empty((num_tokens, self.num_heads, self.head_size), - dtype=query.dtype) - for prompt_len, mask in zip(prefill_meta.prompt_lens, - prefill_meta.attn_bias): + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype) + for prompt_len, mask in zip(attn_metadata.prompt_lens, + attn_metadata.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], @@ -194,32 +174,28 @@ def forward( dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) - out[start:end, :, :] = sub_out + output[start:end, :, :] = sub_out start = end - assert out.shape == output[:num_prefill_tokens].shape - output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") - if decode_meta := attn_metadata.decode_metadata: + else: # Decoding run. - out = PagedAttention.forward_decode( - decode_query, + output = PagedAttention.forward_decode( + query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + attn_metadata.block_tables, + attn_metadata.context_lens, + attn_metadata.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, kv_scale, ) - assert out.shape == output[num_prefill_tokens:].shape - output[num_prefill_tokens:] # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) @@ -241,7 +217,7 @@ def _make_alibi_bias( bias = bias[None, :] - bias[:, None] num_heads = alibi_slopes.shape[0] - bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) + bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( (1, prompt_len, prompt_len), diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 2bf97338da0ed..eda4e8989c163 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -25,6 +25,7 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, assert lora_config is None, "cpu backend doesn't support LoRA" model_config = _verify_and_get_model_config(model_config) cache_config = _verify_and_get_cache_config(cache_config) + scheduler_config = _verify_and_get_scheduler_config(scheduler_config) self.model_config = model_config self.cache_config = cache_config @@ -116,6 +117,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: return config +def _verify_and_get_scheduler_config( + config: SchedulerConfig) -> SchedulerConfig: + if config.chunked_prefill_enabled: + logger.warning("Chunked prefill is not supported on CPU, disable it.") + config.chunked_prefill_enabled = False + + return config + + def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: _GB = 1 << 30 if config.enable_prefix_caching: diff --git a/vllm/utils.py b/vllm/utils.py index 8ab8927512cc9..fdb0a3768ab0d 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool: print_warning_once("Pin memory is not supported on Neuron.") return False elif is_cpu(): - print_warning_once("Pin memory is not supported on CPU.") return False return True diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py new file mode 100644 index 0000000000000..49e1ad5709f5d --- /dev/null +++ b/vllm/worker/cpu_model_runner.py @@ -0,0 +1,408 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader import get_model +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad, maybe_expand_dim + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + + +class CPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + + # model_config can be None in tests/samplers/test_sampler.py. + # FIXME(woosuk): This is a hack to make the tests work. Refactor this. + self.sliding_window = (model_config.get_sliding_window() + if model_config is not None else None) + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + + self.model = None + self.block_size = None # Set after initial profiling. + + self.kv_cache_dtype = kv_cache_dtype + + self.attn_backend = get_attn_backend( + self.model_config.dtype if model_config is not None else None) + + def load_model(self) -> None: + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + prompt_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + prompt_len = len(prompt_tokens) + + prompt_lens.append(prompt_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, prompt_len))) + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, prompt_len - self.sliding_window) + + for i in range(computed_len, prompt_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + prompt_lens=prompt_lens, + num_prefills=len(prompt_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + prefill_metadata=None, + decode_metadata=None, + max_context_len=None, + context_lens=None, + block_tables=torch.tensor([]), + slot_mapping=slot_mapping, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + prompt_lens, + ) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + context_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + context_lens.append(context_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_context_len = max(context_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + prompt_lens=None, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + max_context_len=max_context_len, + num_prefills=0, + prefill_metadata=None, + decode_metadata=None, + context_lens=context_lens, + block_tables=block_tables, + kv_cache_dtype=self.kv_cache_dtype, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> SamplingMetadata: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + selected_token_indices: List[int] = [] + generators: List[torch.Generator] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + categorized_sampled_token_indices_start_idx = 0 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + if seq_group_metadata.is_prompt: + assert len(seq_ids) == 1 + subquery_len = prompt_lens[i] + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += subquery_len - 1 + + categorized_sample_indices[ + sampling_params.sampling_type].append([ + categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx + ]) + categorized_sample_indices_start_idx += 1 + categorized_sampled_token_indices_start_idx += 1 + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + subquery_len - 1)) + selected_token_indices.append(selected_token_start_idx + + subquery_len - 1) + selected_token_start_idx += subquery_len + + if sampling_params.seed is not None: + seq_group_metadata.state.generator = torch.Generator( + device=self.device).manual_seed(sampling_params.seed) + else: + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[ + sampling_params.sampling_type].extend( + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs))) + categorized_sample_indices_start_idx += num_seqs + categorized_sampled_token_indices_start_idx += num_seqs + + if sampling_params.seed is not None: + generators.append(seq_group_metadata.state.generator) + + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long) + + categorized_sample_indices = { + t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2) + for t, seq_ids in categorized_sample_indices.items() + } + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + generators=generators, + ) + return sampling_metadata + + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, + SamplingMetadata]: + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, + prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode(seq_group_metadata_list) + prompt_lens = [] + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens) + # Broadcast the metadata. + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": + sampling_metadata.selected_token_indices, + } + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop( + "selected_token_indices") + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + generators=None, + perform_sampling=False, + ) + + return ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + ) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, attn_metadata, sampling_metadata + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not sampling_metadata.perform_sampling: + return None + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 751384eb72af3..3989207e8dd83 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -12,25 +12,14 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_executor.model_loader import get_model from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.worker.model_runner import ModelRunner +from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase logger = init_logger(__name__) -class CPUModelRunner(ModelRunner): - - def load_model(self) -> None: - self.model = get_model(self.model_config, - self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - - class CPUCacheEngine: """Manages the KV cache for CPU backend. From a10d3056da644c31e4ebf95a2b6ad65a626a7350 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 11 Apr 2024 13:35:51 -0700 Subject: [PATCH 018/413] [Core] Set `linear_weights` directly on the layer (#3977) --- csrc/quantization/gptq/q_gemm.cu | 2 +- tests/kernels/test_moe.py | 2 +- vllm/lora/layers.py | 12 +-- vllm/model_executor/layers/linear.py | 77 ++++++++++--------- .../model_executor/layers/quantization/awq.py | 29 +++---- .../layers/quantization/gptq.py | 47 ++++++----- .../layers/quantization/marlin.py | 23 +++--- .../layers/quantization/squeezellm.py | 24 +++--- 8 files changed, 114 insertions(+), 102 deletions(-) diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 655158e38f557..cc56649917a8a 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -2067,7 +2067,7 @@ void gptq_shuffle const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); vllm::gptq::shuffle_exllama_weight( (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(), + q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(), q_weight.size(0) * 32 / bit, q_weight.size(1), bit diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index affbbfb4aa94e..046f11d957bdd 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype): ).cuda() # Load the weights - vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data + vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 84a94091486d7..a8ec4dcfd6137 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -368,7 +368,7 @@ def set_mapping( def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora( x, self.lora_a_stacked, @@ -402,10 +402,6 @@ def forward(self, input_): if self.base_layer.skip_bias_add else None) return output, output_bias - @property - def linear_weights(self): - return self.base_layer.linear_weights - @classmethod def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, @@ -505,7 +501,7 @@ def set_lora( def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -746,7 +742,7 @@ def set_lora( def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -838,7 +834,7 @@ def set_mapping( def apply_weights(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x) + self.base_layer, x) _apply_lora( x, self.lora_a_stacked, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 8f42b3e8a4abe..3ca870742efc5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn.functional as F @@ -28,19 +28,24 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - """Create weights for a linear layer.""" + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for a linear layer. + + The weights will be set as attributes of the layer.""" raise NotImplementedError @abstractmethod def apply_weights(self, - weights: Dict[str, torch.Tensor], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - """Apply the weights to the input tensor.""" + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - return {"weight": weight} + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, torch.Tensor], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - weight = weights["weight"] + weight = layer.weight if self.separate_bias_add: if bias is not None: return F.linear(x, weight) + bias @@ -111,12 +118,9 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) + self.linear_method.create_weights(self, self.input_size, + self.output_size, self.input_size, + self.output_size, self.params_dtype) if bias: self.bias = Parameter( torch.empty(self.output_size, dtype=self.params_dtype)) @@ -126,7 +130,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None - output = self.linear_method.apply_weights(self.linear_weights, x, bias) + output = self.linear_method.apply_weights(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias @@ -177,13 +181,13 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + self.linear_method.create_weights(self, + self.input_size, + self.output_size_per_partition, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -211,8 +215,7 @@ def forward(self, input_): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. - output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_, bias) + output_parallel = self.linear_method.apply_weights(self, input_, bias) if self.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -523,13 +526,13 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + self.linear_method.create_weights(self, + self.input_size_per_partition, + self.output_size, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " @@ -569,7 +572,7 @@ def forward(self, input_): # Matrix multiply. output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_parallel) + self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index daea5ac73e429..98651aed8be0e 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -136,19 +137,21 @@ def create_weights(self, input_size_per_partition: int, "input_dim": 0, "output_dim": 1, }) - return { - "qweight": qweight, - "qzeros": qzeros, - "scales": scales, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - scales = weights["scales"] - qzeros = weights["qzeros"] + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros pack_factor = self.quant_config.pack_factor out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) @@ -163,5 +166,5 @@ def apply_weights(self, out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 757ab1af8392e..f370b94a210ee 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( @@ -179,37 +181,40 @@ def create_weights( "input_dim": scale_and_zero_input_dim, "output_dim": 1, }) - return { - "qweight": qweight, - "g_idx": g_idx, - "qzeros": qzeros, - "scales": scales, - "exllama_state": exllama_state, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("g_idx", g_idx) + set_weight_attrs(g_idx, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + layer.exllama_state = exllama_state def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] + qweight = layer.qweight out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass - if weights["exllama_state"] == ExllamaState.UNINITIALIZED: + if layer.exllama_state == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: - weights["g_idx"] = torch.argsort(weights["g_idx"]).to( - torch.int) + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - weights["g_idx"] = torch.empty((1, 1), device="meta") - weights["exllama_state"] = ExllamaState.READY - ops.gptq_shuffle(weights["qweight"], weights["g_idx"], + layer.g_idx.data = torch.empty((0, ), + device=layer.g_idx.device) + layer.exllama_state = ExllamaState.READY + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) - output = ops.gptq_gemm(reshaped_x, weights["qweight"], - weights["qzeros"], weights["scales"], - weights["g_idx"], - weights["exllama_state"] == ExllamaState.READY, + output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, + layer.scales, layer.g_idx, + layer.exllama_state == ExllamaState.READY, self.quant_config.weight_bits) if bias is not None: - output = output + bias + output.add_(bias) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index a6482c059cc41..bf0500f1155a1 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -91,12 +91,14 @@ def __init__(self, quant_config: MarlinConfig): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if params_dtype != torch.float16: @@ -187,21 +189,22 @@ def create_weights( dtype=torch.int), requires_grad=False) - return { - "B": qweight, - "s": scales, - "workspace": workspace, - } + layer.register_parameter("B", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("s", scales) + set_weight_attrs(scales, extra_weight_attrs) + layer.register_parameter("workspace", workspace) + set_weight_attrs(workspace, extra_weight_attrs) def apply_weights( self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weights["B"] - scales = weights["s"] - workspace = weights["workspace"] + qweight = layer.B + scales = layer.s + workspace = layer.workspace x_2d = x.view(-1, x.shape[-1]) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index bb295df2acc3f..661ff9c55d0d1 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -68,10 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -103,17 +104,18 @@ def create_weights(self, input_size_per_partition: int, set_weight_attrs(lookup_table, { "output_dim": 0, }) - return { - "qweight": qweight, - "lookup_table": lookup_table, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("lookup_table", lookup_table) + set_weight_attrs(lookup_table, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - lookup_table = weights["lookup_table"] + qweight = layer.qweight + lookup_table = layer.lookup_table out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) if is_hip(): @@ -126,5 +128,5 @@ def apply_weights(self, ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) From 559eb852f83fe7867390dd2986b4f93a6572cf10 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 11 Apr 2024 14:00:48 -0700 Subject: [PATCH 019/413] [Core] init_distributed_environment align with init_process_group(#4014) [Core][Distributed] make init_distributed_environment compatible with init_process_group (#4014) --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4bb77146295af..9fceffe7cb88b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -39,9 +39,9 @@ def init_distributed_environment( - world_size: int, - rank: int, - distributed_init_method: Optional[str] = None, + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", local_rank: int = -1, backend: str = "nccl", ): From 95e7d4a97cd64f8c6dc226ec0bbceebef6458701 Mon Sep 17 00:00:00 2001 From: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:15:50 -0700 Subject: [PATCH 020/413] Fix echo/logprob OpenAI completion bug (#3441) Co-authored-by: Dylan Hawk --- tests/entrypoints/test_openai_server.py | 31 ++++++++++++ vllm/entrypoints/openai/serving_chat.py | 9 ++-- vllm/entrypoints/openai/serving_completion.py | 15 ++++-- vllm/entrypoints/openai/serving_engine.py | 47 +++++++++++-------- 4 files changed, 73 insertions(+), 29 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 6f2086c4dd269..7940430b8b654 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -742,5 +742,36 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI): assert content.strip() == ground_truth +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, + model_name: str): + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # test using text and token IDs + for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]): + completion = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + echo=True, + logprobs=1) + + prompt_text = tokenizer.decode(prompt) if isinstance(prompt, + list) else prompt + assert (completion.choices[0].text is not None + and re.search(r"^" + prompt_text, completion.choices[0].text)) + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) > 5 + assert (len(logprobs.token_logprobs) > 5 + and logprobs.token_logprobs[0] is None) + assert (len(logprobs.top_logprobs) > 5 + and logprobs.top_logprobs[0] is None) + assert len(logprobs.tokens) > 5 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0980c3d3cb614..a03c5dc88108f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -63,8 +63,9 @@ async def create_chat_completion( request_id = f"cmpl-{random_uuid()}" try: - token_ids = self._validate_prompt_and_tokenize(request, - prompt=prompt) + # Tokenize/detokenize depending on prompt format (string/token list) + prompt_ids, prompt_text = self._validate_prompt_and_tokenize( + request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) guided_decode_logits_processor = ( @@ -78,8 +79,8 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt, sampling_params, - request_id, token_ids, + result_generator = self.engine.generate(prompt_text, sampling_params, + request_id, prompt_ids, lora_request) # Streaming response if request.stream: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 06e7a9225fefb..c1f1744a118bd 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -136,23 +136,24 @@ async def create_completion(self, request: CompletionRequest, for i, prompt in enumerate(prompts): if prompt_is_tokens: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt_ids=prompt, truncate_prompt_tokens=sampling_params. truncate_prompt_tokens) else: - input_ids = self._validate_prompt_and_tokenize( + prompt_formats = self._validate_prompt_and_tokenize( request, prompt=prompt, truncate_prompt_tokens=sampling_params. truncate_prompt_tokens) + prompt_ids, prompt_text = prompt_formats generators.append( - self.engine.generate(prompt, + self.engine.generate(prompt_text, sampling_params, f"{request_id}-{i}", - prompt_token_ids=input_ids, + prompt_token_ids=prompt_ids, lora_request=lora_request)) except ValueError as e: # TODO: Use a vllm-specific Validation Error @@ -326,7 +327,8 @@ def request_output_to_completion_response( output_text = prompt_text elif request.echo and request.max_tokens > 0: token_ids = prompt_token_ids + output.token_ids - top_logprobs = prompt_logprobs + output.logprobs + top_logprobs = (prompt_logprobs + output.logprobs + if request.logprobs else None) output_text = prompt_text + output.text else: token_ids = output.token_ids @@ -334,6 +336,9 @@ def request_output_to_completion_response( output_text = output.text if request.logprobs is not None: + assert top_logprobs is not None, ( + "top_logprobs must be provided when logprobs " + "is requested") logprobs = self._create_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8f69388c0251e..77a568b564039 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,7 @@ import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union from pydantic import conint @@ -99,27 +99,32 @@ def _create_logprobs( last_token_len = 0 if num_output_top_logprobs: logprobs.top_logprobs = [] + for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] - if step_top_logprobs is not None: - token_logprob = step_top_logprobs[token_id].logprob + if step_top_logprobs is None: + token = self.tokenizer.decode(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(None) + logprobs.top_logprobs.append(None) else: - token_logprob = None - token = step_top_logprobs[token_id].decoded_token - logprobs.tokens.append(token) - logprobs.token_logprobs.append(token_logprob) + token_logprob = step_top_logprobs[token_id].logprob + token = step_top_logprobs[token_id].decoded_token + logprobs.tokens.append(token) + logprobs.token_logprobs.append(token_logprob) + + if num_output_top_logprobs: + logprobs.top_logprobs.append({ + p.decoded_token: p.logprob + for i, p in step_top_logprobs.items() + } if step_top_logprobs else None) + if len(logprobs.text_offset) == 0: logprobs.text_offset.append(initial_text_offset) else: logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len) last_token_len = len(token) - - if num_output_top_logprobs: - logprobs.top_logprobs.append({ - p.decoded_token: p.logprob - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) return logprobs def create_error_response( @@ -164,12 +169,12 @@ def _maybe_get_lora(self, request) -> Optional[LoRARequest]: raise ValueError("The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( - self, - request: Union[ChatCompletionRequest, CompletionRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None - ) -> List[int]: + self, + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None, + truncate_prompt_tokens: Optional[conint(ge=1)] = None + ) -> Tuple[List[int], str]: if not (prompt or prompt_ids): raise ValueError("Either prompt or prompt_ids should be provided.") if (prompt and prompt_ids): @@ -187,6 +192,8 @@ def _validate_prompt_and_tokenize( else: input_ids = prompt_ids + input_text = prompt if prompt is not None else self.tokenizer.decode( + prompt_ids) token_num = len(input_ids) if request.max_tokens is None: @@ -201,4 +208,4 @@ def _validate_prompt_and_tokenize( f"{request.max_tokens} in the completion). " f"Please reduce the length of the messages or completion.", ) else: - return input_ids + return input_ids, input_text From 1e96c3341a4e055ae392085fecc7a672295b71c2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 11 Apr 2024 15:18:57 -0700 Subject: [PATCH 021/413] Add extra punica sizes to support bigger vocabs (#4015) --- csrc/punica/bgmv/bgmv_config.h | 12 +++++- csrc/punica/punica_ops.cc | 14 +++--- tests/lora/test_layers.py | 78 +++++++++++++++++++--------------- tests/lora/test_punica.py | 49 +++++++++++++++++++-- vllm/lora/layers.py | 4 +- 5 files changed, 109 insertions(+), 48 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 1084a0f20df6b..9b76b98ab3322 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 49152) \ -// Keep above in sync with vllm/lora/layers::SamplerWithLoRA + f(in_T, out_T, W_T, narrow, 64000) \ + f(in_T, out_T, W_T, narrow, 64256) \ + f(in_T, out_T, W_T, narrow, 64512) \ + f(in_T, out_T, W_T, narrow, 102400) \ + f(in_T, out_T, W_T, narrow, 102656) \ + f(in_T, out_T, W_T, narrow, 102912) \ + f(in_T, out_T, W_T, narrow, 128000) \ + f(in_T, out_T, W_T, narrow, 128256) \ + f(in_T, out_T, W_T, narrow, 128512) \ +// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA +// and vllm/tests/lora/test_punica.py // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc index 28739be14b862..7ebfd851c4feb 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cc @@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, } } -inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { - return (uint32_t(a) << 16) | uint32_t(b); +inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) { + return (uint64_t(a) << 32) | uint64_t(b); } #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") @@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { template inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, const int64_t *lora_indices, - uint16_t in_features, uint16_t out_features, + uint32_t in_features, uint32_t out_features, int64_t y_offset, int64_t full_y_size, int64_t batch_size, int64_t num_layers, int64_t layer_idx, float scale) { - switch (pack_u16(in_features, out_features)) { + switch (pack_u32(in_features, out_features)) { #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ - case pack_u16(feat_in, feat_out): \ + case pack_u32(feat_in, feat_out): \ bgmv_kernel(Y, X, W, lora_indices, y_offset, \ full_y_size, batch_size, num_layers, \ layer_idx, scale); \ @@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, CHECK_EQ(y.size(0), x.size(0)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); bool ok = false; - if (h_in < 65536 && h_out < 65536) { + if (h_in <= 128512 && h_out <= 128512) { // TODO: See if we can get rid of this massive nested switch switch (x.scalar_type()) { case at::ScalarType::Half: @@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, CHECK_EQ(y.size(0), x.size(0)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); bool ok = false; - if (h_in < 65536 && h_out < 65536) { + if (h_in <= 128512 && h_out <= 128512) { // TODO: See if we can get rid of this massive nested switch switch (x.scalar_type()) { case at::ScalarType::Half: diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 71ce6f1764832..e9e0c8554c1ef 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -170,7 +170,8 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_embeddings(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_embeddings(dist_init, num_loras, device, vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(512, 256) + embedding = VocabParallelEmbedding(vocab_size, 256) embedding.weight.data = torch.rand_like(embedding.weight.data) - embedding.weight.data[512:, :] = 0 + embedding.weight.data[vocab_size:, :] = 0 lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) lora_embedding.create_lora_weights(max_loras, lora_config) @@ -203,12 +204,13 @@ def create_random_embedding_layer(): active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info) lora_result = lora_embedding(torch.cat(inputs)) @@ -240,12 +242,13 @@ def create_random_embedding_layer(): active_lora_ids=[0], num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(inputs)) @@ -263,7 +266,9 @@ def create_random_embedding_layer(): # reason="Fails when loras are in any slot other than the first.") @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_embeddings_with_new_embeddings(dist_init, num_loras, device, + vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(512, 256) + embedding = VocabParallelEmbedding(vocab_size, 256) embedding_data = torch.rand_like(embedding.weight.data) embedding.weight.data = embedding_data - embedding.weight.data[512:, :] = 0 + embedding.weight.data[vocab_size:, :] = 0 expanded_embedding = VocabParallelEmbedding( - 512 + lora_config.lora_extra_vocab_size * max_loras, + vocab_size + lora_config.lora_extra_vocab_size * max_loras, 256, - org_num_embeddings=512) - expanded_embedding.weight.data[:512, :] = embedding_data + org_num_embeddings=vocab_size) + expanded_embedding.weight.data[:vocab_size, :] = embedding_data # We need to deepcopy the embedding as it will be modified # in place lora_embedding = VocabParallelEmbeddingWithLoRA( @@ -298,7 +303,7 @@ def create_random_embedding_layer(): id_to_index, layer=lora_embedding, layer_weights=torch.zeros( - (256, 512 + lora_config.lora_extra_vocab_size)), + (256, vocab_size + lora_config.lora_extra_vocab_size)), generate_embeddings_tensor=256, ) @@ -316,7 +321,7 @@ def create_random_embedding_layer(): active_lora_ids=list(lora_dict.keys()), num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -327,16 +332,18 @@ def create_random_embedding_layer(): for input_, original_input_, lora_id in zip(inputs, original_inputs, prompt_mapping): embedding_id = lora_id - 1 - input_[-1] = 512 + (embedding_id * embeddings_tensor_len) - original_input_[-1] = 512 - input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) - original_input_[-2] = 512 + embeddings_tensor_len - 1 + input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) + original_input_[-1] = vocab_size + input_[-2] = vocab_size + ( + (embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = vocab_size + embeddings_tensor_len - 1 mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) - expanded_embedding.weight[512:512 + + expanded_embedding.weight[vocab_size:vocab_size + (embeddings_tensor_len * max_loras)] = torch.cat(embeddings_tensors) @@ -370,14 +377,15 @@ def create_random_embedding_layer(): active_lora_ids=[0], num_inputs=num_loras * 3, input_size=(200, ), - input_range=(1, 512), + input_range=(1, vocab_size), ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) original_inputs = deepcopy(inputs) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) + vocab_size, + lora_config.lora_extra_vocab_size) lora_embedding.set_mapping(*mapping_info, ) lora_result = lora_embedding(torch.cat(original_inputs)) @@ -393,7 +401,9 @@ def create_random_embedding_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +def test_lm_head_logits_processor(dist_init, num_loras, device, + vocab_size) -> None: torch.set_default_device(device) max_loras = 8 @@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: lora_dtype=torch.float16) def _pretest(): - linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, - 1024, 32000) + linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, + 1024, vocab_size) linear.weight.data = torch.rand_like(linear.weight.data) - linear.weight.data[:, 32000:] = 0 + linear.weight.data[:, vocab_size:] = 0 logits_processor = LogitsProcessor( - 32000 + lora_config.lora_extra_vocab_size, 32000) + vocab_size + lora_config.lora_extra_vocab_size, vocab_size) lora_logits_processor = LogitsProcessorWithLoRA( logits_processor, 1024, linear.weight.dtype, linear.weight.device) lora_logits_processor.create_lora_weights(max_loras, lora_config) @@ -444,7 +454,7 @@ def _pretest(): lora_mapping, id_to_index, max_loras, - 32000, + vocab_size, lora_config.lora_extra_vocab_size, ) lora_logits_processor.set_mapping(*mapping_info, ) @@ -460,7 +470,7 @@ def _pretest(): org_vocab_size:logits_processor.org_vocab_size + embeddings_tensor_len] = embeddings_tensor - logits_processor.org_vocab_size = (32000 + + logits_processor.org_vocab_size = (vocab_size + lora_config.lora_extra_vocab_size) expected_results = [] for input_, lora_id in zip(inputs, prompt_mapping): @@ -468,11 +478,11 @@ def _pretest(): result = logits_processor._get_logits(hidden_states=input_, embedding=linear.weight, embedding_bias=None) - result[:, 32000 + embeddings_tensor_len:] = float("-inf") + result[:, vocab_size + embeddings_tensor_len:] = float("-inf") result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling expected_results.append(result) expected_result = torch.cat(expected_results) - logits_processor.org_vocab_size = 32000 + logits_processor.org_vocab_size = vocab_size # Check that resetting the lora weights succeeds @@ -489,14 +499,14 @@ def _pretest(): lora_mapping = LoRAMapping(index_mapping, prompt_mapping) mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, - 32000, + vocab_size, lora_config.lora_extra_vocab_size) lora_logits_processor.set_mapping(*mapping_info, ) lora_result = lora_logits_processor._get_logits( hidden_states=torch.cat(inputs), embedding=original_weight, - embedding_bias=None)[:, :32000] + embedding_bias=None)[:, :vocab_size] expected_result = logits_processor._get_logits( hidden_states=torch.cat(inputs), embedding=original_weight, diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 2736a1c7ade27..cab8b44ccd2df 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -43,10 +43,51 @@ def _lora_ref_impl( H1 = H2 = [ - 128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456, - 3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216, - 10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512, - 32768, 33024 + 128, + 256, + 512, + 1024, + 1152, + 1280, + 1536, + 2048, + 2304, + 2560, + 2752, + 3072, + 3456, + 3584, + 4096, + 4608, + 5120, + 5504, + 5632, + 6144, + 6848, + 6912, + 7168, + 8192, + 9216, + 10240, + 11008, + 13824, + 14336, + 22016, + 24576, + 27392, + 32000, + 32256, + 32512, + 32768, + 33024, + 36864, + 49152, + 64000, + 64256, + 102400, + 102656, + 128000, + 128256, ] SEED = [0xabcdabcd987] CUDA_DEVICES = [ diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index a8ec4dcfd6137..5456b5613c47a 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -935,9 +935,9 @@ def create_lora_weights( model_config: Optional[PretrainedConfig] = None, ) -> None: # Keep this in sync with csrc/punica/bgmv/bgmv_config.h - if 32000 < self.base_layer.vocab_size > 33024: + if 32000 < self.base_layer.vocab_size > 128512: raise ValueError("When using LoRA, vocab size must be " - "32000 >= vocab_size <= 33024") + "32000 >= vocab_size <= 128512") self.lora_a_stacked = torch.zeros( ( max_loras, From e46a60aa4c90cf3dfd9b90782f2eeabbda935eef Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 11 Apr 2024 23:34:12 +0100 Subject: [PATCH 022/413] [BugFix] Fix handling of stop strings and stop token ids (#3672) --- tests/conftest.py | 2 +- .../{samplers => engine}/test_stop_reason.py | 2 +- tests/engine/test_stop_strings.py | 111 ++++++++++++++++++ vllm/engine/llm_engine.py | 98 ++++++++++------ vllm/outputs.py | 4 +- vllm/sampling_params.py | 9 ++ vllm/sequence.py | 6 + vllm/transformers_utils/detokenizer.py | 7 +- 8 files changed, 202 insertions(+), 37 deletions(-) rename tests/{samplers => engine}/test_stop_reason.py (97%) create mode 100644 tests/engine/test_stop_strings.py diff --git a/tests/conftest.py b/tests/conftest.py index a7e8963af0eda..5c50fc2d1bab6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -401,7 +401,7 @@ def __del__(self): cleanup() -@pytest.fixture +@pytest.fixture(scope="session") def vllm_runner(): return VllmRunner diff --git a/tests/samplers/test_stop_reason.py b/tests/engine/test_stop_reason.py similarity index 97% rename from tests/samplers/test_stop_reason.py rename to tests/engine/test_stop_reason.py index b242c405a4fb6..b2f521a8ae4ce 100644 --- a/tests/samplers/test_stop_reason.py +++ b/tests/engine/test_stop_reason.py @@ -3,7 +3,7 @@ 2. One of the provided stop tokens 3. The EOS token -Run `pytest tests/samplers/test_stop_reason.py`. +Run `pytest tests/engine/test_stop_reason.py`. """ import pytest diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py new file mode 100644 index 0000000000000..6b747beb4b543 --- /dev/null +++ b/tests/engine/test_stop_strings.py @@ -0,0 +1,111 @@ +from typing import Any, List, Optional + +import pytest + +from vllm import CompletionOutput, LLMEngine, SamplingParams + +MODEL = "meta-llama/llama-2-7b-hf" +MAX_TOKENS = 200 + + +@pytest.fixture(scope="session") +def vllm_model(vllm_runner): + return vllm_runner(MODEL) + + +@pytest.mark.skip_global_cleanup +def test_stop_basic(vllm_model): + _test_stopping(vllm_model.model.llm_engine, + stop=["."], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=".") + + _test_stopping(vllm_model.model.llm_engine, + stop=["."], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization.", + expected_reason=".") + + +@pytest.mark.skip_global_cleanup +def test_stop_multi_tokens(vllm_model): + _test_stopping( + vllm_model.model.llm_engine, + stop=["group of peo", "short"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer organization. We are a ", + expected_reason="group of peo") + + _test_stopping( + vllm_model.model.llm_engine, + stop=["group of peo", "short"], + include_in_output=True, + expected_output= + "VLLM is a 100% volunteer organization. We are a group of peo", + expected_reason="group of peo") + + +@pytest.mark.skip_global_cleanup +def test_stop_partial_token(vllm_model): + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=False, + expected_output="VLLM is a 100% volunteer or", + expected_reason="gani") + + _test_stopping(vllm_model.model.llm_engine, + stop=["gani"], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organi", + expected_reason="gani") + + +@pytest.mark.skip_global_cleanup +def test_stop_token_id(vllm_model): + # token id 13013 => " organization" + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=False, + expected_output="VLLM is a 100% volunteer", + expected_reason=13013) + + _test_stopping(vllm_model.model.llm_engine, + stop_token_ids=[13013], + include_in_output=True, + expected_output="VLLM is a 100% volunteer organization", + expected_reason=13013) + + +def _test_stopping(llm_engine: LLMEngine, + expected_output: str, + expected_reason: Any, + stop: Optional[List[str]] = None, + stop_token_ids: Optional[List[int]] = None, + include_in_output: bool = False) -> None: + llm_engine.add_request( + "id", "A story about vLLM:\n", + SamplingParams( + temperature=0.0, + max_tokens=MAX_TOKENS, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_in_output, + ), None) + + output: Optional[CompletionOutput] = None + output_text = "" + stop_reason = None + while llm_engine.has_unfinished_requests(): + (request_output, ) = llm_engine.step() + (output, ) = request_output.outputs + + # Ensure we don't backtrack + assert output.text.startswith(output_text) + output_text = output.text + stop_reason = output.stop_reason + + assert output is not None + assert output_text == expected_output + assert stop_reason == expected_reason diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ddfdda898a5c6..a91629a630591 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -501,9 +501,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, for seq, _ in child_seqs: if seq_group.sampling_params.detokenize: - self.detokenizer.decode_sequence_inplace( + new_char_count = self.detokenizer.decode_sequence_inplace( seq, seq_group.sampling_params) - self._check_stop(seq, seq_group.sampling_params) + else: + new_char_count = 0 + self._check_stop(seq, new_char_count, seq_group.sampling_params) # Non-beam search case if not seq_group.sampling_params.use_beam_search: @@ -798,56 +800,86 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) - def _check_stop(self, seq: Sequence, + def _check_stop(self, seq: Sequence, new_char_count: int, sampling_params: SamplingParams) -> None: - """Stop the finished sequences.""" - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return + """Stop the finished sequences. - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return + new_char_count is the number of chars added to the + sequence's output text for the newly generated token + """ # Check if the minimum number of tokens has been generated yet; # skip the stop string/token checks if not if seq.get_output_len() < sampling_params.min_tokens: return - if sampling_params.detokenize: - for stop_str in sampling_params.stop: - if seq.output_text.endswith(stop_str): - self._finalize_sequence(seq, sampling_params, stop_str) - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if a stop token was encountered. + # This assumes a single token produced per step. last_token_id = seq.get_last_token_id() if last_token_id in sampling_params.stop_token_ids: - stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( - last_token_id) - self._finalize_sequence(seq, sampling_params, stop_str) + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + # Remove last token + seq.output_text = seq.output_text[:-new_char_count] seq.status = SequenceStatus.FINISHED_STOPPED seq.stop_reason = last_token_id return - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): + # Check if any stop strings are matched. + stop_str = self._check_stop_strings(seq, new_char_count, + sampling_params) + if stop_str is not None: seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str return - def _finalize_sequence(self, seq: Sequence, - sampling_params: SamplingParams, - stop_string: str) -> None: - if sampling_params.include_stop_str_in_output: + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return - if stop_string and seq.output_text.endswith(stop_string): - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_string)] + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + @staticmethod + def _check_stop_strings(seq: Sequence, new_char_count: int, + sampling_params: SamplingParams) -> Optional[str]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns the stop string if matched or else None. + """ + if not new_char_count: + return None + + for stop_str in sampling_params.stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = seq.output_text.find( + stop_str, -new_char_count - stop_string_len) + if stop_index == -1: + continue + + if sampling_params.include_stop_str_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(seq.output_text): + # No truncation required. + return stop_str + + # Truncate the output text to either the beginning + # or end of the stop string. + seq.output_text = seq.output_text[:stop_index] + return stop_str + return None def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) diff --git a/vllm/outputs.py b/vllm/outputs.py index 61fe20bfc2744..d01be0eb0efd2 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -112,8 +112,10 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # always has the logprobs of the sampled tokens even if the # logprobs are not requested. include_logprobs = seq_group.sampling_params.logprobs is not None + text_buffer_length = seq_group.sampling_params.output_text_buffer_length outputs = [ - CompletionOutput(seqs.index(seq), seq.output_text, + CompletionOutput(seqs.index(seq), + seq.get_output_text_to_return(text_buffer_length), seq.get_output_token_ids(), seq.get_cumulative_logprob(), seq.output_logprobs if include_logprobs else None, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4fdc3c6dedaef..0b9787608798c 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -166,6 +166,13 @@ def __init__( self.logits_processors = logits_processors self.include_stop_str_in_output = include_stop_str_in_output self.truncate_prompt_tokens = truncate_prompt_tokens + # Number of characters to hold back for stop string evaluation + # until sequence is finished. + if self.stop and not include_stop_str_in_output: + self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 + else: + self.output_text_buffer_length = 0 + self._verify_args() if self.use_beam_search: self._verify_beam_search() @@ -226,6 +233,8 @@ def _verify_args(self) -> None: and self.truncate_prompt_tokens < 1): raise ValueError(f"truncate_prompt_tokens must be >= 1, " f"got {self.truncate_prompt_tokens}") + if any(not stop_str for stop_str in self.stop): + raise ValueError("stop cannot contain an empty string.") if self.stop and not self.detokenize: raise ValueError( "stop strings are only supported when detokenize is True. " diff --git a/vllm/sequence.py b/vllm/sequence.py index 77029908c2218..cdb6cce6f0255 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -235,6 +235,12 @@ def __init__( def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + def get_output_text_to_return(self, buffer_length: int): + # We return the full output text if the sequence is finished. + truncate = buffer_length and not self.is_finished() + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 486c1938e1e10..005932f1e3df4 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -87,12 +87,15 @@ def decode_prompt_logprobs_inplace( prev_tokens.extend(next_iter_tokens) def decode_sequence_inplace(self, seq: Sequence, - prms: SamplingParams) -> None: + prms: SamplingParams) -> int: """Decodes the new token for a sequence. In-place operation. Args: seq: The sequence to decode. prms: The sampling parameters used to generate the sequence. + + Returns: + The number of characters added to the output text. """ all_input_ids = seq.get_token_ids() token_id_generated_this_iteration = all_input_ids[-1] @@ -151,6 +154,8 @@ def decode_sequence_inplace(self, seq: Sequence, seq.read_offset = read_offset seq.output_text += new_decoded_token_text + return len(new_decoded_token_text) + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], From c2b4a1bce9a7707179cdfab2fb498c20b2b221e6 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Thu, 11 Apr 2024 17:17:21 -0700 Subject: [PATCH 023/413] [Doc] Add typing hints / mypy types cleanup (#3816) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- benchmarks/backend_request_func.py | 62 ++++++++++--------- docs/source/conf.py | 3 +- setup.py | 5 +- vllm/core/block/interfaces.py | 31 ++++++---- vllm/engine/metrics.py | 10 ++- vllm/logger.py | 8 ++- .../model_executor/layers/rotary_embedding.py | 15 ++--- vllm/transformers_utils/config.py | 4 +- vllm/transformers_utils/configs/dbrx.py | 2 +- .../transformers_utils/tokenizers/baichuan.py | 10 +-- vllm/utils.py | 4 +- 11 files changed, 90 insertions(+), 64 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index ad428bd1c3644..bab570252c929 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -27,8 +27,8 @@ class RequestFuncInput: class RequestFuncOutput: generated_text: str = "" success: bool = False - latency: float = 0 - ttft: float = 0 # Time to first token + latency: float = 0.0 + ttft: float = 0.0 # Time to first token itl: List[float] = field( default_factory=list) # List of inter-token latencies prompt_len: int = 0 @@ -58,23 +58,24 @@ async def async_request_tgi( output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data:") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data:") data = json.loads(chunk) timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -119,23 +120,24 @@ async def async_request_trt_llm( output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data:") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data:") data = json.loads(chunk) timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -151,7 +153,7 @@ async def async_request_trt_llm( output.success = True else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False @@ -195,7 +197,7 @@ async def async_request_deepspeed_mii( output.generated_text = parsed_resp["text"][0] output.success = True else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False @@ -234,19 +236,20 @@ async def async_request_openai_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data: ") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -255,7 +258,7 @@ async def async_request_openai_completions( if data["choices"][0]["text"]: timestamp = time.perf_counter() # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -315,19 +318,20 @@ async def async_request_openai_chat_completions( output.prompt_len = request_func_input.prompt_len generated_text = "" - ttft = 0 + ttft = 0.0 st = time.perf_counter() most_recent_timestamp = st try: async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: continue - chunk = remove_prefix(chunk.decode("utf-8"), "data: ") + chunk = remove_prefix(chunk_bytes.decode("utf-8"), + "data: ") if chunk == "[DONE]": latency = time.perf_counter() - st else: @@ -337,7 +341,7 @@ async def async_request_openai_chat_completions( delta = data["choices"][0]["delta"] if delta.get("content", None): # First token - if ttft == 0: + if ttft == 0.0: ttft = time.perf_counter() - st output.ttft = ttft @@ -354,7 +358,7 @@ async def async_request_openai_chat_completions( output.success = True output.latency = latency else: - output.error = response.reason + output.error = response.reason or "" output.success = False except Exception: output.success = False diff --git a/docs/source/conf.py b/docs/source/conf.py index 44cda7c99cdd5..7a8c365ffb3bb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,6 +12,7 @@ import logging import sys +from typing import List from sphinx.ext import autodoc @@ -45,7 +46,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] +exclude_patterns: List[str] = [] # Exclude the prompt "$" when copying code copybutton_prompt_text = r"\$ " diff --git a/setup.py b/setup.py index 98c92f9196e7e..9f0814e9f3bff 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ import subprocess import sys from shutil import which -from typing import List +from typing import Dict, List import torch from packaging.version import Version, parse @@ -52,7 +52,7 @@ def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: class cmake_build_ext(build_ext): # A dict of extension directories that have been configured. - did_config = {} + did_config: Dict[str, bool] = {} # # Determine number of compilation jobs and optionally nvcc compile threads. @@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version: Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py """ + assert CUDA_HOME is not None, "CUDA_HOME is not set" nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True) output = nvcc_output.split() diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 9f466566f096b..fbceacf0ec417 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,5 +1,5 @@ -from abc import ABC, abstractmethod, abstractproperty -from typing import Dict, List, Optional, Protocol +from abc import ABC, abstractmethod +from typing import Dict, FrozenSet, List, Optional, Protocol from vllm.utils import Device @@ -10,23 +10,28 @@ class Block(ABC): def append_token_ids(self, token_ids: List[int]) -> None: pass - @abstractproperty + @property + @abstractmethod def block_id(self) -> Optional[int]: pass - @abstractproperty + @property + @abstractmethod def token_ids(self) -> List[int]: pass - @abstractproperty + @property + @abstractmethod def num_empty_slots(self) -> int: pass - @abstractproperty + @property + @abstractmethod def is_full(self) -> bool: pass - @abstractproperty + @property + @abstractmethod def prev_block(self) -> Optional["Block"]: pass @@ -47,12 +52,13 @@ def __call__( class BlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: pass @abstractmethod def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + token_ids: List[int], device: Device) -> Block: pass @abstractmethod @@ -64,11 +70,12 @@ def fork(self, last_block: Block) -> List[Block]: pass @abstractmethod - def get_num_free_blocks(self) -> int: + def get_num_free_blocks(self, device: Device) -> int: pass - @abstractproperty - def all_block_ids(self) -> frozenset[int]: + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: pass @abstractmethod diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 905db52a1912b..02560907a1282 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Protocol import numpy as np from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, @@ -119,6 +119,12 @@ class Stats: time_e2e_requests: List[float] +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> Dict[str, str]: + ... + + class StatLogger: """StatLogger is used LLMEngine to log to Promethus and Stdout.""" @@ -135,7 +141,7 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: self.labels = labels self.metrics = Metrics(labelnames=list(labels.keys())) - def info(self, type: str, obj: object) -> None: + def info(self, type: str, obj: SupportsMetricsInfo) -> None: if type == "cache_config": self.metrics.info_cache_config.info(obj.metrics_info()) diff --git a/vllm/logger.py b/vllm/logger.py index e5e46f5cce3fe..af9575085ef37 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -4,6 +4,7 @@ import logging import os import sys +from typing import Optional VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) @@ -26,7 +27,7 @@ def format(self, record): _root_logger = logging.getLogger("vllm") -_default_handler = None +_default_handler: Optional[logging.Handler] = None def _setup_logger(): @@ -55,7 +56,12 @@ def init_logger(name: str): # Use the same settings as above for root logger logger = logging.getLogger(name) logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) + if VLLM_CONFIGURE_LOGGING: + if _default_handler is None: + raise ValueError( + "_default_handler is not set up. This should never happen!" + " Please open an issue on Github.") logger.addHandler(_default_handler) logger.propagate = False return logger diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index eb8d5f6dfb2a9..6519781c8a8eb 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int, # Find dim range bounds based on rotations -def _yarn_find_correction_range(low_rot: int, - high_rot: int, - dim: int, - base: float = 10000, - max_position_embeddings: int = 2048) -> int: +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> Tuple[int, int]: low = math.floor( _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil( @@ -293,8 +294,8 @@ def __init__( *, extrapolation_factor: float = 1, attn_factor: float = 1, - beta_fast: float = 32, - beta_slow: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8a6ba6c5b396c..ce7a30dce72fa 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,10 +1,10 @@ -from typing import Optional +from typing import Dict, Optional from transformers import AutoConfig, PretrainedConfig from vllm.transformers_utils.configs import * -_CONFIG_REGISTRY = { +_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, "mpt": MPTConfig, diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py index 3a19af7129e73..1d2724f22abd6 100644 --- a/vllm/transformers_utils/configs/dbrx.py +++ b/vllm/transformers_utils/configs/dbrx.py @@ -12,7 +12,7 @@ logger = logging.get_logger(__name__) -DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore class DbrxAttentionConfig(PretrainedConfig): diff --git a/vllm/transformers_utils/tokenizers/baichuan.py b/vllm/transformers_utils/tokenizers/baichuan.py index 02045bdcb2ccf..79894035cb1f1 100644 --- a/vllm/transformers_utils/tokenizers/baichuan.py +++ b/vllm/transformers_utils/tokenizers/baichuan.py @@ -16,11 +16,11 @@ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} -PRETRAINED_VOCAB_FILES_MAP = { +PRETRAINED_VOCAB_FILES_MAP = { # type: ignore "vocab_file": {}, "tokenizer_file": {}, } -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} # type: ignore class BaichuanTokenizer(PreTrainedTokenizer): @@ -148,9 +148,9 @@ def save_vocabulary(self, `Tuple(str)`: Paths to the files saved. """ if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) " - "should be a directory") - return + raise ValueError(f"Vocabulary path ({save_directory}) " + "should be a directory") + out_vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + diff --git a/vllm/utils.py b/vllm/utils.py index fdb0a3768ab0d..669b65891d0db 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -294,7 +294,7 @@ def create_kv_caches_with_random( head_size: int, cache_dtype: Optional[Union[str, torch.dtype]], model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: Optional[int] = 0, + seed: int = 0, device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) @@ -400,7 +400,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): gc.collect() -def str_to_int_tuple(s: str) -> Tuple[int]: +def str_to_int_tuple(s: str) -> Tuple[int, ...]: """Convert a string to a tuple of integers.""" try: return tuple(map(int, s.split(","))) From 1096717ae9e0b414ad625c1a12354dd1d949ffb1 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Fri, 12 Apr 2024 12:02:44 +0800 Subject: [PATCH 024/413] [Core] Support LoRA on quantized models (#4012) --- tests/lora/conftest.py | 5 + tests/lora/test_quant_model.py | 179 +++++++++++++++++++++++++++++++++ vllm/config.py | 9 +- vllm/lora/layers.py | 67 +++++++----- 4 files changed, 234 insertions(+), 26 deletions(-) create mode 100644 tests/lora/test_quant_model.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 207c635e2dc86..1127cc33183c9 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -143,6 +143,11 @@ def baichuan_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider") +@pytest.fixture(scope="session") +def tinyllama_lora_files(): + return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") + + @pytest.fixture def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py new file mode 100644 index 0000000000000..3d86a4366aa57 --- /dev/null +++ b/tests/lora/test_quant_model.py @@ -0,0 +1,179 @@ +# Adapted from +# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py +from dataclasses import dataclass +from typing import List + +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +from .conftest import cleanup + + +@dataclass +class ModelWithQuantization: + model_path: str + quantization: str + + +MODELS: List[ModelWithQuantization] = [ + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + quantization="AWQ"), + ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + quantization="GPTQ"), +] + + +def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256): + raw_prompts = [ + "Give me an orange-ish brown color", + "Give me a neon pink color", + ] + + def format_prompt_tuples(prompt): + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + prompts = [format_prompt_tuples(p) for p in raw_prompts] + + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=max_tokens, + stop=["<|im_end|>"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", [1]) +def test_quant_model_lora(tinyllama_lora_files, model, tp_size): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < tp_size: + # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + max_model_len=400, + tensor_parallel_size=tp_size, + quantization=model.quantization, + trust_remote_code=True) + + if model.quantization is None: + expected_no_lora_output = [ + "Here are some examples of orange-brown colors", + "I'm sorry, I don't have" + ] + expected_lora_output = [ + "#ff8050", + "#ff8080", + ] + elif model.quantization == "AWQ": + expected_no_lora_output = [ + "I'm sorry, I don't understand", + "I'm sorry, I don't understand", + ] + expected_lora_output = [ + "#f07700: A v", + "#f00000: A v", + ] + elif model.quantization == "GPTQ": + expected_no_lora_output = [ + "I'm sorry, I don't have", + "I'm sorry, I don't have", + ] + expected_lora_output = [ + "#f08800: This is", + "#f07788 \n#", + ] + + def expect_match(output, expected_output): + # HACK: GPTQ lora outputs are just incredibly unstable. + # Assert that the outputs changed. + if (model.quantization == "GPTQ" + and expected_output is expected_lora_output): + assert output != expected_no_lora_output + for i, o in enumerate(output): + assert o.startswith( + '#'), f"Expected example {i} to start with # but got {o}" + return + assert output == expected_output + + max_tokens = 10 + + print("lora adapter created") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 1") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=1, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("no lora") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=0, + max_tokens=max_tokens) + expect_match(output, expected_no_lora_output) + + print("lora 2") + output = do_sample(llm, + tinyllama_lora_files, + lora_id=2, + max_tokens=max_tokens) + expect_match(output, expected_lora_output) + + print("removing lora") + + del llm + cleanup() + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.skip("Requires multiple GPUs") +def test_quant_model_tp_equality(tinyllama_lora_files, model): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < 2: + # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") + + llm_tp1 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1, + quantization=model.quantization, + trust_remote_code=True) + output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) + + del llm_tp1 + cleanup() + + llm_tp2 = vllm.LLM(model=model.model_path, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=2, + quantization=model.quantization) + output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) + + del llm_tp2 + cleanup() + + assert output_tp1 == output_tp2 diff --git a/vllm/config.py b/vllm/config.py index 4102edbe01d35..da7eb2810ff05 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -822,9 +822,12 @@ def verify_with_model_config(self, model_config: ModelConfig): self.lora_dtype = model_config.dtype elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) - if model_config.quantization is not None: - raise ValueError( - "LoRA is not supported with quantized models yet.") + if model_config.quantization and model_config.quantization not in [ + "awq", "gptq" + ]: + # TODO support marlin and squeezellm + logger.warning(f"{model_config.quantization} quantization is not " + "tested with LoRA yet.") def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): if scheduler_config.max_num_batched_tokens > 65528: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5456b5613c47a..4b9653de73a88 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -29,6 +29,19 @@ pass +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + if hasattr(base_layer, "weight"): + return base_layer.weight.device + if hasattr(base_layer, "linear_weights") and isinstance( + base_layer.linear_weights, dict): + values = list(base_layer.linear_weights.values()) + if len(values) and isinstance(values[0], torch.Tensor): + return values[0].device + raise ValueError(f"Unsupported base layer: {base_layer}") + + def _apply_lora( x: torch.Tensor, lora_a_stacked: torch.Tensor, @@ -302,6 +315,9 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() self.base_layer = base_layer self.tp_size = get_tensor_model_parallel_world_size() + self.input_size = self.base_layer.input_size + self.output_size = self.base_layer.output_size_per_partition + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, @@ -312,17 +328,17 @@ def create_lora_weights( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_b_stacked = torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0], + self.output_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.indices: Optional[torch.Tensor] = None @@ -442,18 +458,18 @@ def create_lora_weights( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) self.lora_b_stacked = tuple( torch.zeros( max_loras, 1, - self.base_layer.weight.shape[0] // 2, + self.output_size // 2, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) for _ in range(n_slices)) self.indices: Optional[torch.Tensor] = None @@ -619,25 +635,25 @@ def create_lora_weights( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) self.lora_b_stacked = ( @@ -647,7 +663,7 @@ def create_lora_weights( self.q_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -655,7 +671,7 @@ def create_lora_weights( self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), torch.zeros( max_loras, @@ -663,7 +679,7 @@ def create_lora_weights( self.kv_proj_shard_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ), ) @@ -766,6 +782,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__() self.base_layer = base_layer + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + self.device = _get_lora_device(self.base_layer) def create_lora_weights( self, @@ -777,20 +796,20 @@ def create_lora_weights( max_loras, 1, lora_config.max_lora_rank, - self.base_layer.weight.shape[1], + self.input_size, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.lora_b_stacked = torch.zeros( ( max_loras, 1, - self.base_layer.weight.shape[0], + self.output_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, + device=self.device, ) self.indices: Optional[torch.Tensor] = None self.indices_len: Optional[List[int]] = None @@ -809,7 +828,7 @@ def set_lora( self.reset_lora(index) if self.base_layer.tp_size > 1: tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.base_layer.weight.shape[1] + shard_size = self.input_size start_idx = tensor_model_parallel_rank * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size lora_a = lora_a[start_idx:end_idx, :] @@ -884,7 +903,9 @@ def forward(self, input_): @property def weight(self): - return self.base_layer.weight + + return self.base_layer.weight if hasattr( + self.base_layer, "weight") else self.base_layer.qweight @classmethod def can_replace_layer(cls, source_layer: nn.Module, From 7fd3949a0b1c6cd0dcd7066aca48d9d589f2f68e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 13:30:54 +0800 Subject: [PATCH 025/413] [Frontend][Core] Move `merge_async_iterators` to utils (#4026) --- vllm/entrypoints/openai/serving_completion.py | 38 +----------------- vllm/utils.py | 40 ++++++++++++++++++- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index c1f1744a118bd..e24aa2489a80f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,4 +1,3 @@ -import asyncio import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, Tuple) @@ -17,7 +16,7 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput -from vllm.utils import random_uuid +from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: return prompt_is_tokens, prompts -def merge_async_iterators(*iterators): - """Merge multiple asynchronous iterators into a single iterator. - - This method handle the case where some iterators finish before others. - When it yields, it yields a tuple (i, item) where i is the index of the - iterator that yields the item. - """ - queue = asyncio.Queue() - - finished = [False] * len(iterators) - - async def producer(i, iterator): - try: - async for item in iterator: - await queue.put((i, item)) - except Exception as e: - await queue.put(e) - finished[i] = True - - _tasks = [ - asyncio.create_task(producer(i, iterator)) - for i, iterator in enumerate(iterators) - ] - - async def consumer(): - while not all(finished) or not queue.empty(): - item = await queue.get() - if isinstance(item, Exception): - raise item - yield item - await asyncio.gather(*_tasks) - - return consumer() - - class OpenAIServingCompletion(OpenAIServing): def __init__(self, diff --git a/vllm/utils.py b/vllm/utils.py index 669b65891d0db..0967dfc969c8a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -9,8 +9,8 @@ from collections import OrderedDict, defaultdict from functools import lru_cache, partial from platform import uname -from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List, - Optional, Tuple, TypeVar, Union) +from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, + Hashable, List, Optional, Tuple, TypeVar, Union) import psutil import torch @@ -181,6 +181,42 @@ def _async_wrapper(*args, **kwargs) -> asyncio.Future: return _async_wrapper +def merge_async_iterators( + *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue() + + finished = [False] * len(iterators) + + async def producer(i: int, iterator: AsyncIterator[T]): + try: + async for item in iterator: + await queue.put((i, item)) + except Exception as e: + await queue.put(e) + finished[i] = True + + _tasks = [ + asyncio.create_task(producer(i, iterator)) + for i, iterator in enumerate(iterators) + ] + + async def consumer(): + while not all(finished) or not queue.empty(): + item = await queue.get() + if isinstance(item, Exception): + raise item + yield item + await asyncio.gather(*_tasks) + + return consumer() + + def get_ip() -> str: host_ip = os.environ.get("HOST_IP") if host_ip: From 36729bac1303b655b816b77f45b17237bfafd692 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 13 Apr 2024 01:56:57 +0900 Subject: [PATCH 026/413] [Test] Test multiple attn backend for chunked prefill. (#4023) --- .buildkite/test-pipeline.yaml | 8 +++++++- .../test_basic_correctness.py | 6 ------ .../basic_correctness/test_chunked_prefill.py | 4 ---- vllm/attention/backends/rocm_flash_attn.py | 18 ++++++------------ 4 files changed, 13 insertions(+), 23 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 695290ed74ab5..8d7d6304cf12e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -12,7 +12,13 @@ steps: command: pytest -v -s async_engine - label: Basic Correctness Test - command: pytest -v -s basic_correctness + commands: + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test command: pytest -v -s core diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index bd4c7ea3301be..97cff623c5e1d 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -4,8 +4,6 @@ """ import pytest -from vllm.attention.selector import VLLM_ATTENTION_BACKEND - MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -16,7 +14,6 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("enforce_eager", [False, True]) -@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"]) def test_models( hf_runner, vllm_runner, @@ -25,10 +22,7 @@ def test_models( dtype: str, max_tokens: int, enforce_eager: bool, - attn_backend: str, - monkeypatch, ) -> None: - monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend) hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 9ff07b3c09020..d83416eb51b43 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -33,10 +33,6 @@ def test_models( enforce_eager: bool, tensor_parallel_size: int, ) -> None: - if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16 - and not enforce_eager): - pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} " - "for high TP to save testing time.") max_num_seqs = min(chunked_prefill_token_size, 256) enable_chunked_prefill = False max_num_batched_tokens = None diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e55435cd2c947..c42660fb8f74f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -162,7 +162,7 @@ def __init__( # AMD Radeon 7900 series (gfx1100) currently does not support # xFormers nor FlashAttention. As a temporary workaround, we use # naive PyTorch implementation of attention. - self.attn_fuc = _naive_attention() + self.attn_fuc = _naive_attention logger.debug("Using naive attention in ROCmBackend") elif self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 @@ -334,26 +334,21 @@ def _naive_attention( prompt_lens: List[int], scale: float, ) -> torch.Tensor: - num_tokens = query.shape[0] output = torch.empty_like(query) start = 0 for _, prompt_len in enumerate(prompt_lens): end = start + prompt_len out = _naive_masked_attention( - query[None, start:end], - key[None, start:end], - value[None, start:end], + query[start:end], + key[start:end], + value[start:end], scale, ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) start += prompt_len - # Using view got RuntimeError: view size is not compatible - # with input tensor's size and stride (at least one - # dimension spans across two contiguous subspaces). - # Use reshape instead. - return output.reshape(num_tokens, -1) + return output def _naive_masked_attention( @@ -362,14 +357,13 @@ def _naive_masked_attention( value: torch.Tensor, scale: float, ) -> torch.Tensor: - seq_len, _, _ = query.shape + seq_len, head_size, head_dim = query.shape attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype, device=query.device), diagonal=1) attn_mask = attn_mask * torch.finfo(query.dtype).min - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() attn_weights = attn_weights + attn_mask.float() attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) From 96b6a6d790115d04bb87d410f3bdd5d7d85b43f1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Apr 2024 12:35:44 -0700 Subject: [PATCH 027/413] [Bugfix] fix type hint for py 3.8 (#4036) --- vllm/executor/executor_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c18edd75d7a4d..55bccfa8e3ca9 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -31,7 +31,7 @@ def __init__( raise NotImplementedError @abstractmethod - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. From d4ec9ffb9574988132d927fd1615180522877262 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 12 Apr 2024 13:56:04 -0700 Subject: [PATCH 028/413] [Misc] Fix typo in scheduler.py (#4022) --- vllm/core/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 2942eab735a92..e44f983e15374 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -674,7 +674,7 @@ def _schedule_prefills( def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - The current policy is designed to opimimize the throughput. First, + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can be swapped or preempted. From 09473ee41c0a22c4d18936ea7eb2328071c19308 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 13 Apr 2024 06:35:50 +0900 Subject: [PATCH 029/413] [mypy] Add mypy type annotation part 1 (#4006) --- .github/workflows/mypy.yaml | 50 ++++++++++++++++++++++++++ format.sh | 22 +++++++++--- pyproject.toml | 5 ++- requirements-common.txt | 3 +- requirements-dev.txt | 2 +- vllm/config.py | 9 +++-- vllm/core/block_manager_v1.py | 12 ++++--- vllm/core/block_manager_v2.py | 4 ++- vllm/core/interfaces.py | 4 ++- vllm/core/scheduler.py | 25 +++++++------ vllm/distributed/communication_op.py | 10 +++--- vllm/engine/ray_utils.py | 18 ++++++---- vllm/entrypoints/api_server.py | 1 + vllm/entrypoints/llm.py | 8 +++-- vllm/executor/cpu_executor.py | 4 +-- vllm/executor/gpu_executor.py | 4 +-- vllm/executor/neuron_executor.py | 4 +-- vllm/executor/ray_gpu_executor.py | 11 +++--- vllm/sampling_params.py | 5 +-- vllm/sequence.py | 8 ++--- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/detokenizer.py | 7 ++-- vllm/transformers_utils/tokenizer.py | 4 +-- vllm/usage/usage_lib.py | 8 ++--- vllm/utils.py | 12 ++++--- 25 files changed, 171 insertions(+), 72 deletions(-) create mode 100644 .github/workflows/mypy.yaml diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml new file mode 100644 index 0000000000000..fbe0f816fd4af --- /dev/null +++ b/.github/workflows/mypy.yaml @@ -0,0 +1,50 @@ +name: mypy + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + ruff: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install mypy==1.9.0 + pip install types-setuptools + pip install types-PyYAML + pip install types-requests + pip install types-setuptools + - name: Mypy + run: | + mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml + + # TODO(sang): Follow up + # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml + # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml + diff --git a/format.sh b/format.sh index deb57b2b049d1..1c195b899c742 100755 --- a/format.sh +++ b/format.sh @@ -93,9 +93,23 @@ fi echo 'vLLM yapf: Done' # Run mypy -# TODO(zhuohan): Enable mypy -# echo 'vLLM mypy:' -# mypy +echo 'vLLM mypy:' +mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml + +# TODO(sang): Follow up +# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml +# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml + CODESPELL_EXCLUDES=( '--skip' '*docs/source/_build/**' @@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then exit 1 fi - - diff --git a/pyproject.toml b/pyproject.toml index 2a00d6796ee02..b870a4b85897b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,10 +46,13 @@ ignore = [ python_version = "3.8" ignore_missing_imports = true + check_untyped_defs = true files = "vllm" # TODO(woosuk): Include the code from Megatron and HuggingFace. -exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/" +exclude = [ + "vllm/model_executor/parallel_utils/|vllm/model_executor/models/", +] [tool.codespell] diff --git a/requirements-common.txt b/requirements-common.txt index ff053388a23e1..c96f9c9937fb0 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -11,4 +11,5 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer -outlines == 0.0.34 # Requires torch >= 2.1.0 \ No newline at end of file +outlines == 0.0.34 # Requires torch >= 2.1.0 +typing_extensions \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 75d22bbdb2a1b..96dfda6faf00f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,7 +7,7 @@ codespell==2.2.6 isort==5.13.2 # type checking -mypy==0.991 +mypy==1.9.0 types-PyYAML types-requests types-setuptools diff --git a/vllm/config.py b/vllm/config.py index da7eb2810ff05..bbda4ecf3cc56 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2,7 +2,7 @@ import json import os from dataclasses import dataclass, fields -from typing import TYPE_CHECKING, ClassVar, Optional, Union +from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import torch from packaging.version import Version @@ -141,7 +141,7 @@ def _verify_load_format(self) -> None: supported_load_format = [ "auto", "pt", "safetensors", "npcache", "dummy" ] - rocm_not_supported_load_format = [] + rocm_not_supported_load_format: List[str] = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " @@ -679,6 +679,9 @@ def maybe_create_spec_config( "num_speculative_tokens to be provided, but found " f"{speculative_model=} and {num_speculative_tokens=}.") + assert (speculative_model is not None + and num_speculative_tokens is not None) + # TODO: The user should be able to specify revision/quantization/max # model len for the draft model. It is not currently supported. draft_revision = None @@ -993,7 +996,7 @@ def _get_and_verify_max_len( derived_max_model_len *= scaling_factor if max_model_len is None: - max_model_len = derived_max_model_len + max_model_len = int(derived_max_model_len) elif max_model_len > derived_max_model_len: # Some models might have a separate key for specifying model_max_length # that will be bigger than derived_max_model_len. We compare user input diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e7e3b4dc1e9b4..e391a3b1e5a33 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -1,5 +1,6 @@ """A block manager that manages token blocks.""" from abc import ABC, abstractmethod +from collections.abc import Sequence as GenericSequence from itertools import count, takewhile from os.path import commonprefix from typing import Dict, List, Optional, Set @@ -231,10 +232,10 @@ def __init__( if self.enable_caching: logger.info("Automatic prefix caching is enabled.") - self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size, - num_gpu_blocks) - self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size, - num_cpu_blocks) + self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.CPU, block_size, num_cpu_blocks) else: self.gpu_allocator = UncachedBlockAllocator( Device.GPU, block_size, num_gpu_blocks) @@ -588,7 +589,8 @@ def get_all_computed_blocks(self, seq: Sequence) -> List[int]: for b in takewhile(lambda b: b.computed, block_table[:-1]) ] - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: """Return the block ids that are common for a given sequence group. Used in prefill (can skip prefill of some blocks). diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 813e71ad883b2..19f0cf415eb34 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,4 +1,5 @@ """A block manager that manages token blocks.""" +from collections.abc import Sequence as GenericSequence from typing import Dict, List, Optional from vllm.core.block.block_table import BlockTable @@ -205,7 +206,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup): # as computed. self.block_allocator.mark_blocks_as_computed() - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: """Determine which blocks for which we skip prefill. With prefix caching we can skip prefill for previously-generated blocks. diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 711536bcc97be..c1f68a2e891bf 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,5 +1,6 @@ import enum from abc import ABC, abstractmethod +from collections.abc import Sequence as GenericSequence from typing import Dict, List from vllm.sequence import Sequence, SequenceGroup @@ -103,7 +104,8 @@ def access_all_blocks_in_seq( pass @abstractmethod - def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index e44f983e15374..18ddcd1d6d466 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -42,8 +42,8 @@ class SchedulingBudget: """ token_budget: int max_num_seqs: int - _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set) - _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set) + _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set) + _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set) _num_batched_tokens: int = 0 _num_curr_seqs: int = 0 @@ -133,7 +133,7 @@ def is_empty(self) -> bool: return (not self.scheduled_seq_groups and not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy) - def _sort_by_lora_ids(self) -> bool: + def _sort_by_lora_ids(self): self.scheduled_seq_groups = sorted( self.scheduled_seq_groups, key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) @@ -337,7 +337,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: self.free_seq(seq) def has_unfinished_seqs(self) -> bool: - return self.waiting or self.running or self.swapped + return len(self.waiting) != 0 or len(self.running) != 0 or len( + self.swapped) != 0 def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) @@ -404,7 +405,7 @@ def _schedule_running( budget.subtract_num_seqs(seq_group.request_id, num_running_seqs) if curr_loras is not None and seq_group.lora_int_id > 0: - curr_loras.pop(seq_group.lora_int_id) + curr_loras.remove(seq_group.lora_int_id) if running_queue: # Preempt the lowest-priority sequence groups. @@ -496,7 +497,7 @@ def _schedule_swapped( now = time.time() swapped_queue = policy.sort_by_priority(now, swapped_queue) - leftover_swapped = deque() + leftover_swapped: Deque[SequenceGroup] = deque() while swapped_queue: seq_group = swapped_queue[0] @@ -507,7 +508,9 @@ def _schedule_swapped( lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id - if (lora_int_id > 0 and lora_int_id not in curr_loras + assert curr_loras is not None + assert self.lora_config is not None + if (lora_int_id > 0 and (lora_int_id not in curr_loras) and len(curr_loras) >= self.lora_config.max_loras): # We don't have a space for another LoRA, so # we ignore this request for now. @@ -593,7 +596,7 @@ def _schedule_prefills( # Copy the queue so that the input queue is not modified. waiting_queue = deque([s for s in waiting_queue]) - leftover_waiting_sequences = deque() + leftover_waiting_sequences: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and waiting_queue: seq_group = waiting_queue[0] @@ -635,6 +638,8 @@ def _schedule_prefills( lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None if (self.lora_enabled and lora_int_id > 0 and lora_int_id not in curr_loras and len(curr_loras) >= self.lora_config.max_loras): @@ -780,7 +785,7 @@ def _schedule_chunked_prefill(self): token_budget=self.scheduler_config.max_num_batched_tokens, max_num_seqs=self.scheduler_config.max_num_seqs, ) - curr_loras = set() + curr_loras: Set[int] = set() remaining_waiting, prefills = (self.waiting, SchedulerPrefillOutputs.create_empty()) @@ -1087,7 +1092,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int: def _get_num_new_tokens(self, seq_group: SequenceGroup, status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> Tuple[int, bool]: + budget: SchedulingBudget) -> int: """Get the next new tokens to compute for a given sequence group that's in a given `status`. diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 1004d626b6a4b..a3e93691a1e8e 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,5 +1,5 @@ from collections import namedtuple -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch.distributed import ProcessGroup @@ -144,7 +144,7 @@ def broadcast_tensor_dict( tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, -) -> Dict[Any, Union[torch.Tensor, Any]]: +) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: """Broadcast the input tensor dictionary.""" group = group or torch.distributed.group.WORLD ranks = torch.distributed.get_process_group_ranks(group) @@ -157,10 +157,10 @@ def broadcast_tensor_dict( rank = torch.distributed.get_rank() if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] assert isinstance( tensor_dict, dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - metadata_list = [] for key, value in tensor_dict.items(): if isinstance(value, torch.Tensor): assert value.is_cuda, ( @@ -190,10 +190,10 @@ def broadcast_tensor_dict( torch.distributed.broadcast_object_list(recv_metadata_list, src=src, group=group) - metadata_list = recv_metadata_list[0] + assert recv_metadata_list[0] is not None tensor_dict = {} async_handles = [] - for key, value in metadata_list: + for key, value in recv_metadata_list[0]: if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, dtype=value.dtype, diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index 70d5c9b1fae05..04d4ed83976d0 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -1,9 +1,10 @@ import pickle -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger from vllm.utils import get_ip, is_hip, set_cuda_visible_devices +from vllm.worker.worker import Worker logger = init_logger(__name__) @@ -18,15 +19,20 @@ def __init__(self, init_cached_hf_modules=False) -> None: if init_cached_hf_modules: from transformers.dynamic_module_utils import init_hf_modules init_hf_modules() - self.worker = None + self._worker: Optional[Worker] = None # Since the compiled DAG runs a main execution # in a different thread that calls cuda.set_device. # The flag indicates is set_device is called on # that thread. self.compiled_dag_cuda_device_set = False - def init_worker(self, worker_init_fn): - self.worker = worker_init_fn() + def init_worker(self, worker_init_fn: Callable[[], Worker]): + self._worker = worker_init_fn() + + @property + def worker(self) -> Worker: + assert self._worker is not None + return self._worker def __getattr__(self, name): return getattr(self.worker, name) @@ -70,8 +76,8 @@ def execute_model_compiled_dag_remote(self, ignored): logger.warning(f"Failed to import Ray with {e!r}. " "For distributed inference, please install Ray with " "`pip install ray`.") - ray = None - RayWorkerVllm = None + ray = None # type: ignore + RayWorkerVllm = None # type: ignore def initialize_ray_cluster( diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 2a47eae112c12..587142adb9c6b 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -47,6 +47,7 @@ async def generate(request: Request) -> Response: sampling_params = SamplingParams(**request_dict) request_id = random_uuid() + assert engine is not None results_generator = engine.generate(prompt, sampling_params, request_id) # Streaming case diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 5777e8179a1c1..63ff0b30da552 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -170,8 +170,12 @@ def generate( multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. - num_requests = len(prompts) if prompts is not None else len( - prompt_token_ids) + if prompts is not None: + num_requests = len(prompts) + else: + assert prompt_token_ids is not None + num_requests = len(prompt_token_ids) + for i in range(num_requests): prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index eda4e8989c163..33e67d8b3eec2 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch @@ -61,7 +61,7 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 80ca5cb7367c5..f20221a0b941a 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -66,7 +66,7 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 57436a85cfa27..ee8e87432fa67 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -47,7 +47,7 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6c0ccd7e64c90..b937693c92257 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,7 +3,7 @@ import os import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, @@ -197,7 +197,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_parallel_loading_workers, ) - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. This invokes `determine_num_available_blocks` on each worker and takes @@ -205,7 +205,7 @@ def determine_num_available_blocks(self) -> tuple[int, int]: compatible with all workers. Returns: - - tuple[num_gpu_blocks, num_cpu_blocks] + - Tuple[num_gpu_blocks, num_cpu_blocks] """ # Get the maximum number of blocks that can be allocated on GPU and CPU. num_blocks = self._run_workers("determine_num_available_blocks", ) @@ -276,7 +276,7 @@ def _run_workers( self, method: str, *args, - driver_args: Optional[List[Any]] = None, + driver_args: Optional[Tuple[Any, ...]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, @@ -291,6 +291,7 @@ def _run_workers( if use_ray_compiled_dag: # Right now, compiled DAG can only accept a single # input. TODO(sang): Fix it. + assert self.forward_dag is not None output_channels = self.forward_dag.execute(1) else: # Start the ray workers first. @@ -369,7 +370,7 @@ async def _run_workers_async( self, method: str, *args, - driver_args: Optional[List[Any]] = None, + driver_args: Optional[Tuple[Any, ...]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> Any: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 0b9787608798c..53a38b25bfdac 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -5,7 +5,8 @@ from typing import Callable, List, Optional, Union import torch -from pydantic import conint +from pydantic import Field +from typing_extensions import Annotated _SAMPLING_EPS = 1e-5 @@ -127,7 +128,7 @@ def __init__( skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, logits_processors: Optional[List[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n diff --git a/vllm/sequence.py b/vllm/sequence.py index cdb6cce6f0255..dcde81df19923 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -171,10 +171,10 @@ def get_last_token_id(self) -> int: return self.prompt_token_ids[-1] return self.output_token_ids[-1] - def get_prompt_token_ids(self) -> int: + def get_prompt_token_ids(self) -> List[int]: return self.prompt_token_ids - def get_output_token_ids(self) -> int: + def get_output_token_ids(self) -> List[int]: return self.output_token_ids @property @@ -370,7 +370,7 @@ class SequenceGroupState: """Mutable state tied to a specific sequence group""" # torch.Generator used in seeded sampling - generator: Optional = None + generator: Optional = None # type: ignore class MultiModalData: @@ -599,7 +599,7 @@ def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @property - def token_chunk_size(self) -> int: + def token_chunk_size(self) -> Optional[int]: """Return the number of tokens to be processed (chunk size).""" return self._token_chunk_size diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index ce7a30dce72fa..1756c91a612f0 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,7 +2,8 @@ from transformers import AutoConfig, PretrainedConfig -from vllm.transformers_utils.configs import * +from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, + JAISConfig, MPTConfig, RWConfig) _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 005932f1e3df4..f064c26c3f40c 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders( # NOTE(woosuk): The following code is slow because it runs a for loop over # the output_tokens. In Python, running a for loop over a list can be slow # even when the loop body is very simple. - sub_texts = [] - current_sub_text = [] + sub_texts: List[str] = [] + current_sub_text: List[str] = [] all_special_tokens = set(tokenizer.all_special_tokens) for token in output_tokens: if skip_special_tokens and token in all_special_tokens: @@ -263,6 +263,7 @@ def detokenize_incrementally( tokenizer, all_input_ids[:-1], skip_special_tokens=skip_special_tokens) + assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. if new_token_id >= len(tokenizer): @@ -271,6 +272,8 @@ def detokenize_incrementally( # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens) + if isinstance(new_tokens, str): + new_tokens = [new_tokens] output_tokens = prev_tokens + new_tokens # If this is the first iteration, return all tokens. diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index e216a99af91f9..5d3d5801c960d 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -5,7 +5,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizers import * +from vllm.transformers_utils.tokenizers import BaichuanTokenizer from vllm.utils import make_async logger = init_logger(__name__) @@ -28,7 +28,7 @@ def get_cached_tokenizer( tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) tokenizer_len = len(tokenizer) - class CachedTokenizer(tokenizer.__class__): + class CachedTokenizer(tokenizer.__class__): # type: ignore @property def all_special_ids(self): diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 658fe5c98f5ee..b2672f7f1da61 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -7,7 +7,7 @@ from enum import Enum from pathlib import Path from threading import Thread -from typing import Dict, Optional +from typing import Any, Dict, Optional from uuid import uuid4 import cpuinfo @@ -124,7 +124,7 @@ def __init__(self) -> None: def report_usage(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any] = None) -> None: + extra_kvs: Optional[Dict[str, Any]] = None) -> None: t = Thread(target=self._report_usage_worker, args=(model_architecture, usage_context, extra_kvs or {}), daemon=True) @@ -132,13 +132,13 @@ def report_usage(self, def _report_usage_worker(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any]) -> None: + extra_kvs: Dict[str, Any]) -> None: self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_continous_usage() def _report_usage_once(self, model_architecture: str, usage_context: UsageContext, - extra_kvs: Dict[str, any]) -> None: + extra_kvs: Dict[str, Any]) -> None: # Platform information if torch.cuda.is_available(): device_property = torch.cuda.get_device_properties(0) diff --git a/vllm/utils.py b/vllm/utils.py index 0967dfc969c8a..4c0dc9ca729a9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -60,7 +60,7 @@ def __contains__(self, key: Hashable) -> bool: def __len__(self) -> int: return len(self.cache) - def __getitem__(self, key: Hashable) -> T: + def __getitem__(self, key: Hashable) -> Optional[T]: return self.get(key) def __setitem__(self, key: Hashable, value: T) -> None: @@ -76,7 +76,7 @@ def get(self, key: Hashable, default_value: Optional[T] = None) -> Optional[T]: if key in self.cache: - value = self.cache[key] + value: Optional[T] = self.cache[key] self.cache.move_to_end(key) else: value = default_value @@ -87,7 +87,7 @@ def put(self, key: Hashable, value: T) -> None: self.cache.move_to_end(key) self._remove_old_if_needed() - def _on_remove(self, key: Hashable, value: T): + def _on_remove(self, key: Hashable, value: Optional[T]): pass def remove_oldest(self): @@ -100,9 +100,11 @@ def _remove_old_if_needed(self) -> None: while len(self.cache) > self.capacity: self.remove_oldest() - def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T: + def pop(self, + key: Hashable, + default_value: Optional[T] = None) -> Optional[T]: run_on_remove = key in self.cache - value = self.cache.pop(key, default_value) + value: Optional[T] = self.cache.pop(key, default_value) if run_on_remove: self._on_remove(key, value) return value From fbb9d9eef48a29e0ea821bbf399e4bf9a08d6ac1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Apr 2024 16:40:39 -0700 Subject: [PATCH 030/413] [Core] fix custom allreduce default value (#4040) --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63ff0b30da552..9e08c253dc539 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -86,7 +86,7 @@ def __init__( swap_space: int = 4, enforce_eager: bool = False, max_context_len_to_capture: int = 8192, - disable_custom_all_reduce: bool = True, + disable_custom_all_reduce: bool = False, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: From d04973ad5446fe05c06035f6b2d99402fc3ac7bf Mon Sep 17 00:00:00 2001 From: Bellk17 Date: Fri, 12 Apr 2024 16:41:26 -0700 Subject: [PATCH 031/413] Fix triton compilation issue (#3984) Co-authored-by: Woosuk Kwon --- vllm/attention/ops/triton_flash_attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 87cf30cbef79a..e160411859f0b 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -415,7 +415,11 @@ def attn_fwd( return is_mqa = hq != hk - off_h_k = off_h_q % hk if is_mqa else off_h_q + if is_mqa: # noqa: SIM108 + off_h_k = off_h_q % hk + else: + off_h_k = off_h_q + n_extra_tokens = 0 if seqlen_k < BLOCK_N: n_extra_tokens = BLOCK_N - seqlen_k From b8aacac31a4e2e03381fdaef6f1e4bbb895f3b64 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Sat, 13 Apr 2024 07:56:37 +0800 Subject: [PATCH 032/413] [Bugfix] Fix LoRA bug (#4032) --- vllm/lora/layers.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 4b9653de73a88..aac86351b15e1 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -32,14 +32,17 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear if hasattr(base_layer, "weight"): return base_layer.weight.device - if hasattr(base_layer, "linear_weights") and isinstance( - base_layer.linear_weights, dict): - values = list(base_layer.linear_weights.values()) - if len(values) and isinstance(values[0], torch.Tensor): - return values[0].device - raise ValueError(f"Unsupported base layer: {base_layer}") + # GPTQ/AWQ/SqueezeLLM + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif hasattr(base_layer, "B"): + return base_layer.B.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") def _apply_lora( From 546e7211684a28bbe53088961b4cf5123e235760 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 12 Apr 2024 18:43:37 -0700 Subject: [PATCH 033/413] [CI/Test] expand ruff and yapf for all supported python version (#4037) --- .github/workflows/mypy.yaml | 2 +- .github/workflows/ruff.yml | 2 +- .github/workflows/yapf.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index fbe0f816fd4af..6db0bb7645ecd 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index e8060e369a889..e71033f828006 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index b163c960db555..04f307bcf8b0e 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} From 5c2e66e4871917c5d59cc4a8b89ef53e690e9bd9 Mon Sep 17 00:00:00 2001 From: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com> Date: Fri, 12 Apr 2024 21:07:04 -0700 Subject: [PATCH 034/413] [Bugfix] More type hint fixes for py 3.8 (#4039) --- vllm/executor/executor_base.py | 2 +- vllm/worker/cpu_worker.py | 4 ++-- vllm/worker/neuron_worker.py | 4 ++-- vllm/worker/worker_base.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 55bccfa8e3ca9..bbfbfc689c99f 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -39,7 +39,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: ExecutorBase may require modification of the result, e.g. to ensure the selected cache sizes are compatible with all workers. - Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks are blocks that are "active" on the device and can be appended to. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be appended to. diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3989207e8dd83..41341b063bed7 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch import torch.distributed @@ -157,7 +157,7 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of blocks available for the KV cache. This determines how many KV blocks can fit into the configured CPU diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 6136d50d0c068..2f22f82c045db 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,5 +1,5 @@ """A Neuron worker class.""" -from typing import List, Optional +from typing import List, Optional, Tuple import torch import torch.distributed @@ -40,7 +40,7 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. Swapping is not yet supported, so always return num_cpu_blocks=0. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e3027c406ffeb..d8c9febb11584 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List +from typing import Dict, List, Tuple from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -18,14 +18,14 @@ def init_device(self) -> None: raise NotImplementedError @abstractmethod - def determine_num_available_blocks(self) -> tuple[int, int]: + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available blocks for the GPU KV cache and swappable CPU KV cache. The implementation may run profiling or other heuristics to determine the size of caches. - Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks are blocks that are "active" on the device and can be appended to. num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be appended to. From 98afde19fc273b1e6a695990b93ec07157b856f1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Apr 2024 07:12:53 -0700 Subject: [PATCH 035/413] [Core][Distributed] improve logging for init dist (#4042) --- vllm/distributed/parallel_state.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 9fceffe7cb88b..1258bf58cb453 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -8,6 +8,10 @@ import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + # Tensor model parallel group that the current rank belongs to. _TENSOR_MODEL_PARALLEL_GROUP = None # Pipeline model parallel group that the current rank belongs to. @@ -45,6 +49,8 @@ def init_distributed_environment( local_rank: int = -1, backend: str = "nccl", ): + logger.debug(f"{world_size=} {rank=} {local_rank=} " + f"{distributed_init_method=} {backend=}") if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " From ec8e3c695f2dce080bde569746180300e91084a3 Mon Sep 17 00:00:00 2001 From: zspo Date: Sat, 13 Apr 2024 22:52:36 +0800 Subject: [PATCH 036/413] [Bugfix] fix_log_time_in_metrics (#4050) --- vllm/engine/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 02560907a1282..04e27e69ce0f3 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -130,7 +130,7 @@ class StatLogger: def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: # Metadata for logging locally. - self.last_local_log = time.monotonic() + self.last_local_log = time.time() self.local_interval = local_interval # Tracked stats over current local logging interval. From 0a430b4ae2763c2f161e3bfb1529acf4685f7caa Mon Sep 17 00:00:00 2001 From: zspo Date: Sat, 13 Apr 2024 22:54:03 +0800 Subject: [PATCH 037/413] [Bugfix] fix_small_bug_in_neuron_executor (#4051) --- vllm/executor/neuron_executor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index ee8e87432fa67..d45f18e466256 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -25,6 +25,7 @@ def __init__( speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config + self.cache_config = cache_config assert lora_config is None, "LoRA is not supported for Neuron backend." self.parallel_config = parallel_config self.scheduler_config = scheduler_config @@ -43,6 +44,7 @@ def _init_worker(self): self.parallel_config, self.scheduler_config, self.device_config, + self.cache_config, ) self.driver_worker.init_device() self.driver_worker.load_model() From 989ae2538df211ca3a31f77ac8e106c5c97c6e53 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Sat, 13 Apr 2024 22:55:05 +0800 Subject: [PATCH 038/413] [Kernel] Add punica dimension for Baichuan-13B (#4053) --- csrc/punica/bgmv/bgmv_config.h | 1 + tests/lora/test_baichuan.py | 2 +- tests/lora/test_punica.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 9b76b98ab3322..d2906914f927e 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -47,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 15360) \ f(in_T, out_T, W_T, narrow, 16384) \ f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 22016) \ diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 2178266d2e0c8..5ab863eea94b3 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files): @pytest.mark.skip("Requires multiple GPUs") -def test_llama_tensor_parallel_equality(baichuan_lora_files): +def test_baichuan_tensor_parallel_equality(baichuan_lora_files): # Cannot use as it will initialize torch.cuda too early... # if torch.cuda.device_count() < 4: # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index cab8b44ccd2df..8b174f01d87d4 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -72,6 +72,7 @@ def _lora_ref_impl( 11008, 13824, 14336, + 15360, 22016, 24576, 27392, From 711a000255eac3e034f0b73aa5cc62b45201a571 Mon Sep 17 00:00:00 2001 From: Sanger Steel Date: Sat, 13 Apr 2024 20:13:01 -0400 Subject: [PATCH 039/413] [Frontend] [Core] feat: Add model loading using `tensorizer` (#3476) --- .buildkite/test-pipeline.yaml | 3 + docs/source/conf.py | 1 + docs/source/models/engine_args.rst | 3 +- examples/tensorize_vllm_model.py | 254 ++++++++++++++ requirements-cpu.txt | 2 +- requirements-dev.txt | 1 + setup.py | 3 + tests/tensorizer/__init__.py | 0 .../tensorize_vllm_model_for_testing.py | 245 ++++++++++++++ tests/tensorizer/test_tensorizer.py | 302 +++++++++++++++++ vllm/config.py | 74 +++- vllm/engine/arg_utils.py | 45 ++- vllm/engine/llm_engine.py | 8 +- vllm/executor/gpu_executor.py | 23 +- vllm/executor/ray_gpu_executor.py | 6 +- vllm/model_executor/model_loader.py | 61 +++- vllm/model_executor/tensorizer_loader.py | 319 ++++++++++++++++++ vllm/model_executor/weight_utils.py | 34 +- vllm/worker/model_runner.py | 9 +- vllm/worker/worker.py | 9 +- 20 files changed, 1351 insertions(+), 51 deletions(-) create mode 100644 examples/tensorize_vllm_model.py create mode 100644 tests/tensorizer/__init__.py create mode 100644 tests/tensorizer/tensorize_vllm_model_for_testing.py create mode 100644 tests/tensorizer/test_tensorizer.py create mode 100644 vllm/model_executor/tensorizer_loader.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 8d7d6304cf12e..aa4582bbda0c7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -91,6 +91,9 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 +- label: Tensorizer Test + command: apt-get install curl libsodium23 && pytest -v -s tensorizer + - label: Metrics Test command: pytest -v -s metrics diff --git a/docs/source/conf.py b/docs/source/conf.py index 7a8c365ffb3bb..19cc8557a7541 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -83,6 +83,7 @@ "vllm._C", "numpy", "tqdm", + "tensorizer", ] for mock_target in autodoc_mock_imports: diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index d8a7ac72e0175..886a806934c04 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM: Directory to download and load the weights, default to the default cache dir of huggingface. -.. option:: --load-format {auto,pt,safetensors,npcache,dummy} +.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer} The format of the model weights to load. @@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM: * "safetensors" will load the weights in the safetensors format. * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. * "dummy" will initialize the weights with random values, mainly for profiling. + * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`. .. option:: --dtype {auto,half,float16,bfloat16,float,float32} diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py new file mode 100644 index 0000000000000..3c20a38c7f726 --- /dev/null +++ b/examples/tensorize_vllm_model.py @@ -0,0 +1,254 @@ +import argparse +import dataclasses +import os +import time +import uuid +from functools import partial +from typing import Type + +import torch +import torch.nn as nn +from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, + TensorSerializer, stream_io) +from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor +from transformers import AutoConfig, PretrainedConfig + +from vllm.distributed import initialize_model_parallel +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.tensorizer_loader import TensorizerArgs + +# yapf conflicts with isort for this docstring +# yapf: disable +""" +tensorize_vllm_model.py is a script that can be used to serialize and +deserialize vLLM models. These models can be loaded using tensorizer directly +to the GPU extremely quickly. Tensor encryption and decryption is also +supported, although libsodium must be installed to use it. Install +vllm with tensorizer support using `pip install vllm[tensorizer]`. + +To serialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + serialize \ + --serialized-directory s3://my-bucket/ \ + --suffix vllm + +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, +and saves it to your S3 bucket. A local directory can also be used. + +You can also encrypt the model weights with a randomly-generated key by +providing a `--keyfile` argument. + +To deserialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + deserialize \ + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + +Which downloads the model tensors from your S3 bucket and deserializes them. +To provide S3 credentials, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, +the OpenAI entrypoint, as arguments for LLM(), or as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. + + +You can also provide a `--keyfile` argument to decrypt the model weights if +they were serialized with encryption. + +For more information on the available arguments, run +`python tensorize_vllm_model.py --help`. +""" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="An example script that can be used to serialize and " + "deserialize vLLM models. These models " + "can be loaded using tensorizer directly to the GPU " + "extremely quickly. Tensor encryption and decryption is " + "also supported, although libsodium must be installed to " + "use it.") + parser = EngineArgs.add_cli_args(parser) + subparsers = parser.add_subparsers(dest='command') + + serialize_parser = subparsers.add_parser( + 'serialize', help="Serialize a model to `--serialized-directory`") + + serialize_parser.add_argument( + "--suffix", + type=str, + required=False, + help=( + "The suffix to append to the serialized model directory, which is " + "used to construct the location of the serialized model tensors, " + "e.g. if `--serialized-directory` is `s3://my-bucket/` and " + "`--suffix` is `v1`, the serialized model tensors will be " + "saved to " + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " + "If none is provided, a random UUID will be used.")) + serialize_parser.add_argument( + "--serialized-directory", + type=str, + required=True, + help="The directory to serialize the model to. " + "This can be a local directory or S3 URI. The path to where the " + "tensors are saved is a combination of the supplied `dir` and model " + "reference ID. For instance, if `dir` is the serialized directory, " + "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will " + "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, " + "where `suffix` is given by `--suffix` or a random UUID if not " + "provided.") + + serialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path")) + + deserialize_parser = subparsers.add_parser( + 'deserialize', + help=("Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used.")) + + deserialize_parser.add_argument( + "--path-to-tensors", + type=str, + required=True, + help="The local path or S3 URI to the model tensors to deserialize. ") + + deserialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption")) + + return parser.parse_args() + + +def make_model_contiguous(model): + # Ensure tensors are saved in memory contiguously + for param in model.parameters(): + param.data = param.data.contiguous() + + +def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def serialize(): + + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + + model = (engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() if keyfile else None + if keyfile: + with _write_stream(keyfile) as stream: + stream.write(encryption_params.key) + + with _write_stream(model_path) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + + print("Serialization complete. Model tensors saved to", model_path) + if keyfile: + print("Key saved to", keyfile) + + +def deserialize(): + config = AutoConfig.from_pretrained(model_ref) + + with no_init_or_tensor(): + model_class = _get_vllm_model_architecture(config) + model = model_class(config) + + before_mem = get_mem_usage() + start = time.time() + + if keyfile: + with _read_stream(keyfile) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + tensorizer_args.deserializer_params['encryption'] = \ + decryption_params + + with (_read_stream(model_path)) as stream, TensorDeserializer( + stream, **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.time() + + # Brag about how fast we are. + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + print( + f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + ) + print(f"Memory usage before: {before_mem}") + print(f"Memory usage after: {after_mem}") + + return model + + +args = parse_args() + +s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") + or None) +s3_secret_access_key = (args.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY") or None) + +s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) + +_read_stream, _write_stream = (partial( + stream_io.open_stream, + mode=mode, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, +) for mode in ("rb", "wb+")) + +model_ref = args.model + +model_name = model_ref.split("/")[1] + +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = "8080" + +torch.distributed.init_process_group(world_size=1, rank=0) +initialize_model_parallel() + +keyfile = args.keyfile if args.keyfile else None + +if args.command == "serialize": + input_dir = args.serialized_directory.rstrip('/') + suffix = args.suffix if args.suffix else uuid.uuid4().hex + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" + model_path = f"{base_path}/model.tensors" + serialize() +elif args.command == "deserialize": + tensorizer_args = TensorizerArgs.from_cli_args(args) + model_path = args.path_to_tensors + deserialize() +else: + raise ValueError("Either serialize or deserialize must be specified.") diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 36d20bc9473ea..5779b38b24e69 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -3,4 +3,4 @@ # Dependencies for x86_64 CPUs torch == 2.2.1+cpu -triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. +triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 96dfda6faf00f..1317e51b2dd11 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,6 +14,7 @@ types-setuptools # testing pytest +tensorizer==2.9.0a0 pytest-forked pytest-asyncio pytest-rerunfailures diff --git a/setup.py b/setup.py index 9f0814e9f3bff..813321efe796d 100644 --- a/setup.py +++ b/setup.py @@ -405,6 +405,9 @@ def _read_requirements(filename: str) -> List[str]: python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, + extras_require={ + "optional": ["tensorizer==2.9.0a1"], + }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, ) diff --git a/tests/tensorizer/__init__.py b/tests/tensorizer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tensorizer/tensorize_vllm_model_for_testing.py b/tests/tensorizer/tensorize_vllm_model_for_testing.py new file mode 100644 index 0000000000000..d0be08329fd64 --- /dev/null +++ b/tests/tensorizer/tensorize_vllm_model_for_testing.py @@ -0,0 +1,245 @@ +import argparse +import dataclasses +import os +import time +import uuid +from functools import partial +from typing import Type + +import torch +import torch.nn as nn +from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, + TensorSerializer, stream_io) +from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor +from transformers import AutoConfig, PretrainedConfig + +from vllm.distributed import initialize_model_parallel +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.tensorizer_loader import TensorizerArgs + +# yapf conflicts with isort for this docstring +# yapf: disable +""" +tensorize_vllm_model.py is a script that can be used to serialize and +deserialize vLLM models. These models can be loaded using tensorizer directly +to the GPU extremely quickly. Tensor encryption and decryption is also +supported, although libsodium must be installed to use it. Install +vllm with tensorizer support using `pip install vllm[tensorizer]`. + +To serialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + serialize \ + --serialized-directory s3://my-bucket/ \ + --suffix vllm + +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, +and saves it to your S3 bucket. A local directory can also be used. + +You can also encrypt the model weights with a randomly-generated key by +providing a `--keyfile` argument. + +To deserialize a model, you can run something like this: + +python tensorize_vllm_model.py \ + --model EleutherAI/gpt-j-6B \ + --dtype float16 \ + deserialize \ + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + +Which downloads the model tensors from your S3 bucket and deserializes them. +To provide S3 credentials, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, +the OpenAI entrypoint, as arguments for LLM(), or as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. + + +You can also provide a `--keyfile` argument to decrypt the model weights if +they were serialized with encryption. + +For more information on the available arguments, run +`python tensorize_vllm_model.py --help`. +""" + + +def parse_args(): + parser = argparse.ArgumentParser( + description="An example script that can be used to serialize and " + "deserialize vLLM models. These models " + "can be loaded using tensorizer directly to the GPU " + "extremely quickly. Tensor encryption and decryption is " + "also supported, although libsodium must be installed to " + "use it.") + parser = EngineArgs.add_cli_args(parser) + subparsers = parser.add_subparsers(dest='command') + + serialize_parser = subparsers.add_parser( + 'serialize', help="Serialize a model to `--serialized-directory`") + + serialize_parser.add_argument( + "--suffix", + type=str, + required=False, + help=( + "The suffix to append to the serialized model directory, which is " + "used to construct the location of the serialized model tensors, " + "e.g. if `--serialized-directory` is `s3://my-bucket/` and " + "`--suffix` is `v1`, the serialized model tensors will be " + "saved to " + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " + "If none is provided, a random UUID will be used.")) + serialize_parser.add_argument( + "--serialized-directory", + type=str, + required=True) + + serialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Encrypt the model weights with a randomly-generated binary key," + " and save the key at this path")) + + deserialize_parser = subparsers.add_parser( + 'deserialize', + help=("Deserialize a model from `--path-to-tensors`" + " to verify it can be loaded and used.")) + + deserialize_parser.add_argument( + "--path-to-tensors", + type=str, + required=True, + help="The local path or S3 URI to the model tensors to deserialize. ") + + deserialize_parser.add_argument( + "--keyfile", + type=str, + required=False, + help=("Path to a binary key to use to decrypt the model weights," + " if the model was serialized with encryption")) + + return parser.parse_args() + + +def make_model_contiguous(model): + # Ensure tensors are saved in memory contiguously + for param in model.parameters(): + param.data = param.data.contiguous() + + +def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def serialize(): + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + + model = (engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() if keyfile else None + if keyfile: + with _write_stream(keyfile) as stream: + stream.write(encryption_params.key) + + with _write_stream(model_path) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + + print("Serialization complete. Model tensors saved to", model_path) + if keyfile: + print("Key saved to", keyfile) + + +def deserialize(): + config = AutoConfig.from_pretrained(model_ref) + + with no_init_or_tensor(): + model_class = _get_vllm_model_architecture(config) + model = model_class(config) + + before_mem = get_mem_usage() + start = time.time() + + if keyfile: + with _read_stream(keyfile) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + tensorizer_args.deserializer_params['encryption'] = \ + decryption_params + + with (_read_stream(model_path)) as stream, TensorDeserializer( + stream, **tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(model) + end = time.time() + + # Brag about how fast we are. + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + print( + f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + ) + print(f"Memory usage before: {before_mem}") + print(f"Memory usage after: {after_mem}") + + return model + + +args = parse_args() + +s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") + or None) +s3_secret_access_key = (args.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY") or None) + +s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) + +_read_stream, _write_stream = (partial( + stream_io.open_stream, + mode=mode, + s3_access_key_id=s3_access_key_id, + s3_secret_access_key=s3_secret_access_key, + s3_endpoint=s3_endpoint, +) for mode in ("rb", "wb+")) + +model_ref = args.model + +model_name = model_ref.split("/")[1] + +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = "8080" + +torch.distributed.init_process_group(world_size=1, rank=0) +initialize_model_parallel() + +keyfile = args.keyfile if args.keyfile else None + +if args.command == "serialize": + input_dir = args.serialized_directory.rstrip('/') + suffix = args.suffix if args.suffix else uuid.uuid4().hex + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" + model_path = f"{base_path}/model.tensors" + serialize() +elif args.command == "deserialize": + tensorizer_args = TensorizerArgs.from_cli_args(args) + model_path = args.path_to_tensors + deserialize() +else: + raise ValueError("Either serialize or deserialize must be specified.") diff --git a/tests/tensorizer/test_tensorizer.py b/tests/tensorizer/test_tensorizer.py new file mode 100644 index 0000000000000..2ab893e95da9c --- /dev/null +++ b/tests/tensorizer/test_tensorizer.py @@ -0,0 +1,302 @@ +import gc +import subprocess +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from tests.entrypoints.test_openai_server import ServerRunner +from vllm import SamplingParams +from vllm.config import TensorizerConfig +from vllm.model_executor.tensorizer_loader import ( + EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer, + load_with_tensorizer, open_stream) + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) + +model_ref = "facebook/opt-125m" + + +def is_curl_installed(): + try: + subprocess.check_call(['curl', '--version']) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + return False + + +@pytest.fixture(autouse=True) +def tensorizer_config(): + config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True) + return config + + +@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent') +def test_load_with_tensorizer(mock_agent, tensorizer_config): + mock_linear_method = MagicMock() + mock_agent_instance = mock_agent.return_value + mock_agent_instance.deserialize.return_value = MagicMock() + + result = load_with_tensorizer(tensorizer_config, + linear_method=mock_linear_method) + + mock_agent.assert_called_once_with(tensorizer_config, + linear_method=mock_linear_method) + mock_agent_instance.deserialize.assert_called_once() + assert result == mock_agent_instance.deserialize.return_value + + +def test_is_vllm_model_with_vllm_in_uri(tensorizer_config): + tensorizer_config.vllm_tensorized = True + + result = is_vllm_serialized_tensorizer(tensorizer_config) + + assert result is True + + +def test_is_vllm_model_without_vllm_in_uri(tensorizer_config): + tensorizer_config.vllm_tensorized = False + + result = is_vllm_serialized_tensorizer(tensorizer_config) + + assert result is False + + +def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): + vllm_model = vllm_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + outputs = vllm_model.generate(prompts, sampling_params) + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(model) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner(model_ref, + load_format="tensorizer", + tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=True) + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + # Assumes SamplingParams being seeded ensures the outputs are deterministic + assert outputs == deserialized_outputs + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_can_deserialize_s3(vllm_runner): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + loaded_hf_model = vllm_runner( + model_ref, + tensorizer_uri=tensorized_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + ) + + deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) + + assert deserialized_outputs + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_deserialized_encrypted_vllm_model_has_same_outputs( + vllm_runner, tmp_path): + vllm_model = vllm_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + key_path = tmp_path / (model_ref + ".key") + outputs = vllm_model.generate(prompts, sampling_params) + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + + encryption_params = EncryptionParams.random() + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + with open_stream(key_path, "wb+") as stream: + stream.write(encryption_params.key) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner(model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + encryption_keyfile=key_path, + num_readers=1, + vllm_tensorized=True) + + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + # Assumes SamplingParams being seeded ensures the outputs are deterministic + assert outputs == deserialized_outputs + + +def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, + tmp_path): + hf_model = hf_runner(model_ref) + model_path = tmp_path / (model_ref + ".tensors") + max_tokens = 50 + outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(hf_model.model) + del hf_model + gc.collect() + torch.cuda.empty_cache() + loaded_hf_model = vllm_runner(model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False) + + deserialized_outputs = loaded_hf_model.generate_greedy( + prompts, max_tokens=max_tokens) + + assert outputs == deserialized_outputs + + +def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): + from huggingface_hub import snapshot_download + + from examples.multilora_inference import (create_test_prompts, + process_requests) + + model_ref = "meta-llama/Llama-2-7b-hf" + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + + # Serialize model before deserializing and binding LoRA adapters + vllm_model = vllm_runner(model_ref, ) + model_path = tmp_path / (model_ref + ".tensors") + model = (vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(model) + del vllm_model, model + gc.collect() + torch.cuda.empty_cache() + loaded_vllm_model = vllm_runner( + model_ref, + tensorizer_uri=model_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=True, + enable_lora=True, + max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, + max_num_seqs=50, + max_model_len=1000, + ) + process_requests(loaded_vllm_model.model.llm_engine, test_prompts) + + assert loaded_vllm_model + + +def test_load_without_tensorizer_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner(model_ref, tensorizer_uri="test") + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_tensorize_vllm_model(tmp_path): + # Test serialize command + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "serialize", "--serialized-directory", + tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + print(result.stdout) # Print the output of the serialize command + + assert result.returncode == 0, (f"Serialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + + # Test deserialize command + deserialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors", + path_to_tensors + ] + result = subprocess.run(deserialize_args, capture_output=True, text=True) + assert result.returncode == 0, (f"Deserialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_openai_apiserver_with_tensorizer(tmp_path): + ## Serialize model + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--dtype", "float16", "serialize", "--serialized-directory", + tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + print(result.stdout) # Print the output of the serialize command + + assert result.returncode == 0, (f"Serialize command failed with output:" + f"\n{result.stdout}\n{result.stderr}") + + path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + + ## Start OpenAI API server + openai_args = [ + "--model", model_ref, "--dtype", "float16", "--load-format", + "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized", + "--port", "8000" + ] + + server = ServerRunner.remote(openai_args) + + print("Server ready.") + assert server.ready.remote() + + +def test_raise_value_error_on_invalid_load_format(vllm_runner): + with pytest.raises(ValueError): + vllm_runner(model_ref, + load_format="safetensors", + tensorizer_uri="test") + + +def test_tensorizer_with_tp(vllm_runner): + with pytest.raises(ValueError): + model_ref = "EleutherAI/pythia-1.4b" + tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" + + vllm_runner( + model_ref, + tensorizer_uri=tensorized_path, + load_format="tensorizer", + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + tensor_parallel_size=2, + ) + + +@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") +def test_tensorizer_warn_quant(tmp_path): + model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" + serialize_args = [ + "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", + model_ref, "--quantization", "gptq", "--tensorizer-uri", "test", + "serialize", "--serialized-directory", tmp_path, "--suffix", "tests" + ] + result = subprocess.run(serialize_args, capture_output=True, text=True) + assert 'PerformanceWarning' in result.stderr diff --git a/vllm/config.py b/vllm/config.py index bbda4ecf3cc56..dce2944b2ee8a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,8 @@ import enum +import io import json import os +import typing from dataclasses import dataclass, fields from typing import TYPE_CHECKING, ClassVar, List, Optional, Union @@ -16,6 +18,8 @@ if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup + from vllm.model_executor.tensorizer_loader import TensorizerArgs + logger = init_logger(__name__) _GB = 1 << 30 @@ -139,13 +143,14 @@ def __init__( def _verify_load_format(self) -> None: load_format = self.load_format.lower() supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy" + "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer" ] rocm_not_supported_load_format: List[str] = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or " + "'dummy'.") if is_hip() and load_format in rocm_not_supported_load_format: rocm_supported_load_format = [ f for f in supported_load_format @@ -882,6 +887,65 @@ def get_image_input_enum_type( f"{[x.name for x in cls.ImageInputType]}.") from e +@dataclass +class TensorizerConfig: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[torch.nn.Module] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Union[str, torch.dtype] = None + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + from vllm.model_executor.tensorizer_loader import TensorizerArgs + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if (parallel_config.tensor_parallel_size > 1 + and self.tensorizer_uri is not None): + raise ValueError( + "Loading to multiple GPUs is not currently supported with " + "vLLM-serialized models. Please set tensor_parallel_size=1." + " or use a non-vLLM-serialized model, such as a " + "serialized Hugging Face `PretrainedModel`.") + + def verify_with_model_config(self, model_config) -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + from vllm.model_executor.tensorizer_loader import ( + tensorizer_warning) + tensorizer_warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + if (model_config.load_format != "tensorizer" + and self.tensorizer_uri is not None): + raise ValueError( + "A tensorizer uri was passed for tensorizer loading, but the " + f"load format was set to {model_config.load_format}. " + "Please set the load format to 'tensorizer' to use " + f"tensorizer args.") + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, @@ -1029,6 +1093,7 @@ class EngineConfig: lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] + tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1036,6 +1101,11 @@ def __post_init__(self): self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.tensorizer_config: + self.tensorizer_config.verify_with_parallel_config( + self.parallel_config) + self.tensorizer_config.verify_with_model_config(self.model_config) + if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index daefddc01b431..831a03be65f61 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,12 +1,15 @@ import argparse import dataclasses +import io +import os from dataclasses import dataclass -from typing import Optional +from typing import BinaryIO, Optional, Union from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TokenizerPoolConfig, - VisionLanguageConfig) + SpeculativeConfig, TensorizerConfig, + TokenizerPoolConfig, VisionLanguageConfig) +from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -58,12 +61,22 @@ class EngineArgs: num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 + # Tensorizer configuration parameters + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, + bytes, os.PathLike, int] = None + vllm_tensorized: bool = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + # Related to Vision-language models such as llava image_input_type: Optional[str] = None image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None - scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False @@ -135,7 +148,9 @@ def add_cli_args( '--load-format', type=str, default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + choices=[ + 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer' + ], help='The format of the model weights to load. ' '"auto" will try to load the weights in the safetensors format ' 'and fall back to the pytorch bin format if safetensors format ' @@ -145,7 +160,10 @@ def add_cli_args( '"npcache" will load the weights in pytorch format and store ' 'a numpy cache to speed up the loading. ' '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') + 'which is mainly for profiling.' + '"tensorizer" will load the weights using tensorizer from CoreWeave' + 'which assumes tensorizer_uri is set to the location of the ' + 'serialized weights.') parser.add_argument( '--dtype', type=str, @@ -403,6 +421,7 @@ def add_cli_args( default=None, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding') + parser = TensorizerArgs.add_cli_args(parser) return parser @classmethod @@ -465,6 +484,17 @@ def create_engine_config(self, ) -> EngineConfig: max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + tensorizer_config = TensorizerConfig( + tensorizer_uri=self.tensorizer_uri, + vllm_tensorized=self.vllm_tensorized, + verify_hash=self.verify_hash, + num_readers=self.num_readers, + encryption_keyfile=self.encryption_keyfile, + s3_access_key_id=self.s3_access_key_id, + s3_secret_access_key=self.s3_secret_access_key, + s3_endpoint=self.s3_endpoint, + ) + if self.image_input_type: if (not self.image_token_id or not self.image_input_shape or not self.image_feature_size): @@ -488,7 +518,8 @@ def create_engine_config(self, ) -> EngineConfig: device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, - speculative_config=speculative_config) + speculative_config=speculative_config, + tensorizer_config=tensorizer_config) @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a91629a630591..8c37c5a9d6ee9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,7 +6,7 @@ import vllm from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -74,6 +74,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -110,6 +111,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.tensorizer_config = tensorizer_config self.log_stats = log_stats self._init_tokenizer() @@ -125,6 +127,7 @@ def __init__( lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, + tensorizer_config=tensorizer_config, ) self._initialize_kv_caches() @@ -264,6 +267,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.tensorizer_config: + self.tensorizer_config.verify_with_parallel_config( + self.parallel_config) if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index f20221a0b941a..30577ecf62faa 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -2,7 +2,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,17 +15,14 @@ class GPUExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig]) -> None: self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config @@ -33,6 +30,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.tensorizer_config = tensorizer_config assert (not speculative_config ), "Speculative decoding not yet supported for GPU backend" @@ -61,6 +59,7 @@ def _init_worker(self): distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) self.driver_worker.init_device() diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b937693c92257..28dc3e0db312a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -42,6 +42,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -50,6 +51,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.tensorizer_config = tensorizer_config assert (not speculative_config ), "Speculative decoding not yet supported for RayGPU backend." @@ -171,6 +173,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method=distributed_init_method, lora_config=lora_config, vision_language_config=vision_language_config, + tensorizer_config=self.tensorizer_config, )) # Initialize the driver worker with the Worker class. @@ -187,6 +190,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 2745dbd89ab0f..c70ca48bca70a 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -3,11 +3,14 @@ from typing import Tuple, Type import torch -import torch.nn as nn +from torch import nn from vllm.config import DeviceConfig, ModelConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.llava import LlavaForConditionalGeneration +from vllm.model_executor.tensorizer_loader import ( + ParameterizedLoadFormat, is_vllm_serialized_tensorizer, + load_with_tensorizer) from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -51,6 +54,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, **kwargs) -> nn.Module: lora_config = kwargs.get("lora_config", None) vision_language_config = kwargs.get("vision_language_config", None) + tensorizer_config = kwargs.get("tensorizer_config", None) model_class = _get_model_architecture(model_config)[0] # Get the (maybe quantized) linear method. @@ -71,33 +75,54 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " f"{supported_dtypes}") + linear_method = quant_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. + extra_kwargs = {} + if hasattr(model_class, "supported_lora_modules"): + extra_kwargs["lora_config"] = lora_config + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + elif model_class in _VISION_MODEL_CLASSES: + extra_kwargs["vision_language_config"] = vision_language_config + with torch.device(device_config.device): - if hasattr(model_class, "supported_lora_modules"): - model = model_class(model_config.hf_config, linear_method, - lora_config) - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. If this is important to you, " - "please open an issue on github.") - else: - if model_class not in _VISION_MODEL_CLASSES: - model = model_class(model_config.hf_config, linear_method) - else: - model = model_class(model_config.hf_config, - vision_language_config, linear_method) + if (model_config.load_format == "tensorizer" + and is_vllm_serialized_tensorizer(tensorizer_config)): + extra_kwargs["linear_method"] = linear_method + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + return model.eval() + model = model_class(config=model_config.hf_config, + linear_method=linear_method, + **extra_kwargs) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) else: # Load the weights from the cached or downloaded files. - model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) + if model_config.load_format == "tensorizer": + # Provide a dynamic load format for `model.load_weights` + # to retain tensorizer args from CLI. + model_config.load_format = ParameterizedLoadFormat( + model_config.load_format) + model_config.load_format.params = ( + tensorizer_config._construct_tensorizer_args()) + + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + ) return model.eval() diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py new file mode 100644 index 0000000000000..ed3ad9e2ffa15 --- /dev/null +++ b/vllm/model_executor/tensorizer_loader.py @@ -0,0 +1,319 @@ +import argparse +import dataclasses +import io +import os +import time +import typing +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from vllm.config import TensorizerConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + +tensorizer_load_fail = False + +try: + from tensorizer import (DecryptionParams, EncryptionParams, + TensorDeserializer, TensorSerializer) + from tensorizer.stream_io import open_stream + from tensorizer.utils import (convert_bytes, get_mem_usage, + no_init_or_tensor) +except ImportError: + tensorizer_load_fail = True + +__all__ = [ + 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', + 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', + 'no_init_or_tensor' +] + +logger = init_logger(__name__) + + +def load_with_tensorizer(tensorizer_config: TensorizerConfig, + **extra_kwargs) -> nn.Module: + tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) + return tensorizer.deserialize() + + +def tensorizer_warning(message: str): + return warnings.warn(message, category=PerformanceWarning, stacklevel=2) + + +def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool: + if tensorizer_config is None: + return False + return tensorizer_config.vllm_tensorized + + +class ParameterizedLoadFormat(str): + __slots__ = "params" + + +class PerformanceWarning(UserWarning): + + def __str__(self): + return (f"{super().__str__()}" + " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS" + " environment variable to hide this)") + + +if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower() + not in ("", "0", "n", "no", "off", "disable")): + warnings.simplefilter("ignore", category=PerformanceWarning) + + +@dataclass +class TensorizerArgs: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + """ + Args for the TensorizerAgent class. These are used to configure the behavior + of the TensorDeserializer when loading tensors from a serialized model. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is 1. This greatly increases + performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + """ + + def __post_init__(self): + self.file_obj = self.tensorizer_uri + self.s3_access_key_id = (self.s3_access_key_id + or os.environ.get("S3_ACCESS_KEY_ID")) or None + self.s3_secret_access_key = ( + self.s3_secret_access_key + or os.environ.get("S3_SECRET_ACCESS_KEY")) or None + self.s3_endpoint = (self.s3_endpoint + or os.environ.get("S3_ENDPOINT_URL")) or None + self.stream_params = { + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + + # Omitting self.dtype and self.device as this behaves weirdly + self.deserializer_params = { + "verify_hash": self.verify_hash, + "encryption": self.encryption_keyfile, + "num_readers": self.num_readers + } + if self.encryption_keyfile: + with open_stream( + self.encryption_keyfile, + **self.stream_params, + ) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + self.deserializer_params['encryption'] = decryption_params + + def add_cli_args( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Tensorizer CLI arguments""" + + # Create the argument group + group = parser.add_argument_group( + 'tensorizer options', + description=('Options for configuring the behavior of the' + ' tensorizer deserializer when ' + '--load-format=tensorizer')) + + group.add_argument( + "--tensorizer-uri", + help="Path to serialized model tensors. Can be a local file path," + " or an HTTP(S) or S3 URI.", + ) + group.add_argument( + "--verify-hash", + action="store_true", + help="If enabled, the hashes of each tensor will be verified" + " against the hashes stored in the file metadata. An exception" + " will be raised if any of the hashes do not match.", + ) + group.add_argument( + "--encryption-keyfile", + default=None, + help="The file path to a binary file containing a binary key to " + "use for decryption. Can be a file path or S3 network URI.") + group.add_argument( + "--num-readers", + default=1, + type=int, + help="Controls how many threads are allowed to read concurrently " + "from the source file.") + group.add_argument( + "--s3-access-key-id", + default=None, + help="The access key for the S3 bucket. Can also be set via the " + "S3_ACCESS_KEY_ID environment variable.", + ) + group.add_argument( + "--s3-secret-access-key", + default=None, + help="The secret access key for the S3 bucket. Can also be set via " + "the S3_SECRET_ACCESS_KEY environment variable.", + ) + group.add_argument( + "--s3-endpoint", + default=None, + help="The endpoint for the S3 bucket. Can also be set via the " + "S3_ENDPOINT_URL environment variable.", + ) + group.add_argument( + "--vllm-tensorized", + action="store_true", + help="If enabled, indicates that the serialized model is a vLLM " + "model. This is used to determine the behavior of the " + "TensorDeserializer when loading tensors from a " + "serialized model.") + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. + tensorizer_args = cls(**{ + attr: getattr(args, attr) + for attr in attrs if hasattr(args, attr) + }) + return tensorizer_args + + +class TensorizerAgent: + """ + A class for performing tensorizer deserializations specifically for + vLLM models using plaid_mode. Uses TensorizerArgs to configure the + behavior of the TensorDeserializer when loading tensors from a serialized + model. For deserializations of HuggingFace models, TensorDeserializer is + instead used as an iterator directly in the func hf_model_weights_iterator + in vllm/model_executor/weight_utils.py + """ + + def __init__(self, tensorizer_config: TensorizerConfig, + linear_method: LinearMethodBase, **extra_kwargs): + self.tensorizer_config = tensorizer_config + self.tensorizer_args = ( + self.tensorizer_config._construct_tensorizer_args()) + self.extra_kwargs = extra_kwargs + if extra_kwargs.get("linear_method", None) is not None: + self.linear_method = extra_kwargs["linear_method"] + else: + self.linear_method = linear_method + self.model = self._init_model() + + if tensorizer_load_fail: + raise ImportError( + "Tensorizer is not installed. Please install tensorizer " + "to use this feature with `pip install vllm[tensorizer]`.") + + def _init_model(self): + model_args = self.tensorizer_config.hf_config + model_args.torch_dtype = self.tensorizer_config.dtype + with no_init_or_tensor(): + return self.tensorizer_config.model_class( + config=model_args, + linear_method=self.linear_method, + **self.extra_kwargs) + + def _resize_lora_embeddings(self): + """Modify LoRA embedding layers to use bigger tensors + to allow for adapter added tokens.""" + for child in self.model.modules(): + if (isinstance(child, VocabParallelEmbedding) + and child.weight.shape[0] < + child.num_embeddings_per_partition): + new_weight = torch.empty(child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device) + new_weight[:child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0]:].fill_(0) + child.weight.data = new_weight + + def _check_tensors_on_meta_device(self): + for tensor in self.model.state_dict().values(): + if tensor.device.type == 'meta': + raise ValueError( + "The serialized model contains tensors on the meta device," + " indicating that some tensors were not loaded properly." + " Please check that the parameters of the model being" + " specified match that of the serialized model, such as" + " its quantization.") + + def deserialize(self): + """ + Deserialize the model using the TensorDeserializer. This method is + specifically for vLLM models using tensorizer's plaid_mode. + + The deserializer makes use of tensorizer_args.stream_params + to configure the behavior of the stream when loading tensors from a + serialized model. The deserializer_params are used to configure the + behavior of the TensorDeserializer when loading tensors themselves. + Documentation on these params can be found in TensorizerArgs + + Returns: + nn.Module: The deserialized model. + """ + before_mem = get_mem_usage() + # Lazy load the tensors from S3 into the model. + start = time.perf_counter() + with open_stream( + self.tensorizer_args.tensorizer_uri, + mode="rb", + **self.tensorizer_args.stream_params, + ) as stream, TensorDeserializer( + stream, + dtype=self.tensorizer_config.dtype, + **self.tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(self.model) + end = time.perf_counter() + + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info(f"Deserialized {total_bytes_str} in " + f"{end - start:0.2f}s, {per_second}/s") + logger.info(f"Memory usage before: {before_mem}") + logger.info(f"Memory usage after: {after_mem}") + + self._check_tensors_on_meta_device() + self._resize_lora_embeddings() + return self.model.eval() diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 0961478930d74..08425604f0511 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -5,7 +5,7 @@ import json import os from collections import defaultdict -from typing import Any, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union import filelock import huggingface_hub.constants @@ -161,7 +161,8 @@ def prepare_hf_model_weights( revision: Optional[str] = None, ) -> Tuple[str, List[str], bool]: # Download model weights from huggingface. - is_local = os.path.isdir(model_name_or_path) + is_local = os.path.isdir(model_name_or_path) \ + and load_format != "tensorizer" use_safetensors = False # Some quantized models use .pt files for storing the weights. if load_format == "auto": @@ -173,13 +174,15 @@ def prepare_hf_model_weights( allow_patterns = ["*.pt"] elif load_format == "npcache": allow_patterns = ["*.bin"] + elif load_format == "tensorizer": + allow_patterns = ["*.tensors"] else: raise ValueError(f"Unknown load_format: {load_format}") if fall_back_to_pt: allow_patterns += ["*.pt"] - if not is_local: + if not is_local and load_format != "tensorizer": # Before we download we look at that is available: fs = HfFileSystem() file_list = fs.ls(model_name_or_path, detail=False, revision=revision) @@ -224,6 +227,9 @@ def prepare_hf_model_weights( if not any(f.endswith(x) for x in blacklist) ] + if load_format == "tensorizer": + return hf_folder, hf_weights_files, use_safetensors + if len(hf_weights_files) == 0: raise RuntimeError( f"Cannot find any model weights with `{model_name_or_path}`") @@ -234,7 +240,7 @@ def prepare_hf_model_weights( def hf_model_weights_iterator( model_name_or_path: str, cache_dir: Optional[str] = None, - load_format: str = "auto", + load_format: Union[Tuple, str] = "auto", revision: Optional[str] = None, fall_back_to_pt: Optional[bool] = True, ) -> Iterator[Tuple[str, torch.Tensor]]: @@ -277,6 +283,26 @@ def hf_model_weights_iterator( with open(param_path, "rb") as f: param = np.load(f) yield name, torch.from_numpy(param) + elif load_format == "tensorizer": + from vllm.model_executor.tensorizer_loader import (TensorDeserializer, + open_stream, + tensorizer_warning) + tensorizer_args = load_format.params + tensorizer_warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the examples/tensorize_vllm_model.py example " + "script for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + for name, param in state.items(): + yield name, param + del state elif use_safetensors: for st_file in hf_weights_files: with safe_open(st_file, framework="pt") as f: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 47ad8f0c9b78b..7dbe14ead0976 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,7 +10,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VisionLanguageConfig) + SchedulerConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) @@ -111,11 +112,13 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + tensorizer_config: Optional[TensorizerConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config + self.tensorizer_config = tensorizer_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. @@ -158,7 +161,9 @@ def load_model(self) -> None: lora_config=self.lora_config, vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + tensorizer_config=self.tensorizer_config, + ) self.model_memory_usage = m.consumed_memory logger.info(f"Loading model weights took " diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 3f0b2fd83f3e5..82491c6df6616 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,8 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -42,6 +43,7 @@ def __init__( distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + tensorizer_config: Optional[TensorizerConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -53,6 +55,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.tensorizer_config = tensorizer_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -70,7 +73,9 @@ def __init__( lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, - vision_language_config=vision_language_config) + vision_language_config=vision_language_config, + tensorizer_config=tensorizer_config, + ) # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine = None From 2cd6b4f3625466eb5849bcfd7a6fb316735adab8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 13 Apr 2024 23:40:21 -0700 Subject: [PATCH 040/413] [Core] avoid too many cuda context by caching p2p test (#4021) --- .../device_communicators/custom_all_reduce.py | 53 +++++------ vllm/distributed/parallel_state.py | 9 ++ vllm/distributed/utils.py | 87 ++++++++++++++++++- 3 files changed, 116 insertions(+), 33 deletions(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 84238d2e46076..f83caef879da3 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -42,12 +42,17 @@ def init_custom_ar() -> None: " disable_custom_all_reduce=True explicitly.", world_size, str(_SUPPORTED_WORLD_SIZES)) return - if not _can_p2p(rank, world_size): + num_dev = torch.cuda.device_count() + # note: num dev can be larger than world_size if we're only using + # first few GPUs + if num_dev < world_size: logger.warn( - "Custom allreduce is disabled because your platform lacks GPU P2P" - " capability or P2P test failed. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return + "Cannot test GPU P2P because not all GPUs are visible to the " + "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" + " is set.") + return False + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported full_nvlink = _is_full_nvlink(rank, world_size) if world_size > 2 and not full_nvlink: logger.warn( @@ -55,6 +60,15 @@ def init_custom_ar() -> None: " than two PCIe-only GPUs. To silence this warning, specify" " disable_custom_all_reduce=True explicitly.") return + # test P2P capability + # this is expensive to compute at the first time + # then we cache the result + if not _can_p2p(rank, world_size): + logger.warn( + "Custom allreduce is disabled because your platform lacks GPU P2P" + " capability or P2P test failed. To silence this warning, specify" + " disable_custom_all_reduce=True explicitly.") + return _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) @@ -143,40 +157,15 @@ def _is_full_nvlink(rank, world_size): def _can_p2p(rank: int, world_size: int) -> bool: - num_dev = torch.cuda.device_count() - # note: num dev can be larger than world_size if we're only using - # first few GPUs - if num_dev < world_size: - logger.warn( - "Cannot test GPU P2P because not all GPUs are visible to the " - "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" - " is set.") - return False + from vllm.distributed.utils import gpu_p2p_access_check for i in range(world_size): if i == rank: continue - if not torch.cuda.can_device_access_peer(rank, i): - return False - # on some platforms, P2P support might be buggy and we need - # additional checks. See also: - # https://github.com/vllm-project/vllm/issues/2728 - if not _can_actually_p2p(rank, i): + if not gpu_p2p_access_check(rank, i): return False return True -# code partly borrowed from -# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10 -# License: MIT -def _can_actually_p2p(idx_a, idx_b): - dev_i = f"cuda:{idx_a}" - dev_j = f"cuda:{idx_b}" - a = torch.randn(5, device=dev_i) + 123.0 - b = a.to(dev_j) - c = b.to(dev_i) - return torch.all(a == c) - - class CustomAllreduce: # max_size: max supported allreduce size diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 1258bf58cb453..e2473736375e0 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -41,6 +41,13 @@ # source rank when broadcasting from the first or last pipeline stage. _PIPELINE_GLOBAL_RANKS = None +_LOCAL_RANK = -1 + + +def get_local_rank(): + global _LOCAL_RANK + return _LOCAL_RANK + def init_distributed_environment( world_size: int = -1, @@ -66,6 +73,8 @@ def init_distributed_environment( ranks = list(range(torch.distributed.get_world_size())) _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, backend="gloo") + global _LOCAL_RANK + _LOCAL_RANK = local_rank def initialize_model_parallel( diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 0cd420c8e11b5..e0a871ebe1756 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -2,9 +2,18 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -from typing import Sequence +import json +import os +from typing import Dict, Optional, Sequence import torch +import torch.distributed as dist + +from vllm.logger import init_logger + +from .parallel_state import get_cpu_world_group, get_local_rank + +logger = init_logger(__name__) def ensure_divisibility(numerator, denominator): @@ -46,3 +55,79 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list + + +# code partly borrowed from +# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10 +# License: MIT +def _can_actually_p2p(idx_a, idx_b): + dev_i = f"cuda:{idx_a}" + dev_j = f"cuda:{idx_b}" + a = torch.randn(5, device=dev_i) + 123.0 + b = a.to(dev_j) + c = b.to(dev_i) + return torch.all(a == c).cpu().item() + + +# why do we need this cache? +# 1. we can have runtime checks for P2P access, where every process checks +# P2P access to all other GPUs. Unfortunately, the test might cost many +# (world_size * world_size) cuda context, and reduce the memory available +# for the model. see https://github.com/vllm-project/vllm/issues/3821 +# 2. alternatively, we can have a p2p map that is generated by the master +# process and broadcasted to all other processes. This still requires +# #world_size of cuda context, belonging to the master process, on each GPU. +# 3. we can have a cache file, that records the p2p access status. The first +# time the master process checks the p2p access, it will generate the cache +# file, at the cost of #world_size of cuda context. Later on, all processes +# can read the cache file to check the p2p access status without any cost of +# additional cuda context. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(i: int, j: int) -> bool: + """Check if GPU i can access GPU j.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{i}->{j}"] + + is_distributed = dist.is_initialized() + + num_dev = torch.cuda.device_count() + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + path = os.path.expanduser( + f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + if (not is_distributed or get_local_rank() == 0) \ + and (not os.path.exists(path)): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info(f"generating GPU P2P access cache for in {path}") + cache = {} + for _i in range(num_dev): + for _j in range(num_dev): + # on some platforms, P2P support might be buggy and we need + # additional checks. See also: + # https://github.com/vllm-project/vllm/issues/2728 + cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer( + _i, _j) and _can_actually_p2p(_i, _j) + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + cpu_world_group = get_cpu_world_group() + dist.barrier(cpu_world_group) + logger.info(f"reading GPU P2P access cache from {path}") + with open(path, "r") as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{i}->{j}"] From 563c54f760f870ae44c7662c8a9ec3a223a3c4c4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 14 Apr 2024 22:12:42 +0100 Subject: [PATCH 041/413] [BugFix] Fix tensorizer extra in setup.py (#4072) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 813321efe796d..19a9150ad2e64 100644 --- a/setup.py +++ b/setup.py @@ -406,7 +406,7 @@ def _read_requirements(filename: str) -> List[str]: install_requires=get_requirements(), ext_modules=ext_modules, extras_require={ - "optional": ["tensorizer==2.9.0a1"], + "tensorizer": ["tensorizer==2.9.0a1"], }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, From aceb17cf2d629175a484c3d9df355f44bd334cb3 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Sun, 14 Apr 2024 14:35:55 -0700 Subject: [PATCH 042/413] [Docs] document that mixtral 8x22b is supported (#4073) --- README.md | 2 +- docs/source/models/supported_models.rst | 36 ++++++++++++------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index d53227b82d87a..8434c11883341 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) -- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) +- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) - OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c09b0ff250437..5e5ce871f61dd 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -30,23 +30,23 @@ Alongside each architecture, we include some popular models that use it. * - :code:`CohereForCausalLM` - Command-R - :code:`CohereForAI/c4ai-command-r-v01`, etc. - - + - * - :code:`DbrxForCausalLM` - DBRX - :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc. - - + - * - :code:`DeciLMForCausalLM` - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - - + - * - :code:`BloomForCausalLM` - BLOOM, BLOOMZ, BLOOMChat - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. - - + - * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. - - + - * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. @@ -54,19 +54,19 @@ Alongside each architecture, we include some popular models that use it. * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. - - + - * - :code:`GPTBigCodeForCausalLM` - StarCoder, SantaCoder, WizardCoder - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc. - - + - * - :code:`GPTJForCausalLM` - GPT-J - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc. - - + - * - :code:`GPTNeoXForCausalLM` - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc. - - + - * - :code:`InternLMForCausalLM` - InternLM - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc. @@ -93,32 +93,32 @@ Alongside each architecture, we include some popular models that use it. - ✅︎ * - :code:`MixtralForCausalLM` - Mixtral-8x7B, Mixtral-8x7B-Instruct - - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc. + - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc. - ✅︎ * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. - - + - * - :code:`OLMoForCausalLM` - OLMo - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc. - - + - * - :code:`OPTForCausalLM` - OPT, OPT-IML - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. - - + - * - :code:`OrionForCausalLM` - Orion - :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc. - - + - * - :code:`PhiForCausalLM` - Phi - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - - + - * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. - - + - * - :code:`Qwen2ForCausalLM` - Qwen2 - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. @@ -126,11 +126,11 @@ Alongside each architecture, we include some popular models that use it. * - :code:`Qwen2MoeForCausalLM` - Qwen2MoE - :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. - - + - * - :code:`StableLmForCausalLM` - StableLM - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. - - + - If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. Otherwise, please refer to :ref:`Adding a New Model ` for instructions on how to implement support for your model. From 8db1bf32f8924403c6a845b5ce71ba0f14beb038 Mon Sep 17 00:00:00 2001 From: Roy Date: Mon, 15 Apr 2024 08:43:54 +0800 Subject: [PATCH 043/413] [Misc] Upgrade triton to 2.2.0 (#4061) --- requirements-cpu.txt | 2 +- requirements-cuda.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 5779b38b24e69..e911ad03295f0 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -3,4 +3,4 @@ # Dependencies for x86_64 CPUs torch == 2.2.1+cpu -triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file +triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6ee75e8139c04..c6d2cd46aee54 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,4 +7,3 @@ pynvml == 11.5.0 vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 -triton >= 2.1.0 From e11e2007368b22fce05b9ecdf00dd48eda471f9e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 14 Apr 2024 21:50:08 -0700 Subject: [PATCH 044/413] [Bugfix] Fix filelock version requirement (#4075) --- requirements-common.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index c96f9c9937fb0..90a3bc8abc1db 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -12,4 +12,5 @@ pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer outlines == 0.0.34 # Requires torch >= 2.1.0 -typing_extensions \ No newline at end of file +typing_extensions +filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 From 0003e9154bf1091d0de7ca7a6c7f1253df1eca5b Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Mon, 15 Apr 2024 23:35:55 +0800 Subject: [PATCH 045/413] [Misc][Minor] Fix CPU block num log in CPUExecutor. (#4088) --- vllm/executor/cpu_executor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 33e67d8b3eec2..e63a88be7868f 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -74,7 +74,10 @@ def initialize_cache(self, num_gpu_blocks: int, # NOTE: We log here to avoid multiple logs when number of workers is # greater than one. We could log in the engine, but not all executors # have GPUs. - logger.info(f"# CPU blocks: {num_cpu_blocks}") + # NOTE: `cpu block` for CPU backend is located on CPU memory but is + # referred as `gpu block`. Because we want to reuse the existing block + # management procedure. + logger.info(f"# CPU blocks: {num_gpu_blocks}") self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model(self, From eb46fbfda25348422918c4a876e17aef05fc5e34 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 15 Apr 2024 13:05:09 -0700 Subject: [PATCH 046/413] [Core] Simplifications to executor classes (#4071) --- vllm/executor/cpu_executor.py | 31 +++++++++------------------- vllm/executor/executor_base.py | 27 +++++++++++++++++------- vllm/executor/gpu_executor.py | 32 ++++------------------------- vllm/executor/neuron_executor.py | 29 ++++++-------------------- vllm/executor/ray_gpu_executor.py | 34 ++++--------------------------- 5 files changed, 44 insertions(+), 109 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index e63a88be7868f..f562e4e0ae3de 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,10 +1,9 @@ import os -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple import torch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -16,23 +15,13 @@ class CPUExecutor(ExecutorBase): - def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], *args, **kwargs) -> None: - assert device_config.device_type == "cpu" - assert lora_config is None, "cpu backend doesn't support LoRA" - model_config = _verify_and_get_model_config(model_config) - cache_config = _verify_and_get_cache_config(cache_config) - scheduler_config = _verify_and_get_scheduler_config(scheduler_config) - - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config + def _init_executor(self) -> None: + assert self.device_config.device_type == "cpu" + assert self.lora_config is None, "cpu backend doesn't support LoRA" + self.model_config = _verify_and_get_model_config(self.model_config) + self.cache_config = _verify_and_get_cache_config(self.cache_config) + self.scheduler_config = _verify_and_get_scheduler_config( + self.scheduler_config) # Instantiate the worker and load the model to CPU. self._init_worker() @@ -99,7 +88,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index bbfbfc689c99f..bbb6ec80f7b7e 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + TensorizerConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -16,7 +16,6 @@ class ExecutorBase(ABC): that can execute the model on multiple devices. """ - @abstractmethod def __init__( self, model_config: ModelConfig, @@ -27,8 +26,23 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], ) -> None: - raise NotImplementedError + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.vision_language_config = vision_language_config + self.speculative_config = speculative_config + self.tensorizer_config = tensorizer_config + + self._init_executor() + + @abstractmethod + def _init_executor(self) -> None: + pass @abstractmethod def determine_num_available_blocks(self) -> Tuple[int, int]: @@ -71,7 +85,7 @@ def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError @abstractmethod - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise NotImplementedError @abstractmethod @@ -94,8 +108,7 @@ async def execute_model_async( """Executes one model step on the given sequences.""" raise NotImplementedError - @abstractmethod async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" - raise NotImplementedError + self.check_health() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 30577ecf62faa..bae509f48025b 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,8 +1,5 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,24 +12,8 @@ class GPUExecutor(ExecutorBase): - def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - tensorizer_config: Optional[TensorizerConfig]) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.vision_language_config = vision_language_config - self.tensorizer_config = tensorizer_config - - assert (not speculative_config + def _init_executor(self) -> None: + assert (not self.speculative_config ), "Speculative decoding not yet supported for GPU backend" # Instantiate the worker and load the model to GPU. @@ -103,7 +84,7 @@ def remove_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: @@ -127,8 +108,3 @@ async def execute_model_async( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy) return output - - async def check_health_async(self) -> None: - # GPUExecutor will always be healthy as long as - # it's running. - return diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index d45f18e466256..273b17a927efd 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,8 +1,5 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -13,24 +10,10 @@ class NeuronExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - assert lora_config is None, "LoRA is not supported for Neuron backend." - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - assert (not speculative_config + def _init_executor(self) -> None: + assert (self.lora_config is + None), "LoRA is not supported for Neuron backend." + assert (not self.speculative_config ), "Speculative decoding not yet supported for Neuron backend." # Instantiate the worker and load the model to the device. @@ -80,7 +63,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.driver_worker.remove_lora(lora_id) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() def check_health(self) -> None: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 28dc3e0db312a..5db2f3f652532 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,11 +3,8 @@ import os import pickle from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -32,27 +29,8 @@ class RayGPUExecutor(ExecutorBase): - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - tensorizer_config: Optional[TensorizerConfig], - ) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.vision_language_config = vision_language_config - self.tensorizer_config = tensorizer_config - assert (not speculative_config + def _init_executor(self) -> None: + assert (not self.speculative_config ), "Speculative decoding not yet supported for RayGPU backend." assert self.parallel_config.worker_use_ray @@ -273,7 +251,7 @@ def remove_lora(self, lora_id: int) -> bool: lora_id=lora_id, ) - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: return self._run_workers("list_loras") def _run_workers( @@ -416,7 +394,3 @@ async def execute_model_async( # Only the driver worker returns the sampling results. output = all_outputs[0] return output - - async def check_health_async(self) -> None: - """Raises an error if engine is unhealthy.""" - self._check_if_any_actor_is_dead() From d619ae2d19c41d9aa8f68fa0e5e32cc410dc2522 Mon Sep 17 00:00:00 2001 From: Sanger Steel Date: Mon, 15 Apr 2024 16:28:25 -0400 Subject: [PATCH 047/413] [Doc] Add better clarity for tensorizer usage (#4090) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- docs/source/models/engine_args.rst | 2 +- examples/tensorize_vllm_model.py | 60 +++++++++++++++++------- vllm/model_executor/tensorizer_loader.py | 6 +-- 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index 886a806934c04..235cb4e128c99 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -45,7 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM: * "safetensors" will load the weights in the safetensors format. * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. * "dummy" will initialize the weights with random values, mainly for profiling. - * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`. + * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_ See `examples/tensorize_vllm_model.py `_ to serialize a vLLM model, and for more information. .. option:: --dtype {auto,half,float16,bfloat16,float,float32} diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py index 3c20a38c7f726..8cf8be09d0b9c 100644 --- a/examples/tensorize_vllm_model.py +++ b/examples/tensorize_vllm_model.py @@ -23,14 +23,16 @@ # yapf: disable """ tensorize_vllm_model.py is a script that can be used to serialize and -deserialize vLLM models. These models can be loaded using tensorizer directly -to the GPU extremely quickly. Tensor encryption and decryption is also -supported, although libsodium must be installed to use it. Install -vllm with tensorizer support using `pip install vllm[tensorizer]`. +deserialize vLLM models. These models can be loaded using tensorizer +to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint, +or locally. Tensor encryption and decryption is also supported, although +libsodium must be installed to use it. Install vllm with tensorizer support +using `pip install vllm[tensorizer]`. -To serialize a model, you can run something like this: +To serialize a model, install vLLM from source, then run something +like this from the root level of this repository: -python tensorize_vllm_model.py \ +python -m examples.tensorize_vllm_model \ --model EleutherAI/gpt-j-6B \ --dtype float16 \ serialize \ @@ -38,31 +40,57 @@ --suffix vllm Which downloads the model from HuggingFace, loads it into vLLM, serializes it, -and saves it to your S3 bucket. A local directory can also be used. +and saves it to your S3 bucket. A local directory can also be used. This +assumes your S3 credentials are specified as environment variables +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. +To provide S3 credentials directly, you can provide `--s3-access-key-id` and +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this +script. You can also encrypt the model weights with a randomly-generated key by providing a `--keyfile` argument. -To deserialize a model, you can run something like this: +To deserialize a model, you can run something like this from the root +level of this repository: -python tensorize_vllm_model.py \ +python -m examples.tensorize_vllm_model \ --model EleutherAI/gpt-j-6B \ --dtype float16 \ deserialize \ --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors Which downloads the model tensors from your S3 bucket and deserializes them. -To provide S3 credentials, you can provide `--s3-access-key-id` and -`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, -the OpenAI entrypoint, as arguments for LLM(), or as environment variables -in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. - You can also provide a `--keyfile` argument to decrypt the model weights if they were serialized with encryption. -For more information on the available arguments, run -`python tensorize_vllm_model.py --help`. +For more information on the available arguments for serializing, run +`python -m examples.tensorize_vllm_model serialize --help`. + +Or for deserializing: + +`python -m examples.tensorize_vllm_model deserialize --help`. + +Once a model is serialized, it can be used to load the model when running the +OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing +the `--tensorizer-uri` CLI argument that is functionally the same as the +`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to +signify that the model to be deserialized is a vLLM model, rather than a +HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer +in the same inference server, albeit without the speed optimizations. To +deserialize an encrypted file, the `--encryption-keyfile` argument can be used +to provide the path to the keyfile used to encrypt the model weights. For +information on all the arguments that can be used to configure tensorizer's +deserialization, check out the tensorizer options argument group in the +`vllm/entrypoints/openai/api_server.py` script with `--help`. + +Tensorizer can also be invoked with the `LLM` class directly to load models: + + llm = LLM(model="facebook/opt-125m", + load_format="tensorizer", + tensorizer_uri=path_to_opt_tensors, + num_readers=3, + vllm_tensorized=True) """ diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py index ed3ad9e2ffa15..8550cc97aefe8 100644 --- a/vllm/model_executor/tensorizer_loader.py +++ b/vllm/model_executor/tensorizer_loader.py @@ -126,7 +126,6 @@ def __post_init__(self): "s3_endpoint": self.s3_endpoint, } - # Omitting self.dtype and self.device as this behaves weirdly self.deserializer_params = { "verify_hash": self.verify_hash, "encryption": self.encryption_keyfile, @@ -145,7 +144,7 @@ def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Tensorizer CLI arguments""" - # Create the argument group + # Tensorizer options arg group group = parser.add_argument_group( 'tensorizer options', description=('Options for configuring the behavior of the' @@ -205,9 +204,7 @@ def add_cli_args( @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": - # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] - # Set the attributes from the parsed arguments. tensorizer_args = cls(**{ attr: getattr(args, attr) for attr in attrs if hasattr(args, attr) @@ -291,7 +288,6 @@ def deserialize(self): nn.Module: The deserialized model. """ before_mem = get_mem_usage() - # Lazy load the tensors from S3 into the model. start = time.perf_counter() with open_stream( self.tensorizer_args.tensorizer_uri, From 4695397dcfef693a0a10f1eb8bf77ea905c54829 Mon Sep 17 00:00:00 2001 From: Ricky Xu Date: Mon, 15 Apr 2024 14:24:45 -0700 Subject: [PATCH 048/413] [Bugfix] Fix ray workers profiling with nsight (#4095) --- vllm/executor/ray_gpu_executor.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 5db2f3f652532..7aca5e36107aa 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -48,6 +48,21 @@ def _init_executor(self) -> None: if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() + def _configure_ray_workers_use_nsight(self, + ray_remote_kwargs) -> Dict[str, Any]: + # If nsight profiling is enabled, we need to set the profiling + # configuration for the ray workers as runtime env. + runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) + runtime_env.update({ + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } + }) + + return ray_remote_kwargs + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if self.parallel_config.tensor_parallel_size == 1: @@ -63,6 +78,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The remaining workers are the actual ray actors. self.workers: List[RayWorkerVllm] = [] + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + # Create the workers. driver_ip = get_ip() for bundle_id, bundle in enumerate(placement_group.bundle_specs): From 37e84a403d6d11b670a42e84153204cd8b76b849 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 Apr 2024 06:47:31 +0900 Subject: [PATCH 049/413] [Typing] Fix Sequence type GenericAlias only available after Python 3.9. (#4092) --- vllm/core/block_manager_v1.py | 5 +++-- vllm/core/block_manager_v2.py | 2 +- vllm/core/interfaces.py | 2 +- vllm/utils.py | 7 ++++--- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e391a3b1e5a33..be093922b84f2 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -1,9 +1,10 @@ """A block manager that manages token blocks.""" from abc import ABC, abstractmethod -from collections.abc import Sequence as GenericSequence from itertools import count, takewhile from os.path import commonprefix -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Set from vllm.block import BlockTable, PhysicalTokenBlock from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 19f0cf415eb34..6339a6baf4161 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,6 +1,6 @@ """A block manager that manages token blocks.""" -from collections.abc import Sequence as GenericSequence from typing import Dict, List, Optional +from typing import Sequence as GenericSequence from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index c1f68a2e891bf..56c2c5995c38b 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,7 +1,7 @@ import enum from abc import ABC, abstractmethod -from collections.abc import Sequence as GenericSequence from typing import Dict, List +from typing import Sequence as GenericSequence from vllm.sequence import Sequence, SequenceGroup diff --git a/vllm/utils.py b/vllm/utils.py index 4c0dc9ca729a9..aad62516ad1b9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -6,11 +6,12 @@ import subprocess import uuid import warnings -from collections import OrderedDict, defaultdict +from collections import defaultdict from functools import lru_cache, partial from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, - Hashable, List, Optional, Tuple, TypeVar, Union) + Hashable, List, Optional, OrderedDict, Tuple, TypeVar, + Union) import psutil import torch @@ -51,7 +52,7 @@ def reset(self) -> None: class LRUCache(Generic[T]): def __init__(self, capacity: int): - self.cache = OrderedDict[Hashable, T]() + self.cache: OrderedDict[Hashable, T] = OrderedDict() self.capacity = capacity def __contains__(self, key: Hashable) -> bool: From 4e7ee664e201442e24e2298a36a5264b98691626 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 16 Apr 2024 14:24:53 +0900 Subject: [PATCH 050/413] [Core] Fix engine-use-ray broken (#4105) --- tests/async_engine/test_api_server.py | 17 +++++++++++++---- vllm/engine/async_llm_engine.py | 7 +++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 248bfbc8ab5c0..7f57d5cf9b182 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture -def api_server(tokenizer_pool_size: int): +def api_server(tokenizer_pool_size: int, engine_use_ray: bool, + worker_use_ray: bool): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() - uvicorn_process = subprocess.Popen([ + commands = [ sys.executable, "-u", str(script_path), "--model", "facebook/opt-125m", "--host", "127.0.0.1", "--tokenizer-pool-size", str(tokenizer_pool_size) - ]) + ] + if engine_use_ray: + commands.append("--engine-use-ray") + if worker_use_ray: + commands.append("--worker-use-ray") + uvicorn_process = subprocess.Popen(commands) yield uvicorn_process.terminate() @pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) -def test_api_server(api_server, tokenizer_pool_size: int): +@pytest.mark.parametrize("worker_use_ray", [False, True]) +@pytest.mark.parametrize("engine_use_ray", [False, True]) +def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool, + engine_use_ray: bool): """ Run the API server and test it. diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f610495135121..1dbf58904541c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -333,8 +333,7 @@ def from_engine_args( if engine_config.device_config.device_type == "neuron": raise NotImplementedError("Neuron is not supported for " "async engine yet.") - elif (engine_config.parallel_config.worker_use_ray - or engine_args.engine_use_ray): + elif engine_config.parallel_config.worker_use_ray: initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync @@ -410,8 +409,8 @@ def _init_engine(self, *args, else: # FIXME(woosuk): This is a bit hacky. Be careful when changing the # order of the arguments. - cache_config = args[1] - parallel_config = args[2] + cache_config = kwargs["cache_config"] + parallel_config = kwargs["parallel_config"] if parallel_config.tensor_parallel_size == 1: num_gpus = cache_config.gpu_memory_utilization else: From 05434764cd99990035779cf9a4ed86623b528825 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Tue, 16 Apr 2024 08:54:57 +0300 Subject: [PATCH 051/413] LM Format Enforcer Guided Decoding Support (#3868) Co-authored-by: Simon Mo --- requirements-common.txt | 1 + tests/entrypoints/test_guided_processors.py | 42 +++++++- tests/entrypoints/test_openai_server.py | 69 ++++++++---- vllm/config.py | 26 ++++- vllm/engine/arg_utils.py | 18 +++- vllm/engine/llm_engine.py | 10 +- vllm/entrypoints/openai/protocol.py | 12 +++ vllm/entrypoints/openai/serving_chat.py | 6 +- vllm/entrypoints/openai/serving_completion.py | 6 +- .../guided_decoding/__init__.py | 25 +++++ .../lm_format_enforcer_decoding.py | 69 ++++++++++++ .../outlines_decoding.py} | 7 +- .../outlines_logits_processors.py} | 100 +++++++++--------- 13 files changed, 304 insertions(+), 87 deletions(-) create mode 100644 vllm/model_executor/guided_decoding/__init__.py create mode 100644 vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py rename vllm/model_executor/{guided_decoding.py => guided_decoding/outlines_decoding.py} (93%) rename vllm/model_executor/{guided_logits_processors.py => guided_decoding/outlines_logits_processors.py} (70%) diff --git a/requirements-common.txt b/requirements-common.txt index 90a3bc8abc1db..c1614d2537b25 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -11,6 +11,7 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer +lm-format-enforcer == 0.9.3 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 5622744566bcc..30f0ad5d8272f 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -1,11 +1,14 @@ # This unit test should be moved to a new # tests/test_guided_decoding directory. - +import pytest import torch from transformers import AutoTokenizer -from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.entrypoints.openai.protocol import CompletionRequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + JSONLogitsProcessor, RegexLogitsProcessor) TEST_SCHEMA = { "type": "object", @@ -73,3 +76,36 @@ def test_guided_logits_processors(): json_LP(token_ids, tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"]) +async def test_guided_logits_processor_black_box(backend: str): + tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + token_ids = tokenizer.encode( + f"Give an example IPv4 address with this regex: {TEST_REGEX}") + regex_request = CompletionRequest(model='test', + prompt=token_ids, + guided_regex=TEST_REGEX) + regex_lp = await get_guided_decoding_logits_processor( + backend, regex_request, tokenizer) + assert regex_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = regex_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) + + token_ids = tokenizer.encode( + f"Give an employee profile that fits this schema: {TEST_SCHEMA}") + json_request = CompletionRequest(model='test', + prompt=token_ids, + guided_json=TEST_SCHEMA) + json_lp = await get_guided_decoding_logits_processor( + backend, json_request, tokenizer) + assert json_lp is not None + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + tensor = json_lp(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 7940430b8b654..14e6ee0ffe9d9 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text -async def test_guided_json_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example JSON for an employee profile " @@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): n=3, temperature=1.0, max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) -async def test_guided_json_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_json_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) @@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=500, - extra_body=dict(guided_json=TEST_SCHEMA)) + max_tokens=1000, + extra_body=dict(guided_json=TEST_SCHEMA, + guided_decoding_backend=guided_decoding_backend)) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): assert json1["age"] != json2["age"] -async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", n=3, temperature=1.0, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None -async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(TEST_REGEX, ip1) is not None @@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict(guided_regex=TEST_REGEX)) + extra_body=dict(guided_regex=TEST_REGEX, + guided_decoding_backend=guided_decoding_backend)) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(TEST_REGEX, ip2) is not None assert ip1 != ip2 -async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): completion = await client.completions.create( model=MODEL_NAME, prompt="The best language for type-safe systems programming is ", n=2, temperature=1.0, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 2 @@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): assert completion.choices[i].text in TEST_CHOICE -async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice1 = chat_completion.choices[0].message.content assert choice1 in TEST_CHOICE @@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict(guided_choice=TEST_CHOICE)) + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) choice2 = chat_completion.choices[0].message.content assert choice2 in TEST_CHOICE assert choice1 != choice2 -async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", - extra_body=dict(guided_json=42)) + extra_body=dict(guided_json=42, + guided_decoding_backend=guided_decoding_backend)) messages = [{ "role": "system", diff --git a/vllm/config.py b/vllm/config.py index dce2944b2ee8a..bf31b03b7c6c4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -66,8 +66,8 @@ class ModelConfig: weights. If None, we assume the model weights are not quantized. quantization_param_path: Path to JSON file containing scaling factors. Used to load KV cache scaling factors into the model when KV cache - type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also - be used to load activation and weight scaling factors when the + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. @@ -422,7 +422,7 @@ def verify_with_parallel_config( @dataclass class TokenizerPoolConfig: """Configuration for the tokenizer pool. - + Args: pool_size: Number of tokenizer workers in the pool. pool_type: Type of the pool. @@ -446,9 +446,9 @@ def create_config( tokenizer_pool_extra_config: Optional[Union[str, dict]] ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. - + If tokenizer_pool_size is 0, return None. - + Args: tokenizer_pool_size: Number of tokenizer workers in the pool. tokenizer_pool_type: Type of the pool. @@ -1079,6 +1079,21 @@ def _get_and_verify_max_len( return int(max_model_len) +@dataclass +class DecodingConfig: + """Dataclass which contains the decoding strategy of the engine""" + + # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer' + guided_decoding_backend: str = 'outlines' + + def __post_init__(self): + valid_guided_backends = ['outlines', 'lm-format-enforcer'] + backend = self.guided_decoding_backend + if backend not in valid_guided_backends: + raise ValueError(f"Invalid guided_decoding_backend '{backend}," + f"must be one of {valid_guided_backends}") + + @dataclass(frozen=True) class EngineConfig: """Dataclass which contains all engine-related configuration. This @@ -1093,6 +1108,7 @@ class EngineConfig: lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] + decoding_config: Optional[DecodingConfig] tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 831a03be65f61..3de74b0ac28b9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,9 +5,9 @@ from dataclasses import dataclass from typing import BinaryIO, Optional, Union -from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TensorizerConfig, +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, + EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig, TensorizerConfig, TokenizerPoolConfig, VisionLanguageConfig) from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -80,6 +80,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False + guided_decoding_backend: str = 'outlines' # Speculative decoding configuration. speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None @@ -200,6 +201,13 @@ def add_cli_args( default=EngineArgs.max_model_len, help='model context length. If unspecified, ' 'will be automatically derived from the model.') + parser.add_argument( + '--guided-decoding-backend', + type=str, + default='outlines', + choices=['outlines', 'lm-format-enforcer'], + help='Which engine will be used for guided decoding' + ' (JSON schema / regex etc)') # Parallel arguments parser.add_argument('--worker-use-ray', action='store_true', @@ -511,6 +519,9 @@ def create_engine_config(self, ) -> EngineConfig: else: vision_language_config = None + decoding_config = DecodingConfig( + guided_decoding_backend=self.guided_decoding_backend) + return EngineConfig(model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -519,6 +530,7 @@ def create_engine_config(self, ) -> EngineConfig: lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, + decoding_config=decoding_config, tensorizer_config=tensorizer_config) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c37c5a9d6ee9..f06c1d18ace4b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,9 +4,10 @@ from transformers import PreTrainedTokenizer import vllm -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, TensorizerConfig, + VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -74,6 +75,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, @@ -100,6 +102,7 @@ def __init__( f"kv_cache_dtype={cache_config.cache_dtype}, " f"quantization_param_path={model_config.quantization_param_path}, " f"device_config={device_config.device}, " + f"decoding_config={decoding_config!r}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. @@ -111,6 +114,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.decoding_config = decoding_config or DecodingConfig() self.tensorizer_config = tensorizer_config self.log_stats = log_stats diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f94d22d279cc4..cf779d44c816b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'")) # doc: end-chat-completion-extra-params @@ -265,6 +271,12 @@ class CompletionRequest(BaseModel): description=( "If specified, the output will follow the context free grammar."), ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) # doc: end-completion-extra-params diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a03c5dc88108f..c9ed4a9de20f4 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -68,9 +68,13 @@ async def create_chat_completion( request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = self.engine.engine.decoding_config + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logits_processor: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e24aa2489a80f..a71f2d6a4426a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -88,9 +88,13 @@ async def create_completion(self, request: CompletionRequest, try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) + decoding_config = self.engine.engine.decoding_config + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend guided_decode_logit_processor = ( await get_guided_decoding_logits_processor( - request, await self.engine.get_tokenizer())) + guided_decoding_backend, request, await + self.engine.get_tokenizer())) if guided_decode_logit_processor is not None: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py new file mode 100644 index 0000000000000..0558d6c95d97b --- /dev/null +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -0,0 +1,25 @@ +from typing import Optional, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( + get_lm_format_enforcer_guided_decoding_logits_processor) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_guided_decoding_logits_processor( + guided_decoding_backend: str, request: Union[CompletionRequest, + ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + if guided_decoding_backend == 'outlines': + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + if guided_decoding_backend == 'lm-format-enforcer': + return await get_lm_format_enforcer_guided_decoding_logits_processor( + request, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_decoding_backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer'") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py new file mode 100644 index 0000000000000..0d74a5f8e81ff --- /dev/null +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -0,0 +1,69 @@ +from functools import lru_cache +from json import loads as json_loads +from typing import Optional, Union + +from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, + RegexParser, StringParser, + TokenEnforcerTokenizerData, UnionParser) +from lmformatenforcer.integrations.vllm import ( + build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) +from pydantic import BaseModel +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest) +from vllm.model_executor.guided_decoding.outlines_decoding import ( + get_outlines_guided_decoding_logits_processor) +from vllm.sampling_params import LogitsProcessor + + +async def get_lm_format_enforcer_guided_decoding_logits_processor( + request: Union[CompletionRequest, ChatCompletionRequest], + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if request.guided_json: + schema = _normalize_json_schema_object(request.guided_json) + character_level_parser = JsonSchemaParser(schema) + elif request.guided_choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in request.guided_choice]) + elif request.guided_regex: + character_level_parser = RegexParser(request.guided_regex) + elif request.guided_grammar: + # CFG grammar not supported by LMFE, revert to outlines + return await get_outlines_guided_decoding_logits_processor( + request, tokenizer) + elif (request.response_format is not None + and request.response_format.type == "json_object"): + character_level_parser = JsonSchemaParser( + None) # None means any json object + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + +def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: + if isinstance(schema, str): + return json_loads(schema) + if isinstance(schema, dict): + return schema + if isinstance(schema, BaseModel): + return schema.model_json_schema() + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData: + return build_vllm_token_enforcer_tokenizer_data(tokenizer) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py similarity index 93% rename from vllm/model_executor/guided_decoding.py rename to vllm/model_executor/guided_decoding/outlines_decoding.py index 8e710f1ac2b53..bd4564a36e1ed 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -12,9 +12,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest) -from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor, - JSONLogitsProcessor, - RegexLogitsProcessor) +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) class GuidedDecodingMode(Enum): @@ -54,7 +53,7 @@ class GuidedDecodingMode(Enum): global_thread_pool = None # used for generating logits processor fsm -async def get_guided_decoding_logits_processor( +async def get_outlines_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: """ diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py similarity index 70% rename from vllm/model_executor/guided_logits_processors.py rename to vllm/model_executor/guided_decoding/outlines_logits_processors.py index 035fe00037328..28041695546dc 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -13,9 +13,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import json import math from collections import defaultdict +from functools import lru_cache from typing import Callable, DefaultDict, Dict, List, Optional, Union import torch @@ -27,50 +29,6 @@ class BaseLogitsProcessor: - def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. The decoder of outlines, returns a list whereas - the decode of vLLM returns an str. To sync the vLLM decoder with - outlines internal api, the decoder should be adapted. In addition - we need to handle the missing spaces to Llama's tokenizer to be - able to compile FSMs for this model. - - """ - if getattr(tokenizer, "_outlines_adapted", False): - return tokenizer - - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) - - def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE - - string = tokenizer.convert_tokens_to_string([token]) - - # A hack to handle missing spaces to HF's Llama tokenizers - if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": - return " " + string - - return string - - def change_decoder( - decoder: Callable[[List[int]], str] - ) -> Callable[[List[int]], List[str]]: - """Sync vLLM's decoder with the outlines by returning list.""" - - def new_decoder(inp_tokens: List[int]) -> List[str]: - return [decoder(inp_tokens)] - - return new_decoder - - tokenizer.convert_token_to_string = convert_token_to_string - tokenizer.decode = change_decoder(tokenizer.decode) - setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 - - return tokenizer - def init_state(self): """Initialize the FSM states.""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) @@ -78,7 +36,6 @@ def init_state(self): def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" - seq_id = hash(tuple(input_ids)) if len(input_ids) == 0: @@ -96,7 +53,6 @@ def __call__(self, input_ids: List[int], device=scores.device) mask[allowed_tokens] = 0 scores.add_(mask) - return scores @@ -113,7 +69,7 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = RegexFSM(regex_string, tokenizer) self.fsm = fsm @@ -167,6 +123,54 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - tokenizer = self.adapt_tokenizer(tokenizer) + tokenizer = _adapt_tokenizer(tokenizer) fsm = CFGFSM(cfg, tokenizer) self.fsm = fsm + + +@lru_cache +def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. + + """ + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer = copy.deepcopy(tokenizer) + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], + str]) -> Callable[[List[int]], List[str]]: + """Sync vLLM's decoder with the outlines by returning list.""" + + def new_decoder(inp_tokens: List[int]) -> List[str]: + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 + + return tokenizer From 69e1d2fb6922b2d388bae41286d8867976cbd6c6 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 16 Apr 2024 11:34:39 -0700 Subject: [PATCH 052/413] [Core] Refactor model loading code (#4097) --- .buildkite/test-pipeline.yaml | 2 +- examples/fp8/extract_scales.py | 4 +- examples/tensorize_vllm_model.py | 2 +- tests/lora/conftest.py | 10 +- tests/lora/test_worker.py | 10 +- tests/model_executor/weight_utils.py | 2 +- .../test_autogptq_marlin_configs.py | 4 - tests/samplers/test_sampler.py | 14 +- tests/spec_decode/utils.py | 1 + .../__init__.py | 0 .../tensorize_vllm_model_for_testing.py | 4 +- .../test_tensorizer.py | 139 ++++--- tests/test_config.py | 4 - tests/test_logits_processor.py | 7 +- tests/worker/test_model_runner.py | 39 +- tests/worker/test_swap.py | 1 + vllm/config.py | 201 ++++------ vllm/engine/arg_utils.py | 59 ++- vllm/engine/llm_engine.py | 19 +- vllm/executor/cpu_executor.py | 1 + vllm/executor/executor_base.py | 10 +- vllm/executor/gpu_executor.py | 2 +- vllm/executor/ray_gpu_executor.py | 5 +- vllm/model_executor/model_loader.py | 128 ------- vllm/model_executor/model_loader/__init__.py | 30 ++ vllm/model_executor/model_loader/loader.py | 354 ++++++++++++++++++ .../neuron.py} | 0 .../tensorizer.py} | 116 ++++-- vllm/model_executor/model_loader/utils.py | 40 ++ .../{ => model_loader}/weight_utils.py | 295 +++++++-------- vllm/model_executor/models/baichuan.py | 14 +- vllm/model_executor/models/bloom.py | 14 +- vllm/model_executor/models/chatglm.py | 14 +- vllm/model_executor/models/commandr.py | 16 +- vllm/model_executor/models/dbrx.py | 16 +- vllm/model_executor/models/decilm.py | 14 +- vllm/model_executor/models/deepseek.py | 20 +- vllm/model_executor/models/falcon.py | 14 +- vllm/model_executor/models/gemma.py | 14 +- vllm/model_executor/models/gpt2.py | 14 +- vllm/model_executor/models/gpt_bigcode.py | 14 +- vllm/model_executor/models/gpt_j.py | 14 +- vllm/model_executor/models/gpt_neox.py | 14 +- vllm/model_executor/models/internlm2.py | 14 +- vllm/model_executor/models/jais.py | 16 +- vllm/model_executor/models/llama.py | 16 +- vllm/model_executor/models/llava.py | 14 +- vllm/model_executor/models/minicpm.py | 14 +- vllm/model_executor/models/mixtral.py | 20 +- vllm/model_executor/models/mixtral_quant.py | 19 +- vllm/model_executor/models/mpt.py | 14 +- vllm/model_executor/models/olmo.py | 16 +- vllm/model_executor/models/opt.py | 14 +- vllm/model_executor/models/orion.py | 14 +- vllm/model_executor/models/phi.py | 14 +- vllm/model_executor/models/qwen.py | 14 +- vllm/model_executor/models/qwen2.py | 14 +- vllm/model_executor/models/qwen2_moe.py | 20 +- vllm/model_executor/models/stablelm.py | 14 +- vllm/model_executor/models/starcoder2.py | 14 +- vllm/model_executor/models/xverse.py | 14 +- vllm/transformers_utils/tokenizer.py | 21 +- vllm/worker/cpu_model_runner.py | 12 +- vllm/worker/cpu_worker.py | 7 +- vllm/worker/model_runner.py | 15 +- vllm/worker/neuron_model_runner.py | 2 +- vllm/worker/worker.py | 10 +- 67 files changed, 1064 insertions(+), 973 deletions(-) rename tests/{tensorizer => tensorizer_loader}/__init__.py (100%) rename tests/{tensorizer => tensorizer_loader}/tensorize_vllm_model_for_testing.py (98%) rename tests/{tensorizer => tensorizer_loader}/test_tensorizer.py (67%) delete mode 100644 vllm/model_executor/model_loader.py create mode 100644 vllm/model_executor/model_loader/__init__.py create mode 100644 vllm/model_executor/model_loader/loader.py rename vllm/model_executor/{neuron_model_loader.py => model_loader/neuron.py} (100%) rename vllm/model_executor/{tensorizer_loader.py => model_loader/tensorizer.py} (78%) create mode 100644 vllm/model_executor/model_loader/utils.py rename vllm/model_executor/{ => model_loader}/weight_utils.py (53%) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aa4582bbda0c7..f39c3302ac2e9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -92,7 +92,7 @@ steps: parallelism: 4 - label: Tensorizer Test - command: apt-get install curl libsodium23 && pytest -v -s tensorizer + command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader - label: Metrics Test command: pytest -v -s metrics diff --git a/examples/fp8/extract_scales.py b/examples/fp8/extract_scales.py index 5e5b31265e3af..1eb961a5a76e3 100644 --- a/examples/fp8/extract_scales.py +++ b/examples/fp8/extract_scales.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.schema import QuantParamSchema -# Adapted from vllm/model_executor/weight_utils.py +# Adapted from vllm/model_executor/model_loader/weight_utils.py # The main differences are that we add the NPZ format and simplify # its functionality drastically for our purposes (e.g. we assume that # the quantized model exists locally and there is no need to download it) @@ -71,7 +71,7 @@ def _prepare_hf_weights( return hf_weights_files, use_safetensors -# Adapted from vllm/model_executor/weight_utils.py +# Adapted from vllm/model_executor/model_loader/weight_utils.py def _hf_tensorfile_iterator(filename: str, load_format: str, use_safetensors: bool): if load_format == "npz": diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py index 8cf8be09d0b9c..e2456168de9d5 100644 --- a/examples/tensorize_vllm_model.py +++ b/examples/tensorize_vllm_model.py @@ -16,8 +16,8 @@ from vllm.distributed import initialize_model_parallel from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.model_loader.tensorizer import TensorizerArgs from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.tensorizer_loader import TensorizerArgs # yapf conflicts with isort for this docstring # yapf: disable diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 1127cc33183c9..2dabfb6b4337c 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -153,11 +153,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() get_model_old = get_model - def get_model_patched(model_config, device_config, **kwargs): - return get_model_old(model_config, - device_config, - lora_config=LoRAConfig(max_loras=4, - max_lora_rank=8)) + def get_model_patched(*, model_config, device_config, **kwargs): + kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8) + return get_model_old(model_config=model_config, + device_config=device_config, + **kwargs) with patch("vllm.worker.model_runner.get_model", get_model_patched): engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 54594690f7922..732e91a52c0a9 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -3,8 +3,8 @@ import tempfile from unittest.mock import patch -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig) from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest from vllm.worker.worker import Worker @@ -18,12 +18,14 @@ def test_worker_apply_lora(sql_lora_files): "meta-llama/Llama-2-7b-hf", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, ), + load_config=LoadConfig( + download_dir=None, + load_format="dummy", + ), parallel_config=ParallelConfig(1, 1, False), scheduler_config=SchedulerConfig(32, 32, 32), device_config=DeviceConfig("cuda"), diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py index 3154f2826d10c..b0086dd7a7d71 100644 --- a/tests/model_executor/weight_utils.py +++ b/tests/model_executor/weight_utils.py @@ -3,7 +3,7 @@ import huggingface_hub.constants import pytest -from vllm.model_executor.weight_utils import enable_hf_transfer +from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer def test_hf_transfer_auto_activation(): diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py index cd64622e2226f..1310b4da218b5 100644 --- a/tests/quantization/test_autogptq_marlin_configs.py +++ b/tests/quantization/test_autogptq_marlin_configs.py @@ -36,8 +36,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None: model_path, tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -49,8 +47,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None: model_path, tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 26e2d29ffd04c..dbbe13b8da060 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -32,7 +32,12 @@ def _prepare_test( 1e-2, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(None, None, None, None, None) + model_runner = ModelRunner(model_config=None, + parallel_config=None, + scheduler_config=None, + device_config=None, + load_config=None, + lora_config=None) return input_tensor, fake_logits, sampler, model_runner @@ -591,7 +596,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(None, None, None, None, None) + model_runner = ModelRunner(model_config=None, + parallel_config=None, + scheduler_config=None, + device_config=None, + load_config=None, + lora_config=None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 4637826f254d6..edba4c226b289 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -118,6 +118,7 @@ def create_worker(cls: type, scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, cache_config=engine_config.cache_config, + load_config=engine_config.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/tests/tensorizer/__init__.py b/tests/tensorizer_loader/__init__.py similarity index 100% rename from tests/tensorizer/__init__.py rename to tests/tensorizer_loader/__init__.py diff --git a/tests/tensorizer/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py similarity index 98% rename from tests/tensorizer/tensorize_vllm_model_for_testing.py rename to tests/tensorizer_loader/tensorize_vllm_model_for_testing.py index d0be08329fd64..e4b15fd57add4 100644 --- a/tests/tensorizer/tensorize_vllm_model_for_testing.py +++ b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py @@ -16,8 +16,8 @@ from vllm.distributed import initialize_model_parallel from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.model_executor.model_loader.tensorizer import TensorizerArgs from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.tensorizer_loader import TensorizerArgs # yapf conflicts with isort for this docstring # yapf: disable @@ -74,7 +74,7 @@ def parse_args(): "extremely quickly. Tensor encryption and decryption is " "also supported, although libsodium must be installed to " "use it.") - parser = EngineArgs.add_cli_args(parser) + parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser)) subparsers = parser.add_subparsers(dest='command') serialize_parser = subparsers.add_parser( diff --git a/tests/tensorizer/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py similarity index 67% rename from tests/tensorizer/test_tensorizer.py rename to tests/tensorizer_loader/test_tensorizer.py index 2ab893e95da9c..a97cc0b3706b4 100644 --- a/tests/tensorizer/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -1,16 +1,19 @@ import gc +import json +import os import subprocess from unittest.mock import MagicMock, patch +import openai import pytest +import ray import torch from tests.entrypoints.test_openai_server import ServerRunner from vllm import SamplingParams -from vllm.config import TensorizerConfig -from vllm.model_executor.tensorizer_loader import ( - EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer, - load_with_tensorizer, open_stream) +from vllm.model_executor.model_loader.tensorizer import ( + EncryptionParams, TensorizerConfig, TensorSerializer, + is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream) prompts = [ "Hello, my name is", @@ -22,6 +25,8 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) model_ref = "facebook/opt-125m" +tensorize_model_for_testing_script = os.path.join( + os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py") def is_curl_installed(): @@ -38,7 +43,7 @@ def tensorizer_config(): return config -@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent') +@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent') def test_load_with_tensorizer(mock_agent, tensorizer_config): mock_linear_method = MagicMock() mock_agent_instance = mock_agent.return_value @@ -81,11 +86,13 @@ def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): del vllm_model, model gc.collect() torch.cuda.empty_cache() - loaded_vllm_model = vllm_runner(model_ref, - load_format="tensorizer", - tensorizer_uri=model_path, - num_readers=1, - vllm_tensorized=True) + loaded_vllm_model = vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=True), + ) deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # Assumes SamplingParams being seeded ensures the outputs are deterministic @@ -97,14 +104,14 @@ def test_can_deserialize_s3(vllm_runner): model_ref = "EleutherAI/pythia-1.4b" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" - loaded_hf_model = vllm_runner( - model_ref, - tensorizer_uri=tensorized_path, - load_format="tensorizer", - num_readers=1, - vllm_tensorized=False, - s3_endpoint="object.ord1.coreweave.com", - ) + loaded_hf_model = vllm_runner(model_ref, + load_format="tensorizer", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=tensorized_path, + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + )) deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) @@ -131,11 +138,12 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( gc.collect() torch.cuda.empty_cache() loaded_vllm_model = vllm_runner(model_ref, - tensorizer_uri=model_path, load_format="tensorizer", - encryption_keyfile=key_path, - num_readers=1, - vllm_tensorized=True) + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=model_path, + encryption_keyfile=key_path, + num_readers=1, + vllm_tensorized=True)) deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) @@ -156,10 +164,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, gc.collect() torch.cuda.empty_cache() loaded_hf_model = vllm_runner(model_ref, - tensorizer_uri=model_path, load_format="tensorizer", - num_readers=1, - vllm_tensorized=False) + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=False)) deserialized_outputs = loaded_hf_model.generate_greedy( prompts, max_tokens=max_tokens) @@ -190,10 +199,12 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): torch.cuda.empty_cache() loaded_vllm_model = vllm_runner( model_ref, - tensorizer_uri=model_path, load_format="tensorizer", - num_readers=1, - vllm_tensorized=True, + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=model_path, + num_readers=1, + vllm_tensorized=True, + ), enable_lora=True, max_loras=1, max_lora_rank=8, @@ -208,16 +219,18 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): def test_load_without_tensorizer_load_format(vllm_runner): with pytest.raises(ValueError): - vllm_runner(model_ref, tensorizer_uri="test") + vllm_runner(model_ref, + model_loader_extra_config=TensorizerConfig( + tensorizer_uri="test", vllm_tensorized=False)) @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_tensorize_vllm_model(tmp_path): # Test serialize command serialize_args = [ - "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", - model_ref, "--dtype", "float16", "serialize", "--serialized-directory", - tmp_path, "--suffix", "tests" + "python3", tensorize_model_for_testing_script, "--model", model_ref, + "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, + "--suffix", "tests" ] result = subprocess.run(serialize_args, capture_output=True, text=True) print(result.stdout) # Print the output of the serialize command @@ -229,8 +242,8 @@ def test_tensorize_vllm_model(tmp_path): # Test deserialize command deserialize_args = [ - "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", - model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors", + "python3", tensorize_model_for_testing_script, "--model", model_ref, + "--dtype", "float16", "deserialize", "--path-to-tensors", path_to_tensors ] result = subprocess.run(deserialize_args, capture_output=True, text=True) @@ -242,9 +255,9 @@ def test_tensorize_vllm_model(tmp_path): def test_openai_apiserver_with_tensorizer(tmp_path): ## Serialize model serialize_args = [ - "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", - model_ref, "--dtype", "float16", "serialize", "--serialized-directory", - tmp_path, "--suffix", "tests" + "python3", tensorize_model_for_testing_script, "--model", model_ref, + "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, + "--suffix", "tests" ] result = subprocess.run(serialize_args, capture_output=True, text=True) print(result.stdout) # Print the output of the serialize command @@ -253,25 +266,47 @@ def test_openai_apiserver_with_tensorizer(tmp_path): f"\n{result.stdout}\n{result.stderr}") path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" + model_loader_extra_config = { + "tensorizer_uri": path_to_tensors, + "vllm_tensorized": True + } ## Start OpenAI API server openai_args = [ "--model", model_ref, "--dtype", "float16", "--load-format", - "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized", - "--port", "8000" + "tensorizer", "--model-loader-extra-config", + json.dumps(model_loader_extra_config), "--port", "8000" ] server = ServerRunner.remote(openai_args) + assert ray.get(server.ready.remote()) print("Server ready.") - assert server.ready.remote() + + client = openai.OpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + completion = client.completions.create(model=model_ref, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) def test_raise_value_error_on_invalid_load_format(vllm_runner): with pytest.raises(ValueError): vllm_runner(model_ref, load_format="safetensors", - tensorizer_uri="test") + model_loader_extra_config=TensorizerConfig( + tensorizer_uri="test", vllm_tensorized=False)) def test_tensorizer_with_tp(vllm_runner): @@ -281,22 +316,12 @@ def test_tensorizer_with_tp(vllm_runner): vllm_runner( model_ref, - tensorizer_uri=tensorized_path, load_format="tensorizer", - num_readers=1, - vllm_tensorized=False, - s3_endpoint="object.ord1.coreweave.com", + model_loader_extra_config=TensorizerConfig( + tensorizer_uri=tensorized_path, + num_readers=1, + vllm_tensorized=False, + s3_endpoint="object.ord1.coreweave.com", + ), tensor_parallel_size=2, ) - - -@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_tensorizer_warn_quant(tmp_path): - model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" - serialize_args = [ - "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model", - model_ref, "--quantization", "gptq", "--tensorizer-uri", "test", - "serialize", "--serialized-directory", tmp_path, "--suffix", "tests" - ] - result = subprocess.run(serialize_args, capture_output=True, text=True) - assert 'PerformanceWarning' in result.stderr diff --git a/tests/test_config.py b/tests/test_config.py index 13a9f76212679..19db10630bbae 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,8 +11,6 @@ def test_get_sliding_window(): "Qwen/Qwen1.5-7B", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -30,8 +28,6 @@ def test_get_sliding_window(): "mistralai/Mistral-7B-v0.1", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index fe321520114f7..5bb93ca74855b 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -37,7 +37,12 @@ def _prepare_test( 1e-2, dtype=input_tensor.dtype) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - model_runner = ModelRunner(None, None, None, None, None) + model_runner = ModelRunner(model_config=None, + parallel_config=None, + scheduler_config=None, + device_config=None, + load_config=None, + lora_config=None) return input_tensor, fake_logits, logits_processor, model_runner diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index dcaae4af4a6f8..59bed2ce0dad3 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -12,7 +12,12 @@ def test_prepare_prompt(batch_size): 100000, 100000, enable_chunked_prefill=False) - model_runner = ModelRunner(None, None, scheduler_config, None, None) + model_runner = ModelRunner(model_config=None, + parallel_config=None, + scheduler_config=scheduler_config, + device_config=None, + load_config=None, + lora_config=None) model_runner.set_block_size(16) prompt_lens = [] @@ -118,8 +123,6 @@ def test_prepare_decode_cuda_graph(batch_size): "facebook/opt-125m", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -129,8 +132,12 @@ def test_prepare_decode_cuda_graph(batch_size): 100000, 100000, enable_chunked_prefill=False) - model_runner = ModelRunner(model_config, None, scheduler_config, None, - None) + model_runner = ModelRunner(model_config=model_config, + parallel_config=None, + scheduler_config=scheduler_config, + device_config=None, + load_config=None, + lora_config=None) model_runner.set_block_size(16) prompt_lens = [] @@ -205,14 +212,17 @@ def test_empty_seq_group(): "facebook/opt-125m", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config, None, None, None, None) + model_runner = ModelRunner(model_config=model_config, + parallel_config=None, + scheduler_config=None, + device_config=None, + load_config=None, + lora_config=None) model_runner.set_block_size(16) seq_group_metadata_list = [] input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( @@ -251,8 +261,6 @@ def mock_get_process_group_ranks(group=None): "facebook/opt-125m", tokenizer_mode="auto", trust_remote_code=False, - download_dir=None, - load_format="dummy", seed=0, dtype="float16", revision=None, @@ -262,11 +270,12 @@ def mock_get_process_group_ranks(group=None): 100000, 100000, enable_chunked_prefill=True) - model_runner = ModelRunner(model_config, - None, - scheduler_config, - None, - None, + model_runner = ModelRunner(model_config=model_config, + parallel_config=None, + scheduler_config=scheduler_config, + device_config=None, + load_config=None, + lora_config=None, is_driver_worker=True) model_runner.set_block_size(16) diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 8edb1cf05c08e..1804cf78d8003 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -23,6 +23,7 @@ def test_swap() -> None: scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, cache_config=engine_config.cache_config, + load_config=engine_config.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/vllm/config.py b/vllm/config.py index bf31b03b7c6c4..5a29620e85ac6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,9 +1,7 @@ import enum -import io import json import os -import typing -from dataclasses import dataclass, fields +from dataclasses import dataclass, field, fields from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import torch @@ -18,10 +16,14 @@ if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup - from vllm.model_executor.tensorizer_loader import TensorizerArgs + from vllm.model_executor.model_loader.loader import BaseModelLoader logger = init_logger(__name__) +# If true, will load models from ModelScope instead of Hugging Face Hub. +VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE", + "False").lower() == "true" + _GB = 1 << 30 @@ -35,18 +37,6 @@ class ModelConfig: available, and "slow" will always use the slow tokenizer. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. - "dummy" will initialize the weights with random values, which is - mainly for profiling. dtype: Data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. @@ -83,8 +73,6 @@ def __init__( tokenizer: str, tokenizer_mode: str, trust_remote_code: bool, - download_dir: Optional[str], - load_format: str, dtype: Union[str, torch.dtype], seed: int, revision: Optional[str] = None, @@ -101,8 +89,6 @@ def __init__( self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode self.trust_remote_code = trust_remote_code - self.download_dir = download_dir - self.load_format = load_format self.seed = seed self.revision = revision self.code_revision = code_revision @@ -113,64 +99,16 @@ def __init__( self.max_context_len_to_capture = max_context_len_to_capture self.max_logprobs = max_logprobs - if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. - from modelscope.hub.snapshot_download import snapshot_download - - if not os.path.exists(model): - model_path = snapshot_download(model_id=model, - cache_dir=download_dir, - revision=revision) - else: - model_path = model - self.model = model_path - self.download_dir = model_path - self.tokenizer = model_path - self.hf_config = get_config(self.model, trust_remote_code, revision, code_revision) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config, max_model_len) - self._verify_load_format() self._verify_tokenizer_mode() self._verify_quantization() self._verify_cuda_graph() - def _verify_load_format(self) -> None: - load_format = self.load_format.lower() - supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer" - ] - rocm_not_supported_load_format: List[str] = [] - if load_format not in supported_load_format: - raise ValueError( - f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or " - "'dummy'.") - if is_hip() and load_format in rocm_not_supported_load_format: - rocm_supported_load_format = [ - f for f in supported_load_format - if (f not in rocm_not_supported_load_format) - ] - raise ValueError( - f"load format '{load_format}' is not supported in ROCm. " - f"Supported load format are " - f"{rocm_supported_load_format}") - - # TODO: Remove this check once HF updates the pt weights of Mixtral. - architectures = getattr(self.hf_config, "architectures", []) - # architectures can be None instead of [] - if architectures and "MixtralForCausalLM" in architectures \ - and load_format == "pt": - raise ValueError( - "Currently, the 'pt' format is not supported for Mixtral. " - "Please use the 'safetensors' format instead. ") - self.load_format = load_format - def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() if tokenizer_mode not in ["auto", "slow"]: @@ -471,6 +409,65 @@ def create_config( return tokenizer_pool_config +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + """ + + load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field( + default_factory=dict) + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads( + model_loader_extra_config) + self._verify_load_format() + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}") + + class ParallelConfig: """Configuration for the distributed execution. @@ -699,8 +696,6 @@ def maybe_create_spec_config( tokenizer=target_model_config.tokenizer, tokenizer_mode=target_model_config.tokenizer_mode, trust_remote_code=target_model_config.trust_remote_code, - download_dir=target_model_config.download_dir, - load_format=target_model_config.load_format, dtype=target_model_config.dtype, seed=target_model_config.seed, revision=draft_revision, @@ -887,65 +882,6 @@ def get_image_input_enum_type( f"{[x.name for x in cls.ImageInputType]}.") from e -@dataclass -class TensorizerConfig: - tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, - str, bytes, os.PathLike, int] - vllm_tensorized: bool - verify_hash: Optional[bool] = False - num_readers: Optional[int] = 1 - encryption_keyfile: Optional[str] = None - s3_access_key_id: Optional[str] = None - s3_secret_access_key: Optional[str] = None - s3_endpoint: Optional[str] = None - model_class: Optional[torch.nn.Module] = None - hf_config: Optional[PretrainedConfig] = None - dtype: Union[str, torch.dtype] = None - - def _construct_tensorizer_args(self) -> "TensorizerArgs": - from vllm.model_executor.tensorizer_loader import TensorizerArgs - tensorizer_args = { - "tensorizer_uri": self.tensorizer_uri, - "vllm_tensorized": self.vllm_tensorized, - "verify_hash": self.verify_hash, - "num_readers": self.num_readers, - "encryption_keyfile": self.encryption_keyfile, - "s3_access_key_id": self.s3_access_key_id, - "s3_secret_access_key": self.s3_secret_access_key, - "s3_endpoint": self.s3_endpoint, - } - return TensorizerArgs(**tensorizer_args) - - def verify_with_parallel_config( - self, - parallel_config: "ParallelConfig", - ) -> None: - if (parallel_config.tensor_parallel_size > 1 - and self.tensorizer_uri is not None): - raise ValueError( - "Loading to multiple GPUs is not currently supported with " - "vLLM-serialized models. Please set tensor_parallel_size=1." - " or use a non-vLLM-serialized model, such as a " - "serialized Hugging Face `PretrainedModel`.") - - def verify_with_model_config(self, model_config) -> None: - if (model_config.quantization is not None - and self.tensorizer_uri is not None): - from vllm.model_executor.tensorizer_loader import ( - tensorizer_warning) - tensorizer_warning( - "Loading a model using Tensorizer with quantization on vLLM" - " is unstable and may lead to errors.") - - if (model_config.load_format != "tensorizer" - and self.tensorizer_uri is not None): - raise ValueError( - "A tensorizer uri was passed for tensorizer loading, but the " - f"load format was set to {model_config.load_format}. " - "Please set the load format to 'tensorizer' to use " - f"tensorizer args.") - - _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, @@ -1105,11 +1041,11 @@ class EngineConfig: parallel_config: ParallelConfig scheduler_config: SchedulerConfig device_config: DeviceConfig + load_config: LoadConfig lora_config: Optional[LoRAConfig] vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] decoding_config: Optional[DecodingConfig] - tensorizer_config: Optional[TensorizerConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1117,11 +1053,6 @@ def __post_init__(self): self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.tensorizer_config: - self.tensorizer_config.verify_with_parallel_config( - self.parallel_config) - self.tensorizer_config.verify_with_model_config(self.model_config) - if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3de74b0ac28b9..c61c0cc67d7a2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,15 +1,12 @@ import argparse import dataclasses -import io -import os from dataclasses import dataclass -from typing import BinaryIO, Optional, Union +from typing import Optional from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig, TensorizerConfig, + EngineConfig, LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, VisionLanguageConfig) -from vllm.model_executor.tensorizer_loader import TensorizerArgs from vllm.utils import str_to_int_tuple @@ -60,17 +57,7 @@ class EngineArgs: ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 - - # Tensorizer configuration parameters - tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, - bytes, os.PathLike, int] = None - vllm_tensorized: bool = False - verify_hash: Optional[bool] = False - num_readers: Optional[int] = 1 - encryption_keyfile: Optional[str] = None - s3_access_key_id: Optional[str] = None - s3_secret_access_key: Optional[str] = None - s3_endpoint: Optional[str] = None + model_loader_extra_config: Optional[dict] = None # Related to Vision-language models such as llava image_input_type: Optional[str] = None @@ -429,7 +416,16 @@ def add_cli_args( default=None, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding') - parser = TensorizerArgs.add_cli_args(parser) + + parser.add_argument('--model-loader-extra-config', + type=str, + default=EngineArgs.model_loader_extra_config, + help='Extra config for model loader. ' + 'This will be passed to the model loader ' + 'corresponding to the chosen load_format. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary.') + return parser @classmethod @@ -444,11 +440,11 @@ def create_engine_config(self, ) -> EngineConfig: device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, - self.trust_remote_code, self.download_dir, self.load_format, - self.dtype, self.seed, self.revision, self.code_revision, - self.tokenizer_revision, self.max_model_len, self.quantization, - self.quantization_param_path, self.enforce_eager, - self.max_context_len_to_capture, self.max_logprobs) + self.trust_remote_code, self.dtype, self.seed, self.revision, + self.code_revision, self.tokenizer_revision, self.max_model_len, + self.quantization, self.quantization_param_path, + self.enforce_eager, self.max_context_len_to_capture, + self.max_logprobs) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, @@ -492,15 +488,10 @@ def create_engine_config(self, ) -> EngineConfig: max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None - tensorizer_config = TensorizerConfig( - tensorizer_uri=self.tensorizer_uri, - vllm_tensorized=self.vllm_tensorized, - verify_hash=self.verify_hash, - num_readers=self.num_readers, - encryption_keyfile=self.encryption_keyfile, - s3_access_key_id=self.s3_access_key_id, - s3_secret_access_key=self.s3_secret_access_key, - s3_endpoint=self.s3_endpoint, + load_config = LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, ) if self.image_input_type: @@ -530,8 +521,8 @@ def create_engine_config(self, ) -> EngineConfig: lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, - decoding_config=decoding_config, - tensorizer_config=tensorizer_config) + load_config=load_config, + decoding_config=decoding_config) @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f06c1d18ace4b..563694946d16e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -4,9 +4,9 @@ from transformers import PreTrainedTokenizer import vllm -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, TensorizerConfig, +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, + LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs @@ -72,11 +72,11 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], - tensorizer_config: Optional[TensorizerConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -92,8 +92,8 @@ def __init__( f"trust_remote_code={model_config.trust_remote_code}, " f"dtype={model_config.dtype}, " f"max_seq_len={model_config.max_model_len}, " - f"download_dir={model_config.download_dir!r}, " - f"load_format={model_config.load_format}, " + f"download_dir={load_config.download_dir!r}, " + f"load_format={load_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"disable_custom_all_reduce=" f"{parallel_config.disable_custom_all_reduce}, " @@ -114,8 +114,8 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.speculative_config = speculative_config + self.load_config = load_config self.decoding_config = decoding_config or DecodingConfig() - self.tensorizer_config = tensorizer_config self.log_stats = log_stats self._init_tokenizer() @@ -131,7 +131,7 @@ def __init__( lora_config=lora_config, vision_language_config=vision_language_config, speculative_config=speculative_config, - tensorizer_config=tensorizer_config, + load_config=load_config, ) self._initialize_kv_caches() @@ -271,9 +271,6 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) - if self.tensorizer_config: - self.tensorizer_config.verify_with_parallel_config( - self.parallel_config) if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index f562e4e0ae3de..426e2c41d8427 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -40,6 +40,7 @@ def _init_worker(self): scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, + load_config=self.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index bbb6ec80f7b7e..8cc04c5299ca1 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -23,20 +23,20 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], - tensorizer_config: Optional[TensorizerConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config + self.load_config = load_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config self.speculative_config = speculative_config - self.tensorizer_config = tensorizer_config self._init_executor() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index bae509f48025b..3a9537effe6d9 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -35,12 +35,12 @@ def _init_worker(self): scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, + load_config=self.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - tensorizer_config=self.tensorizer_config, is_driver_worker=True, ) self.driver_worker.init_device() diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 7aca5e36107aa..4065c0868d79a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -147,6 +147,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", model_config = copy.deepcopy(self.model_config) parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) + load_config = copy.deepcopy(self.load_config) device_config = copy.deepcopy(self.device_config) lora_config = copy.deepcopy(self.lora_config) cache_config = copy.deepcopy(self.cache_config) @@ -165,12 +166,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", scheduler_config=scheduler_config, device_config=device_config, cache_config=cache_config, + load_config=load_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, lora_config=lora_config, vision_language_config=vision_language_config, - tensorizer_config=self.tensorizer_config, )) # Initialize the driver worker with the Worker class. @@ -187,7 +188,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - tensorizer_config=self.tensorizer_config, + load_config=self.load_config, is_driver_worker=True, ) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py deleted file mode 100644 index c70ca48bca70a..0000000000000 --- a/vllm/model_executor/model_loader.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Utilities for selecting and loading models.""" -import contextlib -from typing import Tuple, Type - -import torch -from torch import nn - -from vllm.config import DeviceConfig, ModelConfig -from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.models.llava import LlavaForConditionalGeneration -from vllm.model_executor.tensorizer_loader import ( - ParameterizedLoadFormat, is_vllm_serialized_tensorizer, - load_with_tensorizer) -from vllm.model_executor.weight_utils import (get_quant_config, - initialize_dummy_weights) - -_VISION_MODEL_CLASSES = [ - LlavaForConditionalGeneration, -] - - -@contextlib.contextmanager -def _set_default_torch_dtype(dtype: torch.dtype): - """Sets the default torch dtype to the given dtype.""" - old_dtype = torch.get_default_dtype() - torch.set_default_dtype(dtype) - yield - torch.set_default_dtype(old_dtype) - - -def _get_model_architecture( - model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: - architectures = getattr(model_config.hf_config, "architectures", []) - # Special handling for quantized Mixtral. - # FIXME(woosuk): This is a temporary hack. - if (model_config.quantization is not None - and "MixtralForCausalLM" in architectures): - architectures = ["QuantMixtralForCausalLM"] - - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def get_architecture_class_name(model_config: ModelConfig) -> str: - return _get_model_architecture(model_config)[1] - - -def get_model(model_config: ModelConfig, device_config: DeviceConfig, - **kwargs) -> nn.Module: - lora_config = kwargs.get("lora_config", None) - vision_language_config = kwargs.get("vision_language_config", None) - tensorizer_config = kwargs.get("tensorizer_config", None) - model_class = _get_model_architecture(model_config)[0] - - # Get the (maybe quantized) linear method. - linear_method = None - if model_config.quantization is not None: - quant_config = get_quant_config(model_config) - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") - supported_dtypes = quant_config.get_supported_act_dtypes() - if model_config.dtype not in supported_dtypes: - raise ValueError( - f"{model_config.dtype} is not supported for quantization " - f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") - - linear_method = quant_config.get_linear_method() - - with _set_default_torch_dtype(model_config.dtype): - # Create a model instance. - # The weights will be initialized as empty tensors. - extra_kwargs = {} - if hasattr(model_class, "supported_lora_modules"): - extra_kwargs["lora_config"] = lora_config - elif lora_config: - raise ValueError( - f"Model {model_class.__name__} does not support LoRA, " - "but LoRA is enabled. Support for this model may " - "be added in the future. If this is important to you, " - "please open an issue on github.") - elif model_class in _VISION_MODEL_CLASSES: - extra_kwargs["vision_language_config"] = vision_language_config - - with torch.device(device_config.device): - if (model_config.load_format == "tensorizer" - and is_vllm_serialized_tensorizer(tensorizer_config)): - extra_kwargs["linear_method"] = linear_method - tensorizer_config.model_class = model_class - tensorizer_config.hf_config = model_config.hf_config - tensorizer_config.dtype = model_config.dtype - model = load_with_tensorizer(tensorizer_config, **extra_kwargs) - return model.eval() - model = model_class(config=model_config.hf_config, - linear_method=linear_method, - **extra_kwargs) - if model_config.load_format == "dummy": - # NOTE(woosuk): For accurate performance evaluation, we assign - # random values to the weights. - initialize_dummy_weights(model) - else: - # Load the weights from the cached or downloaded files. - if model_config.load_format == "tensorizer": - # Provide a dynamic load format for `model.load_weights` - # to retain tensorizer args from CLI. - model_config.load_format = ParameterizedLoadFormat( - model_config.load_format) - model_config.load_format.params = ( - tensorizer_config._construct_tensorizer_args()) - - model.load_weights( - model_config.model, - model_config.download_dir, - model_config.load_format, - model_config.revision, - ) - return model.eval() diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py new file mode 100644 index 0000000000000..6f90e49994fb2 --- /dev/null +++ b/vllm/model_executor/model_loader/__init__.py @@ -0,0 +1,30 @@ +from typing import Optional + +from torch import nn + +from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.model_executor.model_loader.loader import (BaseModelLoader, + get_model_loader) +from vllm.model_executor.model_loader.utils import ( + get_architecture_class_name, get_model_architecture) + + +def get_model( + *, model_config: ModelConfig, load_config: LoadConfig, + device_config: DeviceConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: + loader = get_model_loader(load_config) + return loader.load_model(model_config=model_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config) + + +__all__ = [ + "get_model", "get_model_loader", "BaseModelLoader", + "get_architecture_class_name", "get_model_architecture" +] diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py new file mode 100644 index 0000000000000..3b1d125ef8a67 --- /dev/null +++ b/vllm/model_executor/model_loader/loader.py @@ -0,0 +1,354 @@ +# ruff: noqa: SIM117 +import copy +import glob +import os +from abc import ABC, abstractmethod +from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, + Type) + +import torch +from torch import nn + +from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig, + LoadFormat, LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VisionLanguageConfig) +from vllm.logger import init_logger +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer, + tensorizer_weights_iterator) +from vllm.model_executor.model_loader.utils import (get_model_architecture, + set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf, filter_files_not_needed_for_inference, + get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, + pt_weights_iterator, safetensors_weights_iterator) +from vllm.model_executor.models.llava import LlavaForConditionalGeneration + +if TYPE_CHECKING: + from vllm.model_executor.layers.linear import LinearMethodBase + +_VISION_MODEL_CLASSES = [ + LlavaForConditionalGeneration, +] + +logger = init_logger(__name__) + + +def _get_linear_method( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional["LinearMethodBase"]: + """Get the (maybe quantized) linear method.""" + linear_method = None + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + + linear_method = quant_config.get_linear_method() + return linear_method + + +def _get_model_initialization_kwargs( + model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig] +) -> Dict[str, Any]: + """Get extra kwargs for model initialization.""" + extra_kwargs = {} + if hasattr(model_class, "supported_lora_modules"): + extra_kwargs["lora_config"] = lora_config + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + elif model_class in _VISION_MODEL_CLASSES: + extra_kwargs["vision_language_config"] = vision_language_config + return extra_kwargs + + +def _initialize_model( + model_config: ModelConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: + """Initialize a model with the given configurations.""" + model_class = get_model_architecture(model_config)[0] + linear_method = _get_linear_method(model_config, load_config) + + return model_class(config=model_config.hf_config, + linear_method=linear_method, + **_get_model_initialization_kwargs( + model_class, lora_config, vision_language_config)) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + """Load a model with the given configurations.""" + ... + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + revision=revision) + else: + model_path = model + return model_path + return None + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf(model_name_or_path, + self.load_config.download_dir, + allow_patterns) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if not use_safetensors: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], + fall_back_to_pt: bool + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision, fall_back_to_pt) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + return np_cache_weights_iterator(model_name_or_path, + self.load_config.download_dir, + hf_folder, hf_weights_files) + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + return pt_weights_iterator(hf_weights_files) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config) + model.load_weights( + self._get_weights_iterator(model_config.model, + model_config.revision, + fall_back_to_pt=getattr( + model, + "fall_back_to_pt_during_load", + True)), ) + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + return model.eval() + + +class TensorizerLoader(BaseModelLoader): + """Model loader using CoreWeave's tensorizer library.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if isinstance(load_config.model_loader_extra_config, TensorizerConfig): + self.tensorizer_config = load_config.model_loader_extra_config + else: + self.tensorizer_config = TensorizerConfig( + **load_config.model_loader_extra_config) + + def _verify_config(self, model_config: ModelConfig, + parallel_config: ParallelConfig): + self.tensorizer_config.verify_with_model_config(model_config) + self.tensorizer_config.verify_with_parallel_config(parallel_config) + + def _get_weights_iterator( + self) -> Generator[Tuple[str, torch.Tensor], None, None]: + tensorizer_args = self.tensorizer_config._construct_tensorizer_args() + return tensorizer_weights_iterator(tensorizer_args) + + def _load_model_unserialized( + self, model_config: ModelConfig, device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig] + ) -> nn.Module: + """Load an unserialized model with tensorizer. + + Unserialized here means "not serialized with tensorizer". This + should still be faster than default HuggingFace loading, but will + be slower than loading a tensorizer-serialized model. + """ + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config) + + model.load_weights(self._get_weights_iterator()) + return model.eval() + + def _load_model_serialized( + self, model_config: ModelConfig, device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig] + ) -> nn.Module: + """Load a serialized model with tensorizer. + + See the examples/tensorize_vllm_model.py example " + script for serializing vLLM models.""" + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model_class = get_model_architecture(model_config)[0] + linear_method = _get_linear_method(model_config, + self.load_config) + extra_kwargs = _get_model_initialization_kwargs( + model_class, lora_config, vision_language_config) + extra_kwargs["linear_method"] = linear_method + + tensorizer_config = copy.copy(self.tensorizer_config) + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + + model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + return model.eval() + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + self._verify_config(model_config, parallel_config) + + if is_vllm_serialized_tensorizer(self.tensorizer_config): + return self._load_model_serialized(model_config, device_config, + lora_config, + vision_language_config) + return self._load_model_unserialized(model_config, device_config, + lora_config, + vision_language_config) + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.TENSORIZER: + return TensorizerLoader(load_config) + + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/neuron_model_loader.py b/vllm/model_executor/model_loader/neuron.py similarity index 100% rename from vllm/model_executor/neuron_model_loader.py rename to vllm/model_executor/model_loader/neuron.py diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer.py similarity index 78% rename from vllm/model_executor/tensorizer_loader.py rename to vllm/model_executor/model_loader/tensorizer.py index 8550cc97aefe8..ad554844384eb 100644 --- a/vllm/model_executor/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -4,20 +4,20 @@ import os import time import typing -import warnings from dataclasses import dataclass -from typing import Optional, Union +from typing import Generator, Optional, Tuple, Type, Union import torch from torch import nn +from transformers import PretrainedConfig -from vllm.config import TensorizerConfig +from vllm.config import ModelConfig, ParallelConfig from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -tensorizer_load_fail = False +tensorizer_load_fail = None try: from tensorizer import (DecryptionParams, EncryptionParams, @@ -25,51 +25,78 @@ from tensorizer.stream_io import open_stream from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) -except ImportError: - tensorizer_load_fail = True +except ImportError as e: + tensorizer_load_fail = e __all__ = [ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', - 'no_init_or_tensor' + 'no_init_or_tensor', 'TensorizerConfig' ] logger = init_logger(__name__) +@dataclass +class TensorizerConfig: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, + str, bytes, os.PathLike, int] + vllm_tensorized: bool + verify_hash: Optional[bool] = False + num_readers: Optional[int] = 1 + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[Type[torch.nn.Module]] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Optional[Union[str, torch.dtype]] = None + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if (parallel_config.tensor_parallel_size > 1 + and self.tensorizer_uri is not None): + raise ValueError( + "Loading to multiple GPUs is not currently supported with " + "vLLM-serialized models. Please set tensor_parallel_size=1." + " or use a non-vLLM-serialized model, such as a " + "serialized Hugging Face `PretrainedModel`.") + + def verify_with_model_config(self, model_config: "ModelConfig") -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + logger.warning( + "Loading a model using Tensorizer with quantization on vLLM" + " is unstable and may lead to errors.") + + def load_with_tensorizer(tensorizer_config: TensorizerConfig, **extra_kwargs) -> nn.Module: tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) return tensorizer.deserialize() -def tensorizer_warning(message: str): - return warnings.warn(message, category=PerformanceWarning, stacklevel=2) - - def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool: if tensorizer_config is None: return False return tensorizer_config.vllm_tensorized -class ParameterizedLoadFormat(str): - __slots__ = "params" - - -class PerformanceWarning(UserWarning): - - def __str__(self): - return (f"{super().__str__()}" - " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS" - " environment variable to hide this)") - - -if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower() - not in ("", "0", "n", "no", "off", "disable")): - warnings.simplefilter("ignore", category=PerformanceWarning) - - @dataclass class TensorizerArgs: tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, @@ -219,11 +246,17 @@ class TensorizerAgent: behavior of the TensorDeserializer when loading tensors from a serialized model. For deserializations of HuggingFace models, TensorDeserializer is instead used as an iterator directly in the func hf_model_weights_iterator - in vllm/model_executor/weight_utils.py + in vllm/model_executor/model_loader/weight_utils.py """ def __init__(self, tensorizer_config: TensorizerConfig, linear_method: LinearMethodBase, **extra_kwargs): + if tensorizer_load_fail is not None: + raise ImportError( + "Tensorizer is not installed. Please install tensorizer " + "to use this feature with `pip install vllm[tensorizer]`." + ) from tensorizer_load_fail + self.tensorizer_config = tensorizer_config self.tensorizer_args = ( self.tensorizer_config._construct_tensorizer_args()) @@ -234,11 +267,6 @@ def __init__(self, tensorizer_config: TensorizerConfig, self.linear_method = linear_method self.model = self._init_model() - if tensorizer_load_fail: - raise ImportError( - "Tensorizer is not installed. Please install tensorizer " - "to use this feature with `pip install vllm[tensorizer]`.") - def _init_model(self): model_args = self.tensorizer_config.hf_config model_args.torch_dtype = self.tensorizer_config.dtype @@ -313,3 +341,23 @@ def deserialize(self): self._check_tensors_on_meta_device() self._resize_lora_embeddings() return self.model.eval() + + +def tensorizer_weights_iterator( + tensorizer_args: "TensorizerArgs" +) -> Generator[Tuple[str, torch.Tensor], None, None]: + logger.warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the examples/tensorize_vllm_model.py example " + "script for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + for name, param in state.items(): + yield name, param + del state diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py new file mode 100644 index 0000000000000..a0a3b2784614d --- /dev/null +++ b/vllm/model_executor/model_loader/utils.py @@ -0,0 +1,40 @@ +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Tuple, Type + +import torch +from torch import nn + +from vllm.config import ModelConfig +from vllm.model_executor.models import ModelRegistry + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model_architecture( + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + if (model_config.quantization is not None + and "MixtralForCausalLM" in architectures): + architectures = ["QuantMixtralForCausalLM"] + + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py similarity index 53% rename from vllm/model_executor/weight_utils.py rename to vllm/model_executor/model_loader/weight_utils.py index 08425604f0511..1798db0136868 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -4,8 +4,9 @@ import hashlib import json import os +import tempfile from collections import defaultdict -from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, Generator, Iterable, List, Optional, Tuple import filelock import huggingface_hub.constants @@ -15,7 +16,7 @@ from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm -from vllm.config import ModelConfig +from vllm.config import LoadConfig, ModelConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) @@ -27,8 +28,7 @@ # can share the same lock without error. # lock files in the temp directory will be automatically deleted when the # system reboots, so users will not complain about annoying lock files -temp_dir = os.environ.get('TMPDIR') or os.environ.get( - 'TEMP') or os.environ.get('TMP') or "/tmp/" +temp_dir = tempfile.gettempdir() def enable_hf_transfer(): @@ -46,7 +46,7 @@ def enable_hf_transfer(): enable_hf_transfer() -class Disabledtqdm(tqdm): +class DisabledTqdm(tqdm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, disable=True) @@ -114,7 +114,8 @@ def convert_bin_to_safetensor_file( # TODO(woosuk): Move this to other place. -def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: quant_cls = get_quantization_config(model_config.quantization) # Read the quantization config from the HF model config, if available. hf_quant_config = getattr(model_config.hf_config, "quantization_config", @@ -125,12 +126,12 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. - with get_lock(model_name_or_path, model_config.download_dir): + with get_lock(model_name_or_path, load_config.download_dir): hf_folder = snapshot_download(model_name_or_path, revision=model_config.revision, allow_patterns="*.json", - cache_dir=model_config.download_dir, - tqdm_class=Disabledtqdm) + cache_dir=load_config.download_dir, + tqdm_class=DisabledTqdm) else: hf_folder = model_name_or_path config_files = glob.glob(os.path.join(hf_folder, "*.json")) @@ -153,169 +154,127 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: return quant_cls.from_config(config) -def prepare_hf_model_weights( - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - fall_back_to_pt: bool = True, - revision: Optional[str] = None, -) -> Tuple[str, List[str], bool]: - # Download model weights from huggingface. - is_local = os.path.isdir(model_name_or_path) \ - and load_format != "tensorizer" - use_safetensors = False - # Some quantized models use .pt files for storing the weights. - if load_format == "auto": - allow_patterns = ["*.safetensors", "*.bin"] - elif load_format == "safetensors": - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == "pt": - allow_patterns = ["*.pt"] - elif load_format == "npcache": - allow_patterns = ["*.bin"] - elif load_format == "tensorizer": - allow_patterns = ["*.tensors"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - if not is_local and load_format != "tensorizer": - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break - - logger.info(f"Using model weights format {allow_patterns}") - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download(model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=Disabledtqdm, - revision=revision) - else: - hf_folder = model_name_or_path - hf_weights_files: List[str] = [] +def download_weights_from_hf(model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + + Returns: + str: The path to the downloaded model weights. + """ + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things for pattern in allow_patterns: - hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] break - if not use_safetensors: - # Exclude files that are not needed for inference. - # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 - blacklist = [ - "training_args.bin", - "optimizer.bin", - "optimizer.pt", - "scheduler.pt", - "scaler.pt", - ] - hf_weights_files = [ - f for f in hf_weights_files - if not any(f.endswith(x) for x in blacklist) - ] - - if load_format == "tensorizer": - return hf_folder, hf_weights_files, use_safetensors - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{model_name_or_path}`") - - return hf_folder, hf_weights_files, use_safetensors - - -def hf_model_weights_iterator( - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: Union[Tuple, str] = "auto", - revision: Optional[str] = None, - fall_back_to_pt: Optional[bool] = True, -) -> Iterator[Tuple[str, torch.Tensor]]: - hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights( - model_name_or_path, - cache_dir=cache_dir, - load_format=load_format, - fall_back_to_pt=fall_back_to_pt, - revision=revision) - - if load_format == "npcache": - # Currently np_cache only support *.bin checkpoints - assert use_safetensors is False - - # Convert the model weights from torch tensors to numpy arrays for - # faster loading. - np_folder = os.path.join(hf_folder, "np") - os.makedirs(np_folder, exist_ok=True) - weight_names_file = os.path.join(np_folder, "weight_names.json") - # Use file lock to prevent multiple processes from - # dumping the same model weights to numpy at the same time. - with get_lock(model_name_or_path, cache_dir): - if not os.path.exists(weight_names_file): - weight_names = [] - for bin_file in hf_weights_files: - state = torch.load(bin_file, map_location="cpu") - for name, param in state.items(): - param_path = os.path.join(np_folder, name) - with open(param_path, "wb") as f: - np.save(f, param.cpu().detach().numpy()) - weight_names.append(name) - with open(weight_names_file, "w") as f: - json.dump(weight_names, f) - - with open(weight_names_file, "r") as f: - weight_names = json.load(f) - - for name in weight_names: - param_path = os.path.join(np_folder, name) - with open(param_path, "rb") as f: - param = np.load(f) - yield name, torch.from_numpy(param) - elif load_format == "tensorizer": - from vllm.model_executor.tensorizer_loader import (TensorDeserializer, - open_stream, - tensorizer_warning) - tensorizer_args = load_format.params - tensorizer_warning( - "Deserializing HuggingFace models is not optimized for " - "loading on vLLM, as tensorizer is forced to load to CPU. " - "Consider deserializing a vLLM model instead for faster " - "load times. See the examples/tensorize_vllm_model.py example " - "script for serializing vLLM models.") - - deserializer_args = tensorizer_args.deserializer_params - stream_params = tensorizer_args.stream_params - stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) - with TensorDeserializer(stream, **deserializer_args, - device="cpu") as state: - for name, param in state.items(): + + logger.info(f"Using model weights format {allow_patterns}") + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download(model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision) + return hf_folder + + +def filter_files_not_needed_for_inference( + hf_weights_files: List[str]) -> List[str]: + """ + Exclude files that are not needed for inference. + + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +def np_cache_weights_iterator( + model_name_or_path: str, cache_dir: Optional[str], hf_folder: str, + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model np files. + + Will dump the model weights to numpy files if they are not already dumped. + """ + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names = [] + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file, "r") as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + + +def safetensors_weights_iterator( + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + for st_file in hf_weights_files: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) yield name, param + + +def pt_weights_iterator( + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + yield name, param del state - elif use_safetensors: - for st_file in hf_weights_files: - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param - else: - for bin_file in hf_weights_files: - state = torch.load(bin_file, map_location="cpu") - for name, param in state.items(): - yield name, param - del state - torch.cuda.empty_cache() + torch.cuda.empty_cache() def kv_cache_scales_loader( diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 30588aecdebe9..69162b0a92d65 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -19,7 +19,7 @@ # limitations under the License. """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -40,9 +40,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -340,19 +339,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if name == "lm_head.weight": diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 40966ab33631a..14f325e624f41 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -17,7 +17,7 @@ # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -35,9 +35,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -298,14 +297,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if name == "lm_head.weight": continue if not name.startswith("transformer."): diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 7b46ba306619a..3cdb7a7bca1c1 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -2,7 +2,7 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -22,9 +22,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig @@ -370,14 +369,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_pos_emb.inv_freq" in name: continue if "word_embeddings" in name: diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index aa9b28b676e0b..d80969773e163 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -20,7 +20,7 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch import torch.utils.checkpoint @@ -41,10 +41,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -335,13 +334,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -352,8 +345,7 @@ def load_weights( ] params_dict = dict(self.named_parameters()) loaded_params = set() - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 49eb7f1b2c185..179094b8fd7aa 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,5 +1,5 @@ # coding=utf-8 -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn @@ -18,10 +18,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -391,20 +390,13 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [( "ws" if weight_name in ["w1", "v1"] else "w2s", f"experts.mlp.{weight_name}", ) for weight_name in ["w1", "v1", "w2"]] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: for param_name, weight_name in expert_params_mapping: if weight_name not in name: continue diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index abf4a462871b0..d476630ee6f11 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -23,16 +23,15 @@ # limitations under the License. """Inference-only DeciLM model compatible with HuggingFace weights.""" -from typing import Optional +from typing import Iterable, Optional, Tuple import torch from transformers import PretrainedConfig from vllm.config import LoRAConfig from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) class DeciLMForCausalLM(LlamaForCausalLM): @@ -65,11 +64,7 @@ def __init__( linear_method=linear_method, lora_config=lora_config) - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -79,8 +74,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index c7dd11d07e6da..46101a152ec0d 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -44,9 +44,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -316,6 +315,8 @@ def forward( class DeepseekModel(nn.Module): + fall_back_to_pt_during_load = False + def __init__( self, config: PretrainedConfig, @@ -395,11 +396,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -410,12 +407,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 4f1ebcd5fb43c..25ce239d14662 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -19,7 +19,7 @@ """PyTorch Falcon model.""" import math -from typing import List, Optional, Union +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -40,9 +40,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import RWConfig @@ -399,11 +398,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): total_num_heads = self.config.num_attention_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads @@ -413,8 +408,7 @@ def load_weights(self, total_num_kv_heads = total_num_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if name == "lm_head.weight": # Falcon uses tied embeddings. continue diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index fc1fc35570368..6d01537c5c344 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -15,7 +15,7 @@ # limitations under the License. """Inference-only Gemma model compatible with HuggingFace weights.""" from functools import lru_cache -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -36,9 +36,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput logger = init_logger(__name__) @@ -346,11 +345,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -361,8 +356,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) loaded_params = set() - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: for (param_name, shard_name, shard_id) in stacked_params_mapping: if shard_name not in name: continue diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 43f0d47fcb122..850050c7232d0 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -34,9 +34,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -239,14 +238,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index cec2d771adfa8..8278ba02514d5 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -35,9 +35,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -260,14 +259,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: continue if ".attn.bias" in name: diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 5660097652748..7a830d7f9c965 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -34,9 +34,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -248,11 +247,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -262,8 +257,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "attn.bias" in name or "attn.masked_bias" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 2f9e2171cf114..b946aed92ed35 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -34,9 +34,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -262,14 +261,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): continue diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 6e9cbd3f9f43f..db1da8bdc4fb9 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -18,9 +18,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -274,19 +273,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w1", 0), ("gate_up_proj", "w3", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index a041b0c9a0452..e7ee749e824e4 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -20,7 +20,7 @@ """Inference-only Jais model compatible with HuggingFace weights.""" import math -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -36,9 +36,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import JAISConfig @@ -303,16 +302,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c86e292e7df1a..016e3b039d1e8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -42,10 +42,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator, - kv_cache_scales_loader) from vllm.sequence import SamplerOutput from vllm.utils import is_hip @@ -376,11 +375,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -390,8 +385,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index c2571d0893c8d..314a2792bf167 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -13,10 +13,9 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput _KEYS_TO_MODIFY_MAPPING = { @@ -198,11 +197,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # only doing this for language model part for now. stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -213,8 +208,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 49eda9c9a8112..f0d72fafcaf70 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -22,7 +22,7 @@ # limitations under the License. """Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -45,10 +45,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -472,11 +471,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -493,8 +488,7 @@ def load_weights(self, for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ff552a9d86536..4d1755f2bbe63 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -43,10 +43,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -319,6 +318,8 @@ def forward( class MixtralForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -393,11 +394,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -414,12 +411,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 1f0c0e912beea..acd13cc27f159 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Mixtral model.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import numpy as np import torch @@ -43,9 +43,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -327,6 +326,7 @@ def forward( class MixtralForCausalLM(nn.Module): + fall_back_to_pt_during_load = False def __init__( self, @@ -366,11 +366,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -379,12 +375,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index af4cdce29d085..340f63286739b 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -1,7 +1,7 @@ # coding=utf-8 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch import torch.nn as nn @@ -18,9 +18,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig @@ -284,14 +283,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 3513c72879102..b92003bc0e067 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -36,7 +36,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch # this model must need this dependency @@ -56,9 +56,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -348,16 +347,9 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: # attention if ".att" in name: name = name.replace(".att", ".attn.att") diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 3a640850662c0..89263166bca81 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only OPT model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -35,9 +35,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -315,11 +314,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -327,8 +322,7 @@ def load_weights(self, ("qkv_proj", "v_proj", "v"), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "lm_head.weight" in name: continue if name.startswith("decoder."): diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index c606ac027e9d9..bbb9fa5347cc8 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -4,7 +4,7 @@ # Copyright (c) OrionStar Inc. # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -22,9 +22,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -280,11 +279,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -294,8 +289,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index e91624da90955..f974b78a0fbda 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -35,7 +35,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -53,9 +53,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -265,11 +264,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -278,8 +273,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 6213a2ded65ab..a77da7cb15984 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -4,7 +4,7 @@ # Copyright (c) Alibaba Cloud. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -23,9 +23,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -253,19 +252,14 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w2", 0), ("gate_up_proj", "w1", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 796e30e633e85..71b906e20ac19 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -42,9 +42,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -331,11 +330,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -345,8 +340,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if self.config.tie_word_embeddings and "lm_head.weight" in name: diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index f920b4f5a40c7..59908bc9ef26a 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch import torch.nn.functional as F @@ -46,9 +46,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -366,6 +365,8 @@ def forward( class Qwen2MoeForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + def __init__( self, config: PretrainedConfig, @@ -404,11 +405,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -419,12 +416,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 651598b770f13..3e6c2db6f3c65 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -19,7 +19,7 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -37,9 +37,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -262,11 +261,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -276,8 +271,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 76e8e48673413..b90f3da141c2e 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Starcoder2 model.""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -36,9 +36,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -274,11 +273,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -287,8 +282,7 @@ def load_weights(self, ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 7e9ce9e5c8e15..4e905390c2340 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Xverse model compatible with HuggingFace weights.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -40,9 +40,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -331,11 +330,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -344,8 +339,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + for name, loaded_weight in weights: if ("rotary_emb.inv_freq" in name or "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name): diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 5d3d5801c960d..c98a673bfed4b 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,8 +1,10 @@ +import os from typing import Optional, Union from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) +from vllm.config import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizers import BaichuanTokenizer @@ -57,9 +59,26 @@ def get_tokenizer( tokenizer_mode: str = "auto", trust_remote_code: bool = False, tokenizer_revision: Optional[str] = None, + download_dir: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - """Gets a tokenizer for the given model name via Huggingface.""" + """Gets a tokenizer for the given model name via Huggingface/modelscope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # Only set the tokenizer here, model will be downloaded on the workers. + if not os.path.exists(tokenizer_name): + tokenizer_path = snapshot_download( + model_id=tokenizer_name, + cache_dir=download_dir, + revision=tokenizer_revision, + # Ignore weights - we only need the tokenizer. + ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"]) + tokenizer_name = tokenizer_path + if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError( diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 49e1ad5709f5d..d378e3a90e1e7 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -3,8 +3,8 @@ import torch from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -26,6 +26,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, @@ -36,6 +37,7 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config + self.load_config = load_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. @@ -55,8 +57,10 @@ def __init__( self.model_config.dtype if model_config is not None else None) def load_model(self) -> None: - self.model = get_model(self.model_config, - self.device_config, + self.model = get_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + vision_language_config=None, lora_config=self.lora_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 41341b063bed7..6610b9c4be876 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -5,8 +5,8 @@ import torch.distributed from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -117,6 +117,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, cache_config: CacheConfig, + load_config: LoadConfig, local_rank: int, rank: int, distributed_init_method: str, @@ -129,6 +130,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config + self.load_config = load_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -141,6 +143,7 @@ def __init__( parallel_config, scheduler_config, device_config, + load_config=self.load_config, lora_config=self.lora_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7dbe14ead0976..42c06a1b19361 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,9 +9,8 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) -from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, TensorizerConfig, - VisionLanguageConfig) +from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) @@ -108,17 +107,17 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, - tensorizer_config: Optional[TensorizerConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config - self.tensorizer_config = tensorizer_config + self.load_config = load_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. @@ -156,13 +155,13 @@ def __init__( def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( - self.model_config, - self.device_config, + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, lora_config=self.lora_config, vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - tensorizer_config=self.tensorizer_config, ) self.model_memory_usage = m.consumed_memory diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index fff721a80c204..f70a7193effeb 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -6,7 +6,7 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata -from vllm.model_executor.neuron_model_loader import get_neuron_model +from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import (async_tensor_h2d, is_pin_memory_available, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 82491c6df6616..6a79285f60579 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -6,8 +6,8 @@ import torch import torch.distributed -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, TensorizerConfig, +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, @@ -38,12 +38,12 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, cache_config: CacheConfig, + load_config: LoadConfig, local_rank: int, rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, - tensorizer_config: Optional[TensorizerConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -55,7 +55,7 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config - self.tensorizer_config = tensorizer_config + self.load_config = load_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -70,11 +70,11 @@ def __init__( parallel_config, scheduler_config, device_config, + load_config=load_config, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, - tensorizer_config=tensorizer_config, ) # Uninitialized cache engine. Will be initialized by # initialize_cache. From e95cd879598b834f85e70ebcd23db316ae430540 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 16 Apr 2024 13:09:21 -0700 Subject: [PATCH 053/413] [Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894) --- tests/core/block/e2e/test_correctness.py | 70 ++++ tests/core/utils.py | 17 +- .../output_processor/test_multi_step.py | 270 +++++++++++++ tests/spec_decode/e2e/test_correctness.py | 127 +++++- tests/spec_decode/test_multi_step_worker.py | 4 +- tests/spec_decode/test_spec_decode_worker.py | 32 +- tests/spec_decode/utils.py | 2 +- vllm/core/block/block_table.py | 1 - vllm/core/scheduler.py | 8 +- vllm/engine/async_llm_engine.py | 4 +- vllm/engine/llm_engine.py | 371 +++--------------- vllm/engine/output_processor/__init__.py | 0 vllm/engine/output_processor/interfaces.py | 69 ++++ vllm/engine/output_processor/multi_step.py | 126 ++++++ vllm/engine/output_processor/single_step.py | 276 +++++++++++++ vllm/engine/output_processor/stop_checker.py | 101 +++++ vllm/engine/output_processor/util.py | 16 + vllm/executor/cpu_executor.py | 3 +- vllm/executor/executor_base.py | 5 +- vllm/executor/gpu_executor.py | 79 +++- vllm/executor/neuron_executor.py | 5 +- vllm/executor/ray_gpu_executor.py | 3 +- vllm/sequence.py | 13 + vllm/spec_decode/batch_expansion.py | 23 +- vllm/spec_decode/multi_step_worker.py | 16 +- vllm/spec_decode/spec_decode_worker.py | 44 ++- vllm/spec_decode/util.py | 26 ++ vllm/worker/cpu_worker.py | 8 +- vllm/worker/neuron_worker.py | 11 +- vllm/worker/worker.py | 11 +- vllm/worker/worker_base.py | 13 +- 31 files changed, 1347 insertions(+), 407 deletions(-) create mode 100644 tests/engine/output_processor/test_multi_step.py create mode 100644 vllm/engine/output_processor/__init__.py create mode 100644 vllm/engine/output_processor/interfaces.py create mode 100644 vllm/engine/output_processor/multi_step.py create mode 100644 vllm/engine/output_processor/single_step.py create mode 100644 vllm/engine/output_processor/stop_checker.py create mode 100644 vllm/engine/output_processor/util.py diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 94b65401e1dd4..0ee78a9b0a8ea 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@pytest.mark.parametrize( + "common_llm_kwargs", + [ + { + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "enable_chunked_prefill": True, + "max_num_batched_tokens": 2, + "max_num_seqs": 2, + }, + ]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [ + { + "use_v2_block_manager": False, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "use_v2_block_manager": True, + "num_lookahead_slots": 0, + }, + { + "use_v2_block_manager": True, + "num_lookahead_slots": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_chunked_prefill_block_manager_v2(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify that chunked prefill works with BlockManagerV2, with and without + lookahead scheduling. + """ + output_len = 32 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids with BlockManagerV1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with BlockManagerV2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) diff --git a/tests/core/utils.py b/tests/core/utils.py index fbbdb07cb8e6e..22c1d3826dff4 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,5 +1,5 @@ import time -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple from vllm import SamplingParams from vllm.lora.request import LoRARequest @@ -31,14 +31,17 @@ def create_dummy_prompt( def create_seq_group( - seq_prompt_len=1024, - seq_output_lens=(128, ), - request_id='0', - seq_id_start=0, -) -> SequenceGroup: + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: assert len(seq_output_lens) > 0 + if sampling_params is None: + sampling_params = SamplingParams() + prompt_token_ids = [0] * seq_prompt_len seqs = [] @@ -60,7 +63,7 @@ def create_seq_group( seq_group = SequenceGroup( request_id=request_id, seqs=seqs, - sampling_params=SamplingParams(), + sampling_params=sampling_params, arrival_time=time.time(), ) diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py new file mode 100644 index 0000000000000..6da3da091db78 --- /dev/null +++ b/tests/engine/output_processor/test_multi_step.py @@ -0,0 +1,270 @@ +import random +from unittest.mock import MagicMock + +import pytest +from transformers import PreTrainedTokenizer + +from tests.core.utils import create_seq_group +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sampling_params import SamplingParams +from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, + SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + + +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [1, 12]) +@pytest.mark.skip_global_cleanup +def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): + """Verify multi-step decoding appends token ids correctly. + + We append token ids and verify all the token ids were appended correctly. + Note that ignore_eos=True. + """ + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=1024, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams(max_tokens=seq_output_len + + num_new_tokens, + ignore_eos=True), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids + output_processor.process_outputs(seq_group, outputs) + assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids + + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8]) +@pytest.mark.parametrize("max_tokens", [128 + 3]) +@pytest.mark.skip_global_cleanup +def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, max_tokens: int): + """Verify tokens after max_tokens are dropped and not appended to the + sequence. + """ + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams(max_tokens=max_tokens, ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to not go over max tokens in len. + assert seq.get_len() == seq_prompt_len + max_tokens + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [12]) +@pytest.mark.parametrize("seed", list(range(6))) +@pytest.mark.skip_global_cleanup +def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, seed: int): + """Verify the eos token id is included in the sequence, but subsequent + tokens are dropped (not appended to sequence). + """ + random.seed(seed) + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + eos_token_id = 100 + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + # Ensure enough space. + max_tokens=seq_output_len + num_new_tokens, ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + assert eos_token_id not in new_token_ids + eos_index = random.randint(0, len(new_token_ids) - 1) + new_token_ids[eos_index] = eos_token_id + + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to not go beyond provided eos. + assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1) + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:eos_index + 1] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + + +@pytest.mark.parametrize("seq_prompt_len", [1024]) +@pytest.mark.parametrize("seq_output_len", [128]) +@pytest.mark.parametrize("num_new_tokens", [12]) +@pytest.mark.parametrize("seed", list(range(6))) +@pytest.mark.skip_global_cleanup +def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, + seq_output_len: int, seed: int): + """When sampling parameters dictate that we should ignore the eos token id, + ensure all token ids are appended even if the eos token id is emitted. + """ + random.seed(seed) + detokenizer = MagicMock(spec=Detokenizer) + scheduler = MagicMock(spec=Scheduler) + stop_checker = MagicMock(spec=StopChecker) + seq_counter = Counter() + + eos_token_id = 100 + + output_processor = MultiStepOutputProcessor( + detokenizer=detokenizer, + scheduler=scheduler, + seq_counter=seq_counter, + get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id), + stop_checker=stop_checker, + ) + + seq_group = create_seq_group( + seq_prompt_len=seq_prompt_len, + seq_output_lens=[seq_output_len], + sampling_params=SamplingParams( + # Ensure enough space. + max_tokens=seq_output_len + num_new_tokens, + ignore_eos=True, + ), + ) + + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + new_token_ids = list(range(num_new_tokens)) + assert eos_token_id not in new_token_ids + eos_index = random.randint(0, len(new_token_ids) - 1) + new_token_ids[eos_index] = eos_token_id + + outputs = [ + SequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq.seq_id, + output_token=output_token, + logprobs={output_token: Logprob(0.0)}, + ) + ], + prompt_logprobs=None, + ) for output_token in new_token_ids + ] + + assert seq.get_len() == seq_prompt_len + seq_output_len + output_processor.process_outputs(seq_group, outputs) + + # Expect the processed sequence to go beyond eos. + assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens + + # Expect the correct tokens were appended. + expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens - + seq_output_len] + assert seq.get_token_ids( + )[-len(expected_appended_tokens):] == expected_appended_tokens + + +def mock_tokenizer(eos_token_id=1000): + tokenizer = MagicMock(spec=PreTrainedTokenizer) + tokenizer.eos_token_id = eos_token_id + return tokenizer diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index b5a6fcb7900a3..a8ebd66841eb2 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -1,4 +1,8 @@ +from itertools import cycle +from typing import List, Tuple + import pytest +from transformers import AutoTokenizer from vllm import SamplingParams @@ -7,18 +11,47 @@ "common_llm_kwargs", [{ # Use a small model for a fast test. - "model": "facebook/opt-125m", - "speculative_model": "facebook/opt-125m", - "num_speculative_tokens": 5, + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", + + # Skip real loading for fast test. + "load_format": "dummy", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, # Required for spec decode. "use_v2_block_manager": True }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 1, + }, + { + # No spec decode. + }, + ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [1]) +# NOTE: We should run more permutations of this test (more BS, more seeds). But +# because our spec decode generates gibberish token ids, the likelihood of +# emitting an invalid token combination is nontrivial. This causes divergence in +# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf- +# start" bytes are emitted. @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_config(test_llm_generator): - output_len = 1024 +def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): + """Run generation with speculative decoding on a batch. Verify the engine + generates the correct number of tokens (via ignore_eos=True), and that the + detokenization matches HF transformers. + """ + output_len = 32 temperature = 0.0 prompts = [ @@ -28,23 +61,91 @@ def test_spec_decode_config(test_llm_generator): "The future of AI is", ] + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + skip_special_tokens=True, + spaces_between_special_tokens=False, + ) + + batch_tokens, batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + # Expect a generation for each prompt in the batch. + assert len(batch_token_ids) == len(prompts) + + # Expect each generation to have expected number of tokens (note + # ignore_eos=True). + assert all(len(token_ids) == output_len for token_ids in batch_token_ids) + + # Expect detokenized string to match. + tok = AutoTokenizer.from_pretrained("JackFram/llama-68m") + for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids): + expected_tokens = tok.decode(actual_token_ids) + print(f"{actual_token_ids=}") + assert actual_tokens.strip() == expected_tokens.strip() + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Skip real loading for fast test. + "load_format": "dummy", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Expect failure as spec decode not supported by + # Ray backend. + "worker_use_ray": True, + }, + ]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail(test_llm_generator): + """Verify that speculative decoding with Ray fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + sampling_params = SamplingParams( max_tokens=output_len, ignore_eos=True, temperature=temperature, ) - with pytest.raises( - AssertionError, - match="Speculative decoding not yet supported for GPU backend"): - get_token_ids_from_llm_generator(test_llm_generator, prompts, - sampling_params) + with pytest.raises(AssertionError, + match="Speculative decoding not yet supported for "): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): +def get_output_from_llm_generator( + llm_generator, prompts, + sampling_params) -> Tuple[List[str], List[List[int]]]: for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) token_ids = [output.outputs[0].token_ids for output in outputs] + tokens = [output.outputs[0].text for output in outputs] del llm - return token_ids + return tokens, token_ids diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index f4d44108b47c2..d6edbab579afd 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -125,7 +125,7 @@ def test_same_output_for_single_step(): zero_kv_cache(worker.cache_engine) set_random_seed(seed) expected_output = worker.execute_model( - **single_step_execute_model_data.to_dict(), ) + **single_step_execute_model_data.to_dict(), )[0] actual_token_ids = [ output.samples[0].output_token for output in actual_output @@ -219,7 +219,7 @@ def test_same_output_for_multi_step(): continuations=continuations, final_seq_lens=final_seq_lens)) - single_step_output.append( + single_step_output.extend( worker.execute_model(**execute_model_data.to_dict(), )) # Append output tokens to new sequence data. diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 47aff8f575413..0a3110775e2d6 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -6,6 +6,7 @@ from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) @@ -37,7 +38,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): execute_model_data, _, _ = create_batch(batch_size, k) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(**execute_model_data.to_dict(), + num_lookahead_slots=k) call_args_list = draft_worker.get_spec_proposals.call_args_list assert len(call_args_list) == 1 @@ -102,7 +104,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): target_worker.execute_model.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(**execute_model_data.to_dict(), + num_lookahead_slots=k) seen_contexts = [] @@ -189,13 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): target_output = create_sampler_output_list(target_token_ids, target_token_probs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] exception_secret = 'artifical stop' rejection_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k) + worker.execute_model(**execute_model_data.to_dict(), + num_lookahead_slots=k) assert len(rejection_sampler.call_args_list) == 1 args, _ = rejection_sampler.call_args_list[0] @@ -268,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int): target_output = create_sampler_output_list(target_token_ids, target_token_probs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] rejection_sampler_output = torch.randint(low=0, high=vocab_size, @@ -283,7 +287,7 @@ def test_correctly_formats_output(k: int, batch_size: int): rejection_sampler.return_value = rejection_sampler_output output = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) expected_output = create_sampler_output_list( rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) @@ -380,7 +384,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): target_output = create_sampler_output_list(target_token_ids, target_token_probs) - target_worker.execute_model.return_value = target_output[0] + target_worker.execute_model.return_value = [target_output[0]] rejection_sampler_output = torch.randint(low=0, high=vocab_size, @@ -400,7 +404,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): mock_rejsample_metrics) output = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics call_args_list = ( @@ -423,6 +427,8 @@ def test_k_equals_zero(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) + target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -435,7 +441,7 @@ def test_k_equals_zero(k: int, batch_size: int): batch_size, k, prev_output_token_len=0) out = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" @@ -443,7 +449,7 @@ def test_k_equals_zero(k: int, batch_size: int): 0].sampled_tokens is None, "expect gpu tensor references to be None" draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict(), return_python_output=False) + **execute_model_data.to_dict()) target_worker.execute_model.assert_called_once_with( **execute_model_data.to_dict()) @@ -462,6 +468,8 @@ def test_empty_input_batch(k: int, batch_size: int): rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) + target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] + draft_worker.device = 'cuda' target_worker.device = 'cuda' @@ -474,7 +482,7 @@ def test_empty_input_batch(k: int, batch_size: int): batch_size, k, prev_output_token_len=0) out = worker.execute_model(**execute_model_data.to_dict(), - num_spec_tokens=k) + num_lookahead_slots=k) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" @@ -482,7 +490,7 @@ def test_empty_input_batch(k: int, batch_size: int): 0].sampled_tokens is None, "expect gpu tensor references to be None" draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict(), return_python_output=False) + **execute_model_data.to_dict()) target_worker.execute_model.assert_called_once_with( **execute_model_data.to_dict()) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index edba4c226b289..d04b6029493f4 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -212,7 +212,7 @@ def create_sampler_output_list( SequenceOutput( output_token=token_id, parent_seq_id=seq_ids[seq_index], - logprobs={token_id: 0}, + logprobs={token_id: Logprob(0)}, ) ], prompt_logprobs=None, diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index ba061bbc4fbcb..560267e55ea3a 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -104,7 +104,6 @@ def append_token_ids(self, token_ids (List[int]): The sequence of token IDs to be appended. """ assert self._is_allocated - assert token_ids, "can't append empty token ids" self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 18ddcd1d6d466..4198550621030 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -762,9 +762,7 @@ def _schedule_default(self) -> SchedulerOutputs: blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, swapped_in.blocks_to_copy), ignored_seq_groups=prefills.ignored_seq_groups, - num_lookahead_slots=(prefills.num_lookahead_slots + - running_scheduled.num_lookahead_slots + - swapped_in.num_lookahead_slots), + num_lookahead_slots=running_scheduled.num_lookahead_slots, ) def _schedule_chunked_prefill(self): @@ -850,9 +848,7 @@ def _schedule_chunked_prefill(self): blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, swapped_in.blocks_to_copy), ignored_seq_groups=prefills.ignored_seq_groups, - num_lookahead_slots=(prefills.num_lookahead_slots + - running_scheduled.num_lookahead_slots + - swapped_in.num_lookahead_slots), + num_lookahead_slots=running_scheduled.num_lookahead_slots, ) def _schedule(self) -> SchedulerOutputs: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1dbf58904541c..27192449bf15a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -217,7 +217,9 @@ async def step_async(self) -> List[RequestOutput]: else: output = [] - return self._process_model_outputs(output, scheduler_outputs) + return self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups) async def encode_request_async( self, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 563694946d16e..c3de57e249ff8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,5 +1,5 @@ import time -from typing import Iterable, List, Optional, Tuple, Type, Union +from typing import Iterable, List, Optional, Type, Union from transformers import PreTrainedTokenizer @@ -11,6 +11,10 @@ from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.ray_utils import initialize_ray_cluster from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger @@ -18,8 +22,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupOutput, SequenceOutput, - SequenceStatus) + SequenceGroup) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -187,6 +190,21 @@ def __init__( labels=dict(model_name=model_config.model)) self.stat_logger.info("cache_config", self.cache_config) + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + self.get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + self.get_tokenizer_for_seq, + ), + )) + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -412,240 +430,32 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() - def _check_beam_search_early_stopping( - self, - early_stopping: Union[bool, str], - sampling_params: SamplingParams, - best_running_seq: Sequence, - current_worst_seq: Sequence, - ) -> bool: - assert sampling_params.use_beam_search - length_penalty = sampling_params.length_penalty - if early_stopping is True: - return True - - current_worst_score = current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=current_worst_seq.eos_token_id) - if early_stopping is False: - highest_attainable_score = best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id) - else: - assert early_stopping == "never" - if length_penalty > 0.0: - # If length_penalty > 0.0, beam search will prefer longer - # sequences. The highest attainable score calculation is - # based on the longest possible sequence length in this case. - max_possible_length = max( - best_running_seq.get_prompt_len() + - sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id, - seq_len=max_possible_length)) - else: - # Otherwise, beam search will prefer shorter sequences. The - # highest attainable score calculation is based on the current - # sequence length. - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id)) - return current_worst_score >= highest_attainable_score - - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: - - # Process prompt logprobs - prompt_logprobs = outputs.prompt_logprobs - if prompt_logprobs is not None and seq_group.sampling_params.detokenize: - self.detokenizer.decode_prompt_logprobs_inplace( - seq_group, prompt_logprobs) - seq_group.prompt_logprobs = prompt_logprobs - - # Process samples - samples = outputs.samples - parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - existing_finished_seqs = seq_group.get_finished_seqs() - parent_child_dict = { - parent_seq.seq_id: [] - for parent_seq in parent_seqs - } - for sample in samples: - parent_child_dict[sample.parent_seq_id].append(sample) - # List of (child, parent) - child_seqs: List[Tuple[Sequence, Sequence]] = [] - - # Process the child samples for each parent sequence - for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ - parent.seq_id] - if len(child_samples) == 0: - # This parent sequence has no children samples. Remove - # the parent sequence from the sequence group since it will - # not be used in the future iterations. - parent.status = SequenceStatus.FINISHED_ABORTED - seq_group.remove(parent.seq_id) - self.scheduler.free_seq(parent) - continue - # Fork the parent sequence if there are multiple child samples. - for child_sample in child_samples[:-1]: - new_child_seq_id = next(self.seq_counter) - child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs) - child_seqs.append((child, parent)) - # Continue the parent sequence for the last child sample. - # We reuse the parent sequence here to reduce redundant memory - # copies, especially when using non-beam search sampling methods. - last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) - child_seqs.append((parent, parent)) - - for seq, _ in child_seqs: - if seq_group.sampling_params.detokenize: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, seq_group.sampling_params) - else: - new_char_count = 0 - self._check_stop(seq, new_char_count, seq_group.sampling_params) - - # Non-beam search case - if not seq_group.sampling_params.use_beam_search: - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - # NOTE: we need to fork the new sequences before freeing the - # old sequences. - for seq, parent in child_seqs: - if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) - return - - # Beam search case - # Select the child sequences to keep in the sequence group. - selected_child_seqs = [] - unselected_child_seqs = [] - beam_width = seq_group.sampling_params.best_of - length_penalty = seq_group.sampling_params.length_penalty - - # Select the newly finished sequences with the highest scores - # to replace existing finished sequences. - # Tuple of (seq, parent, is_new) - existing_finished_seqs = [(seq, None, False) - for seq in existing_finished_seqs] - new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs - if seq.is_finished()] - all_finished_seqs = existing_finished_seqs + new_finished_seqs - # Sort the finished sequences by their scores. - all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - for seq, parent, is_new in all_finished_seqs[:beam_width]: - if is_new: - # A newly generated child sequence finishes and has a high - # score, so we will add it into the sequence group. - selected_child_seqs.append((seq, parent)) - for seq, parent, is_new in all_finished_seqs[beam_width:]: - if is_new: - # A newly generated child sequence finishes but has a low - # score, so we will not add it into the sequence group. - # Additionally, if this sequence is a continuation of a - # parent sequence, we will need remove the parent sequence - # from the sequence group. - unselected_child_seqs.append((seq, parent)) - else: - # An existing finished sequence has a low score, so we will - # remove it from the sequence group. - seq_group.remove(seq.seq_id) - - # select the top beam_width sequences from the running - # sequences for the next iteration to continue the beam - # search. - running_child_seqs = [(seq, parent) for seq, parent in child_seqs - if not seq.is_finished()] - # Sort the running sequences by their scores. - running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - - # Check if we can stop the beam search. - if len(running_child_seqs) == 0: - # No running sequences, stop the beam search. - stop_beam_search = True - elif len(all_finished_seqs) < beam_width: - # Not enough finished sequences, continue the beam search. - stop_beam_search = False - else: - # Check the early stopping criteria - best_running_seq = running_child_seqs[0][0] - current_worst_seq = all_finished_seqs[beam_width - 1][0] - stop_beam_search = self._check_beam_search_early_stopping( - seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) - - if stop_beam_search: - # Stop the beam search and remove all the running sequences from - # the sequence group. - unselected_child_seqs.extend(running_child_seqs) - else: - # Continue the beam search and select the top beam_width sequences - # to continue the beam search. - selected_child_seqs.extend(running_child_seqs[:beam_width]) - # The remaining running sequences will not be used in the next - # iteration. Again, if these sequences are continuations of - # parent sequences, we will need to remove the parent sequences - # from the sequence group. - unselected_child_seqs.extend(running_child_seqs[beam_width:]) - - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in selected_child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - self.scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - for seq, parent in selected_child_seqs: - if seq is parent and seq.is_finished(): - self.scheduler.free_seq(seq) - - # Remove the unselected parent sequences from the sequence group and - # free their memory in block manager. - for seq, parent in unselected_child_seqs: - if seq is parent: - # Remove the parent sequence if it is not selected for next - # iteration - seq_group.remove(seq.seq_id) - self.scheduler.free_seq(seq) - def _process_model_outputs( - self, output: SamplerOutput, - scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + self, output: List[SamplerOutput], + scheduled_seq_groups: List[SequenceGroup], + ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: + """Apply the model output to the sequences in the scheduled seq groups. + + Returns RequestOutputs that can be returned to the client. + """ + now = time.time() + + # Organize outputs by [sequence group][step] instead of + # [step][sequence group]. + output_by_sequence_group = create_output_by_sequence_group( + sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) + # Update the scheduled sequence groups with the model outputs. - scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output): + for scheduled_seq_group, outputs in zip(scheduled_seq_groups, + output_by_sequence_group): seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) # If uncomputed tokens > 0, it means prefill is chunked. # We don't need to process outputs in that case. if seq_group.get_num_uncomputed_tokens() == 0: - self._process_sequence_group_outputs(seq_group, outputs) + self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() @@ -657,13 +467,9 @@ def _process_model_outputs( seq_group.maybe_set_first_token_time(now) request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) - for seq_group in scheduler_outputs.ignored_seq_groups: + for seq_group in ignored_seq_groups: request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) - - # Log stats. - if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs)) return request_outputs def step(self) -> List[RequestOutput]: @@ -721,13 +527,23 @@ def step(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): output = self.model_executor.execute_model( - seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, - scheduler_outputs.blocks_to_swap_out, - scheduler_outputs.blocks_to_copy) + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots) else: output = [] - return self._process_model_outputs(output, scheduler_outputs) + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups) + + # Log stats. + if self.log_stats: + self.stat_logger.log(self._get_stats(scheduler_outputs)) + + return request_outputs def do_log_stats(self) -> None: """Forced log when no requests active.""" @@ -807,87 +623,6 @@ def _get_stats(self, time_e2e_requests=time_e2e_requests, ) - def _check_stop(self, seq: Sequence, new_char_count: int, - sampling_params: SamplingParams) -> None: - """Stop the finished sequences. - - new_char_count is the number of chars added to the - sequence's output text for the newly generated token - """ - - # Check if the minimum number of tokens has been generated yet; - # skip the stop string/token checks if not - if seq.get_output_len() < sampling_params.min_tokens: - return - - # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == seq.eos_token_id): - seq.status = SequenceStatus.FINISHED_STOPPED - return - - # Check if a stop token was encountered. - # This assumes a single token produced per step. - last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: - if new_char_count and ( - not sampling_params.include_stop_str_in_output): - # Remove last token - seq.output_text = seq.output_text[:-new_char_count] - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = last_token_id - return - - # Check if any stop strings are matched. - stop_str = self._check_stop_strings(seq, new_char_count, - sampling_params) - if stop_str is not None: - seq.status = SequenceStatus.FINISHED_STOPPED - seq.stop_reason = stop_str - return - - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - @staticmethod - def _check_stop_strings(seq: Sequence, new_char_count: int, - sampling_params: SamplingParams) -> Optional[str]: - """Check if any stop strings are matched and truncate sequence - output text accordingly. - - Returns the stop string if matched or else None. - """ - if not new_char_count: - return None - - for stop_str in sampling_params.stop: - stop_string_len = len(stop_str) - # Avoid searching already-searched text. - stop_index = seq.output_text.find( - stop_str, -new_char_count - stop_string_len) - if stop_index == -1: - continue - - if sampling_params.include_stop_str_in_output: - # Truncate to end of stop string. - stop_index += stop_string_len - if stop_index >= len(seq.output_text): - # No truncation required. - return stop_str - - # Truncate the output text to either the beginning - # or end of the stop string. - seq.output_text = seq.output_text[:stop_index] - return stop_str - return None - def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) diff --git a/vllm/engine/output_processor/__init__.py b/vllm/engine/output_processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py new file mode 100644 index 0000000000000..9ddac7a04cb36 --- /dev/null +++ b/vllm/engine/output_processor/interfaces.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from typing import Callable, Iterable, List + +from transformers import PreTrainedTokenizer + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput +from vllm.transformers_utils.detokenizer import Detokenizer + + +class SequenceGroupOutputProcessor(ABC): + """Interface for logic that processes new token ids in sequence groups, + managing detokenization, stop checking, and freeing/forking sequences with + the scheduler. + + This is highly coupled with the LLMEngine and should be seen as an extension + of it. The logic is separated to simplify the LLMEngine class and allow + separate implementations for single-step decoding (which supports beam + search sequence forking) and multi-step decoding (which does not support + beam search, but does support speculative decoding). + """ + + @staticmethod + def create_output_processor( + scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, + scheduler: Scheduler, + seq_counter: Iterable[int], + get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + stop_checker: "StopChecker", + ): + """Create an output processor. + + This returns a single-step output processor if num_lookahead_slots is + zero, else returns a multi-step output processor. + """ + if scheduler_config.num_lookahead_slots == 0: + # Importing here to avoid cycle. + from vllm.engine.output_processor.single_step import ( + SingleStepOutputProcessor) + return SingleStepOutputProcessor( + scheduler_config, + detokenizer, + scheduler, + seq_counter, + stop_checker, + ) + else: + # Importing here to avoid cycle. + from vllm.engine.output_processor.multi_step import ( + MultiStepOutputProcessor) + return MultiStepOutputProcessor( + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + stop_checker, + ) + + @abstractmethod + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process new token ids for the sequence group. Handles logic such as + detokenization, stop checking, and freeing/forking sequences in the + scheduler. + """ + pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py new file mode 100644 index 0000000000000..50da0d35fcec1 --- /dev/null +++ b/vllm/engine/output_processor/multi_step.py @@ -0,0 +1,126 @@ +from typing import Callable, Iterable, List + +from transformers import PreTrainedTokenizer + +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.sequence import (Logprob, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer + +logger = init_logger(__name__) + + +class MultiStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles logic related to + detokenization and stopping conditions. It specializes to "multi-step + decoding", where vLLM's worker may generate multiple tokens per invocation. + This is currently mutually exclusive with advanced sampling techniques like + beam search, which motivates the separation of this logic from the single + step output processor. + + This class is responsible for things such as correctly appending all new + token ids to their sequence, detokenizing new token ids, truncating new + output tokens after an eos token, and correctly handling the case where the + number of new output tokens per sequence differs in a single batch. + """ + + def __init__( + self, + detokenizer: Detokenizer, + scheduler: Scheduler, + seq_counter: Iterable[int], + get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + stop_checker: StopChecker, + ): + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.stop_checker = stop_checker + + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Append new tokens in the outputs to sequences in the sequence group. + + This only supports sequence groups of size 1. It supports greater than + one new token per sequence. + + This applies logic like stop condition checking and detokenization, + including freeing finished sequences. It also handles cases where there + are tokens emitted after the EOS token. + """ + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) + + assert seqs, "expected running sequences" + assert len(seqs) == 1, ( + "Beam search not supported in multi-step decoding.") + seq = seqs[0] + + # Since there's only one sequence per sequence group, we can take the + # first sample. + samples = [outputs[step].samples[0] for step in range(len(outputs))] + + # -1 means the output token is not valid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples if sample.output_token != -1 + ] + assert valid_samples + + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) + + def _process_seq_outputs(self, seq: Sequence, + valid_samples: List[SequenceOutput], + sampling_params: SamplingParams) -> None: + output_token_ids = [sample.output_token for sample in valid_samples] + + # Truncate to max_tokens if necessary. + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + + len(output_token_ids)) + if remaining_tokens < 0: + valid_samples = valid_samples[:remaining_tokens] + output_token_ids = output_token_ids[:remaining_tokens] + + # Truncate any tokens after EOS. This is required as spec decode + # generates a fixed number of tokens without evaluating stopping + # conditions within the block. This can cause an eos token to be + # unintentionally ignored. + if not sampling_params.ignore_eos: + eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id + # Avoiding .index calls as exception throwing in the happy path + # is expensive. + for i in range(len(output_token_ids)): + if output_token_ids[i] == eos_token_id: + output_token_ids = output_token_ids[:i + 1] + valid_samples = valid_samples[:i + 1] + break + + # Incrementally append tokens to the sequence, as if we had only one new + # token. + for output_token_id in output_token_ids: + seq.append_token_id( + token_id=output_token_id, + # TODO emit logprobs in multi-step decoding. + logprobs={output_token_id: Logprob(0.0)}, + ) + + new_char_count = 0 + if sampling_params.detokenize: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count=new_char_count, + sampling_params=sampling_params) + if seq.is_finished(): + break + + if seq.is_finished(): + self.scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py new file mode 100644 index 0000000000000..1b7eb014f802b --- /dev/null +++ b/vllm/engine/output_processor/single_step.py @@ -0,0 +1,276 @@ +from typing import Iterable, List, Tuple, Union + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, + SequenceOutput, SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer + +logger = init_logger(__name__) + + +class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles "output processing" logic, + which happens after the model returns generated token ids and before + scheduling of the next batch. Output processing logic includes + detokenization, and determining if a sequence is finished (e.g. via max len + or eos token). + + The SingleStepOutputProcessor is specialized to the case where the model + emits at most a single token per invocation, which precludes configurations + such as speculative decoding or multi-step decoding. This enables beam + search sampling, which requires forking/finishing/freeing sequences in a way + that is currently difficult to schedule multiple steps ahead of time. + """ + + def __init__( + self, + scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, + scheduler: Scheduler, + seq_counter: Iterable[int], + stop_checker: StopChecker, + ): + self.scheduler_config = scheduler_config + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.stop_checker = stop_checker + + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Append all new tokens to sequences in the sequence group. Fork any + surviving beam candidates; free any unsurviving ones. + + Invokes detokenizer to detokenize new tokens, and also marks sequences + as finished if they meet stop conditions. + """ + assert (len(outputs) == 1 + ), f"{type(self)} does not support multiple outputs per step" + return self._process_sequence_group_outputs(sequence_group, outputs[0]) + + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput) -> None: + + # Process prompt logprobs + prompt_logprobs = outputs.prompt_logprobs + if prompt_logprobs is not None and seq_group.sampling_params.detokenize: + self.detokenizer.decode_prompt_logprobs_inplace( + seq_group, prompt_logprobs) + seq_group.prompt_logprobs = prompt_logprobs + + # Process samples + samples = outputs.samples + parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + existing_finished_seqs = seq_group.get_finished_seqs() + parent_child_dict = { + parent_seq.seq_id: [] + for parent_seq in parent_seqs + } + for sample in samples: + parent_child_dict[sample.parent_seq_id].append(sample) + # List of (child, parent) + child_seqs: List[Tuple[Sequence, Sequence]] = [] + + # Process the child samples for each parent sequence + for parent in parent_seqs: + child_samples: List[SequenceOutput] = parent_child_dict[ + parent.seq_id] + if len(child_samples) == 0: + # This parent sequence has no children samples. Remove + # the parent sequence from the sequence group since it will + # not be used in the future iterations. + parent.status = SequenceStatus.FINISHED_ABORTED + seq_group.remove(parent.seq_id) + self.scheduler.free_seq(parent) + continue + # Fork the parent sequence if there are multiple child samples. + for child_sample in child_samples[:-1]: + new_child_seq_id = next(self.seq_counter) + child = parent.fork(new_child_seq_id) + child.append_token_id(child_sample.output_token, + child_sample.logprobs) + child_seqs.append((child, parent)) + # Continue the parent sequence for the last child sample. + # We reuse the parent sequence here to reduce redundant memory + # copies, especially when using non-beam search sampling methods. + last_child_sample = child_samples[-1] + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) + child_seqs.append((parent, parent)) + + for seq, _ in child_seqs: + if seq_group.sampling_params.detokenize: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, seq_group.sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence(seq, new_char_count, + seq_group.sampling_params) + + # Non-beam search case + if not seq_group.sampling_params.use_beam_search: + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + # NOTE: we need to fork the new sequences before freeing the + # old sequences. + for seq, parent in child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + return + + # Beam search case + # Select the child sequences to keep in the sequence group. + selected_child_seqs = [] + unselected_child_seqs = [] + beam_width = seq_group.sampling_params.best_of + length_penalty = seq_group.sampling_params.length_penalty + + # Select the newly finished sequences with the highest scores + # to replace existing finished sequences. + # Tuple of (seq, parent, is_new) + existing_finished_seqs = [(seq, None, False) + for seq in existing_finished_seqs] + new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs + if seq.is_finished()] + all_finished_seqs = existing_finished_seqs + new_finished_seqs + # Sort the finished sequences by their scores. + all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), + reverse=True) + for seq, parent, is_new in all_finished_seqs[:beam_width]: + if is_new: + # A newly generated child sequence finishes and has a high + # score, so we will add it into the sequence group. + selected_child_seqs.append((seq, parent)) + for seq, parent, is_new in all_finished_seqs[beam_width:]: + if is_new: + # A newly generated child sequence finishes but has a low + # score, so we will not add it into the sequence group. + # Additionally, if this sequence is a continuation of a + # parent sequence, we will need remove the parent sequence + # from the sequence group. + unselected_child_seqs.append((seq, parent)) + else: + # An existing finished sequence has a low score, so we will + # remove it from the sequence group. + seq_group.remove(seq.seq_id) + + # select the top beam_width sequences from the running + # sequences for the next iteration to continue the beam + # search. + running_child_seqs = [(seq, parent) for seq, parent in child_seqs + if not seq.is_finished()] + # Sort the running sequences by their scores. + running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), + reverse=True) + + # Check if we can stop the beam search. + if len(running_child_seqs) == 0: + # No running sequences, stop the beam search. + stop_beam_search = True + elif len(all_finished_seqs) < beam_width: + # Not enough finished sequences, continue the beam search. + stop_beam_search = False + else: + # Check the early stopping criteria + best_running_seq = running_child_seqs[0][0] + current_worst_seq = all_finished_seqs[beam_width - 1][0] + stop_beam_search = self._check_beam_search_early_stopping( + seq_group.sampling_params.early_stopping, + seq_group.sampling_params, best_running_seq, current_worst_seq) + + if stop_beam_search: + # Stop the beam search and remove all the running sequences from + # the sequence group. + unselected_child_seqs.extend(running_child_seqs) + else: + # Continue the beam search and select the top beam_width sequences + # to continue the beam search. + selected_child_seqs.extend(running_child_seqs[:beam_width]) + # The remaining running sequences will not be used in the next + # iteration. Again, if these sequences are continuations of + # parent sequences, we will need to remove the parent sequences + # from the sequence group. + unselected_child_seqs.extend(running_child_seqs[beam_width:]) + + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in selected_child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + for seq, parent in selected_child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + + # Remove the unselected parent sequences from the sequence group and + # free their memory in block manager. + for seq, parent in unselected_child_seqs: + if seq is parent: + # Remove the parent sequence if it is not selected for next + # iteration + seq_group.remove(seq.seq_id) + self.scheduler.free_seq(seq) + + def _check_beam_search_early_stopping( + self, + early_stopping: Union[bool, str], + sampling_params: SamplingParams, + best_running_seq: Sequence, + current_worst_seq: Sequence, + ) -> bool: + assert sampling_params.use_beam_search + length_penalty = sampling_params.length_penalty + if early_stopping is True: + return True + + current_worst_score = current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=current_worst_seq.eos_token_id) + if early_stopping is False: + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id) + else: + assert early_stopping == "never" + if length_penalty > 0.0: + # If length_penalty > 0.0, beam search will prefer longer + # sequences. The highest attainable score calculation is + # based on the longest possible sequence length in this case. + max_possible_length = max( + best_running_seq.get_prompt_len() + + sampling_params.max_tokens, + self.scheduler_config.max_model_len) + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id, + seq_len=max_possible_length)) + else: + # Otherwise, beam search will prefer shorter sequences. The + # highest attainable score calculation is based on the current + # sequence length. + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=best_running_seq.eos_token_id)) + return current_worst_score >= highest_attainable_score diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py new file mode 100644 index 0000000000000..66deb9b591746 --- /dev/null +++ b/vllm/engine/output_processor/stop_checker.py @@ -0,0 +1,101 @@ +from typing import Callable, Optional + +from transformers import PreTrainedTokenizer + +from vllm.sampling_params import SamplingParams +from vllm.sequence import Sequence, SequenceStatus + + +class StopChecker: + """LLMEngine helper class which separates out the logic involving stop + checking. This checks things such as: whether the eos token was emitted, + whether the max_tokens has been consumed, whether a stop string has been + emitted, or if we have exceeded the max model len. + """ + + def __init__(self, max_model_len: int, + get_tokenizer_for_seq: Callable[[Sequence], + PreTrainedTokenizer]): + self.max_model_len = max_model_len + self.get_tokenizer_for_seq = get_tokenizer_for_seq + + def maybe_stop_sequence(self, seq: Sequence, new_char_count: int, + sampling_params: SamplingParams) -> None: + """Stop the finished sequences. + + new_char_count is the number of chars added to the + sequence's output text for the newly generated token + """ + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if a stop token was encountered. + # This assumes a single token produced per step. + last_token_id = seq.get_last_token_id() + if last_token_id in sampling_params.stop_token_ids: + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + # Remove last token + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if any stop strings are matched. + stop_str = self._check_stop_strings(seq, new_char_count, + sampling_params) + if stop_str is not None: + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + @staticmethod + def _check_stop_strings(seq: Sequence, new_char_count: int, + sampling_params: SamplingParams) -> Optional[str]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns the stop string if matched or else None. + """ + if not new_char_count: + return None + + for stop_str in sampling_params.stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = seq.output_text.find( + stop_str, -new_char_count - stop_string_len) + if stop_index == -1: + continue + + if sampling_params.include_stop_str_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(seq.output_text): + # No truncation required. + return stop_str + + # Truncate the output text to either the beginning + # or end of the stop string. + seq.output_text = seq.output_text[:stop_index] + return stop_str + return None diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py new file mode 100644 index 0000000000000..5fbb09a857a46 --- /dev/null +++ b/vllm/engine/output_processor/util.py @@ -0,0 +1,16 @@ +from typing import List + +from vllm.sequence import SamplerOutput + + +def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], + num_seq_groups: int): + """Helper method which transforms a 2d list organized by + [step][sequence group] into [sequence group][step]. + """ + output_by_sequence_group = [[] for _ in range(num_seq_groups)] + for step in sampler_outputs: + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + + return output_by_sequence_group diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 426e2c41d8427..f925a6fc93dcd 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -74,7 +74,8 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int) -> List[SamplerOutput]: output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 8cc04c5299ca1..1839b5603ff3e 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -72,8 +72,9 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - """Executes one model step on the given sequences.""" + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int) -> List[SamplerOutput]: + """Executes at least one model step on the given sequences.""" raise NotImplementedError @abstractmethod diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3a9537effe6d9..6e4a765e2ffd5 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -13,13 +13,17 @@ class GPUExecutor(ExecutorBase): def _init_executor(self) -> None: - assert (not self.speculative_config - ), "Speculative decoding not yet supported for GPU backend" + """Initialize the worker and load the model. - # Instantiate the worker and load the model to GPU. - self._init_worker() + If speculative decoding is enabled, we instead create the speculative + worker. + """ + if self.speculative_config is None: + self._init_non_spec_worker() + else: + self._init_spec_worker() - def _init_worker(self): + def _init_non_spec_worker(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker from vllm.worker.worker import Worker @@ -46,6 +50,57 @@ def _init_worker(self): self.driver_worker.init_device() self.driver_worker.load_model() + def _init_spec_worker(self): + """Initialize a SpecDecodeWorker, using a draft model for proposals. + """ + assert self.speculative_config is not None + + from vllm.spec_decode.multi_step_worker import MultiStepWorker + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + from vllm.worker.worker import Worker + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + target_worker = Worker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=True, + ) + + draft_worker = MultiStepWorker( + model_config=self.speculative_config.draft_model_config, + parallel_config=self.speculative_config.draft_parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=True, + ) + + spec_decode_worker = SpecDecodeWorker.from_workers( + proposer_worker=draft_worker, scorer_worker=target_worker) + + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = spec_decode_worker + + # Load model handled in spec decode worker. + self.driver_worker.init_device() + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. @@ -63,16 +118,20 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int, + ) -> List[SamplerOutput]: output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, + num_lookahead_slots=num_lookahead_slots, ) return output diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 273b17a927efd..7cc187e297c9f 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -48,10 +48,13 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int) -> List[SamplerOutput]: assert (blocks_to_swap_in == {} and blocks_to_swap_out == {} and blocks_to_copy == {}), ( "Cache operations are not supported for Neuron backend.") + assert num_lookahead_slots == 0, ( + "lookahead not supported for Neuron backend.") output = self.driver_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4065c0868d79a..5f859fdc9c078 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -242,7 +242,8 @@ def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int = 0) -> SamplerOutput: all_outputs = self._run_workers( "execute_model", driver_kwargs={ diff --git a/vllm/sequence.py b/vllm/sequence.py index dcde81df19923..92362a9a5d2a3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -693,3 +693,16 @@ def __len__(self): def __eq__(self, other: object): return isinstance(other, self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. + """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index e0b75837e8a39..88af1dd360155 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,10 +6,10 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) -from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, - sampler_output_to_torch, +from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors, + nvtx_range, sampler_output_to_torch, split_batch_by_proposal_len) -from vllm.worker.worker import Worker +from vllm.worker.worker_base import WorkerBase SeqId = int TargetSeqId = int @@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): of topk/tree. """ - def __init__(self, scorer_worker: Worker, device: str, vocab_size: int): + def __init__(self, scorer_worker: WorkerBase, device: str, + vocab_size: int): self._scorer_worker = scorer_worker self._device = device self._vocab_size = vocab_size @@ -83,7 +84,9 @@ def score_proposals( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - return_python_output=False) + ) + assert len(target_sampler_output) == 1, "expected single-step output" + target_sampler_output = target_sampler_output[0] all_tokens, all_probs = self._contract_batch( original_bs=len(seq_group_metadata_list), @@ -142,6 +145,16 @@ def _contract_batch(self, original_bs: int, This maps the scores of speculative tokens back to their original sequences. """ + + # We mock the device tensors until PR 7/9 is merged (e2e correctness). + # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + maybe_mock_device_tensors( + sampler_output=target_sampler_output, + batch_size=len(non_spec_indices) + num_scoring_tokens, + vocab_size=self._vocab_size, + device=self._device, + ) + (target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 73b6e201c67a9..ce63c329a40aa 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -6,7 +6,8 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) -from vllm.spec_decode.util import sampler_output_to_torch +from vllm.spec_decode.util import (maybe_mock_device_tensors, + sampler_output_to_torch) from vllm.worker.worker import Worker @@ -69,6 +70,9 @@ def execute_model_multi_step( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ) + assert (len(model_output) == 1 + ), "composing multistep workers not supported" + model_output = model_output[0] self._append_new_tokens(model_output, copied_seq_group_metadata_list) @@ -341,6 +345,16 @@ def _merge_outputs( sampler_output = maybe_sampler_output + # We mock the device tensors until PR 7/9 is merged (e2e correctness). + # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + for step_output in sampler_output: + maybe_mock_device_tensors( + sampler_output=step_output, + batch_size=len(proposal_lens), + vocab_size=self._vocab_size, + device=self._device, + ) + proposal_tokens, proposal_probs = sampler_output_to_torch( sampler_output) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 885bf537568e3..be3af7be93864 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,8 +3,9 @@ import torch +from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, +from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -13,8 +14,9 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, split_batch_by_proposal_len) -from vllm.worker.worker import Worker -from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase + +logger = init_logger(__name__) class SpecDecodeWorker(LoraNotSupportedWorkerBase): @@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. """ + @classmethod + def from_workers(cls, proposer_worker: MultiStepWorker, + scorer_worker: WorkerBase) -> "SpecDecodeWorker": + return SpecDecodeWorker( + proposer_worker, + scorer_worker, + # TODO(cade) disable strict mode for speedup. + rejection_sampler=RejectionSampler(strict_mode=True), + ) + def __init__( self, proposer_worker: MultiStepWorker, - scorer_worker: Worker, + scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, ): @@ -87,6 +99,10 @@ def init_device(self) -> None: self.scorer_worker.init_device() self.proposer_worker.init_device() + # NOTE(cade): load_model is not part of the WorkerBase interface. + self.scorer_worker.load_model() + self.proposer_worker.load_model() + self._metrics.init_gpu_tensors(self.rank) self.rejection_sampler.init_gpu_tensors(self.rank) self.scorer = BatchExpansionTop1Scorer( @@ -131,7 +147,7 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]], blocks_to_swap_out: Optional[Dict[int, int]], blocks_to_copy: Optional[Dict[int, List[int]]], - num_spec_tokens: int, + num_lookahead_slots: int, ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ @@ -140,9 +156,11 @@ def execute_model( "speculative decoding " "requires non-None seq_group_metadata_list") + logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}") + # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. - if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0: + if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0: return self._run_no_spec( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, @@ -155,7 +173,7 @@ def execute_model( blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - k=num_spec_tokens, + k=num_lookahead_slots, ) @nvtx_range("spec_decode_worker._run_no_spec") @@ -170,20 +188,24 @@ def _run_no_spec( proposer and scorer model so that the KV cache is consistent between the two. """ + logger.info("run proposer worker no spec") self.proposer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - return_python_output=False) + ) + logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ) + assert len(sampler_output) == 1 + sampler_output = sampler_output[0] # Clear device tensors from sampler output. This reduces communication # overhead when the engine runs in a different process than the workers. @@ -209,11 +231,13 @@ def _run_speculative_decoding_step( sequence. """ + logger.info("get spec proposals") # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals( seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k) + logger.info("score proposals") proposal_scores = self.scorer.score_proposals( seq_group_metadata_list, blocks_to_swap_in, @@ -223,9 +247,11 @@ def _run_speculative_decoding_step( proposals, ) + logger.info("verify proposals") accepted_token_ids = self._verify_tokens(seq_group_metadata_list, proposal_scores, proposals, k) + logger.info("create output list") return self._create_output_sampler_list(seq_group_metadata_list, accepted_token_ids, k) @@ -311,7 +337,7 @@ def _create_output_sampler_list( parent_seq_id=seq_id, output_token=token_id, # TODO Add verifier logprobs. - logprobs={token_id: 0.0}, + logprobs={token_id: Logprob(0.0)}, ) ], prompt_logprobs=None, diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 406568a4bc08c..eb6d4ca1da8e6 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -82,6 +82,32 @@ def sampler_output_to_torch( return sampled_token_ids, sampled_token_probs +def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, + vocab_size: int, device: str) -> None: + """Helper method which mocks out the GPU tensors in SamplerOutput with dummy + values. This will be removed in PR 7/9. + https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + """ + values = [ + sampler_output.sampled_token_probs, sampler_output.sampled_token_ids + ] + assert all(v is None for v in values) or not any(v is None for v in values) + if not any(v is None for v in values): + # Do nothing if the tensors are already created (usually in unit tests). + return + + # Softmax to ensure valid probs. + sampler_output.sampled_token_probs = torch.nn.functional.softmax( + torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), + dim=-1) + + sampler_output.sampled_token_ids = torch.randint(low=10, + high=100, + size=(batch_size, ), + dtype=torch.long, + device=device) + + @contextmanager def nvtx_range(msg, *args, **kwargs): """ diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 6610b9c4be876..afc4a1e1f4630 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -251,7 +251,7 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, - ) -> Optional[SamplerOutput]: + ) -> List[SamplerOutput]: if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) @@ -274,11 +274,13 @@ def execute_model( # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list, self.cpu_cache) - return output + + # CPU worker only supports single-step execution. + return [output] def init_distributed_environment(self) -> None: """Initialize the distributed environment.""" diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 2f22f82c045db..142c6c97f5194 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,5 +1,5 @@ """A Neuron worker class.""" -from typing import List, Optional, Tuple +from typing import List, Tuple import torch import torch.distributed @@ -73,15 +73,18 @@ def initialize_cache(self, num_gpu_blocks: int, def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Optional[SamplerOutput]: + ) -> List[SamplerOutput]: num_seq_groups = len(seq_group_metadata_list) # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list) - return output + + # Neuron worker only supports single-step output. Wrap the output in a + # list to conform to interface. + return [output] def get_cache_block_size_bytes(self) -> int: """Determine the size in bytes of a cache block. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6a79285f60579..e2b47530d41e4 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -210,7 +210,9 @@ def execute_model( blocks_to_swap_in: Optional[Dict[int, int]] = None, blocks_to_swap_out: Optional[Dict[int, int]] = None, blocks_to_copy: Optional[Dict[int, List[int]]] = None, - ) -> Optional[SamplerOutput]: + num_lookahead_slots: int = 0, + ) -> List[SamplerOutput]: + if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups = len(seq_group_metadata_list) @@ -235,11 +237,14 @@ def execute_model( # If there is no input, we don't need to execute the model. if num_seq_groups == 0: - return {} + return [] output = self.model_runner.execute_model(seq_group_metadata_list, self.gpu_cache) - return output + + # Worker only supports single-step execution. Wrap the output in a list + # to conform to interface. + return [output] def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d8c9febb11584..a92f5aea76059 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -40,12 +40,13 @@ def initialize_cache(self, num_gpu_blocks: int, raise NotImplementedError @abstractmethod - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: - """Executes one model step on the given sequences.""" + def execute_model( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, + int], + blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" raise NotImplementedError @abstractmethod From d150e4f89f9da5b093511227449ad940a9a82a52 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 16 Apr 2024 17:56:01 -0700 Subject: [PATCH 054/413] [Misc] [CI] Fix CI failure caught after merge (#4126) --- vllm/executor/gpu_executor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 6e4a765e2ffd5..77c997f97956e 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -68,6 +68,7 @@ def _init_spec_worker(self): scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, + load_config=self.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -82,6 +83,7 @@ def _init_spec_worker(self): scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, + load_config=self.load_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, From 11d652bd4f6f81d09638399885099b78a4e3b9c8 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 16 Apr 2024 22:53:26 -0700 Subject: [PATCH 055/413] [CI] Move CPU/AMD tests to after wait (#4123) --- .buildkite/test-template.j2 | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 3ed23c62c005d..0e1acc9777d4b 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -3,13 +3,6 @@ {% set default_working_dir = "/vllm-workspace/tests" %} steps: - - label: "AMD Test" - agents: - queue: amd - command: bash .buildkite/run-amd-test.sh - - - label: "CPU Test" - command: bash .buildkite/run-cpu-test.sh - label: ":docker: build image" commands: @@ -23,6 +16,14 @@ steps: limit: 5 - wait + - label: "AMD Test" + agents: + queue: amd + command: bash .buildkite/run-amd-test.sh + + - label: "CPU Test" + command: bash .buildkite/run-cpu-test.sh + {% for step in steps %} - label: "{{ step.label }}" agents: From 8438e0569eaf8496aa3d41deb808f2c831b64ecf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 17 Apr 2024 01:34:33 -0700 Subject: [PATCH 056/413] [Core] RayWorkerVllm --> WorkerWrapper to reduce duplication (#4024) [Core] replace narrow-usage RayWorkerVllm to general WorkerWrapper to reduce code duplication (#4024) --- tests/distributed/test_pynccl.py | 7 +- vllm/engine/ray_utils.py | 44 ++------- vllm/executor/ray_gpu_executor.py | 156 ++++++++++++++++-------------- vllm/utils.py | 16 ++- vllm/worker/cpu_worker.py | 5 +- vllm/worker/neuron_worker.py | 4 + vllm/worker/worker.py | 4 + vllm/worker/worker_base.py | 56 +++++++++++ 8 files changed, 176 insertions(+), 116 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index b50eed1c8c722..d58f621d36b86 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -1,18 +1,18 @@ import multiprocessing -import os import pytest import torch from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, ncclGetUniqueId) +from vllm.utils import update_environment_variables def distributed_run(fn, world_size): number_of_processes = world_size processes = [] for i in range(number_of_processes): - env = os.environ.copy() + env = {} env['RANK'] = str(i) env['LOCAL_RANK'] = str(i) env['WORLD_SIZE'] = str(number_of_processes) @@ -32,8 +32,7 @@ def update_env(fn): # so we need to pass the environment variables as arguments # and update the environment variables in the function def wrapper(env): - import os - os.environ.update(env) + update_environment_variables(env) fn() return wrapper diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index 04d4ed83976d0..febae42b84549 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -1,55 +1,28 @@ import pickle -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_ip, is_hip, set_cuda_visible_devices -from vllm.worker.worker import Worker +from vllm.utils import get_ip, is_hip +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) try: import ray - class RayWorkerVllm: + class RayWorkerWrapper(WorkerWrapperBase): """Ray wrapper for vllm.worker.Worker, allowing Worker to be lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" - def __init__(self, init_cached_hf_modules=False) -> None: - if init_cached_hf_modules: - from transformers.dynamic_module_utils import init_hf_modules - init_hf_modules() - self._worker: Optional[Worker] = None + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # Since the compiled DAG runs a main execution # in a different thread that calls cuda.set_device. # The flag indicates is set_device is called on # that thread. self.compiled_dag_cuda_device_set = False - def init_worker(self, worker_init_fn: Callable[[], Worker]): - self._worker = worker_init_fn() - - @property - def worker(self) -> Worker: - assert self._worker is not None - return self._worker - - def __getattr__(self, name): - return getattr(self.worker, name) - - def execute_method(self, method, *args, **kwargs): - try: - executor = getattr(self, method) - return executor(*args, **kwargs) - except Exception as e: - # exceptions in ray worker may cause deadlock - # see https://github.com/vllm-project/vllm/issues/3455 - # print the error and inform the user to solve the error - msg = (f"Error executing method {method}. " - "This might cause deadlock in distributed execution.") - logger.exception(msg) - raise e - def get_node_ip(self) -> str: return get_ip() @@ -58,9 +31,6 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: gpu_ids = ray.get_gpu_ids() return node_id, gpu_ids - def set_cuda_visible_devices(self, device_ids) -> None: - set_cuda_visible_devices(device_ids) - def execute_model_compiled_dag_remote(self, ignored): """Used only when compiled DAG is enabled.""" import torch @@ -77,7 +47,7 @@ def execute_model_compiled_dag_remote(self, ignored): "For distributed inference, please install Ray with " "`pip install ray`.") ray = None # type: ignore - RayWorkerVllm = None # type: ignore + RayWorkerWrapper = None # type: ignore def initialize_ray_cluster( diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 5f859fdc9c078..5a43f1fc28a84 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -1,17 +1,16 @@ import asyncio -import copy import os import pickle from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple -from vllm.engine.ray_utils import RayWorkerVllm, ray +from vllm.engine.ray_utils import RayWorkerWrapper, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async, set_cuda_visible_devices) + make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -74,9 +73,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. - self.driver_dummy_worker: RayWorkerVllm = None + self.driver_dummy_worker: RayWorkerWrapper = None # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerVllm] = [] + self.workers: List[RayWorkerWrapper] = [] if self.parallel_config.ray_workers_use_nsight: ray_remote_kwargs = self._configure_ray_workers_use_nsight( @@ -97,13 +96,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerVllm).remote(self.model_config.trust_remote_code) + )(RayWorkerWrapper).remote( + worker_module_name="vllm.worker.worker", + worker_class_name="Worker", + ) worker_ip = ray.get(worker.get_node_ip.remote()) if worker_ip == driver_ip and self.driver_dummy_worker is None: # If the worker is on the same node as the driver, we use it # as the resource holder for the driver process. self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name="vllm.worker.worker", + worker_class_name="Worker", + ) else: # Else, added to the list of workers. self.workers.append(worker) @@ -115,82 +121,56 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", "GPU node.") # Get the set of GPU IDs used on each node. - driver_node_id, driver_gpu_ids = ray.get( - self.driver_dummy_worker.get_node_and_gpu_ids.remote()) - worker_node_and_gpu_ids = ray.get( - [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) node_workers = defaultdict(list) node_gpus = defaultdict(list) - node_workers[driver_node_id].append(0) - node_gpus[driver_node_id].extend(driver_gpu_ids) - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, - start=1): + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): node_workers[node_id].append(i) node_gpus[node_id].extend(gpu_ids) for node_id, gpu_ids in node_gpus.items(): node_gpus[node_id] = sorted(gpu_ids) # Set CUDA_VISIBLE_DEVICES for the driver and workers. - set_cuda_visible_devices(node_gpus[driver_node_id]) - for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): - worker.set_cuda_visible_devices.remote(node_gpus[node_id]) + all_args_to_update_environment_variables = [] + for (node_id, _) in worker_node_and_gpu_ids: + all_args_to_update_environment_variables.append([{ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])) + }]) + self._run_workers("update_environment_variables", + all_args=all_args_to_update_environment_variables) distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - - model_config = copy.deepcopy(self.model_config) - parallel_config = copy.deepcopy(self.parallel_config) - scheduler_config = copy.deepcopy(self.scheduler_config) - load_config = copy.deepcopy(self.load_config) - device_config = copy.deepcopy(self.device_config) - lora_config = copy.deepcopy(self.lora_config) - cache_config = copy.deepcopy(self.cache_config) - vision_language_config = copy.deepcopy(self.vision_language_config) - - # Initialize the actual workers with the Worker class. - for rank, (worker, (node_id, _)) in enumerate( - zip(self.workers, worker_node_and_gpu_ids), - start=1, - ): + def collect_arg_helper_func(**kwargs): + # avoid writing `{"name": value}` manually + return kwargs + + init_worker_all_kwargs = [] + + # Initialize the actual workers inside worker wrapper. + for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ): local_rank = node_workers[node_id].index(rank) - worker.init_worker.remote( - lambda rank=rank, local_rank=local_rank: Worker( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - cache_config=cache_config, - load_config=load_config, + init_worker_all_kwargs.append( + collect_arg_helper_func( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, - lora_config=lora_config, - vision_language_config=vision_language_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=rank == 0, )) - - # Initialize the driver worker with the Worker class. - driver_rank = 0 - driver_local_rank = node_workers[driver_node_id].index(driver_rank) - self.driver_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - local_rank=driver_local_rank, - rank=driver_rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - load_config=self.load_config, - is_driver_worker=True, - ) + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") self._run_workers( @@ -279,13 +259,35 @@ def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any, ...]] = None, + driver_args: Optional[Tuple[Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, + all_args: Optional[List[List[Any]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: - """Runs the given method on all workers.""" + """Runs the given method on all workers. + all_args and all_kwargs are used to pass heterogeneous arguments, + i.e. different arguments for each worker. + """ + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # for mypy type checking + assert driver_args is not None + assert driver_kwargs is not None + if all_args is None: + all_args = [driver_args] + [args] * len(self.workers) + if all_kwargs is None: + all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers) + + # for mypy type checking + assert all_args is not None + assert all_kwargs is not None if max_concurrent_workers: raise NotImplementedError( @@ -299,8 +301,10 @@ def _run_workers( else: # Start the ray workers first. ray_worker_outputs = [ - worker.execute_method.remote(method, *args, **kwargs) - for worker in self.workers + worker.execute_method.remote(method, *worker_args, + **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_args[1:], all_kwargs[1:]) ] if driver_args is None: @@ -309,9 +313,13 @@ def _run_workers( driver_kwargs = kwargs # Start the driver worker after all the ray workers. - driver_worker_output = getattr(self.driver_worker, - method)(*driver_args, **driver_kwargs) - + if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *all_args[0], **all_kwargs[0]) + else: + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *all_args[0], **all_kwargs[0])) # Get the results of the ray workers. if self.workers: if use_ray_compiled_dag: @@ -386,8 +394,12 @@ async def _run_workers_async( driver_kwargs = kwargs # Run the driver worker asynchronously. - driver_executor = make_async(getattr(self.driver_worker, method)) - coros.append(driver_executor(*driver_args, **driver_kwargs)) + def helper(): + return self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + + driver_executor = make_async(helper) + coros.append(driver_executor()) # Run the ray workers asynchronously. for worker in self.workers: diff --git a/vllm/utils.py b/vllm/utils.py index aad62516ad1b9..e132575e7bf81 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -271,8 +271,12 @@ def get_open_port() -> int: return s.getsockname()[1] -def set_cuda_visible_devices(device_ids: List[int]) -> None: - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) +def update_environment_variables(envs: Dict[str, str]): + for k, v in envs.items(): + if k in os.environ: + logger.warning(f"Overwriting environment variable {k} " + f"from '{os.environ[k]}' to '{v}'") + os.environ[k] = v def chunk_list(lst, chunk_size): @@ -505,3 +509,11 @@ def merge_dicts(dict1: Dict[Any, List[Any]], merged_dict[key].extend(value) return dict(merged_dict) + + +def init_cached_hf_modules(): + """ + Lazy initialization of the Hugging Face modules. + """ + from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index afc4a1e1f4630..8468ace5a2fdc 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -138,7 +138,10 @@ def __init__( self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." - + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() self.model_runner = CPUModelRunner(model_config, parallel_config, scheduler_config, diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 142c6c97f5194..d0e6aaed180e6 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -29,6 +29,10 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() self.model_runner = NeuronModelRunner(model_config, parallel_config, scheduler_config, device_config) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e2b47530d41e4..b021866965401 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -60,6 +60,10 @@ def __init__( if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() self.vision_language_config = vision_language_config if self.vision_language_config: assert not self.lora_config, ( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index a92f5aea76059..309aa6256acea 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,8 +1,14 @@ +import importlib +import os from abc import ABC, abstractmethod from typing import Dict, List, Tuple +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import update_environment_variables + +logger = init_logger(__name__) class WorkerBase(ABC): @@ -82,3 +88,53 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> List[int]: raise ValueError(f"{type(self)} does not support LoRA") + + +class WorkerWrapperBase: + """ + The whole point of this class is to lazily initialize the worker. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + """ + + def __init__(self, + worker_module_name=None, + worker_class_name=None) -> None: + self.worker_module_name = worker_module_name + self.worker_class_name = worker_class_name + self.worker = None + + def update_environment_variables(self, envs: Dict[str, str]) -> None: + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, *args, **kwargs): + """ + Actual initialization of the worker class. + Arguments are passed to the worker class constructor. + """ + mod = importlib.import_module(self.worker_module_name) + worker_class = getattr(mod, self.worker_class_name) + self.worker = worker_class(*args, **kwargs) + + def execute_method(self, method, *args, **kwargs): + try: + if hasattr(self, method): + executor = getattr(self, method) + else: + executor = getattr(self.worker, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. " + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e From fe3b5bbc23a99533bc7d4a94ae073828ed025974 Mon Sep 17 00:00:00 2001 From: Elinx Date: Wed, 17 Apr 2024 19:07:23 +0800 Subject: [PATCH 057/413] [Bugfix] fix output parsing error for trtllm backend (#4137) Co-authored-by: Roger Wang --- benchmarks/backend_request_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index bab570252c929..f9d167590fe47 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -135,6 +135,7 @@ async def async_request_trt_llm( "data:") data = json.loads(chunk) + output.generated_text += data["text_output"] timestamp = time.perf_counter() # First token if ttft == 0.0: @@ -149,7 +150,6 @@ async def async_request_trt_llm( most_recent_timestamp = timestamp output.latency = most_recent_timestamp - st - output.generated_text = json.loads(data)["text_output"] output.success = True else: From a53222544c6385ee314a26fdf42eb14f5b4e5ad9 Mon Sep 17 00:00:00 2001 From: Shoichi Uchinami Date: Thu, 18 Apr 2024 02:02:45 +0900 Subject: [PATCH 058/413] [Kernel] Add punica dimension for Swallow-MS-7B LoRA (#4134) --- csrc/punica/bgmv/bgmv_config.h | 1 + tests/lora/test_punica.py | 1 + 2 files changed, 2 insertions(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index d2906914f927e..fec484d693055 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -60,6 +60,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 32768) \ f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 36864) \ + f(in_T, out_T, W_T, narrow, 43264) \ f(in_T, out_T, W_T, narrow, 49152) \ f(in_T, out_T, W_T, narrow, 64000) \ f(in_T, out_T, W_T, narrow, 64256) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 8b174f01d87d4..f3b9bd5912967 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -82,6 +82,7 @@ def _lora_ref_impl( 32768, 33024, 36864, + 43264, 49152, 64000, 64256, From 533d2a1f3962c45ddebbdeb0ff1cb7cd54d5e771 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 18 Apr 2024 09:28:43 +0900 Subject: [PATCH 059/413] [Typing] Mypy typing part 2 (#4043) Co-authored-by: SangBin Cho --- .github/workflows/mypy.yaml | 8 +- format.sh | 8 +- vllm/engine/async_llm_engine.py | 44 +++++---- vllm/lora/worker_manager.py | 4 +- .../guided_decoding/outlines_decoding.py | 4 +- .../outlines_logits_processors.py | 6 +- vllm/model_executor/model_loader/neuron.py | 16 ++-- .../model_executor/model_loader/tensorizer.py | 1 + vllm/model_executor/sampling_metadata.py | 4 + vllm/spec_decode/batch_expansion.py | 6 +- vllm/spec_decode/interfaces.py | 4 +- vllm/spec_decode/metrics.py | 1 + vllm/spec_decode/multi_step_worker.py | 21 ++-- vllm/spec_decode/spec_decode_worker.py | 6 +- vllm/worker/cpu_model_runner.py | 25 +++-- vllm/worker/cpu_worker.py | 11 ++- vllm/worker/model_runner.py | 96 +++++++++++-------- vllm/worker/neuron_model_runner.py | 22 +++-- vllm/worker/worker.py | 11 ++- vllm/worker/worker_base.py | 8 +- 20 files changed, 180 insertions(+), 126 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 6db0bb7645ecd..477ce9bc9ce85 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -41,10 +41,10 @@ jobs: mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml # TODO(sang): Follow up - # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml - # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml - # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml - # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml diff --git a/format.sh b/format.sh index 1c195b899c742..84ee88b5b4c8a 100755 --- a/format.sh +++ b/format.sh @@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml # TODO(sang): Follow up -# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml -# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml -# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml -# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 27192449bf15a..c3020d2b38db0 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,8 +2,8 @@ import os import time from functools import partial -from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, - Set, Tuple, Type, Union) +from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, + Optional, Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer @@ -52,7 +52,7 @@ class AsyncStream: def __init__(self, request_id: str) -> None: self.request_id = request_id - self._queue = asyncio.Queue() + self._queue: asyncio.Queue = asyncio.Queue() self._finished = False def put(self, item: Union[RequestOutput, Exception]) -> None: @@ -312,15 +312,17 @@ def __init__(self, self.max_log_len = max_log_len self.engine = self._init_engine(*args, **kwargs) - self.background_loop = None + self.background_loop: Optional[asyncio.Future] = None # We need to keep a reference to unshielded # task as well to prevent it from being garbage # collected - self._background_loop_unshielded = None + self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None self.start_engine_loop = start_engine_loop - self._request_tracker: Optional[RequestTracker] = None self._errored_with: Optional[BaseException] = None + # Lazy initialized fields + self._request_tracker: RequestTracker + @classmethod def from_engine_args( cls, @@ -361,11 +363,13 @@ def from_engine_args( @property def is_running(self) -> bool: return (self.background_loop is not None + and self._background_loop_unshielded is not None and not self._background_loop_unshielded.done()) @property def is_stopped(self) -> bool: - return self.errored or (self.background_loop is not None + return self.errored or (self.background_loop is not None and + self._background_loop_unshielded is not None and self._background_loop_unshielded.done()) @property @@ -381,7 +385,7 @@ def _error_callback(self, exc: Exception) -> None: async def get_tokenizer(self) -> "PreTrainedTokenizer": if self.engine_use_ray: - return await self.engine.get_tokenizer.remote() + return await self.engine.get_tokenizer.remote() # type: ignore else: return self.engine.get_tokenizer() @@ -434,7 +438,8 @@ async def engine_step(self) -> bool: # TODO: Maybe add add_request_batch to reduce Ray overhead try: if self.engine_use_ray: - await self.engine.add_request.remote(**new_request) + await self.engine.add_request.remote( # type: ignore + **new_request) else: await self.engine.add_request_async(**new_request) except ValueError as e: @@ -449,7 +454,7 @@ async def engine_step(self) -> bool: await self._engine_abort(finished_requests) if self.engine_use_ray: - request_outputs = await self.engine.step.remote() + request_outputs = await self.engine.step.remote() # type: ignore else: request_outputs = await self.engine.step_async() @@ -462,7 +467,7 @@ async def engine_step(self) -> bool: async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: - await self.engine.abort_request.remote(request_ids) + await self.engine.abort_request.remote(request_ids) # type: ignore else: self.engine.abort_request(request_ids) @@ -525,11 +530,12 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - prompt_token_ids = await self.engine.encode_request_async.remote( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) + prompt_token_ids = await ( + self.engine.encode_request_async.remote( # type: ignore + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request)) else: prompt_token_ids = await self.engine.encode_request_async( request_id=request_id, @@ -676,13 +682,13 @@ def _abort(self, request_id: str) -> None: async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" if self.engine_use_ray: - return await self.engine.get_model_config.remote() + return await self.engine.get_model_config.remote() # type: ignore else: return self.engine.get_model_config() async def do_log_stats(self) -> None: if self.engine_use_ray: - await self.engine.do_log_stats.remote() + await self.engine.do_log_stats.remote() # type: ignore else: self.engine.do_log_stats() @@ -695,7 +701,7 @@ async def check_health(self) -> None: if self.engine_use_ray: try: - await self.engine.check_health.remote() + await self.engine.check_health.remote() # type: ignore except ray.exceptions.RayActorError as e: raise RuntimeError("Engine is dead.") from e else: diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index a0868defbd3ca..5356b79537b05 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -107,12 +107,12 @@ def create_lora_manager( self._lora_manager: LoRAModelManager = lora_manager return lora_manager.model - def set_active_loras(self, lora_requests: List[LoRARequest], + def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: self._apply_loras(lora_requests) self._lora_manager.set_lora_mapping(lora_mapping) - def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: loras_that_exist = self.list_loras() loras_map = { lora_request.lora_int_id: lora_request diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index bd4564a36e1ed..53efebb604048 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -55,7 +55,7 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], - tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: + tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. @@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor( def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest] -) -> Tuple[str, GuidedDecodingMode]: +) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: if request.guided_json: json = request.guided_json diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 28041695546dc..95a67b612f08b 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -21,7 +21,7 @@ from typing import Callable, DefaultDict, Dict, List, Optional, Union import torch -from outlines.fsm.fsm import CFGFSM, RegexFSM +from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel from transformers import PreTrainedTokenizerBase @@ -29,6 +29,10 @@ class BaseLogitsProcessor: + def __init__(self): + # Child class should use initialize in their init. + self.fsm: FSM + def init_state(self): """Initialize the FSM states.""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 43d17ad373b87..07e23aca6cc5f 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,7 +1,7 @@ """Utilities for selecting and loading neuron models.""" import importlib import os -from typing import Optional, Type +from typing import Dict, Optional, Tuple import torch import torch.nn as nn @@ -27,7 +27,7 @@ } # Models supported by Neuron. -_NEURON_SUPPORTED_MODELS = { +_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = { "LlamaForCausalLM": ("transformers_neuronx.llama.model", "LlamaForSampling", "LlamaForCausalLM"), "MistralForCausalLM": ("transformers_neuronx.mistral.model", @@ -43,11 +43,13 @@ def __init__( ) -> None: super().__init__() self.config = config - self.model = None self.logits_processor = LogitsProcessor(config.vocab_size, logits_as_input=True) self.sampler = Sampler() + # Lazy initialized + self.model: nn.Module + def forward( self, input_ids: torch.Tensor, @@ -74,17 +76,17 @@ def sample( def load_weights(self, model_name_or_path: str, **kwargs): arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls, hf_model_cls = ( + neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = ( _NEURON_SUPPORTED_MODELS[arch]) neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) split_model_dir = f"{model_name_or_path}-split" if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")): split_model_dir = model_name_or_path elif not os.path.exists(f"{model_name_or_path}-split"): - hf_model_cls = getattr(transformers, hf_model_cls) + hf_model_cls = getattr(transformers, hf_model_cls_name) from transformers_neuronx.module import save_pretrained_split hf_model = hf_model_cls.from_pretrained(model_name_or_path, @@ -96,7 +98,7 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.to_neuron() -def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: +def _get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: if arch in _NEURON_SUPPORTED_MODELS: diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index ad554844384eb..16be0ecf9ce07 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -167,6 +167,7 @@ def __post_init__(self): decryption_params = DecryptionParams.from_key(key) self.deserializer_params['encryption'] = decryption_params + @staticmethod def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Tensorizer CLI arguments""" diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 534cb75c2fd2f..31032c4cead20 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -113,6 +113,8 @@ def from_sampling_metadata( get_num_triton_sampler_splits(vocab_size)) sample_indices_start_idx = 0 + assert sampling_metadata.seq_groups is not None + assert sampling_metadata.seq_data is not None for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature @@ -147,6 +149,7 @@ def from_sampling_metadata( and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs + assert sampling_metadata.prompt_lens is not None prompt_len = sampling_metadata.prompt_lens[i] temperatures += [temperature] * (prompt_len - 1) top_ps += [top_p] * (prompt_len - 1) @@ -172,6 +175,7 @@ def from_sampling_metadata( is_prompt = i < sampling_metadata.num_prompts if is_prompt: prompt_best_of.append(sampling_params.best_of) + assert sampling_metadata.prompt_lens is not None prompt_len = sampling_metadata.prompt_lens[i] if sampling_params.prompt_logprobs is not None: diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 88af1dd360155..bbc5b1778854f 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -106,7 +106,7 @@ def score_proposals( def _expand_batch( self, seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_token_ids_list: List[TokenId], + proposal_token_ids_list: List[List[TokenId]], proposal_lens_list: List[int], ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]: """Given the input sequences and potentially multiple corresponding @@ -218,7 +218,7 @@ def _create_scoring_model_input( def _create_target_seq_group_metadata( self, input_seq_group_metadata: SequenceGroupMetadata, - proposal_token_ids: List[TokenId], # shape: [batch_size, k] + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] batch_index: int, target_seq_ids_iter: Iterator[TargetSeqId], ) -> List[SequenceGroupMetadata]: @@ -360,7 +360,7 @@ def _get_token_ids_to_score( [0, 1, 2] [0, 1, 2, 3] """ - empty_token_ids = [] + empty_token_ids: List[TokenId] = [] token_ids_to_score = [empty_token_ids] token_ids_to_score.extend([ diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 2a72974d01bdc..f0715120192e5 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional import torch @@ -73,5 +73,5 @@ def score_proposals( blocks_to_copy: Optional[Dict[int, List[int]]], k: int, proposals: SpeculativeProposals, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> SpeculativeScores: raise NotImplementedError diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 5df8fc4316d48..d1e72b6640548 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -112,6 +112,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: Returns a CUDA event recording when the copy is complete. """ + assert self._copy_stream is not None self._copy_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self._copy_stream): diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index ce63c329a40aa..8b722476853fa 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -26,7 +26,8 @@ class MultiStepWorker(Worker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._proposer: Optional[DraftModelTop1Proposer] = None + # Lazy initialization list. + self._proposer: DraftModelTop1Proposer def init_device(self): super().init_device() @@ -338,10 +339,10 @@ def _merge_outputs( self._vocab_size, dtype=torch.float32, device=self._device) - proposal_lens = torch.zeros(len(proposal_lens), - dtype=torch.long, - device=self._device) - return proposal_tokens, proposal_probs, proposal_lens + proposal_lens_tensor = torch.zeros(len(proposal_lens), + dtype=torch.long, + device=self._device) + return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output @@ -376,9 +377,9 @@ def _merge_outputs( proposal_tokens, proposal_probs = (entire_proposal_tokens, entire_proposal_probs) - proposal_lens = torch.zeros(batch_size, - dtype=torch.long, - device=self._device) - proposal_lens[nonzero_proposal_len_indices] = max_proposal_len + proposal_lens_tensor = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len - return proposal_tokens, proposal_probs, proposal_lens + return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index be3af7be93864..68a2a774ef4b7 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -89,7 +89,8 @@ def __init__( self.probs_dtype = self.rejection_sampler.probs_dtype self.token_id_dtype = self.rejection_sampler.token_id_dtype - self.scorer: SpeculativeScorer = None + # Lazy initiazliation. + self.scorer: SpeculativeScorer def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -233,6 +234,9 @@ def _run_speculative_decoding_step( logger.info("get spec proposals") # Generate proposals using draft worker. + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None proposals = self.proposer_worker.get_spec_proposals( seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d378e3a90e1e7..7377c8931cefa 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional, Tuple import torch +from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -48,14 +49,15 @@ def __init__( if device_config is not None else DeviceConfig()) self.device = self.device_config.device - self.model = None - self.block_size = None # Set after initial profiling. - self.kv_cache_dtype = kv_cache_dtype self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) + # Lazy initialization. + self.model: nn.Module # Set after init_Model + self.block_size: int # Set after initial profiling. + def load_model(self) -> None: self.model = get_model(model_config=self.model_config, load_config=self.load_config, @@ -245,7 +247,11 @@ def _prepare_sample( selected_token_indices: List[int] = [] generators: List[torch.Generator] = [] selected_token_start_idx = 0 - categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices: Dict[SamplingType, + List[Tuple[int, int]]] = { + t: [] + for t in SamplingType + } categorized_sample_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0 @@ -262,10 +268,9 @@ def _prepare_sample( categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append([ - categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx - ]) + sampling_params.sampling_type].append( + (categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx)) categorized_sample_indices_start_idx += 1 categorized_sampled_token_indices_start_idx += 1 @@ -328,7 +333,7 @@ def _prepare_sample( def prepare_input_tensors( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata]: if self.is_driver_worker: @@ -381,7 +386,7 @@ def prepare_input_tensors( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 8468ace5a2fdc..3652830b7d519 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.distributed @@ -152,8 +152,8 @@ def __init__( is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine = None - self.cpu_cache = None + self.cache_engine: CPUCacheEngine + self.cpu_cache: List[torch.Tensor] def init_device(self) -> None: self.init_distributed_environment() @@ -257,13 +257,13 @@ def execute_model( ) -> List[SamplerOutput]: if self.is_driver_worker: assert seq_group_metadata_list is not None - num_seq_groups = len(seq_group_metadata_list) + num_seq_groups: int = len(seq_group_metadata_list) assert blocks_to_swap_in is not None assert blocks_to_swap_out is not None assert blocks_to_copy is not None assert len(blocks_to_swap_in) == 0 assert len(blocks_to_swap_out) == 0 - data = { + data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_copy": blocks_to_copy, } @@ -273,6 +273,7 @@ def execute_model( num_seq_groups = data["num_seq_groups"] blocks_to_copy = data["blocks_to_copy"] + assert blocks_to_copy is not None self.cache_copy(blocks_to_copy) # If there is no input, we don't need to execute the model. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 42c06a1b19361..31e08789dfd1f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -128,23 +128,17 @@ def __init__( if device_config is not None else DeviceConfig()) self.device = self.device_config.device - self.model = None - self.block_size = None # Set after initial profiling. - self.lora_manager = None + # Set after load_model. + self.lora_manager: LRUCacheWorkerLoRAManager = None self.graph_runners: Dict[int, CUDAGraphRunner] = {} - self.graph_memory_pool = None # Set during graph capture. + self.graph_memory_pool: Optional[Tuple[ + int, int]] = None # Set during graph capture. self.max_context_len_to_capture = ( self.model_config.max_context_len_to_capture if self.model_config is not None else 0) - # When using CUDA graph, the input block tables must be padded to - # max_context_len_to_capture. However, creating the block table in - # Python can be expensive. To optimize this, we cache the block table - # in numpy and only copy the actual input content at every iteration. - # The shape of the cached block table will be - # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables = None # Set after initial profiling. + self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype self.vision_language_config = vision_language_config @@ -152,6 +146,17 @@ def __init__( self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) + # Lazy initialization + self.model: torch.nn.Module # Set after load_model + self.block_size: int # Set after initial profiling. + # When using CUDA graph, the input block tables must be padded to + # max_context_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables: torch.Tensor # Set after initial profiling. + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -489,16 +494,16 @@ def _prepare_decode( lora_index_mapping.append(0) batch_size = graph_batch_size - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert context_lens.shape[0] == len(input_tokens) - assert context_lens.shape[0] == len(input_positions) - assert context_lens.shape[0] == len(slot_mapping) + assert context_lens_tensor.shape[0] == len(input_tokens) + assert context_lens_tensor.shape[0] == len(input_positions) + assert context_lens_tensor.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -527,7 +532,7 @@ def _prepare_decode( max_prompt_len=None, subquery_start_loc=None, seq_start_loc=None, - context_lens=context_lens, + context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) @@ -551,7 +556,11 @@ def _prepare_sample( selected_token_indices: List[int] = [] generators: List[torch.Generator] = [] selected_token_start_idx = 0 - categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices: Dict[SamplingType, + List[Tuple[int, int]]] = { + t: [] + for t in SamplingType + } categorized_sample_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0 @@ -569,10 +578,9 @@ def _prepare_sample( categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append([ - categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx - ]) + sampling_params.sampling_type].append( + (categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx)) categorized_sample_indices_start_idx += 1 categorized_sampled_token_indices_start_idx += 1 @@ -596,15 +604,16 @@ def _prepare_sample( categorized_sample_indices[ sampling_params.sampling_type].extend( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx + - num_seqs))) + list( + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs)))) categorized_sample_indices_start_idx += num_seqs categorized_sampled_token_indices_start_idx += num_seqs @@ -641,9 +650,9 @@ def _prepare_sample( def prepare_input_tensors( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[int], LoRAMapping, torch.Tensor]: + Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: prefill_reqs = [] decode_reqs = [] @@ -741,6 +750,7 @@ def prepare_input_tensors( if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) else: + assert decode_attn_metadata is not None metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) @@ -809,7 +819,7 @@ def prepare_input_tensors( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, @@ -923,7 +933,7 @@ def remove_all_loras(self) -> bool: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.remove_all_loras() - def set_active_loras(self, lora_requests: List[LoRARequest], + def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") @@ -1065,10 +1075,16 @@ class CUDAGraphRunner: def __init__(self, model: nn.Module): self.model = model - self.graph = None self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} + self._graph: Optional[torch.cuda.CUDAGraph] = None + + @property + def graph(self): + assert self._graph is not None + return self._graph + def capture( self, input_ids: torch.Tensor, @@ -1078,7 +1094,7 @@ def capture( memory_pool, **kwargs, ) -> None: - assert self.graph is None + assert self._graph is None # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). @@ -1095,8 +1111,8 @@ def capture( # Capture the graph. # NOTE(woosuk): Python 3.8 does not support multi-line with statements. # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement - self.graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 with _maybe_pynccl(): hidden_states = self.model( input_ids, diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index f70a7193effeb..487df334d73e3 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional, Tuple import torch +from torch import nn from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -34,9 +35,11 @@ def __init__( self.device_config = (device_config if device_config is not None else DeviceConfig()) self.device = self.device_config.device - self.model = None self.pin_memory = is_pin_memory_available() + # Lazy initialization. + self.model: nn.Module # initialize after load_model. + def load_model(self) -> None: self.model = get_neuron_model(self.model_config, parallel_config=self.parallel_config, @@ -147,7 +150,11 @@ def _prepare_sample( selected_token_indices: List[int] = [] generators: List[torch.Generator] = [] selected_token_start_idx = 0 - categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices: Dict[SamplingType, + List[Tuple[int, int]]] = { + t: [] + for t in SamplingType + } categorized_sample_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0 @@ -165,10 +172,9 @@ def _prepare_sample( categorized_sample_indices_start_idx += prompt_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append([ - categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx - ]) + sampling_params.sampling_type].append( + (categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx)) categorized_sample_indices_start_idx += 1 categorized_sampled_token_indices_start_idx += 1 @@ -237,7 +243,7 @@ def _prepare_sample( def prepare_input_tensors( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. @@ -259,7 +265,7 @@ def prepare_input_tensors( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, input_block_ids, sampling_metadata ) = self.prepare_input_tensors(seq_group_metadata_list) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b021866965401..2203570b37ad6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import torch import torch.distributed @@ -82,8 +82,8 @@ def __init__( ) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine = None - self.gpu_cache = None + self.cache_engine: CacheEngine + self.gpu_cache: List[torch.Tensor] def init_device(self) -> None: if self.device_config.device.type == "cuda": @@ -223,7 +223,7 @@ def execute_model( assert blocks_to_swap_in is not None assert blocks_to_swap_out is not None assert blocks_to_copy is not None - data = { + data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_out": blocks_to_swap_out, @@ -237,6 +237,9 @@ def execute_model( blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_copy = data["blocks_to_copy"] + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) # If there is no input, we don't need to execute the model. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 309aa6256acea..13e062fe64b29 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,7 +1,7 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -56,7 +56,7 @@ def execute_model( raise NotImplementedError @abstractmethod - def get_cache_block_size_bytes() -> int: + def get_cache_block_size_bytes(self) -> int: """Return the size of a single cache block, in bytes. Used in speculative decoding. """ @@ -71,7 +71,7 @@ def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError @abstractmethod - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise NotImplementedError @@ -86,7 +86,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: raise ValueError(f"{type(self)} does not support LoRA") - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA") From 6dc1fc9cfed74f63f3e90c60e4d6ad7065bd4529 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 17 Apr 2024 22:28:52 -0700 Subject: [PATCH 060/413] [Core] nccl integrity check and test (#4155) [Core] Add integrity check during initialization; add test for it (#4155) --- .buildkite/test-pipeline.yaml | 1 + tests/distributed/test_pynccl_library.py | 43 ++++++++++++++++ .../device_communicators/pynccl.py | 38 +++++--------- vllm/utils.py | 51 +++++++++++++++++++ 4 files changed, 107 insertions(+), 26 deletions(-) create mode 100644 tests/distributed/test_pynccl_library.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f39c3302ac2e9..2263dee20fbed 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -33,6 +33,7 @@ steps: num_gpus: 2 # only support 1 or 2 for now. commands: - pytest -v -s test_pynccl.py + - pytest -v -s test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py diff --git a/tests/distributed/test_pynccl_library.py b/tests/distributed/test_pynccl_library.py new file mode 100644 index 0000000000000..ec60a5ed3114d --- /dev/null +++ b/tests/distributed/test_pynccl_library.py @@ -0,0 +1,43 @@ +import multiprocessing +import tempfile + + +def target_fn(env, filepath): + from vllm.utils import update_environment_variables + update_environment_variables(env) + from vllm.utils import nccl_integrity_check + nccl_integrity_check(filepath) + + +def test_library_file(): + # note: don't import vllm.distributed.device_communicators.pynccl + # before running this test, otherwise the library file will be loaded + # and it might interfere with the test + from vllm.utils import find_nccl_library + so_file = find_nccl_library() + with open(so_file, 'rb') as f: + content = f.read() + try: + # corrupt the library file, should raise an exception + with open(so_file, 'wb') as f: + f.write(content[:len(content) // 2]) + p = multiprocessing.Process(target=target_fn, args=({}, so_file)) + p.start() + p.join() + assert p.exitcode != 0 + + # move the library file to a tmp path + # test VLLM_NCCL_SO_PATH + fd, path = tempfile.mkstemp() + with open(path, 'wb') as f: + f.write(content) + p = multiprocessing.Process(target=target_fn, + args=({ + "VLLM_NCCL_SO_PATH": path + }, path)) + p.start() + p.join() + assert p.exitcode == 0 + finally: + with open(so_file, 'wb') as f: + f.write(content) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 0a8bb860efa1c..c57a4f59d442c 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -21,8 +21,7 @@ import ctypes import datetime -import glob -import os +import platform # ===================== import region ===================== import torch @@ -30,40 +29,27 @@ from torch.distributed import ReduceOp from vllm.logger import init_logger +from vllm.utils import find_nccl_library, nccl_integrity_check logger = init_logger(__name__) -so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") - -# check if we have vllm-managed nccl -vllm_nccl_path = None -if torch.version.cuda is not None: - cuda_major = torch.version.cuda.split(".")[0] - path = os.path.expanduser( - f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*") - files = glob.glob(path) - vllm_nccl_path = files[0] if files else None - -# manually load the nccl library -if so_file: - logger.info( - f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}") -else: - if torch.version.cuda is not None: - so_file = vllm_nccl_path or "libnccl.so.2" - elif torch.version.hip is not None: - so_file = "librccl.so.1" - else: - raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.info(f"Loading nccl from library {so_file}") +so_file = find_nccl_library() try: + # load the library in another process. + # if it core dumps, it will not crash the current process + nccl_integrity_check(so_file) nccl = ctypes.CDLL(so_file) except Exception as e: logger.error( f"Failed to load NCCL library from {so_file} ." "It is expected if you are not running on NVIDIA/AMD GPUs." - "Otherwise please set the environment variable VLLM_NCCL_SO_PATH" + "Otherwise, the nccl library might not exist, be corrupted " + f"or it does not support the current platform {platform.platform()}." + f"One solution is to download libnccl2 version 2.18 from " + f"https://developer.download.nvidia.com/compute/cuda/repos/ " + f"and extract the libnccl.so.2 file. If you already have the " + f"library, please set the environment variable VLLM_NCCL_SO_PATH" " to point to the correct nccl library path.") raise e diff --git a/vllm/utils.py b/vllm/utils.py index e132575e7bf81..49e7033c23e6a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,6 +1,7 @@ import asyncio import enum import gc +import glob import os import socket import subprocess @@ -517,3 +518,53 @@ def init_cached_hf_modules(): """ from transformers.dynamic_module_utils import init_hf_modules init_hf_modules() + + +def nccl_integrity_check(filepath): + """ + when the library is corrupted, we cannot catch + the exception in python. it will crash the process. + instead, we use the exit code of `ldd` to check + if the library is corrupted. if not, we will return + the version of the library. + """ + exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null") + if exit_code != 0: + raise RuntimeError(f"Failed to load NCCL library from {filepath} .") + import ctypes + + nccl = ctypes.CDLL(filepath) + version = ctypes.c_int() + nccl.ncclGetVersion.restype = ctypes.c_int + nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] + result = nccl.ncclGetVersion(ctypes.byref(version)) + assert result == 0 + return version.value + + +def find_nccl_library(): + so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") + + # check if we have vllm-managed nccl + vllm_nccl_path = None + if torch.version.cuda is not None: + cuda_major = torch.version.cuda.split(".")[0] + path = os.path.expanduser( + f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*") + files = glob.glob(path) + vllm_nccl_path = files[0] if files else None + + # manually load the nccl library + if so_file: + logger.info( + f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}" + ) + else: + if torch.version.cuda is not None: + so_file = vllm_nccl_path or "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info(f"Found nccl from library {so_file}") + return so_file From 66ded030677c7a0ca696f8d64e41637f4a358c00 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 18 Apr 2024 08:16:26 +0100 Subject: [PATCH 061/413] Allow model to be served under multiple names (#2894) Co-authored-by: Alexandre Payot --- vllm/entrypoints/openai/api_server.py | 8 ++++---- vllm/entrypoints/openai/cli_args.py | 10 +++++++--- vllm/entrypoints/openai/serving_chat.py | 8 ++++---- vllm/entrypoints/openai/serving_completion.py | 6 +++--- vllm/entrypoints/openai/serving_engine.py | 15 ++++++++------- 5 files changed, 26 insertions(+), 21 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 32282bfd8d12b..d6673976bb775 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -150,18 +150,18 @@ async def authentication(request: Request, call_next): logger.info(f"args: {args}") if args.served_model_name is not None: - served_model = args.served_model_name + served_model_names = args.served_model_name else: - served_model = args.model + served_model_names = [args.model] engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - openai_serving_chat = OpenAIServingChat(engine, served_model, + openai_serving_chat = OpenAIServingChat(engine, served_model_names, args.response_role, args.lora_modules, args.chat_template) openai_serving_completion = OpenAIServingCompletion( - engine, served_model, args.lora_modules) + engine, served_model_names, args.lora_modules) app.root_path = args.root_path uvicorn.run(app, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index cc71931b97955..5c361b4d184ee 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -54,11 +54,15 @@ def make_arg_parser(): help="If provided, the server will require this key " "to be presented in the header.") parser.add_argument("--served-model-name", + nargs="+", type=str, default=None, - help="The model name used in the API. If not " - "specified, the model name will be the same as " - "the huggingface name.") + help="The model name(s) used in the API. If multiple " + "names are provided, the server will respond to any " + "of the provided names. The model name in the model " + "field of a response will be the first name in this " + "list. If not specified, the model name will be the " + "same as the `--model` argument.") parser.add_argument( "--lora-modules", type=str, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c9ed4a9de20f4..f35eab15bc487 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -24,12 +24,12 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, - served_model: str, + served_model_names: List[str], response_role: str, lora_modules: Optional[List[LoRA]] = None, chat_template=None): super().__init__(engine=engine, - served_model=served_model, + served_model_names=served_model_names, lora_modules=lora_modules) self.response_role = response_role self._load_chat_template(chat_template) @@ -109,7 +109,7 @@ async def chat_completion_stream_generator( result_generator: AsyncIterator[RequestOutput], request_id: str ) -> Union[ErrorResponse, AsyncGenerator[str, None]]: - model_name = request.model + model_name = self.served_model_names[0] created_time = int(time.time()) chunk_object_type = "chat.completion.chunk" first_iteration = True @@ -251,7 +251,7 @@ async def chat_completion_full_generator( result_generator: AsyncIterator[RequestOutput], request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: - model_name = request.model + model_name = self.served_model_names[0] created_time = int(time.time()) final_res: RequestOutput = None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index a71f2d6a4426a..b7e2530a69b51 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -53,10 +53,10 @@ class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, - served_model: str, + served_model_names: List[str], lora_modules: Optional[List[LoRA]] = None): super().__init__(engine=engine, - served_model=served_model, + served_model_names=served_model_names, lora_modules=lora_modules) async def create_completion(self, request: CompletionRequest, @@ -79,7 +79,7 @@ async def create_completion(self, request: CompletionRequest, return self.create_error_response( "suffix is not currently supported") - model_name = request.model + model_name = self.served_model_names[0] request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 77a568b564039..b5a7a977ebbab 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -29,10 +29,10 @@ class OpenAIServing: def __init__(self, engine: AsyncLLMEngine, - served_model: str, + served_model_names: List[str], lora_modules=Optional[List[LoRA]]): self.engine = engine - self.served_model = served_model + self.served_model_names = served_model_names if lora_modules is None: self.lora_requests = [] else: @@ -74,13 +74,14 @@ async def _post_init(self): async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ - ModelCard(id=self.served_model, - root=self.served_model, + ModelCard(id=served_model_name, + root=self.served_model_names[0], permission=[ModelPermission()]) + for served_model_name in self.served_model_names ] lora_cards = [ ModelCard(id=lora.lora_name, - root=self.served_model, + root=self.served_model_names[0], permission=[ModelPermission()]) for lora in self.lora_requests ] @@ -150,7 +151,7 @@ def create_streaming_error_response( return json_str async def _check_model(self, request) -> Optional[ErrorResponse]: - if request.model == self.served_model: + if request.model in self.served_model_names: return if request.model in [lora.lora_name for lora in self.lora_requests]: return @@ -160,7 +161,7 @@ async def _check_model(self, request) -> Optional[ErrorResponse]: status_code=HTTPStatus.NOT_FOUND) def _maybe_get_lora(self, request) -> Optional[LoRARequest]: - if request.model == self.served_model: + if request.model in self.served_model_names: return for lora in self.lora_requests: if request.model == lora.lora_name: From 53b018edcbc601f0eea9f65f13a9a9620c4be8dc Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 18 Apr 2024 03:21:55 -0400 Subject: [PATCH 062/413] [Bugfix] Get available quantization methods from quantization registry (#4098) --- benchmarks/benchmark_latency.py | 3 ++- benchmarks/benchmark_throughput.py | 4 +++- tests/models/test_marlin.py | 7 +++---- vllm/config.py | 7 ++++--- vllm/engine/arg_utils.py | 3 ++- vllm/model_executor/layers/quantization/__init__.py | 7 ++++--- 6 files changed, 18 insertions(+), 13 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index aadbc441713fc..44da3bad8d840 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -9,6 +9,7 @@ from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS def main(args: argparse.Namespace): @@ -101,7 +102,7 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'gptq', 'squeezellm', None], + choices=[*QUANTIZATION_METHODS, None], default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 6df1e1d628e6c..6bb889d1eceba 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -10,6 +10,8 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + def sample_requests( dataset_path: str, @@ -267,7 +269,7 @@ def main(args: argparse.Namespace): parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'gptq', 'squeezellm', None], + choices=[*QUANTIZATION_METHODS, None], default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 2305db3510060..4fe6daec02520 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -16,13 +16,12 @@ import pytest import torch -from vllm.model_executor.layers.quantization import ( - _QUANTIZATION_CONFIG_REGISTRY) +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] -marlin_not_supported = ( - capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability()) +marlin_not_supported = (capability < + QUANTIZATION_METHODS["marlin"].get_min_capability()) @dataclass diff --git a/vllm/config.py b/vllm/config.py index 5a29620e85ac6..2912d6ccc2c5b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -9,6 +9,7 @@ from transformers import PretrainedConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, is_neuron) @@ -118,8 +119,8 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "squeezellm", "marlin"] - rocm_not_supported_quantization = ["awq", "marlin"] + supported_quantization = [*QUANTIZATION_METHODS] + rocm_supported_quantization = ["gptq", "squeezellm"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -155,7 +156,7 @@ def _verify_quantization(self) -> None: f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}.") if is_hip( - ) and self.quantization in rocm_not_supported_quantization: + ) and self.quantization not in rocm_supported_quantization: raise ValueError( f"{self.quantization} quantization is currently not " f"supported in ROCm.") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c61c0cc67d7a2..2999ab0a7e72a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,6 +7,7 @@ EngineConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, VisionLanguageConfig) +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import str_to_int_tuple @@ -286,7 +287,7 @@ def add_cli_args( parser.add_argument('--quantization', '-q', type=str, - choices=['awq', 'gptq', 'squeezellm', None], + choices=[*QUANTIZATION_METHODS, None], default=EngineArgs.quantization, help='Method used to quantize the weights. If ' 'None, we first check the `quantization_config` ' diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index ad988d48755b0..a3b89a66469eb 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -7,7 +7,7 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig -_QUANTIZATION_CONFIG_REGISTRY = { +QUANTIZATION_METHODS = { "awq": AWQConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, @@ -16,12 +16,13 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: - if quantization not in _QUANTIZATION_CONFIG_REGISTRY: + if quantization not in QUANTIZATION_METHODS: raise ValueError(f"Invalid quantization method: {quantization}") - return _QUANTIZATION_CONFIG_REGISTRY[quantization] + return QUANTIZATION_METHODS[quantization] __all__ = [ "QuantizationConfig", "get_quantization_config", + "QUANTIZATION_METHODS", ] From e8cc7967ff8a6f8432747a9e87ab451d36e1ff57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Moskal?= Date: Thu, 18 Apr 2024 00:51:28 -0700 Subject: [PATCH 063/413] [Bugfix][Kernel] allow non-power-of-two head sizes in prefix prefill (#4128) --- tests/kernels/test_prefix_prefill.py | 2 +- vllm/attention/ops/prefix_prefill.py | 44 +++++++++++++++++----------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 6494fb34af98f..ad31b0a7c2a19 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -10,7 +10,7 @@ NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128] +HEAD_SIZES = [128, 96] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 70f09224f1cf6..4896cf3909c6e 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -47,7 +47,8 @@ def _fwd_kernel( stride_v_cache_bl, num_queries_per_kv: int, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) @@ -59,26 +60,30 @@ def _fwd_kernel( cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len block_start_loc = BLOCK_M * start_m # initialize offsets offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) for start_n in range(0, cur_batch_ctx_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) @@ -99,7 +104,8 @@ def _fwd_kernel( offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -126,7 +132,8 @@ def _fwd_kernel( acc = acc * acc_scale[:, None] # update acc v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -142,16 +149,15 @@ def _fwd_kernel( k_ptrs = K + off_k v_ptrs = V + off_v - block_mask = tl.where( - block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -179,8 +185,8 @@ def _fwd_kernel( # update acc v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), other=0.0) p = p.to(v.dtype) @@ -195,7 +201,8 @@ def _fwd_kernel( out_ptrs = Out + off_o tl.store(out_ptrs, acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len)) return @triton.jit @@ -636,7 +643,8 @@ def context_attention_fwd(q, # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = 2**((Lk - 1).bit_length()) sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] @@ -646,6 +654,7 @@ def context_attention_fwd(q, num_warps = 8 if Lk <= 64 else 8 if alibi_slopes is not None: + assert Lk == Lk_padded _fwd_kernel_alibi[grid]( q, k, @@ -738,6 +747,7 @@ def context_attention_fwd(q, num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, From 705578ae14b648782a8a321dd0903c163bd77375 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 18 Apr 2024 10:55:48 -0700 Subject: [PATCH 064/413] [Docs] document that Meta Llama 3 is supported (#4175) --- README.md | 2 +- docs/source/models/supported_models.rst | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8434c11883341..947d50d4ad764 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) - InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) - Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) -- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) +- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 5e5ce871f61dd..951fc3aac0c75 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -80,8 +80,8 @@ Alongside each architecture, we include some popular models that use it. - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. - * - :code:`LlamaForCausalLM` - - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi - - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. + - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi + - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. - ✅︎ * - :code:`MiniCPMForCausalLM` - MiniCPM From e1bb2fd52dea0bbc772bdf35fd27664c5daec7b2 Mon Sep 17 00:00:00 2001 From: James Whedbee Date: Thu, 18 Apr 2024 16:12:55 -0500 Subject: [PATCH 065/413] [Bugfix] Support logprobs when using guided_json and other constrained decoding fields (#4149) --- tests/entrypoints/test_openai_server.py | 30 +++++++++++++++++++++++ vllm/entrypoints/openai/serving_engine.py | 4 ++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 14e6ee0ffe9d9..0dd30eec30086 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -723,6 +723,36 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) +@pytest.mark.parametrize("guided_decoding_backend", + ["outlines", "lm-format-enforcer"]) +async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, + guided_decoding_backend: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": + "user", + "content": + "The best language for type-safe systems programming is " + }] + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5, + extra_body=dict(guided_choice=TEST_CHOICE, + guided_decoding_backend=guided_decoding_backend)) + top_logprobs = chat_completion.choices[0].logprobs.top_logprobs + + # -9999.0 is the minimum logprob returned by OpenAI + assert all( + isinstance(logprob, float) and logprob >= -9999.0 + for token_dict in top_logprobs + for token, logprob in token_dict.items()) + + async def test_response_format_json_object(server, client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index b5a7a977ebbab..8e5ee88d7f3a9 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -116,7 +116,9 @@ def _create_logprobs( if num_output_top_logprobs: logprobs.top_logprobs.append({ - p.decoded_token: p.logprob + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + p.decoded_token: max(p.logprob, -9999.0) for i, p in step_top_logprobs.items() } if step_top_logprobs else None) From 87fa80c91f5b24a2ee1805b80c3eca8fd6600cd5 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 18 Apr 2024 14:36:39 -0700 Subject: [PATCH 066/413] [Misc] Bump transformers to latest version (#4176) --- requirements-common.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index c1614d2537b25..3de0f98e7c0c3 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -5,7 +5,8 @@ sentencepiece # Required for LLaMA tokenizer. numpy requests py-cpuinfo -transformers >= 4.39.1 # Required for StarCoder2 & Llava. +transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3. +tokenizers >= 0.19.1 # Required for Llama 3. fastapi uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. From cd2f63fb362b2c53b993e7edf6565ab6a5f9f260 Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Thu, 18 Apr 2024 15:26:01 -0700 Subject: [PATCH 067/413] [CI/CD] add neuron docker and ci test scripts (#3571) --- .buildkite/run-neuron-test.sh | 37 ++++++++++++++++++++++++++++++++ .buildkite/test-template.j2 | 5 +++++ Dockerfile.neuron | 36 +++++++++++++++++++++++++++++++ setup.py | 3 ++- vllm/engine/async_llm_engine.py | 4 ++-- vllm/executor/neuron_executor.py | 22 ++++++++++++++++++- 6 files changed, 103 insertions(+), 4 deletions(-) create mode 100644 .buildkite/run-neuron-test.sh create mode 100644 Dockerfile.neuron diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh new file mode 100644 index 0000000000000..8ba03b78e8dbf --- /dev/null +++ b/.buildkite/run-neuron-test.sh @@ -0,0 +1,37 @@ +# This script build the Neuron docker image and run the API server inside the container. +# It serves a sanity check for compilation and basic model usage. +set -e + +# Try building the docker image +aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com +docker build -t neuron -f Dockerfile.neuron . + +# Setup cleanup +remove_docker_container() { docker rm -f neuron || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image +docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \ + --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 & + +# Wait for the server to start +wait_for_server_to_start() { + timeout=300 + counter=0 + + while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do + sleep 1 + counter=$((counter + 1)) + if [ $counter -ge $timeout ]; then + echo "Timeout after $timeout seconds" + break + fi + done +} +wait_for_server_to_start + +# Test a simple prompt +curl -X POST -H "Content-Type: application/json" \ + localhost:8000/generate \ + -d '{"prompt": "San Francisco is a"}' diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 0e1acc9777d4b..fb1086db77823 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -21,6 +21,11 @@ steps: queue: amd command: bash .buildkite/run-amd-test.sh + - label: "Neuron Test" + agents: + queue: neuron + command: bash .buildkite/run-neuron-test.sh + - label: "CPU Test" command: bash .buildkite/run-cpu-test.sh diff --git a/Dockerfile.neuron b/Dockerfile.neuron new file mode 100644 index 0000000000000..fe42b4ef393f1 --- /dev/null +++ b/Dockerfile.neuron @@ -0,0 +1,36 @@ +# default base image +ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04" + +FROM $BASE_IMAGE + +RUN echo "Base image is $BASE_IMAGE" + +# Install some basic utilities +RUN apt-get update && apt-get install python3 python3-pip -y + +### Mount Point ### +# When launching the container, mount the code directory to /app +ARG APP_MOUNT=/app +VOLUME [ ${APP_MOUNT} ] +WORKDIR ${APP_MOUNT} + +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas +RUN python3 -m pip install sentencepiece transformers==4.36.2 -U +RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U + +COPY ./vllm /app/vllm/vllm +COPY ./setup.py /app/vllm/setup.py +COPY ./requirements-common.txt /app/vllm/requirements-common.txt +COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt + +RUN cd /app/vllm \ + && python3 -m pip install -U -r requirements-neuron.txt + +ENV VLLM_BUILD_WITH_NEURON 1 +RUN cd /app/vllm \ + && pip install -e . \ + && cd .. + +CMD ["/bin/bash"] diff --git a/setup.py b/setup.py index 19a9150ad2e64..4b672e1af8494 100644 --- a/setup.py +++ b/setup.py @@ -204,7 +204,8 @@ def _is_neuron() -> bool: subprocess.run(["neuron-ls"], capture_output=True, check=True) except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): torch_neuronx_installed = False - return torch_neuronx_installed + return torch_neuronx_installed or os.environ.get("VLLM_BUILD_WITH_NEURON", + False) def _is_cpu() -> bool: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index c3020d2b38db0..c436ece83f65a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -335,8 +335,8 @@ def from_engine_args( engine_config = engine_args.create_engine_config() if engine_config.device_config.device_type == "neuron": - raise NotImplementedError("Neuron is not supported for " - "async engine yet.") + from vllm.executor.neuron_executor import NeuronExecutorAsync + executor_class = NeuronExecutorAsync elif engine_config.parallel_config.worker_use_ray: initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 7cc187e297c9f..5a137d1bdcb3b 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,9 +1,10 @@ from typing import Dict, List, Set, Tuple -from vllm.executor.executor_base import ExecutorBase +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import make_async logger = init_logger(__name__) @@ -73,3 +74,22 @@ def check_health(self) -> None: # NeuronExecutor will always be healthy as long as # it's running. return + + +class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> SamplerOutput: + output = await make_async(self.driver_worker.execute_model)( + seq_group_metadata_list=seq_group_metadata_list, ) + return output + + async def check_health_async(self) -> None: + # NeuronExecutor will always be healthy as long as + # it's running. + return From 8f9c28fd40a9daa744d08cacb0bd5ac2247a97d1 Mon Sep 17 00:00:00 2001 From: Adam Tilghman Date: Thu, 18 Apr 2024 15:32:47 -0700 Subject: [PATCH 068/413] [Bugfix] Fix CustomAllreduce nvlink topology detection (#3974) [Bugfix] Fix CustomAllreduce pcie nvlink topology detection (#3974) (#4159) --- vllm/distributed/device_communicators/custom_all_reduce.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index f83caef879da3..7602897d3dd8f 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -145,8 +145,10 @@ def _is_full_nvlink(rank, world_size): for i in range(world_size): if i != rank: try: - link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i) - if not link_state: + peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i) + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: return False except pynvml.NVMLError as error: logger.info( From 8a7a3e4436d7284df4c0913f074d77d640a9c6c3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 18 Apr 2024 16:15:12 -0700 Subject: [PATCH 069/413] [Core] add an option to log every function call to for debugging hang/crash in distributed inference (#4079) Co-authored-by: Simon Mo --- .buildkite/test-pipeline.yaml | 2 +- .github/ISSUE_TEMPLATE/400-bug report.yml | 2 + tests/test_logger.py | 27 ++++++++++++ vllm/executor/ray_gpu_executor.py | 12 ++++-- vllm/logger.py | 52 +++++++++++++++++++++++ vllm/utils.py | 13 +++++- vllm/worker/worker_base.py | 20 +++++++-- 7 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 tests/test_logger.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2263dee20fbed..0f920c7ec1442 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -40,7 +40,7 @@ steps: - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - label: Engine Test - command: pytest -v -s engine tokenization test_sequence.py test_config.py + command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test commands: diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index f1124dfa78bbc..c87f8fddcb776 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -57,6 +57,8 @@ body: If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + + If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. placeholder: | A clear and concise description of what the bug is. diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 0000000000000..601f72b50811c --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,27 @@ +import os +import sys +import tempfile + +from vllm.logger import enable_trace_function_call + + +def f1(x): + return f2(x) + + +def f2(x): + return x + + +def test_trace_function_call(): + fd, path = tempfile.mkstemp() + cur_dir = os.path.dirname(__file__) + enable_trace_function_call(path, cur_dir) + f1(1) + with open(path, 'r') as f: + content = f.read() + + assert "f1" in content + assert "f2" in content + sys.settrace(None) + os.remove(path) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 5a43f1fc28a84..f779b0f8a5113 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -10,7 +10,7 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) + get_vllm_instance_id, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -133,12 +133,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", for node_id, gpu_ids in node_gpus.items(): node_gpus[node_id] = sorted(gpu_ids) - # Set CUDA_VISIBLE_DEVICES for the driver and workers. + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [] for (node_id, _) in worker_node_and_gpu_ids: all_args_to_update_environment_variables.append([{ "CUDA_VISIBLE_DEVICES": - ",".join(map(str, node_gpus[node_id])) + ",".join(map(str, node_gpus[node_id])), + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + os.getenv("VLLM_TRACE_FUNCTION", "0"), }]) self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) diff --git a/vllm/logger.py b/vllm/logger.py index af9575085ef37..046f0e9099a4b 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -1,9 +1,11 @@ # Adapted from # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py """Logging configuration for vLLM.""" +import datetime import logging import os import sys +from functools import partial from typing import Optional VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) @@ -65,3 +67,53 @@ def init_logger(name: str): logger.addHandler(_default_handler) logger.propagate = False return logger + + +logger = init_logger(__name__) + + +def _trace_calls(log_path, root_dir, frame, event, arg=None): + if event in ['call', 'return']: + # Extract the filename, line number, function name, and the code object + filename = frame.f_code.co_filename + lineno = frame.f_lineno + func_name = frame.f_code.co_name + if not filename.startswith(root_dir): + # only log the functions in the vllm root_dir + return + # Log every function call or return + try: + with open(log_path, 'a') as f: + if event == 'call': + f.write(f"{datetime.datetime.now()} Call to" + f" {func_name} in {filename}:{lineno}\n") + else: + f.write(f"{datetime.datetime.now()} Return from" + f" {func_name} in {filename}:{lineno}\n") + except NameError: + # modules are deleted during shutdown + pass + return partial(_trace_calls, log_path, root_dir) + + +def enable_trace_function_call(log_file_path: str, + root_dir: Optional[str] = None): + """ + Enable tracing of every function call in code under `root_dir`. + This is useful for debugging hangs or crashes. + `log_file_path` is the path to the log file. + `root_dir` is the root directory of the code to trace. If None, it is the + vllm root directory. + + Note that this call is thread-level, any threads calling this function + will have the trace enabled. Other threads will not be affected. + """ + logger.warning( + "VLLM_TRACE_FUNCTION is enabled. It will record every" + " function executed by Python. This will slow down the code. It " + "is suggested to be used for debugging hang or crashes only.") + logger.info(f"Trace frame log is saved to {log_file_path}") + if root_dir is None: + # by default, this is the vllm root directory + root_dir = os.path.dirname(os.path.dirname(__file__)) + sys.settrace(partial(_trace_calls, log_file_path, root_dir)) diff --git a/vllm/utils.py b/vllm/utils.py index 49e7033c23e6a..fbe86dacaeb99 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -163,6 +163,17 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) +@lru_cache(maxsize=None) +def get_vllm_instance_id(): + """ + If the environment variable VLLM_INSTANCE_ID is set, return it. + Otherwise, return a random UUID. + Instance id represents an instance of the VLLM. All processes in the same + instance should have the same instance id. + """ + return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}") + + @lru_cache(maxsize=None) def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 @@ -274,7 +285,7 @@ def get_open_port() -> int: def update_environment_variables(envs: Dict[str, str]): for k, v in envs.items(): - if k in os.environ: + if k in os.environ and os.environ[k] != v: logger.warning(f"Overwriting environment variable {k} " f"from '{os.environ[k]}' to '{v}'") os.environ[k] = v diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 13e062fe64b29..783dff3a43404 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,12 +1,15 @@ +import datetime import importlib import os +import tempfile +import threading from abc import ABC, abstractmethod from typing import Dict, List, Set, Tuple -from vllm.logger import init_logger +from vllm.logger import enable_trace_function_call, init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import update_environment_variables +from vllm.utils import get_vllm_instance_id, update_environment_variables logger = init_logger(__name__) @@ -115,9 +118,20 @@ def update_environment_variables(self, envs: Dict[str, str]) -> None: def init_worker(self, *args, **kwargs): """ - Actual initialization of the worker class. + Actual initialization of the worker class, and set up + function tracing if required. Arguments are passed to the worker class constructor. """ + if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): + tmp_dir = tempfile.gettempdir() + filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log").replace(" ", "_") + log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + filename) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) + mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) self.worker = worker_class(*args, **kwargs) From a134ef6f5e6c24d3cd459c63557e5db276db25b2 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 18 Apr 2024 21:13:36 -0700 Subject: [PATCH 070/413] Support eos_token_id from generation_config.json (#4182) --- vllm/engine/llm_engine.py | 19 +++++++++++++++++-- vllm/sampling_params.py | 14 +++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c3de57e249ff8..88b344f50767f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,7 +1,7 @@ import time from typing import Iterable, List, Optional, Type, Union -from transformers import PreTrainedTokenizer +from transformers import GenerationConfig, PreTrainedTokenizer import vllm from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, @@ -34,6 +34,17 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5 +def _load_generation_config_dict(model_config: ModelConfig): + try: + return GenerationConfig.from_pretrained( + model_config.model, + revision=model_config.revision, + ).to_diff_dict() + except OSError: + # Not found. + return {} + + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -124,6 +135,8 @@ def __init__( self._init_tokenizer() self.detokenizer = Detokenizer(self.tokenizer) self.seq_counter = Counter() + self.generation_config_fields = _load_generation_config_dict( + model_config) self.model_executor = executor_class( model_config=model_config, @@ -391,6 +404,8 @@ def add_request( # inject the eos token id into the sampling_params to support min_tokens # processing sampling_params.eos_token_id = seq.eos_token_id + sampling_params.update_from_generation_config( + self.generation_config_fields) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, @@ -435,7 +450,7 @@ def _process_model_outputs( scheduled_seq_groups: List[SequenceGroup], ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: """Apply the model output to the sequences in the scheduled seq groups. - + Returns RequestOutputs that can be returned to the client. """ diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 53a38b25bfdac..dc0e60344d858 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -2,7 +2,7 @@ import copy from enum import IntEnum from functools import cached_property -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from pydantic import Field @@ -271,6 +271,18 @@ def _verify_greedy_sampling(self) -> None: raise ValueError("best_of must be 1 when using greedy sampling." f"Got {self.best_of}.") + def update_from_generation_config( + self, generation_config: Dict[str, Any]) -> None: + """Update if there are non-default values from generation_config""" + # Update eos_token_id for generation + if eos_ids := generation_config.get("eos_token_id"): + # it can be either int or list of int + if isinstance(eos_ids, int): + eos_ids = [eos_ids] + original_stop_token_ids = set(self.stop_token_ids) + original_stop_token_ids.update(eos_ids) + self.stop_token_ids = list(original_stop_token_ids) + @cached_property def sampling_type(self) -> SamplingType: if self.use_beam_search: From d17c8477f1fb337ea9fcf439bcab4d323058c1b4 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Fri, 19 Apr 2024 15:59:54 +0800 Subject: [PATCH 071/413] [Bugfix] Fix LoRA loading check (#4138) Co-authored-by: simon-mo --- tests/lora/conftest.py | 6 ++++++ tests/lora/test_lora_checkpoints.py | 22 ++++++++++++++++++++-- vllm/lora/models.py | 4 +++- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 2dabfb6b4337c..a3ffc53d8cd1d 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -143,6 +143,12 @@ def baichuan_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider") +@pytest.fixture(scope="session") +def baichuan_zero_lora_files(): + # all the lora_B weights are initialized to zero. + return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init") + + @pytest.fixture(scope="session") def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index 35ad7342944cd..d4d1665b624ea 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -3,9 +3,16 @@ from vllm.lora.models import LoRAModel from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM +lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"] -@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"]) -def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files): + +@pytest.mark.parametrize("lora_name", lora_lst) +def test_load_checkpoints( + lora_name, + baichuan_lora_files, + baichuan_zero_lora_files, + chatglm3_lora_files, +): supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules @@ -26,6 +33,17 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files): device="cpu", embedding_modules=embedding_modules, embedding_padding_modules=embed_padding_modules) + elif lora_name == "baichuan7B-zero": + #Test that the target_modules contain prefix + # such as "model.layers.0.self_atten.W_pack", and + # the test should pass. + LoRAModel.from_local_checkpoint( + baichuan_zero_lora_files, + expected_lora_modules, + lora_model_id=1, + device="cpu", + embedding_modules=embedding_modules, + embedding_padding_modules=embed_padding_modules) else: # For the baichuan7B model, load chatglm3-6b's LoRA, # and the test should raise the following error. diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 62f1502458008..6bb9fee27d535 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -212,7 +212,9 @@ def from_local_checkpoint( target_modules = config["target_modules"] unexpected_modules = [] for module in target_modules: - if module not in expected_lora_modules: + # Compatible with more modules, such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: unexpected_modules.append(module) # loaded lora's target modules must be a subset of expected_lora_modules if unexpected_modules: From 221d93ecbf51102df69deaf153d35df6d93370f6 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Fri, 19 Apr 2024 01:00:22 -0700 Subject: [PATCH 072/413] Bump version of 0.4.1 (#4177) --- vllm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/__init__.py b/vllm/__init__.py index 2c1fd40573240..5ca4680227598 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams -__version__ = "0.4.0.post1" +__version__ = "0.4.1" __all__ = [ "LLM", From 8f20fc04bf7384089395caa021766cd352d0cf0b Mon Sep 17 00:00:00 2001 From: Uranus <109661872+UranusSeven@users.noreply.github.com> Date: Fri, 19 Apr 2024 16:18:33 +0800 Subject: [PATCH 073/413] [Misc] fix docstrings (#4191) Co-authored-by: Zhong Wang --- vllm/sequence.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 92362a9a5d2a3..7dcacab6f2ab6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -160,7 +160,7 @@ def reset_state_for_recompute(self) -> None: self._stage = SequenceStage.PREFILL def get_num_uncomputed_tokens(self) -> int: - """Return the number of prefil tokens that are not computed.""" + """Return the number of prefill tokens that are not computed.""" # we use `get_len()` which includes prompt_len + output_len instead # of prompt_len here. This is because during recompute we need to # prefill for both prompt and output. @@ -345,12 +345,9 @@ def fork(self, new_seq_id: int) -> "Sequence": def get_num_new_tokens(self) -> int: """Get the number of new tokens to be computed. - Args: - remainig_token_budget: The remaining token budgets. Returns: - The new number of tokens to be computed. I.e., 1 for decode, prompt - size for prefill. If there's not enough remainig_token_budget, it - can return the chunked number of new tokens. + The new number of tokens to be computed. I.e., 1 for decode, or + the remaining prompt size for prefill. """ if self.data.stage == SequenceStage.DECODE: return 1 From 7be4f5628fc9999bf8a6025edd8f098353e0724b Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Fri, 19 Apr 2024 18:08:26 +0300 Subject: [PATCH 074/413] [Bugfix][Core] Restore logging of stats in the async engine (#4150) --- vllm/engine/async_llm_engine.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index c436ece83f65a..ca4ba66f09cb8 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -217,10 +217,16 @@ async def step_async(self) -> List[RequestOutput]: else: output = [] - return self._process_model_outputs( + request_outputs = self._process_model_outputs( output, scheduler_outputs.scheduled_seq_groups, scheduler_outputs.ignored_seq_groups) + # Log stats. + if self.log_stats: + self.stat_logger.log(self._get_stats(scheduler_outputs)) + + return request_outputs + async def encode_request_async( self, request_id: str, # pylint: disable=unused-argument From 15b86408a89d5b998409e7fbe7850e937cc837da Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 19 Apr 2024 12:44:51 -0700 Subject: [PATCH 075/413] [Misc] add nccl in collect env (#4211) --- .github/ISSUE_TEMPLATE/200-installation.yml | 1 + .github/ISSUE_TEMPLATE/300-usage.yml | 1 + .github/ISSUE_TEMPLATE/400-bug report.yml | 1 + .github/ISSUE_TEMPLATE/700-performance discussion.yml | 1 + collect_env.py | 2 ++ 5 files changed, 6 insertions(+) diff --git a/.github/ISSUE_TEMPLATE/200-installation.yml b/.github/ISSUE_TEMPLATE/200-installation.yml index 4c6c96187cc6c..df41ade8c3c01 100644 --- a/.github/ISSUE_TEMPLATE/200-installation.yml +++ b/.github/ISSUE_TEMPLATE/200-installation.yml @@ -18,6 +18,7 @@ body: # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: | ```text The output of `python collect_env.py` diff --git a/.github/ISSUE_TEMPLATE/300-usage.yml b/.github/ISSUE_TEMPLATE/300-usage.yml index 88227b4b2e7b9..54763af1058f6 100644 --- a/.github/ISSUE_TEMPLATE/300-usage.yml +++ b/.github/ISSUE_TEMPLATE/300-usage.yml @@ -18,6 +18,7 @@ body: # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: | ```text The output of `python collect_env.py` diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index c87f8fddcb776..08120ad8e5a60 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -18,6 +18,7 @@ body: # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: | ```text The output of `python collect_env.py` diff --git a/.github/ISSUE_TEMPLATE/700-performance discussion.yml b/.github/ISSUE_TEMPLATE/700-performance discussion.yml index 9e8e7b4aa3530..4f8843420a94e 100644 --- a/.github/ISSUE_TEMPLATE/700-performance discussion.yml +++ b/.github/ISSUE_TEMPLATE/700-performance discussion.yml @@ -39,6 +39,7 @@ body: # For security purposes, please feel free to check the contents of collect_env.py before running it. python collect_env.py ``` + It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: | ```text The output of `python collect_env.py` diff --git a/collect_env.py b/collect_env.py index 8982fba024274..1ecfeb8e22e2f 100644 --- a/collect_env.py +++ b/collect_env.py @@ -63,6 +63,7 @@ "magma", "triton", "optree", + "nccl", } DEFAULT_PIP_PATTERNS = { @@ -73,6 +74,7 @@ "triton", "optree", "onnx", + "nccl", } From bc9df1571b8002738eb8db70a07f552e32feb75f Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Sat, 20 Apr 2024 05:43:56 +0530 Subject: [PATCH 076/413] Pass `tokenizer_revision` when getting tokenizer in openai serving (#4214) --- vllm/entrypoints/openai/serving_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8e5ee88d7f3a9..376b581052d85 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -68,6 +68,7 @@ async def _post_init(self): self.tokenizer = get_tokenizer( engine_model_config.tokenizer, tokenizer_mode=engine_model_config.tokenizer_mode, + tokenizer_revision=engine_model_config.tokenizer_revision, trust_remote_code=engine_model_config.trust_remote_code, truncation_side="left") From 138485a82de50f90536ea0a650dd2f6bba1927e9 Mon Sep 17 00:00:00 2001 From: Ayush Rautwar <42046470+ayusher@users.noreply.github.com> Date: Fri, 19 Apr 2024 23:49:22 -0400 Subject: [PATCH 077/413] [Bugfix] Add fix for JSON whitespace (#4189) Co-authored-by: Ubuntu --- tests/entrypoints/test_openai_server.py | 27 ++++++++++--------- .../outlines_logits_processors.py | 5 ++++ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 0dd30eec30086..85a7ef464c032 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -754,19 +754,20 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, async def test_response_format_json_object(server, client: openai.AsyncOpenAI): - resp = await client.chat.completions.create( - model=MODEL_NAME, - messages=[{ - "role": - "user", - "content": ('what is 1+1? please respond with a JSON object, ' - 'the format is {"result": 2}') - }], - response_format={"type": "json_object"}) - - content = resp.choices[0].message.content - loaded = json.loads(content) - assert loaded == {"result": 2}, loaded + for _ in range(2): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": + "user", + "content": ('what is 1+1? please respond with a JSON object, ' + 'the format is {"result": 2}') + }], + response_format={"type": "json_object"}) + + content = resp.choices[0].message.content + loaded = json.loads(content) + assert loaded == {"result": 2}, loaded async def test_guided_grammar(server, client: openai.AsyncOpenAI): diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 95a67b612f08b..25ab5bf8b6a9c 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -131,6 +131,11 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): fsm = CFGFSM(cfg, tokenizer) self.fsm = fsm + def init_state(self): + """Initialize state with a CFGFSM copy.""" + super().init_state() + self.fsm = self.fsm.copy() + @lru_cache def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): From 682789d4026429a04cc32acb88064265441080dd Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sat, 20 Apr 2024 04:51:33 +0100 Subject: [PATCH 078/413] Fix missing docs and out of sync `EngineArgs` (#4219) Co-authored-by: Harry Mellor --- docs/source/conf.py | 2 + docs/source/models/engine_args.rst | 134 ++---------------------- vllm/engine/arg_utils.py | 159 ++++++++++++++++------------- 3 files changed, 98 insertions(+), 197 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 19cc8557a7541..cfa956b143ba3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,12 +11,14 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. import logging +import os import sys from typing import List from sphinx.ext import autodoc logger = logging.getLogger(__name__) +sys.path.append(os.path.abspath("../..")) # -- Project information ----------------------------------------------------- diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index 235cb4e128c99..92bc7e0e843ed 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -5,133 +5,17 @@ Engine Arguments Below, you can find an explanation of every engine argument for vLLM: -.. option:: --model - - Name or path of the huggingface model to use. - -.. option:: --tokenizer - - Name or path of the huggingface tokenizer to use. - -.. option:: --revision - - The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. - -.. option:: --tokenizer-revision - - The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. - -.. option:: --tokenizer-mode {auto,slow} - - The tokenizer mode. - - * "auto" will use the fast tokenizer if available. - * "slow" will always use the slow tokenizer. - -.. option:: --trust-remote-code - - Trust remote code from huggingface. - -.. option:: --download-dir - - Directory to download and load the weights, default to the default cache dir of huggingface. - -.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer} - - The format of the model weights to load. - - * "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available. - * "pt" will load the weights in the pytorch bin format. - * "safetensors" will load the weights in the safetensors format. - * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. - * "dummy" will initialize the weights with random values, mainly for profiling. - * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_ See `examples/tensorize_vllm_model.py `_ to serialize a vLLM model, and for more information. - -.. option:: --dtype {auto,half,float16,bfloat16,float,float32} - - Data type for model weights and activations. - - * "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. - * "half" for FP16. Recommended for AWQ quantization. - * "float16" is the same as "half". - * "bfloat16" for a balance between precision and range. - * "float" is shorthand for FP32 precision. - * "float32" for FP32 precision. - -.. option:: --max-model-len - - Model context length. If unspecified, will be automatically derived from the model config. - -.. option:: --worker-use-ray - - Use Ray for distributed serving, will be automatically set when using more than 1 GPU. - -.. option:: --pipeline-parallel-size (-pp) - - Number of pipeline stages. - -.. option:: --tensor-parallel-size (-tp) - - Number of tensor parallel replicas. - -.. option:: --max-parallel-loading-workers - - Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models. - -.. option:: --block-size {8,16,32} - - Token block size for contiguous chunks of tokens. - -.. option:: --enable-prefix-caching - - Enables automatic prefix caching - -.. option:: --seed - - Random seed for operations. - -.. option:: --swap-space - - CPU swap space size (GiB) per GPU. - -.. option:: --gpu-memory-utilization - - The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. - For example, a value of 0.5 would imply 50% GPU memory utilization. - If unspecified, will use the default value of 0.9. - -.. option:: --max-num-batched-tokens - - Maximum number of batched tokens per iteration. - -.. option:: --max-num-seqs - - Maximum number of sequences per iteration. - -.. option:: --max-paddings - - Maximum number of paddings in a batch. - -.. option:: --disable-log-stats - - Disable logging statistics. - -.. option:: --quantization (-q) {awq,squeezellm,None} - - Method used to quantize the weights. +.. argparse:: + :module: vllm.engine.arg_utils + :func: _engine_args_parser + :prog: -m vllm.entrypoints.openai.api_server Async Engine Arguments ---------------------- -Below are the additional arguments related to the asynchronous engine: - -.. option:: --engine-use-ray - Use Ray to start the LLM engine in a separate process as the server process. - -.. option:: --disable-log-requests - - Disable logging requests. - -.. option:: --max-log-len +Below are the additional arguments related to the asynchronous engine: - Max number of prompt characters or prompt ID numbers being printed in log. Defaults to unlimited. \ No newline at end of file +.. argparse:: + :module: vllm.engine.arg_utils + :func: _async_engine_args_parser + :prog: -m vllm.entrypoints.openai.api_server \ No newline at end of file diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2999ab0a7e72a..53f129598270c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -82,57 +82,55 @@ def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Shared CLI arguments for vLLM engine.""" - # NOTE: If you update any of the arguments below, please also - # make sure to update docs/source/models/engine_args.rst - # Model arguments parser.add_argument( '--model', type=str, default='facebook/opt-125m', - help='name or path of the huggingface model to use') + help='Name or path of the huggingface model to use.') parser.add_argument( '--tokenizer', type=str, default=EngineArgs.tokenizer, - help='name or path of the huggingface tokenizer to use') + help='Name or path of the huggingface tokenizer to use.') parser.add_argument( '--revision', type=str, default=None, - help='the specific model version to use. It can be a branch ' + help='The specific model version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') parser.add_argument( '--code-revision', type=str, default=None, - help='the specific revision to use for the model code on ' + help='The specific revision to use for the model code on ' 'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'commit id. If unspecified, will use the default version.') parser.add_argument( '--tokenizer-revision', type=str, default=None, - help='the specific tokenizer version to use. It can be a branch ' + help='The specific tokenizer version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') - parser.add_argument('--tokenizer-mode', - type=str, - default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow'], - help='tokenizer mode. "auto" will use the fast ' - 'tokenizer if available, and "slow" will ' - 'always use the slow tokenizer.') + parser.add_argument( + '--tokenizer-mode', + type=str, + default=EngineArgs.tokenizer_mode, + choices=['auto', 'slow'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer.') parser.add_argument('--trust-remote-code', action='store_true', - help='trust remote code from huggingface') + help='Trust remote code from huggingface.') parser.add_argument('--download-dir', type=str, default=EngineArgs.download_dir, - help='directory to download and load the weights, ' + help='Directory to download and load the weights, ' 'default to the default cache dir of ' - 'huggingface') + 'huggingface.') parser.add_argument( '--load-format', type=str, @@ -140,19 +138,19 @@ def add_cli_args( choices=[ 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer' ], - help='The format of the model weights to load. ' - '"auto" will try to load the weights in the safetensors format ' + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' - '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.' - '"tensorizer" will load the weights using tensorizer from CoreWeave' - 'which assumes tensorizer_uri is set to the location of the ' - 'serialized weights.') + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave which assumes tensorizer_uri is set to the location of ' + 'the serialized weights.') parser.add_argument( '--dtype', type=str, @@ -160,10 +158,14 @@ def add_cli_args( choices=[ 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' ], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') + help='Data type for model weights and activations.\n\n' + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + 'BF16 precision for BF16 models.\n' + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.') parser.add_argument( '--kv-cache-dtype', type=str, @@ -172,7 +174,7 @@ def add_cli_args( help='Data type for kv cache storage. If "auto", will use model ' 'data type. FP8_E5M2 (without scaling) is only supported on cuda ' 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria. ') + 'supported for common inference criteria.') parser.add_argument( '--quantization-param-path', type=str, @@ -183,58 +185,59 @@ def add_cli_args( 'default to 1.0, which may cause accuracy issues. ' 'FP8_E5M2 (without scaling) is only supported on cuda version' 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria. ') + 'supported for common inference criteria.') parser.add_argument('--max-model-len', type=int, default=EngineArgs.max_model_len, - help='model context length. If unspecified, ' - 'will be automatically derived from the model.') + help='Model context length. If unspecified, will ' + 'be automatically derived from the model config.') parser.add_argument( '--guided-decoding-backend', type=str, default='outlines', choices=['outlines', 'lm-format-enforcer'], help='Which engine will be used for guided decoding' - ' (JSON schema / regex etc)') + ' (JSON schema / regex etc).') # Parallel arguments parser.add_argument('--worker-use-ray', action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') + help='Use Ray for distributed serving, will be ' + 'automatically set when using more than 1 GPU.') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=EngineArgs.pipeline_parallel_size, - help='number of pipeline stages') + help='Number of pipeline stages.') parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=EngineArgs.tensor_parallel_size, - help='number of tensor parallel replicas') + help='Number of tensor parallel replicas.') parser.add_argument( '--max-parallel-loading-workers', type=int, default=EngineArgs.max_parallel_loading_workers, - help='load model sequentially in multiple batches, ' + help='Load model sequentially in multiple batches, ' 'to avoid RAM OOM when using tensor ' - 'parallel and large models') + 'parallel and large models.') parser.add_argument( '--ray-workers-use-nsight', action='store_true', - help='If specified, use nsight to profile ray workers') + help='If specified, use nsight to profile Ray workers.') # KV cache arguments parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, choices=[8, 16, 32, 128], - help='token block size') + help='Token block size for contiguous chunks of ' + 'tokens.') parser.add_argument('--enable-prefix-caching', action='store_true', - help='Enables automatic prefix caching') + help='Enables automatic prefix caching.') parser.add_argument('--use-v2-block-manager', action='store_true', - help='Use BlockSpaceMangerV2') + help='Use BlockSpaceMangerV2.') parser.add_argument( '--num-lookahead-slots', type=int, @@ -247,18 +250,19 @@ def add_cli_args( parser.add_argument('--seed', type=int, default=EngineArgs.seed, - help='random seed') + help='Random seed for operations.') parser.add_argument('--swap-space', type=int, default=EngineArgs.swap_space, - help='CPU swap space size (GiB) per GPU') + help='CPU swap space size (GiB) per GPU.') parser.add_argument( '--gpu-memory-utilization', type=float, default=EngineArgs.gpu_memory_utilization, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') + help='The fraction of GPU memory to be used for the model ' + 'executor, which can range from 0 to 1. For example, a value of ' + '0.5 would imply 50%% GPU memory utilization. If unspecified, ' + 'will use the default value of 0.9.') parser.add_argument( '--num-gpu-blocks-override', type=int, @@ -268,21 +272,21 @@ def add_cli_args( parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, - help='maximum number of batched tokens per ' - 'iteration') + help='Maximum number of batched tokens per ' + 'iteration.') parser.add_argument('--max-num-seqs', type=int, default=EngineArgs.max_num_seqs, - help='maximum number of sequences per iteration') + help='Maximum number of sequences per iteration.') parser.add_argument( '--max-logprobs', type=int, default=EngineArgs.max_logprobs, - help=('max number of log probs to return logprobs is specified in' - ' SamplingParams')) + help=('Max number of log probs to return logprobs is specified in' + ' SamplingParams.')) parser.add_argument('--disable-log-stats', action='store_true', - help='disable logging statistics') + help='Disable logging statistics.') # Quantization settings. parser.add_argument('--quantization', '-q', @@ -303,13 +307,13 @@ def add_cli_args( parser.add_argument('--max-context-len-to-capture', type=int, default=EngineArgs.max_context_len_to_capture, - help='maximum context length covered by CUDA ' + help='Maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') parser.add_argument('--disable-custom-all-reduce', action='store_true', default=EngineArgs.disable_custom_all_reduce, - help='See ParallelConfig') + help='See ParallelConfig.') parser.add_argument('--tokenizer-pool-size', type=int, default=EngineArgs.tokenizer_pool_size, @@ -402,7 +406,7 @@ def add_cli_args( '--enable-chunked-prefill', action='store_true', help='If set, the prefill requests can be chunked based on the ' - 'max_num_batched_tokens') + 'max_num_batched_tokens.') parser.add_argument( '--speculative-model', @@ -416,7 +420,7 @@ def add_cli_args( type=int, default=None, help='The number of speculative tokens to sample from ' - 'the draft model in speculative decoding') + 'the draft model in speculative decoding.') parser.add_argument('--model-loader-extra-config', type=str, @@ -534,20 +538,31 @@ class AsyncEngineArgs(EngineArgs): max_log_len: Optional[int] = None @staticmethod - def add_cli_args( - parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - parser = EngineArgs.add_cli_args(parser) + def add_cli_args(parser: argparse.ArgumentParser, + async_args_only: bool = False) -> argparse.ArgumentParser: + if not async_args_only: + parser = EngineArgs.add_cli_args(parser) parser.add_argument('--engine-use-ray', action='store_true', - help='use Ray to start the LLM engine in a ' + help='Use Ray to start the LLM engine in a ' 'separate process as the server process.') parser.add_argument('--disable-log-requests', action='store_true', - help='disable logging requests') + help='Disable logging requests.') parser.add_argument('--max-log-len', type=int, default=None, - help='max number of prompt characters or prompt ' - 'ID numbers being printed in log. ' - 'Default: unlimited.') + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') return parser + + +# These functions are used by sphinx to build the documentation +def _engine_args_parser(): + return EngineArgs.add_cli_args(argparse.ArgumentParser()) + + +def _async_engine_args_parser(): + return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(), + async_args_only=True) From a22cdea371bb26b4bdba112d4602736b48ca4a3a Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 19 Apr 2024 21:28:57 -0700 Subject: [PATCH 079/413] [Kernel][FP8] Initial support with dynamic per-tensor scaling (#4118) Provide an initial support to FP8 computation. This PR is inspired by HuggingFace TGI: huggingface/text-generation-inference#1726 This feature can be enabled with --quantization fp8 or -q fp8 when launching an engine. Algorithm: We still load a model checkpoint in FP16/BF16. After the weights are loaded, Fp8LinearMethod calculates the per-tensor scaling factor of weights and quantizes the weights accordingly. The scaling factor will then be stored for future use. Meanwhile, the per-tensor scaling factor for activations is calculated in every forward pass. Initial Results: Currently tested Mistral-7B on 1xH100. With prompt length ~5 and decoding length 128: BF16: 1.47s FP8: 1.66s I'll try to use larger models and try to find more performance bottleneck. Meanwhile, you're welcome to try this code. --- tests/quantization/test_fp8.py | 24 +++ vllm/entrypoints/llm.py | 9 +- vllm/model_executor/layers/linear.py | 8 + .../layers/quantization/__init__.py | 2 + .../model_executor/layers/quantization/fp8.py | 138 ++++++++++++++++++ vllm/model_executor/model_loader/loader.py | 4 + .../model_loader/weight_utils.py | 9 +- 7 files changed, 189 insertions(+), 5 deletions(-) create mode 100644 tests/quantization/test_fp8.py create mode 100644 vllm/model_executor/layers/quantization/fp8.py diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py new file mode 100644 index 0000000000000..fa10e60de10a7 --- /dev/null +++ b/tests/quantization/test_fp8.py @@ -0,0 +1,24 @@ +"""Tests whether FP8 computation is enabled correctly. + +Run `pytest tests/quantization/test_fp8.py --forked`. +""" +import pytest +import torch + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] + + +@pytest.mark.skipif( + capability < QUANTIZATION_METHODS["fp8"].get_min_capability(), + reason="FP8 is not supported on this GPU type.") +def test_load_fp16_model(vllm_runner) -> None: + llm = vllm_runner("facebook/opt-125m", quantization="fp8") + + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model + fc1 = model.model.decoder.layers[0].fc1 + assert isinstance(fc1.linear_method, Fp8LinearMethod) + assert fc1.weight.dtype == torch.float8_e4m3fn diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9e08c253dc539..961de5d5063fa 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -42,10 +42,11 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq", "gptq" and "squeezellm". If None, we first check - the `quantization_config` attribute in the model config file. If - that is None, we assume the model weights are not quantized and use - `dtype` to determine the data type of the weights. + we support "awq", "gptq", "squeezellm", and "fp8" (experimental). + If None, we first check the `quantization_config` attribute in the + model config file. If that is None, we assume the model weights are + not quantized and use `dtype` to determine the data type of + the weights. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. It can be a diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3ca870742efc5..d466d8807fc64 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F +from torch import nn from torch.nn.parameter import Parameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -48,6 +49,13 @@ def apply_weights(self, Expects create_weights to have been called before on the layer.""" raise NotImplementedError + def process_weights_after_loading(self, layer: nn.Module) -> None: + """Process the weight after loading. + + This can be used for example, to transpose weights for computation. + """ + return + class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization. diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index a3b89a66469eb..0344d6e4e3e45 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -3,12 +3,14 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.fp8 import FP8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig QUANTIZATION_METHODS = { "awq": AWQConfig, + "fp8": FP8Config, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py new file mode 100644 index 0000000000000..9dc0e86e1243d --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -0,0 +1,138 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class FP8Config(QuantizationConfig): + """Config class for FP8.""" + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # TODO: PyTorch 2.3.0+ is required to run FP8 on + # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to + # be included: https://github.com/pytorch/pytorch/pull/118881 + return 90 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "FP8Config": + return cls() + + def get_linear_method(self) -> "Fp8LinearMethod": + return Fp8LinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + We now support common FP16/BF16 model checkpoints ONLY. The weight + scaling factor will be initialized after the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: FP8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + set_weight_attrs(weight, extra_weight_attrs) + + w_scale = Parameter( + torch.empty(1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("weight_scaling_factor", w_scale) + + def process_weights_after_loading(self, layer: Module) -> None: + # Although the linear_method is propagated to all layers, + # only linear layers invoke "create_weights". So we check + # whether "weight_scaling_facor" is registered to determine + # whether the layer is a linear layer that requires quantization. + if not hasattr(layer, "weight_scaling_factor"): + return + + qweight, weight_scale = per_tensor_quantize(layer.weight) + # torch._scaled_mm requires column-major in the second + # input (weight), so we transpose the quantized weight. + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scaling_factor.data.copy_(weight_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qinput, x_scale = per_tensor_quantize(x) + output, _ = torch._scaled_mm( + qinput, + layer.weight, + out_dtype=x.dtype, + scale_a=x_scale, + scale_b=layer.weight_scaling_factor, + bias=bias, + ) + return output + + +def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]: + """Quantize a tensor using per-tensor static scaling factor. + + Args: + tensor: The input tensor. + """ + finfo = torch.finfo(torch.float8_e4m3fn) + # Calculate the scale as dtype max divided by absmax. + # Since .abs() creates a new tensor, we use aminmax to get + # the min and max first and then calculate the absmax. + min_val, max_val = tensor.aminmax() + amax = min_val.abs().max(max_val.abs()) + scale = finfo.max / amax.clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + qweight = qweight.to(torch.float8_e4m3fn) + scale = scale.float().reciprocal() + return qweight, scale diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 3b1d125ef8a67..6c8cb2935f37e 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -228,6 +228,10 @@ def load_model(self, *, model_config: ModelConfig, model, "fall_back_to_pt_during_load", True)), ) + for _, module in model.named_modules(): + linear_method = getattr(module, "linear_method", None) + if linear_method is not None: + linear_method.process_weights_after_loading(module) return model.eval() diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 1798db0136868..9995f2afe3cf7 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -134,11 +134,18 @@ def get_quant_config(model_config: ModelConfig, tqdm_class=DisabledTqdm) else: hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + config_files = glob.glob(os.path.join(hf_folder, "*.json")) quant_config_files = [ f for f in config_files if any( - f.endswith(x) for x in quant_cls.get_config_filenames()) + f.endswith(x) for x in possible_config_filenames) ] if len(quant_config_files) == 0: raise ValueError( From 91528575ec0e6fd251bac08973a3abf23d4c318c Mon Sep 17 00:00:00 2001 From: nunjunj <106306814+nunjunj@users.noreply.github.com> Date: Sat, 20 Apr 2024 00:11:57 -0700 Subject: [PATCH 080/413] [Frontend] multiple sampling params support (#3570) --- tests/entrypoints/test_llm_generate.py | 41 ++++++++++++++++++++++++++ vllm/entrypoints/llm.py | 30 ++++++++++++------- 2 files changed, 61 insertions(+), 10 deletions(-) create mode 100644 tests/entrypoints/test_llm_generate.py diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py new file mode 100644 index 0000000000000..5e8b7ca4d9977 --- /dev/null +++ b/tests/entrypoints/test_llm_generate.py @@ -0,0 +1,41 @@ +import pytest + +from vllm import LLM, SamplingParams + + +def test_multiple_sampling_params(): + + llm = LLM(model="facebook/opt-125m", + max_num_batched_tokens=4096, + tensor_parallel_size=1) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + sampling_params = [ + SamplingParams(temperature=0.01, top_p=0.95), + SamplingParams(temperature=0.3, top_p=0.95), + SamplingParams(temperature=0.7, top_p=0.95), + SamplingParams(temperature=0.99, top_p=0.95), + ] + + # Multiple SamplingParams should be matched with each prompt + outputs = llm.generate(prompts, sampling_params=sampling_params) + assert len(prompts) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.generate(prompts, sampling_params=sampling_params[:3]) + + # Single SamplingParams should be applied to every prompt + single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95) + outputs = llm.generate(prompts, sampling_params=single_sampling_params) + assert len(prompts) == len(outputs) + + # sampling_params is None, default params should be applied + outputs = llm.generate(prompts, sampling_params=None) + assert len(prompts) == len(outputs) \ No newline at end of file diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 961de5d5063fa..f745dbd736d17 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -127,7 +127,8 @@ def set_tokenizer( def generate( self, prompts: Optional[Union[str, List[str]]] = None, - sampling_params: Optional[SamplingParams] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, @@ -142,7 +143,10 @@ def generate( Args: prompts: A list of prompts to generate completions for. sampling_params: The sampling parameters for text generation. If - None, we use the default sampling parameters. + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. prompt_token_ids: A list of token IDs for the prompts. If None, we use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. @@ -163,27 +167,33 @@ def generate( and len(prompts) != len(prompt_token_ids)): raise ValueError("The lengths of prompts and prompt_token_ids " "must be the same.") + + if prompts is not None: + num_requests = len(prompts) + else: + assert prompt_token_ids is not None + num_requests = len(prompt_token_ids) + if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() + elif isinstance(sampling_params, + list) and len(sampling_params) != num_requests: + raise ValueError("The lengths of prompts and sampling_params " + "must be the same.") if multi_modal_data: multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. - if prompts is not None: - num_requests = len(prompts) - else: - assert prompt_token_ids is not None - num_requests = len(prompt_token_ids) - for i in range(num_requests): prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] self._add_request( prompt, - sampling_params, + sampling_params[i] + if isinstance(sampling_params, list) else sampling_params, token_ids, lora_request=lora_request, # Get ith image while maintaining the batch dim. @@ -232,4 +242,4 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # This is necessary because some requests may be finished earlier than # its previous requests. outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs + return outputs \ No newline at end of file From cc74b2b232070f74d8765a5eefa49ae93ee45490 Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Sat, 20 Apr 2024 11:33:16 +0300 Subject: [PATCH 081/413] Updating lm-format-enforcer version and adding links to decoding libraries in docs (#4222) --- requirements-common.txt | 2 +- vllm/engine/arg_utils.py | 6 +++++- vllm/model_executor/layers/quantization/fp8.py | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 3de0f98e7c0c3..3cc7bba8f84db 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -12,7 +12,7 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 tiktoken == 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer == 0.9.3 +lm-format-enforcer == 0.9.8 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 53f129598270c..8939a0d537281 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -197,7 +197,11 @@ def add_cli_args( default='outlines', choices=['outlines', 'lm-format-enforcer'], help='Which engine will be used for guided decoding' - ' (JSON schema / regex etc).') + ' (JSON schema / regex etc) by default. Currently support ' + 'https://github.com/outlines-dev/outlines and ' + 'https://github.com/noamgat/lm-format-enforcer.' + ' Can be overridden per request via guided_decoding_backend' + ' parameter.') # Parallel arguments parser.add_argument('--worker-use-ray', action='store_true', diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 9dc0e86e1243d..8df82e0e18edd 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import torch from torch.nn import Module @@ -114,7 +114,7 @@ def apply_weights(self, return output -def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]: +def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: """Quantize a tensor using per-tensor static scaling factor. Args: From fe7d648fe56c138811e9b2b02937b55a84830454 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Sun, 21 Apr 2024 17:15:28 +0100 Subject: [PATCH 082/413] Don't show default value for flags in `EngineArgs` (#4223) Co-authored-by: Harry Mellor --- docs/source/models/engine_args.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index 92bc7e0e843ed..bdf566d3ebbd1 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -9,6 +9,7 @@ Below, you can find an explanation of every engine argument for vLLM: :module: vllm.engine.arg_utils :func: _engine_args_parser :prog: -m vllm.entrypoints.openai.api_server + :nodefaultconst: Async Engine Arguments ---------------------- @@ -18,4 +19,5 @@ Below are the additional arguments related to the asynchronous engine: .. argparse:: :module: vllm.engine.arg_utils :func: _async_engine_args_parser - :prog: -m vllm.entrypoints.openai.api_server \ No newline at end of file + :prog: -m vllm.entrypoints.openai.api_server + :nodefaultconst: \ No newline at end of file From 7f2593b164c2ff115ba4fb9ce95fe63bdd824b85 Mon Sep 17 00:00:00 2001 From: xiaoji <44150358+YeFD@users.noreply.github.com> Date: Mon, 22 Apr 2024 00:57:08 +0800 Subject: [PATCH 083/413] [Doc]: Update the doc of adding new models (#4236) --- docs/source/models/adding_model.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index a82c2cef10e83..cbc8099e6f70f 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -95,7 +95,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a 5. Register your model ---------------------- -Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py `_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py `_. +Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py `_. 6. Out-of-Tree Model Integration -------------------------------------------- From a37d815b83849b5a96a182929dd6f3bd35f68fb8 Mon Sep 17 00:00:00 2001 From: GeauxEric Date: Sun, 21 Apr 2024 15:06:46 -0700 Subject: [PATCH 084/413] Make initialization of tokenizer and detokenizer optional (#3748) Co-authored-by: Yun Ding Co-authored-by: Roger Wang --- tests/engine/test_skip_tokenizer_init.py | 23 ++++++++++++++++ vllm/config.py | 7 ++++- vllm/engine/arg_utils.py | 7 ++++- vllm/engine/llm_engine.py | 29 +++++++++++++++------ vllm/engine/output_processor/single_step.py | 5 ++-- vllm/entrypoints/llm.py | 9 +++++++ 6 files changed, 68 insertions(+), 12 deletions(-) create mode 100644 tests/engine/test_skip_tokenizer_init.py diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py new file mode 100644 index 0000000000000..baa463a316902 --- /dev/null +++ b/tests/engine/test_skip_tokenizer_init.py @@ -0,0 +1,23 @@ +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.sampling_params import SamplingParams + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +def test_skip_tokenizer_initialization(model: str): + # This test checks if the flag skip_tokenizer_init skips the initialization + # of tokenizer and detokenizer. The generated output is expected to contain + # token ids. + llm = LLM(model=model, skip_tokenizer_init=True) + sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) + with pytest.raises(ValueError) as err: + llm.generate("abc", sampling_params) + assert "prompts must be None if" in str(err.value) + outputs = llm.generate(prompt_token_ids=[[1, 2, 3]], + sampling_params=sampling_params) + assert len(outputs) > 0 + completions = outputs[0].outputs + assert len(completions) > 0 + assert completions[0].text == "" + assert completions[0].token_ids diff --git a/vllm/config.py b/vllm/config.py index 2912d6ccc2c5b..97ede0faa21ab 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -66,6 +66,8 @@ class ModelConfig: max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. """ def __init__( @@ -85,6 +87,7 @@ def __init__( enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, max_logprobs: int = 5, + skip_tokenizer_init: bool = False, ) -> None: self.model = model self.tokenizer = tokenizer @@ -99,6 +102,7 @@ def __init__( self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture self.max_logprobs = max_logprobs + self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, code_revision) @@ -106,7 +110,8 @@ def __init__( self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config, max_model_len) - self._verify_tokenizer_mode() + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() self._verify_quantization() self._verify_cuda_graph() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8939a0d537281..5de20633ffdd6 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -16,6 +16,7 @@ class EngineArgs: """Arguments for vLLM engine.""" model: str tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' trust_remote_code: bool = False download_dir: Optional[str] = None @@ -93,6 +94,10 @@ def add_cli_args( type=str, default=EngineArgs.tokenizer, help='Name or path of the huggingface tokenizer to use.') + parser.add_argument( + '--skip-tokenizer-init', + action='store_true', + help='Skip initialization of tokenizer and detokenizer') parser.add_argument( '--revision', type=str, @@ -453,7 +458,7 @@ def create_engine_config(self, ) -> EngineConfig: self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, - self.max_logprobs) + self.max_logprobs, self.skip_tokenizer_init) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 88b344f50767f..d96025ea1fb6a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -100,6 +100,7 @@ def __init__( f"model={model_config.model!r}, " f"speculative_config={speculative_config!r}, " f"tokenizer={model_config.tokenizer!r}, " + f"skip_tokenizer_init={model_config.skip_tokenizer_init}, " f"tokenizer_mode={model_config.tokenizer_mode}, " f"revision={model_config.revision}, " f"tokenizer_revision={model_config.tokenizer_revision}, " @@ -132,8 +133,14 @@ def __init__( self.decoding_config = decoding_config or DecodingConfig() self.log_stats = log_stats - self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) + if not self.model_config.skip_tokenizer_init: + self.tokenizer: BaseTokenizerGroup + self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + else: + self.detokenizer = None + self.tokenizer = None + self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( model_config) @@ -187,9 +194,10 @@ def __init__( parallel_config.disable_custom_all_reduce, }) - # Ping the tokenizer to ensure liveness if it runs in a - # different process. - self.tokenizer.ping() + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of @@ -296,7 +304,7 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer: BaseTokenizerGroup = get_tokenizer_group( + self.tokenizer = get_tokenizer_group( self.parallel_config.tokenizer_pool_config, **init_kwargs) def _verify_args(self) -> None: @@ -393,8 +401,13 @@ def add_request( # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - eos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).eos_token_id + eos_token_id = None + if self.tokenizer: + eos_token_id = self.tokenizer.get_lora_tokenizer( + lora_request).eos_token_id + else: + logger.warning("Use None for EOS token id because tokenizer is " + "not initialized") seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, eos_token_id, lora_request) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 1b7eb014f802b..b32937327ba7f 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -59,7 +59,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Process prompt logprobs prompt_logprobs = outputs.prompt_logprobs - if prompt_logprobs is not None and seq_group.sampling_params.detokenize: + if prompt_logprobs is not None and \ + seq_group.sampling_params.detokenize and self.detokenizer: self.detokenizer.decode_prompt_logprobs_inplace( seq_group, prompt_logprobs) seq_group.prompt_logprobs = prompt_logprobs @@ -105,7 +106,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, child_seqs.append((parent, parent)) for seq, _ in child_seqs: - if seq_group.sampling_params.detokenize: + if seq_group.sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, seq_group.sampling_params) else: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f745dbd736d17..b022707794a78 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -32,6 +32,9 @@ class LLM: tokenizer: The name or path of a HuggingFace Transformers tokenizer. tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. Expect valid prompt_token_ids and None for prompt + from the input. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. tensor_parallel_size: The number of GPUs to use for distributed @@ -76,6 +79,7 @@ def __init__( model: str, tokenizer: Optional[str] = None, tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, trust_remote_code: bool = False, tensor_parallel_size: int = 1, dtype: str = "auto", @@ -96,6 +100,7 @@ def __init__( model=model, tokenizer=tokenizer, tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, trust_remote_code=trust_remote_code, tensor_parallel_size=tensor_parallel_size, dtype=dtype, @@ -160,6 +165,10 @@ def generate( if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " "provided.") + if self.llm_engine.model_config.skip_tokenizer_init \ + and prompts is not None: + raise ValueError("prompts must be None if skip_tokenizer_init " + "is True") if isinstance(prompts, str): # Convert a single prompt to a list. prompts = [prompts] From 95e5b087cfed33bf20891852b2d7a0cac2341519 Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Mon, 22 Apr 2024 00:57:24 -0400 Subject: [PATCH 085/413] [AMD][Hardware][Misc][Bugfix] xformer cleanup and light navi logic and CI fixes and refactoring (#4129) --- .buildkite/test-pipeline.yaml | 2 - Dockerfile.rocm | 5 +- patch_xformers.rocm.sh | 33 ---- .../commonpy_xformers-0.0.23.rocm.patch | 13 -- rocm_patch/flashpy_xformers-0.0.23.rocm.patch | 152 ------------------ vllm/attention/backends/rocm_flash_attn.py | 31 ++-- 6 files changed, 19 insertions(+), 217 deletions(-) delete mode 100644 patch_xformers.rocm.sh delete mode 100644 rocm_patch/commonpy_xformers-0.0.23.rocm.patch delete mode 100644 rocm_patch/flashpy_xformers-0.0.23.rocm.patch diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 0f920c7ec1442..f7c1569696249 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -15,10 +15,8 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test command: pytest -v -s core diff --git a/Dockerfile.rocm b/Dockerfile.rocm index b1c5fac9d78ef..3f84b949481d1 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE" ARG FA_GFX_ARCHS="gfx90a;gfx942" RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" -ARG FA_BRANCH="3d2b6f5" +ARG FA_BRANCH="ae7928c" RUN echo "FA_BRANCH is $FA_BRANCH" # whether to build flash-attention @@ -92,13 +92,10 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip numba -RUN python3 -m pip install xformers==0.0.23 --no-deps RUN cd /app \ && cd vllm \ && pip install -U -r requirements-rocm.txt \ - && if [ "$BUILD_FA" = "1" ]; then \ - bash patch_xformers.rocm.sh; fi \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ && cd .. diff --git a/patch_xformers.rocm.sh b/patch_xformers.rocm.sh deleted file mode 100644 index de427b24d306f..0000000000000 --- a/patch_xformers.rocm.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -set -e - -XFORMERS_VERSION="0.0.23" - -export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)') - -if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then - echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed" - exit 1 -fi - -export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)') -export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)') - -echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}" -echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}" - -if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then - echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}" - patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch" - echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}" -else - echo "${XFORMERS_FMHA_FLASH_PATH} was patched before" -fi - -if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then - echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}" - patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch" - echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}" -else - echo "${XFORMERS_FMHA_COMMON_PATH} was patched before" -fi diff --git a/rocm_patch/commonpy_xformers-0.0.23.rocm.patch b/rocm_patch/commonpy_xformers-0.0.23.rocm.patch deleted file mode 100644 index 4d7495cf13e1d..0000000000000 --- a/rocm_patch/commonpy_xformers-0.0.23.rocm.patch +++ /dev/null @@ -1,13 +0,0 @@ ---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000 -+++ common.py 2023-11-28 16:14:19.846233146 +0000 -@@ -298,8 +298,8 @@ - dtype = d.query.dtype - if device_type not in cls.SUPPORTED_DEVICES: - reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") -- if device_type == "cuda" and not _built_with_cuda: -- reasons.append("xFormers wasn't build with CUDA support") -+ #if device_type == "cuda" and not _built_with_cuda: -+ # reasons.append("xFormers wasn't build with CUDA support") - if device_type == "cuda": - device_capability = torch.cuda.get_device_capability(d.device) - if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: diff --git a/rocm_patch/flashpy_xformers-0.0.23.rocm.patch b/rocm_patch/flashpy_xformers-0.0.23.rocm.patch deleted file mode 100644 index ac846728a7a91..0000000000000 --- a/rocm_patch/flashpy_xformers-0.0.23.rocm.patch +++ /dev/null @@ -1,152 +0,0 @@ ---- flash_ori.py 2023-12-13 05:43:31.530752623 +0000 -+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000 -@@ -36,44 +36,44 @@ - - FLASH_VERSION = "0.0.0" - try: -- try: -- from ... import _C_flashattention # type: ignore[attr-defined] -- from ..._cpp_lib import _build_metadata -- -- if _build_metadata is not None: -- FLASH_VERSION = _build_metadata.flash_version -- except ImportError: -- import flash_attn -- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention -- -- FLASH_VERSION = flash_attn.__version__ -- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) -- if ( -- flash_ver_parsed != (2, 3, 6) -- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1" -- ): -- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api") -+ #try: -+ # from ... import _C_flashattention # type: ignore[attr-defined] -+ # from ..._cpp_lib import _build_metadata -+ -+ # if _build_metadata is not None: -+ # FLASH_VERSION = _build_metadata.flash_version -+ #except ImportError: -+ import flash_attn -+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention -+ -+ FLASH_VERSION = flash_attn.__version__ -+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) -+ # if ( -+ # flash_ver_parsed != (2, 3, 6) -+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1" -+ # ): -+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api") - - # create library so that flash-attn goes through the PyTorch Dispatcher -- _flash_lib = torch.library.Library("xformers_flash", "DEF") -- -- _flash_lib.define( -- "flash_fwd(Tensor query, Tensor key, Tensor value, " -- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, " -- "int max_seqlen_q, int max_seqlen_k, " -- "float p, float softmax_scale, " -- "bool is_causal, int window_left, " -- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)" -- ) -+ #_flash_lib = torch.library.Library("xformers_flash", "DEF") - -- _flash_lib.define( -- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " -- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " -- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " -- "int max_seqlen_q, int max_seqlen_k, " -- "float p, float softmax_scale, bool is_causal, " -- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)" -- ) -+ #_flash_lib.define( -+ # "flash_fwd(Tensor query, Tensor key, Tensor value, " -+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, " -+ # "int max_seqlen_q, int max_seqlen_k, " -+ # "float p, float softmax_scale, " -+ # "bool is_causal, int window_left, " -+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)" -+ #) -+ -+ #_flash_lib.define( -+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " -+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " -+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " -+ # "int max_seqlen_q, int max_seqlen_k, " -+ # "float p, float softmax_scale, bool is_causal, " -+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)" -+ #) - - def _flash_fwd( - query, -@@ -111,8 +111,8 @@ - p, - softmax_scale, - is_causal, -- window_left, # window_size_left -- window_right, # window_size_right -+ # window_left, # window_size_left -+ # window_right, # window_size_right - return_softmax, - None, # rng - ) -@@ -134,15 +134,15 @@ - out, - cu_seq_lens_q, - cu_seq_lens_k, -- seqused_k, -+ # seqused_k, - max_seq_len_q, - max_seq_len_k, - p, - softmax_scale, - False, - is_causal, -- window_left, -- window_right, -+ # window_left, -+ # window_right, - return_softmax, - None, - ) -@@ -184,8 +184,8 @@ - p, - softmax_scale, - is_causal, -- window_left, -- window_right, -+ # window_left, -+ # window_right, - None, - rng_state, - ) -@@ -208,15 +208,15 @@ - softmax_scale, - False, # zero_tensors - is_causal, -- window_left, -- window_right, -+ # window_left, -+ # window_right, - None, - rng_state, - ) - return dq, dk, dv - -- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") -- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") -+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") -+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") - except ImportError: - pass - -@@ -400,7 +400,7 @@ - implementation. - """ - -- OPERATOR = get_operator("xformers_flash", "flash_fwd") -+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd") - SUPPORTED_DEVICES: Set[str] = {"cuda"} - CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) - SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c42660fb8f74f..dbaa71fd16add 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -154,25 +154,30 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9 + self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. self.use_triton_flash_attn = (os.environ.get( "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) - if self.use_naive_attn: - # AMD Radeon 7900 series (gfx1100) currently does not support - # xFormers nor FlashAttention. As a temporary workaround, we use - # naive PyTorch implementation of attention. - self.attn_fuc = _naive_attention - logger.debug("Using naive attention in ROCmBackend") - elif self.use_triton_flash_attn: + if self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) self.attn_func = triton_attention logger.debug("Using Triton FA in ROCmBackend") else: - from flash_attn import flash_attn_varlen_func # noqa: F401 - self.attn_func = flash_attn_varlen_func - logger.debug("Using CK FA in ROCmBackend") + # if not using triton, navi3x not use flash-attn either + if torch.cuda.get_device_capability()[0] == 11: + self.use_naive_attn = True + else: + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + except ModuleNotFoundError: + self.use_naive_attn = True + + if self.use_naive_attn: + self.attn_func = _naive_attention + logger.debug("Using naive attention in ROCmBackend") def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" @@ -247,13 +252,13 @@ def forward( # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - if self.use_naive_attn or self.use_triton_flash_attn: + if self.use_triton_flash_attn or self.use_naive_attn: if self.num_kv_heads != self.num_heads: # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) if self.use_naive_attn: - out = self.attn_fuc( + out = self.attn_func( query, key, value, From 747b1a7147515c08491ef4aa1b23ea23329966ed Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 21 Apr 2024 23:04:16 -0700 Subject: [PATCH 086/413] [Core][Distributed] fix _is_full_nvlink detection (#4233) --- .../device_communicators/custom_all_reduce.py | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 7602897d3dd8f..58cbe77baf7e0 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,5 +1,6 @@ +import os from contextlib import contextmanager -from typing import Optional +from typing import List, Optional import torch import torch.distributed as dist @@ -53,14 +54,20 @@ def init_custom_ar() -> None: return False # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported - full_nvlink = _is_full_nvlink(rank, world_size) + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = list( + map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(","))) + else: + device_ids = list(range(num_dev)) + # this checks hardware and driver support for NVLink + full_nvlink = _is_full_nvlink(device_ids) if world_size > 2 and not full_nvlink: logger.warn( "Custom allreduce is disabled because it's not supported on more" " than two PCIe-only GPUs. To silence this warning, specify" " disable_custom_all_reduce=True explicitly.") return - # test P2P capability + # test P2P capability, this checks software/cudaruntime support # this is expensive to compute at the first time # then we cache the result if not _can_p2p(rank, world_size): @@ -138,23 +145,28 @@ def _nvml(): pynvml.nvmlShutdown() -# query if the set of gpus are fully connected by nvlink (1 hop) @_nvml() -def _is_full_nvlink(rank, world_size): - handle = pynvml.nvmlDeviceGetHandleByIndex(rank) - for i in range(world_size): - if i != rank: - try: - peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i) - p2p_status = pynvml.nvmlDeviceGetP2PStatus( - handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) - if p2p_status != pynvml.NVML_P2P_STATUS_OK: +def _is_full_nvlink(device_ids: List[int]) -> bool: + """ + query if the set of gpus are fully connected by nvlink (1 hop) + Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`, + so it works on real physical device ids. + """ + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError as error: + logger.error( + "NVLink detection failed. This is normal if your" + " machine has no NVLink equipped.", + exc_info=error) return False - except pynvml.NVMLError as error: - logger.info( - f"NVLink detection failed with message \"{str(error)}\". " - "This is normal if your machine has no NVLink equipped") - return False return True From 296cdf8ac7853b671d10f3df094288ea4f35bb38 Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Mon, 22 Apr 2024 15:44:16 +0800 Subject: [PATCH 087/413] [Misc] Add vision language model support to CPU backend (#3968) --- vllm/executor/cpu_executor.py | 1 + vllm/worker/cpu_model_runner.py | 60 ++++++++++++++++++++------------- vllm/worker/cpu_worker.py | 24 ++++++++----- 3 files changed, 53 insertions(+), 32 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index f925a6fc93dcd..35249cd7302cb 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -45,6 +45,7 @@ def _init_worker(self): rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, + vision_language_config=self.vision_language_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7377c8931cefa..a82373d3d1626 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -5,7 +5,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig) + ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -29,6 +29,7 @@ def __init__( device_config: DeviceConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, *args, @@ -38,6 +39,7 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config + self.vision_language_config = vision_language_config self.load_config = load_config self.is_driver_worker = is_driver_worker @@ -59,13 +61,14 @@ def __init__( self.block_size: int # Set after initial profiling. def load_model(self) -> None: - self.model = get_model(model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - vision_language_config=None, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + vision_language_config=self.vision_language_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) def _prepare_prompt( self, @@ -76,6 +79,7 @@ def _prepare_prompt( input_positions: List[int] = [] slot_mapping: List[int] = [] prompt_lens: List[int] = [] + multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt @@ -96,6 +100,10 @@ def _prepare_prompt( # is always the first token in the sequence. input_positions.extend(list(range(computed_len, prompt_len))) + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, @@ -118,6 +126,15 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, @@ -144,12 +161,8 @@ def _prepare_prompt( slot_mapping=slot_mapping, kv_cache_dtype=self.kv_cache_dtype, ) - return ( - input_tokens, - input_positions, - attn_metadata, - prompt_lens, - ) + return (input_tokens, input_positions, attn_metadata, prompt_lens, + multi_modal_input) def _prepare_decode( self, @@ -336,14 +349,16 @@ def prepare_input_tensors( seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata]: + multi_modal_input = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, attn_metadata, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, prompt_lens, + multi_modal_input + ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) @@ -376,12 +391,8 @@ def prepare_input_tensors( perform_sampling=False, ) - return ( - input_tokens, - input_positions, - attn_metadata, - sampling_metadata, - ) + return (input_tokens, input_positions, attn_metadata, + sampling_metadata, multi_modal_input) @torch.inference_mode() def execute_model( @@ -389,7 +400,8 @@ def execute_model( seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: - (input_tokens, input_positions, attn_metadata, sampling_metadata + (input_tokens, input_positions, attn_metadata, sampling_metadata, + multi_modal_input ) = self.prepare_input_tensors(seq_group_metadata_list) model_executable = self.model @@ -399,6 +411,8 @@ def execute_model( "kv_caches": kv_caches, "attn_metadata": attn_metadata, } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3652830b7d519..83ededd742533 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -6,7 +6,8 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig) + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) @@ -122,6 +123,7 @@ def __init__( rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, + vision_language_config: Optional[VisionLanguageConfig] = None, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: @@ -135,21 +137,25 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.vision_language_config = vision_language_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." + if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = CPUModelRunner(model_config, - parallel_config, - scheduler_config, - device_config, - load_config=self.load_config, - lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) + self.model_runner = CPUModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: CPUCacheEngine From e73ed0f1c624f85d348c0709c256a0ae6627986b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 22 Apr 2024 00:54:16 -0700 Subject: [PATCH 088/413] [Bugfix] Fix type annotations in CPU model runner (#4256) --- vllm/worker/cpu_model_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index a82373d3d1626..bf0a6c84e6f07 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -73,7 +73,8 @@ def load_model(self) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + Optional[torch.Tensor]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] @@ -347,8 +348,8 @@ def _prepare_sample( def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, - SamplingMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, + Optional[torch.Tensor]]: multi_modal_input = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or From 077f0a2e8a873340b1a2cf54d6c9043754eb7514 Mon Sep 17 00:00:00 2001 From: Tao He Date: Mon, 22 Apr 2024 17:19:51 +0800 Subject: [PATCH 089/413] [Frontend] Enable support for CPU backend in AsyncLLMEngine. (#3993) Signed-off-by: Tao He --- vllm/engine/async_llm_engine.py | 5 +++++ vllm/executor/cpu_executor.py | 27 +++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index ca4ba66f09cb8..3a2f7db679358 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -343,6 +343,11 @@ def from_engine_args( if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync + elif engine_config.device_config.device_type == "cpu": + assert not engine_config.parallel_config.worker_use_ray, ( + "Ray is not supported with the CPU backend.") + from vllm.executor.cpu_executor import CPUExecutorAsync + executor_class = CPUExecutorAsync elif engine_config.parallel_config.worker_use_ray: initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 35249cd7302cb..8d6a1fff91fd8 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -4,11 +4,12 @@ import torch from vllm.config import CacheConfig, ModelConfig, SchedulerConfig -from vllm.executor.executor_base import ExecutorBase +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) logger = init_logger(__name__) @@ -100,6 +101,28 @@ def check_health(self) -> None: return +class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> SamplerOutput: + output = await make_async(self.driver_worker.execute_model)( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy) + return output + + async def check_health_async(self) -> None: + # CPUExecutor will always be healthy as long as + # it's running. + return + + def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: if config.dtype == torch.float16: logger.warning("float16 is not supported on CPU, casting to bfloat16.") From 15436806912d7ad9371c8bcf6a46857590c107d2 Mon Sep 17 00:00:00 2001 From: alexm-nm <59768536+alexm-nm@users.noreply.github.com> Date: Mon, 22 Apr 2024 12:10:48 -0400 Subject: [PATCH 090/413] [Bugfix] Ensure download_weights_from_hf(..) inside loader is using the revision parameter (#4217) --- vllm/model_executor/model_loader/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 6c8cb2935f37e..64cd186506bdb 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -172,7 +172,7 @@ def _prepare_weights(self, model_name_or_path: str, if not is_local: hf_folder = download_weights_from_hf(model_name_or_path, self.load_config.download_dir, - allow_patterns) + allow_patterns, revision) else: hf_folder = model_name_or_path From 3d925165f2b18379640a63fbb42de95440d63b64 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 22 Apr 2024 17:36:54 +0100 Subject: [PATCH 091/413] Add example scripts to documentation (#4225) Co-authored-by: Harry Mellor --- .gitignore | 2 + docs/source/conf.py | 9 ++- docs/source/generate_examples.py | 61 +++++++++++++++++++ .../examples/examples_index.template.rst | 8 +++ docs/source/index.rst | 1 + ...nt.py => openai_chat_completion_client.py} | 0 6 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 docs/source/generate_examples.py create mode 100644 docs/source/getting_started/examples/examples_index.template.rst rename examples/{openai_chatcompletion_client.py => openai_chat_completion_client.py} (100%) diff --git a/.gitignore b/.gitignore index b1513ef0ddb0c..e077366d1e4a1 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,8 @@ instance/ # Sphinx documentation docs/_build/ +docs/source/getting_started/examples/*.rst +!**/*.template.rst # PyBuilder .pybuilder/ diff --git a/docs/source/conf.py b/docs/source/conf.py index cfa956b143ba3..aac8cbb63ebeb 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -48,7 +48,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns: List[str] = [] +exclude_patterns: List[str] = ["**/*.template.rst"] # Exclude the prompt "$" when copying code copybutton_prompt_text = r"\$ " @@ -73,6 +73,13 @@ # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] + +# Generate additional rst documentation here. +def setup(app): + from docs.source.generate_examples import generate_examples + generate_examples() + + # Mock out external dependencies here. autodoc_mock_imports = [ "cpuinfo", diff --git a/docs/source/generate_examples.py b/docs/source/generate_examples.py new file mode 100644 index 0000000000000..79b49a186236a --- /dev/null +++ b/docs/source/generate_examples.py @@ -0,0 +1,61 @@ +import re +from pathlib import Path + + +def fix_case(text: str) -> str: + subs = [ + ("api", "API"), + ("llm", "LLM"), + ("vllm", "vLLM"), + ("openai", "OpenAI"), + ("multilora", "MultiLoRA"), + ] + for sub in subs: + text = re.sub(*sub, text, flags=re.IGNORECASE) + return text + + +def underline(title: str, character: str = "=") -> str: + return f"{title}\n{character * len(title)}" + + +def generate_title(filename: str) -> str: + # Turn filename into a title + title = filename.replace("_", " ").title() + # Handle acronyms and names + title = fix_case(title) + # Underline title + title = underline(title) + return title + + +def generate_examples(): + root_dir = Path(__file__).parent.parent.parent.resolve() + + # Source paths + script_dir = root_dir / "examples" + script_paths = sorted(script_dir.glob("*.py")) + + # Destination paths + doc_dir = root_dir / "docs/source/getting_started/examples" + doc_paths = [doc_dir / f"{path.stem}.rst" for path in script_paths] + + # Generate the example docs for each example script + for script_path, doc_path in zip(script_paths, doc_paths): + script_url = f"https://github.com/vllm-project/vllm/blob/main/examples/{script_path.name}" + # Make script_path relative to doc_path and call it include_path + include_path = '../../../..' / script_path.relative_to(root_dir) + content = (f"{generate_title(doc_path.stem)}\n\n" + f"Source {script_url}.\n\n" + f".. literalinclude:: {include_path}\n" + " :language: python\n" + " :linenos:\n") + with open(doc_path, "w+") as f: + f.write(content) + + # Generate the toctree for the example scripts + with open(doc_dir / "examples_index.template.rst") as f: + examples_index = f.read() + with open(doc_dir / "examples_index.rst", "w+") as f: + example_docs = "\n ".join(path.stem for path in script_paths) + f.write(examples_index.replace(r"%EXAMPLE_DOCS%", example_docs)) diff --git a/docs/source/getting_started/examples/examples_index.template.rst b/docs/source/getting_started/examples/examples_index.template.rst new file mode 100644 index 0000000000000..1b34cccbae15a --- /dev/null +++ b/docs/source/getting_started/examples/examples_index.template.rst @@ -0,0 +1,8 @@ +Examples +================================= + +.. toctree:: + :maxdepth: 1 + :caption: Scripts + + %EXAMPLE_DOCS% diff --git a/docs/source/index.rst b/docs/source/index.rst index 5d5d52696ba34..e8daa5f052754 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -65,6 +65,7 @@ Documentation getting_started/neuron-installation getting_started/cpu-installation getting_started/quickstart + getting_started/examples/examples_index .. toctree:: :maxdepth: 1 diff --git a/examples/openai_chatcompletion_client.py b/examples/openai_chat_completion_client.py similarity index 100% rename from examples/openai_chatcompletion_client.py rename to examples/openai_chat_completion_client.py From ad8d696a99ca1eee19f1404e16e8e82df592ff85 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 23 Apr 2024 06:11:06 +0900 Subject: [PATCH 092/413] [Core] Scheduler perf fix (#4270) --- tests/core/test_scheduler.py | 18 +++++++++--------- vllm/core/scheduler.py | 7 ++----- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 9588a1bead5f6..a2511238506b0 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -540,7 +540,7 @@ def test_decode_schedule_preempted(): curr_loras = None for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) running.append(seq_group) scheduler.block_manager.can_append_slots = MagicMock() @@ -581,7 +581,7 @@ def test_decode_swap_beam_search(): budget = create_token_budget() for i in range(3): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) running.append(seq_group) append_new_token_seq_group(60, seq_group, 1) budget.add_num_seqs(seq_group.request_id, @@ -629,7 +629,7 @@ def test_schedule_decode_blocks_to_copy_update(): running = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) running.append(seq_group) @@ -659,7 +659,7 @@ def test_schedule_swapped_simple(): curr_loras = None blocks_to_swap_out = {} _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) @@ -687,7 +687,7 @@ def test_schedule_swapped_max_token_budget(): blocks_to_swap_out = {} for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) @@ -721,7 +721,7 @@ def test_schedule_swapped_max_seqs(): blocks_to_swap_out = {} for i in range(4): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) @@ -759,7 +759,7 @@ def test_schedule_swapped_max_loras(): lora_name=str(i), lora_int_id=i + 1, lora_local_path="abc")) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) @@ -783,7 +783,7 @@ def test_schedule_swapped_cannot_swap_in(): blocks_to_swap_out = {} for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) @@ -808,7 +808,7 @@ def test_schedule_swapped_blocks_to_copy(): policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group, 60) + scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) blocks_to_swap_out = {} scheduler._swap_out(seq_group, blocks_to_swap_out) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4198550621030..8d7db09bbea08 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -297,7 +297,6 @@ def num_decoding_tokens_per_seq(self) -> int: def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. - logger.debug(f"add_seq_group {seq_group.request_id}") self.waiting.append(seq_group) def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: @@ -427,7 +426,6 @@ def _schedule_running( swapped_out.append(seq_group) break else: - logger.debug(f"append slot for {seq_group}") self._append_slots(seq_group, blocks_to_copy) is_prefill = seq_group.is_prefill() if is_prefill: @@ -659,7 +657,7 @@ def _schedule_prefills( if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) waiting_queue.popleft() - self._allocate_and_set_running(seq_group, num_new_tokens) + self._allocate_and_set_running(seq_group) seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) @@ -952,8 +950,7 @@ def free_finished_seq_groups(self) -> None: self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) - def _allocate_and_set_running(self, seq_group: SequenceGroup, - num_new_tokens: int) -> None: + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING From ceaf4ed0030c7816537359b8efc750474149ce0f Mon Sep 17 00:00:00 2001 From: Zhanghao Wu Date: Mon, 22 Apr 2024 15:34:31 -0700 Subject: [PATCH 093/413] [Doc] Update the SkyPilot doc with serving and Llama-3 (#4276) --- docs/source/serving/run_on_sky.rst | 287 ++++++++++++++++++++++++++--- 1 file changed, 264 insertions(+), 23 deletions(-) diff --git a/docs/source/serving/run_on_sky.rst b/docs/source/serving/run_on_sky.rst index 2c88d24dc5d0b..bd33c76cec3de 100644 --- a/docs/source/serving/run_on_sky.rst +++ b/docs/source/serving/run_on_sky.rst @@ -1,7 +1,7 @@ .. _on_cloud: -Running on clouds with SkyPilot -=============================== +Deploying and scaling up with SkyPilot +================================================ .. raw:: html @@ -9,51 +9,75 @@ Running on clouds with SkyPilot vLLM

-vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot `__, an open-source framework for running LLMs on any cloud. +vLLM can be **run and scaled to multiple service replicas on clouds and Kubernetes** with `SkyPilot `__, an open-source framework for running LLMs on any cloud. More examples for various open models, such as Llama-3, Mixtral, etc, can be found in `SkyPilot AI gallery `__. -To install SkyPilot and setup your cloud credentials, run: + +Prerequisites +------------- + +- Go to the `HuggingFace model page `__ and request access to the model :code:`meta-llama/Meta-Llama-3-8B-Instruct`. +- Check that you have installed SkyPilot (`docs `__). +- Check that :code:`sky check` shows clouds or Kubernetes are enabled. .. code-block:: console - $ pip install skypilot - $ sky check + pip install skypilot-nightly + sky check + + +Run on a single instance +------------------------ See the vLLM SkyPilot YAML for serving, `serving.yaml `__. .. code-block:: yaml resources: - accelerators: A100 + accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB} # We can use cheaper accelerators for 8B model. + use_spot: True + disk_size: 512 # Ensure model checkpoints can fit. + disk_tier: best + ports: 8081 # Expose to internet traffic. envs: - MODEL_NAME: decapoda-research/llama-13b-hf - TOKENIZER: hf-internal-testing/llama-tokenizer + MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct + HF_TOKEN: # Change to your own huggingface token, or use --env to pass. setup: | - conda create -n vllm python=3.9 -y + conda create -n vllm python=3.10 -y conda activate vllm - git clone https://github.com/vllm-project/vllm.git - cd vllm - pip install . - pip install gradio + + pip install vllm==0.4.0.post1 + # Install Gradio for web UI. + pip install gradio openai + pip install flash-attn==2.5.7 run: | conda activate vllm echo 'Starting vllm api server...' - python -u -m vllm.entrypoints.api_server \ - --model $MODEL_NAME \ - --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \ - --tokenizer $TOKENIZER 2>&1 | tee api_server.log & + python -u -m vllm.entrypoints.openai.api_server \ + --port 8081 \ + --model $MODEL_NAME \ + --trust-remote-code \ + --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \ + 2>&1 | tee api_server.log & + echo 'Waiting for vllm api server to start...' while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done + echo 'Starting gradio server...' - python vllm/examples/gradio_webserver.py + git clone https://github.com/vllm-project/vllm.git || true + python vllm/examples/gradio_openai_chatbot_webserver.py \ + -m $MODEL_NAME \ + --port 8811 \ + --model-url http://localhost:8081/v1 \ + --stop-token-ids 128009,128001 -Start the serving the LLaMA-13B model on an A100 GPU: +Start the serving the Llama-3 8B model on any of the candidate GPUs listed (L4, A10g, ...): .. code-block:: console - $ sky launch serving.yaml + HF_TOKEN="your-huggingface-token" sky launch serving.yaml --env HF_TOKEN Check the output of the command. There will be a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion. @@ -61,9 +85,226 @@ Check the output of the command. There will be a shareable gradio link (like the (task, pid=7431) Running on public URL: https://.gradio.live -**Optional**: Serve the 65B model instead of the default 13B and use more GPU: +**Optional**: Serve the 70B model instead of the default 8B and use more GPU: + +.. code-block:: console + + HF_TOKEN="your-huggingface-token" sky launch serving.yaml --gpus A100:8 --env HF_TOKEN --env MODEL_NAME=meta-llama/Meta-Llama-3-70B-Instruct + + +Scale up to multiple replicas +----------------------------- + +SkyPilot can scale up the service to multiple service replicas with built-in autoscaling, load-balancing and fault-tolerance. You can do it by adding a services section to the YAML file. + +.. code-block:: yaml + + service: + replicas: 2 + # An actual request for readiness probe. + readiness_probe: + path: /v1/chat/completions + post_data: + model: $MODEL_NAME + messages: + - role: user + content: Hello! What is your name? + max_tokens: 1 + +.. raw:: html + +
+ Click to see the full recipe YAML + + +.. code-block:: yaml + + service: + replicas: 2 + # An actual request for readiness probe. + readiness_probe: + path: /v1/chat/completions + post_data: + model: $MODEL_NAME + messages: + - role: user + content: Hello! What is your name? + max_tokens: 1 + + resources: + accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB} # We can use cheaper accelerators for 8B model. + use_spot: True + disk_size: 512 # Ensure model checkpoints can fit. + disk_tier: best + ports: 8081 # Expose to internet traffic. + + envs: + MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct + HF_TOKEN: # Change to your own huggingface token, or use --env to pass. + + setup: | + conda create -n vllm python=3.10 -y + conda activate vllm + + pip install vllm==0.4.0.post1 + # Install Gradio for web UI. + pip install gradio openai + pip install flash-attn==2.5.7 + + run: | + conda activate vllm + echo 'Starting vllm api server...' + python -u -m vllm.entrypoints.openai.api_server \ + --port 8081 \ + --model $MODEL_NAME \ + --trust-remote-code \ + --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \ + 2>&1 | tee api_server.log & + + echo 'Waiting for vllm api server to start...' + while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done + + echo 'Starting gradio server...' + git clone https://github.com/vllm-project/vllm.git || true + python vllm/examples/gradio_openai_chatbot_webserver.py \ + -m $MODEL_NAME \ + --port 8811 \ + --model-url http://localhost:8081/v1 \ + --stop-token-ids 128009,128001 + +.. raw:: html + +
+ +Start the serving the Llama-3 8B model on multiple replicas: + +.. code-block:: console + + HF_TOKEN="your-huggingface-token" sky serve up -n vllm serving.yaml --env HF_TOKEN + + +Wait until the service is ready: .. code-block:: console - sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf + watch -n10 sky serve status vllm + + +.. raw:: html + +
+ Example outputs: + +.. code-block:: console + + Services + NAME VERSION UPTIME STATUS REPLICAS ENDPOINT + vllm 1 35s READY 2/2 xx.yy.zz.100:30001 + + Service Replicas + SERVICE_NAME ID VERSION IP LAUNCHED RESOURCES STATUS REGION + vllm 1 1 xx.yy.zz.121 18 mins ago 1x GCP({'L4': 1}) READY us-east4 + vllm 2 1 xx.yy.zz.245 18 mins ago 1x GCP({'L4': 1}) READY us-east4 + +.. raw:: html + +
+ +After the service is READY, you can find a single endpoint for the service and access the service with the endpoint: + +.. code-block:: console + + ENDPOINT=$(sky serve status --endpoint 8081 vllm) + curl -L http://$ENDPOINT/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Who are you?" + } + ], + "stop_token_ids": [128009, 128001] + }' + +To enable autoscaling, you could specify additional configs in `services`: + +.. code-block:: yaml + + services: + replica_policy: + min_replicas: 0 + max_replicas: 3 + target_qps_per_replica: 2 + +This will scale the service up to when the QPS exceeds 2 for each replica. + + +**Optional**: Connect a GUI to the endpoint +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +It is also possible to access the Llama-3 service with a separate GUI frontend, so the user requests send to the GUI will be load-balanced across replicas. + +.. raw:: html + +
+ Click to see the full GUI YAML + +.. code-block:: yaml + + envs: + MODEL_NAME: meta-llama/Meta-Llama-3-70B-Instruct + ENDPOINT: x.x.x.x:3031 # Address of the API server running vllm. + + resources: + cpus: 2 + + setup: | + conda activate vllm + if [ $? -ne 0 ]; then + conda create -n vllm python=3.10 -y + conda activate vllm + fi + + # Install Gradio for web UI. + pip install gradio openai + + run: | + conda activate vllm + export PATH=$PATH:/sbin + WORKER_IP=$(hostname -I | cut -d' ' -f1) + CONTROLLER_PORT=21001 + WORKER_PORT=21002 + + echo 'Starting gradio server...' + git clone https://github.com/vllm-project/vllm.git || true + python vllm/examples/gradio_openai_chatbot_webserver.py \ + -m $MODEL_NAME \ + --port 8811 \ + --model-url http://$ENDPOINT/v1 \ + --stop-token-ids 128009,128001 | tee ~/gradio.log + +.. raw:: html + +
+ +1. Start the chat web UI: + +.. code-block:: console + + sky launch -c gui ./gui.yaml --env ENDPOINT=$(sky serve status --endpoint vllm) + + +2. Then, we can access the GUI at the returned gradio link: + +.. code-block:: console + + | INFO | stdout | Running on public URL: https://6141e84201ce0bb4ed.gradio.live + From c1b4e4157c0b4154f950adaea85a259fc629c758 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 22 Apr 2024 17:21:48 -0700 Subject: [PATCH 094/413] [Core][Distributed] use absolute path for library file (#4271) --- vllm/utils.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index fbe86dacaeb99..15c8818cc4506 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -553,6 +553,34 @@ def nccl_integrity_check(filepath): return version.value +@lru_cache(maxsize=None) +def find_library(lib_name: str) -> str: + """ + Find the library file in the system. + `lib_name` is full filename, with both prefix and suffix. + This function resolves `lib_name` to the full path of the library. + """ + # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa + # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard + # `/sbin/ldconfig` should exist in all Linux systems. + # `/sbin/ldconfig` searches the library in the system + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() + # each line looks like the following: + # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 + locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line] + # `LD_LIBRARY_PATH` searches the library in the user-defined paths + env_ld_library_path = os.getenv("LD_LIBRARY_PATH") + if not locs and env_ld_library_path: + locs = [ + os.path.join(dir, lib_name) + for dir in env_ld_library_path.split(":") + if os.path.exists(os.path.join(dir, lib_name)) + ] + if not locs: + raise ValueError(f"Cannot find {lib_name} in the system.") + return locs[0] + + def find_nccl_library(): so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") @@ -572,9 +600,9 @@ def find_nccl_library(): ) else: if torch.version.cuda is not None: - so_file = vllm_nccl_path or "libnccl.so.2" + so_file = vllm_nccl_path or find_library("libnccl.so.2") elif torch.version.hip is not None: - so_file = "librccl.so.1" + so_file = find_library("librccl.so.1") else: raise ValueError("NCCL only supports CUDA and ROCm backends.") logger.info(f"Found nccl from library {so_file}") From 34128a697ed2dd4d88a9829f35445fbfc5b85c64 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 23 Apr 2024 02:53:01 +0100 Subject: [PATCH 095/413] Fix `autodoc` directives (#4272) Co-authored-by: Harry Mellor --- docs/source/dev/engine/async_llm_engine.rst | 5 ++--- docs/source/dev/engine/llm_engine.rst | 6 +++--- docs/source/dev/sampling_params.rst | 3 ++- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/dev/engine/async_llm_engine.rst b/docs/source/dev/engine/async_llm_engine.rst index 47db1e0a401b1..93fc310cb543b 100644 --- a/docs/source/dev/engine/async_llm_engine.rst +++ b/docs/source/dev/engine/async_llm_engine.rst @@ -1,7 +1,6 @@ - AsyncLLMEngine ================================= -.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine - :members: generate, abort +.. autoclass:: vllm.AsyncLLMEngine + :members: :show-inheritance: diff --git a/docs/source/dev/engine/llm_engine.rst b/docs/source/dev/engine/llm_engine.rst index 1de6d7adc87c6..0b8c1e219d7c9 100644 --- a/docs/source/dev/engine/llm_engine.rst +++ b/docs/source/dev/engine/llm_engine.rst @@ -1,6 +1,6 @@ LLMEngine ================================= -.. autoclass:: vllm.engine.llm_engine.LLMEngine - :members: add_request, abort_request, step - :show-inheritance: \ No newline at end of file +.. autoclass:: vllm.LLMEngine + :members: + :show-inheritance: diff --git a/docs/source/dev/sampling_params.rst b/docs/source/dev/sampling_params.rst index 844859b3ec1f0..ef3d1509bda6d 100644 --- a/docs/source/dev/sampling_params.rst +++ b/docs/source/dev/sampling_params.rst @@ -1,4 +1,5 @@ Sampling Params =============== -.. automodule:: vllm.sampling_params.SamplingParams \ No newline at end of file +.. autoclass:: vllm.SamplingParams + :members: From 0ae11f78ab89556d5d867abb98f8a132f7507269 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 23 Apr 2024 13:32:44 +0900 Subject: [PATCH 096/413] [Mypy] Part 3 fix typing for nested directories for most of directory (#4161) --- .github/workflows/mypy.yaml | 29 ++++++++++--------- format.sh | 26 ++++++++--------- pyproject.toml | 6 ++-- vllm/attention/backends/abstract.py | 2 +- vllm/attention/backends/rocm_flash_attn.py | 1 + vllm/attention/backends/torch_sdpa.py | 3 +- vllm/attention/backends/xformers.py | 1 + vllm/core/block/block_table.py | 1 + vllm/core/block/common.py | 6 ++-- vllm/core/block/interfaces.py | 6 ++-- .../device_communicators/custom_all_reduce.py | 16 +++++----- .../device_communicators/pynccl.py | 22 ++++++++------ .../device_communicators/pynccl_utils.py | 5 +++- vllm/engine/output_processor/interfaces.py | 5 ++-- vllm/engine/output_processor/multi_step.py | 5 ++-- vllm/engine/output_processor/single_step.py | 9 +++--- vllm/engine/output_processor/util.py | 4 ++- vllm/entrypoints/openai/api_server.py | 6 ++-- vllm/entrypoints/openai/protocol.py | 13 +++++---- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 5 +++- vllm/entrypoints/openai/serving_engine.py | 19 +++++++----- vllm/lora/lora.py | 2 +- vllm/model_executor/layers/ops/sample.py | 3 +- .../model_executor/layers/rotary_embedding.py | 3 +- vllm/transformers_utils/configs/jais.py | 6 ++-- .../tokenizer_group/__init__.py | 2 +- .../tokenizer_group/ray_tokenizer_group.py | 2 ++ .../transformers_utils/tokenizers/baichuan.py | 4 +-- 29 files changed, 126 insertions(+), 88 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 477ce9bc9ce85..9f1855696e20a 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -32,19 +32,20 @@ jobs: pip install types-setuptools - name: Mypy run: | - mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/attention --config-file pyproject.toml + # TODO(sang): Fix nested dir mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml - mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml - mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml - mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml - mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml - mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml - mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml - - mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml - mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml - mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml - mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml - # TODO(sang): Follow up - # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/distributed --config-file pyproject.toml + mypy vllm/entrypoints --config-file pyproject.toml + mypy vllm/executor --config-file pyproject.toml + mypy vllm/usage --config-file pyproject.toml + mypy vllm/*.py --config-file pyproject.toml + mypy vllm/transformers_utils --config-file pyproject.toml + mypy vllm/engine --config-file pyproject.toml + mypy vllm/worker --config-file pyproject.toml + mypy vllm/spec_decode --config-file pyproject.toml + # TODO(sang): Fix nested dir + mypy vllm/model_executor/*.py --config-file pyproject.toml + # TODO(sang): Fix nested dir + # mypy vllm/lora/*.py --config-file pyproject.toml diff --git a/format.sh b/format.sh index 84ee88b5b4c8a..bd2e9e89e1806 100755 --- a/format.sh +++ b/format.sh @@ -94,21 +94,19 @@ echo 'vLLM yapf: Done' # Run mypy echo 'vLLM mypy:' -mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/attention --config-file pyproject.toml mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml -mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml -mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml -mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml -mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml -mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml -mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml - -# TODO(sang): Follow up -mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml -mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml -mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml -mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml -# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/distributed --config-file pyproject.toml +mypy vllm/entrypoints --config-file pyproject.toml +mypy vllm/executor --config-file pyproject.toml +mypy vllm/usage --config-file pyproject.toml +mypy vllm/*.py --config-file pyproject.toml +mypy vllm/transformers_utils --config-file pyproject.toml +mypy vllm/engine --config-file pyproject.toml +mypy vllm/worker --config-file pyproject.toml +mypy vllm/spec_decode --config-file pyproject.toml +mypy vllm/model_executor/*.py --config-file pyproject.toml +# mypy vllm/lora/*.py --config-file pyproject.toml CODESPELL_EXCLUDES=( diff --git a/pyproject.toml b/pyproject.toml index b870a4b85897b..a171d45b4e064 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,15 +46,17 @@ ignore = [ python_version = "3.8" ignore_missing_imports = true - check_untyped_defs = true +check_untyped_defs = true +follow_imports = "skip" files = "vllm" # TODO(woosuk): Include the code from Megatron and HuggingFace. exclude = [ "vllm/model_executor/parallel_utils/|vllm/model_executor/models/", + # Ignore triton kernels in ops. + 'vllm/attention/ops/.*\.py$' ] - [tool.codespell] ignore-words-list = "dout, te, indicies" skip = "./tests/prompts,./benchmarks/sonnet.txt" diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 7a4ccecf702f4..be747c9900368 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -116,7 +116,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata[AttentionMetadataPerStage], + attn_metadata: AttentionMetadata, kv_scale: float, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index dbaa71fd16add..7c5863a030ff5 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -248,6 +248,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. + assert prefill_meta.prompt_lens is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index d21b54b16db4b..55a7ce59ac6e0 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -106,7 +106,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: TorchSDPAMetadata, + attn_metadata: TorchSDPAMetadata, # type: ignore kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -136,6 +136,7 @@ def forward( kv_scale) if attn_metadata.is_prompt: + assert attn_metadata.prompt_lens is not None if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b745a04a143b4..572a4dc79a719 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -288,6 +288,7 @@ def _run_memory_efficient_xformers_forward( value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ + assert attn_metadata.prompt_lens is not None original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index 560267e55ea3a..f1b65b2514f76 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -104,6 +104,7 @@ def append_token_ids(self, token_ids (List[int]): The sequence of token IDs to be appended. """ assert self._is_allocated + assert self._blocks is not None self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index 50c70533c4fbc..f11234a0bf2dd 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -99,7 +99,7 @@ def __init__( refcounter: RefCounter, allocator: BlockAllocator, ): - self._copy_on_writes = defaultdict(list) + self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list) self._refcounter = refcounter self._allocator = allocator @@ -138,6 +138,8 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: prev_block=block.prev_block).block_id # Track src/dst copy. + assert src_block_id is not None + assert block_id is not None self._copy_on_writes[src_block_id].append(block_id) return block_id @@ -180,6 +182,6 @@ def recurse(block: Block, lst: List[Block]) -> None: recurse(block.prev_block, lst) lst.append(block) - all_blocks = [] + all_blocks: List[Block] = [] recurse(last_block, all_blocks) return all_blocks diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index fbceacf0ec417..50ce922118124 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -52,8 +52,7 @@ def __call__( class BlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, prev_block: Optional[Block], - device: Device) -> Block: + def allocate_mutable(self, prev_block: Optional[Block]) -> Block: pass @abstractmethod @@ -98,8 +97,7 @@ class NoFreeBlocksError(ValueError): class DeviceAwareBlockAllocator(BlockAllocator): @abstractmethod - def allocate_mutable(self, prev_block: Optional[Block], - device: Device) -> Block: + def allocate_mutable(self, prev_block: Optional[Block]) -> Block: pass @abstractmethod diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 58cbe77baf7e0..9dbb427d91ff1 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,6 +1,6 @@ import os from contextlib import contextmanager -from typing import List, Optional +from typing import Any, List, Optional import torch import torch.distributed as dist @@ -18,7 +18,7 @@ logger = init_logger(__name__) -_CA_HANDLE = None +_CA_HANDLE: Optional["CustomAllreduce"] = None _IS_CAPTURING = False _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] @@ -51,7 +51,7 @@ def init_custom_ar() -> None: "Cannot test GPU P2P because not all GPUs are visible to the " "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" " is set.") - return False + return # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported if "CUDA_VISIBLE_DEVICES" in os.environ: @@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: ca_handle = get_handle() # when custom allreduce is disabled, this will be None if ca_handle is None: - return + return None if is_capturing(): if torch.cuda.is_current_stream_capturing(): if ca_handle.should_custom_ar(input): @@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: if ca_handle.should_custom_ar(input): return ca_handle.all_reduce_unreg(input) + return None + @contextmanager def _nvml(): @@ -224,14 +226,14 @@ def _get_ipc_meta(self, inp: torch.Tensor): return self._gather_ipc_meta(shard_data) def _gather_ipc_meta(self, shard_data): - all_data = [None] * self.world_size + all_data: List[Optional[Any]] = [None] * self.world_size dist.all_gather_object(all_data, shard_data) handles = [] offsets = [] for i in range(len(all_data)): - handles.append(all_data[i][0]) - offsets.append(all_data[i][1]) + handles.append(all_data[i][0]) # type: ignore + offsets.append(all_data[i][1]) # type: ignore return handles, offsets def register_buffer(self, inp: torch.Tensor): diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index c57a4f59d442c..0707afe922f40 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -107,9 +107,10 @@ def ncclGetUniqueId() -> NcclUniqueId: ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int ] +ncclDataType_t = ctypes.c_int -# enums -class ncclDataType_t(ctypes.c_int): + +class ncclDataTypeEnum: ncclInt8 = 0 ncclChar = 0 ncclUint8 = 1 @@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int): ncclNumTypes = 10 @classmethod - def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t': + def from_torch(cls, dtype: torch.dtype) -> int: if dtype == torch.int8: return cls.ncclInt8 if dtype == torch.uint8: @@ -148,7 +149,10 @@ def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t': raise ValueError(f"Unsupported dtype: {dtype}") -class ncclRedOp_t(ctypes.c_int): +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: ncclSum = 0 ncclProd = 1 ncclMax = 2 @@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int): ncclNumOps = 5 @classmethod - def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t': + def from_torch(cls, op: ReduceOp) -> int: if op == ReduceOp.SUM: return cls.ncclSum if op == ReduceOp.PRODUCT: @@ -180,8 +184,8 @@ def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t': _c_ncclAllReduce = nccl.ncclAllReduce _c_ncclAllReduce.restype = ctypes.c_int _c_ncclAllReduce.argtypes = [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t, - ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t, + ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p ] # equivalent to c declaration: @@ -251,8 +255,8 @@ def all_reduce(self, result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), ctypes.c_void_p(tensor.data_ptr()), tensor.numel(), - ncclDataType_t.from_torch(tensor.dtype), - ncclRedOp_t.from_torch(op), self.comm, + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, ctypes.c_void_p(stream.cuda_stream)) assert result == 0 diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index aeb73015733d1..916dc814af7eb 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -30,6 +30,7 @@ def is_initialized() -> bool: def set_pynccl_stream(stream: torch.cuda.Stream): """Set the cuda stream for communication""" try: + assert comm is not None comm.stream = stream yield finally: @@ -52,6 +53,7 @@ def init_process_group(world_size: int, def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: """All-reduces the input tensor across the process group.""" assert input_.is_cuda, f"{input_} should be a cuda tensor" + assert comm is not None comm.all_reduce(input_, op) @@ -62,8 +64,9 @@ def destroy_process_group() -> None: def get_world_size() -> int: """Returns the world size.""" + assert comm is not None return comm.world_size -def get_nccl_backend(): +def get_nccl_backend() -> Optional["NCCLCommunicator"]: return comm diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 9ddac7a04cb36..f307ea4da3011 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Iterable, List +from typing import Callable, List from transformers import PreTrainedTokenizer @@ -8,6 +8,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter class SequenceGroupOutputProcessor(ABC): @@ -27,7 +28,7 @@ def create_output_processor( scheduler_config: SchedulerConfig, detokenizer: Detokenizer, scheduler: Scheduler, - seq_counter: Iterable[int], + seq_counter: Counter, get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], stop_checker: "StopChecker", ): diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 50da0d35fcec1..39e99d06ed875 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, List +from typing import Callable, List from transformers import PreTrainedTokenizer @@ -11,6 +11,7 @@ from vllm.sequence import (Logprob, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter logger = init_logger(__name__) @@ -33,7 +34,7 @@ def __init__( self, detokenizer: Detokenizer, scheduler: Scheduler, - seq_counter: Iterable[int], + seq_counter: Counter, get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], stop_checker: StopChecker, ): diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index b32937327ba7f..7e9d652446703 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Tuple, Union +from typing import Dict, List, Tuple, Union from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -10,6 +10,7 @@ from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter logger = init_logger(__name__) @@ -33,7 +34,7 @@ def __init__( scheduler_config: SchedulerConfig, detokenizer: Detokenizer, scheduler: Scheduler, - seq_counter: Iterable[int], + seq_counter: Counter, stop_checker: StopChecker, ): self.scheduler_config = scheduler_config @@ -69,7 +70,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) existing_finished_seqs = seq_group.get_finished_seqs() - parent_child_dict = { + parent_child_dict: Dict[int, List[SequenceOutput]] = { parent_seq.seq_id: [] for parent_seq in parent_seqs } @@ -92,7 +93,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, continue # Fork the parent sequence if there are multiple child samples. for child_sample in child_samples[:-1]: - new_child_seq_id = next(self.seq_counter) + new_child_seq_id: int = next(self.seq_counter) child = parent.fork(new_child_seq_id) child.append_token_id(child_sample.output_token, child_sample.logprobs) diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 5fbb09a857a46..d076fee8c2a36 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], """Helper method which transforms a 2d list organized by [step][sequence group] into [sequence group][step]. """ - output_by_sequence_group = [[] for _ in range(num_seq_groups)] + output_by_sequence_group: List[List[SamplerOutput]] = [ + [] for _ in range(num_seq_groups) + ] for step in sampler_outputs: for i, sequence_group_output in enumerate(step): output_by_sequence_group[i].append(sequence_group_output) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d6673976bb775..37d76b8e74055 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,6 +18,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, CompletionRequest, ErrorResponse) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -26,8 +27,8 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds -openai_serving_chat: OpenAIServingChat = None -openai_serving_completion: OpenAIServingCompletion = None +openai_serving_chat: OpenAIServingChat +openai_serving_completion: OpenAIServingCompletion logger = init_logger(__name__) @@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest, return StreamingResponse(content=generator, media_type="text/event-stream") else: + assert isinstance(generator, ChatCompletionResponse) return JSONResponse(content=generator.model_dump()) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index cf779d44c816b..d9763d024eb83 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -4,7 +4,8 @@ from typing import Dict, List, Literal, Optional, Union import torch -from pydantic import BaseModel, Field, conint, model_validator +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Annotated from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid @@ -30,7 +31,7 @@ class ModelPermission(BaseModel): allow_fine_tuning: bool = False organization: str = "*" group: Optional[str] = None - is_blocking: str = False + is_blocking: bool = False class ModelCard(BaseModel): @@ -56,7 +57,7 @@ class UsageInfo(BaseModel): class ResponseFormat(BaseModel): # type must be "json_object" or "text" - type: str = Literal["text", "json_object"] + type: Literal["text", "json_object"] class ChatCompletionRequest(BaseModel): @@ -152,6 +153,7 @@ def to_sampling_params(self) -> SamplingParams: def logit_bias_logits_processor( token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: + assert self.logit_bias is not None for token_id, bias in self.logit_bias.items(): # Clamp the bias between -100 and 100 per OpenAI API spec bias = min(100, max(-100, bias)) @@ -213,7 +215,7 @@ class CompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None logprobs: Optional[int] = None max_tokens: Optional[int] = 16 - n: Optional[int] = 1 + n: int = 1 presence_penalty: Optional[float] = 0.0 seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = Field(default_factory=list) @@ -235,7 +237,7 @@ class CompletionRequest(BaseModel): min_tokens: Optional[int] = 0 skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True - truncate_prompt_tokens: Optional[conint(ge=1)] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: end-completion-sampling-params # doc: begin-completion-extra-params @@ -289,6 +291,7 @@ def to_sampling_params(self): def logit_bias_logits_processor( token_ids: List[int], logits: torch.Tensor) -> torch.Tensor: + assert self.logit_bias is not None for token_id, bias in self.logit_bias.items(): # Clamp the bias between -100 and 100 per OpenAI API spec bias = min(100, max(-100, bias)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index f35eab15bc487..d502dd0a4eb75 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -115,12 +115,12 @@ async def chat_completion_stream_generator( first_iteration = True # Send response for each token for each request.n (index) + assert request.n is not None previous_texts = [""] * request.n previous_num_tokens = [0] * request.n finish_reason_sent = [False] * request.n try: async for res in result_generator: - res: RequestOutput # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST # response (by the try...catch). diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index b7e2530a69b51..211b2e0424c3e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -185,6 +185,7 @@ async def completion_stream_generator( model_name: str, num_prompts: int, ) -> AsyncGenerator[str, None]: + assert request.n is not None previous_texts = [""] * request.n * num_prompts previous_num_tokens = [0] * request.n * num_prompts has_echoed = [False] * request.n * num_prompts @@ -202,6 +203,7 @@ async def completion_stream_generator( # TODO(simon): optimize the performance by avoiding full # text O(n^2) sending. + assert request.max_tokens is not None if request.echo and request.max_tokens == 0: # only return the prompt delta_text = res.prompt @@ -279,7 +281,7 @@ def request_output_to_completion_response( created_time: int, model_name: str, ) -> CompletionResponse: - choices = [] + choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 num_generated_tokens = 0 for final_res in final_res_batch: @@ -289,6 +291,7 @@ def request_output_to_completion_response( prompt_text = final_res.prompt for output in final_res.outputs: + assert request.max_tokens is not None if request.echo and request.max_tokens == 0: token_ids = prompt_token_ids top_logprobs = prompt_logprobs diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 376b581052d85..610e807cae4c7 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -4,7 +4,9 @@ from http import HTTPStatus from typing import Dict, List, Optional, Tuple, Union -from pydantic import conint +from pydantic import Field +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from typing_extensions import Annotated from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -45,7 +47,8 @@ def __init__(self, ] self.max_model_len = 0 - self.tokenizer = None + # Lazy initialized + self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] try: event_loop = asyncio.get_running_loop() @@ -92,7 +95,7 @@ async def show_available_models(self) -> ModelList: def _create_logprobs( self, token_ids: List[int], - top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None, + top_logprobs: List[Optional[Dict[int, Logprob]]], num_output_top_logprobs: Optional[int] = None, initial_text_offset: int = 0, ) -> LogProbs: @@ -108,6 +111,7 @@ def _create_logprobs( token = self.tokenizer.decode(token_id) logprobs.tokens.append(token) logprobs.token_logprobs.append(None) + assert logprobs.top_logprobs is not None logprobs.top_logprobs.append(None) else: token_logprob = step_top_logprobs[token_id].logprob @@ -116,6 +120,7 @@ def _create_logprobs( logprobs.token_logprobs.append(token_logprob) if num_output_top_logprobs: + assert logprobs.top_logprobs is not None logprobs.top_logprobs.append({ # Convert float("-inf") to the # JSON-serializable float that OpenAI uses @@ -155,9 +160,9 @@ def create_streaming_error_response( async def _check_model(self, request) -> Optional[ErrorResponse]: if request.model in self.served_model_names: - return + return None if request.model in [lora.lora_name for lora in self.lora_requests]: - return + return None return self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", @@ -165,7 +170,7 @@ async def _check_model(self, request) -> Optional[ErrorResponse]: def _maybe_get_lora(self, request) -> Optional[LoRARequest]: if request.model in self.served_model_names: - return + return None for lora in self.lora_requests: if request.model == lora.lora_name: return lora @@ -177,7 +182,7 @@ def _validate_prompt_and_tokenize( request: Union[ChatCompletionRequest, CompletionRequest], prompt: Optional[str] = None, prompt_ids: Optional[List[int]] = None, - truncate_prompt_tokens: Optional[conint(ge=1)] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None ) -> Tuple[List[int], str]: if not (prompt or prompt_ids): raise ValueError("Either prompt or prompt_ids should be provided.") diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 21c2196eb2739..fefad16700fe3 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -33,7 +33,7 @@ def __init__( def optimize(self) -> "LoRALayerWeights": """Optimize the LoRA by merging the scaling into lora_b.""" if self.scaling == 1: - return + return self self.lora_b *= self.scaling self.scaling = 1 return self diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py index a19e9461f41f7..d08ae6064aa2a 100644 --- a/vllm/model_executor/layers/ops/sample.py +++ b/vllm/model_executor/layers/ops/sample.py @@ -29,8 +29,8 @@ def _multi_split_sample( sampled_tokens_size: Tuple[int, int], sampled_logprobs_size: Tuple[int, int], sample_indices: torch.Tensor, + logprobs: torch.Tensor, *, - logprobs: Optional[torch.Tensor] = None, modify_greedy_probs: bool = False, save_logprobs: bool = False, ): @@ -167,6 +167,7 @@ def sample( sampled_logprobs_size = (0, 0) logprobs = probs + assert logprobs is not None if _save_modified_probs: sampled_modified_probs_size = sampled_tokens_size else: diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 6519781c8a8eb..a5225148d7828 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -108,7 +108,8 @@ def _forward( query_pass = query[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:] - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] cos, sin = cos_sin.chunk(2, dim=-1) diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py index 94f438716f8bf..b06a946f34a47 100644 --- a/vllm/transformers_utils/configs/jais.py +++ b/vllm/transformers_utils/configs/jais.py @@ -222,13 +222,15 @@ def _alibi_scaling_validation(self): f"got {alibi_scaling_type}") if (alibi_scaling_factor is not None and not isinstance(alibi_scaling_factor, float) - or alibi_scaling_factor <= 1.0): + or (alibi_scaling_factor is not None + and alibi_scaling_factor <= 1.0)): raise ValueError( f"`alibi_scaling`'s factor field must be a float > 1.0," f"got {alibi_scaling_factor}") if (alibi_dynamic_scaling is not None and not isinstance(alibi_dynamic_scaling, int) - or alibi_dynamic_scaling <= 1): + or (alibi_dynamic_scaling is not None + and alibi_dynamic_scaling <= 1)): raise ValueError( f"`alibi_scaling`'s `train_seq_len` field must be an" f"integer > 1, got {alibi_dynamic_scaling}") diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index a3b979e8fbc13..69380d67f9b94 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -11,7 +11,7 @@ from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import ( RayTokenizerGroupPool) else: - RayTokenizerGroupPool = None + RayTokenizerGroupPool = None # type: ignore def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index c00b02fdbbbc0..f3cdc00564dbb 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -89,6 +89,7 @@ def encode(self, This is blocking. """ self._ensure_queue_initialized() + assert self._idle_actors is not None if self._idle_actors.empty(): raise RuntimeError("No idle actors available.") @@ -120,6 +121,7 @@ async def encode_async( This is non-blocking. """ self._ensure_queue_initialized() + assert self._idle_actors is not None actor = await self._idle_actors.get() try: diff --git a/vllm/transformers_utils/tokenizers/baichuan.py b/vllm/transformers_utils/tokenizers/baichuan.py index 79894035cb1f1..76daabc41e0a2 100644 --- a/vllm/transformers_utils/tokenizers/baichuan.py +++ b/vllm/transformers_utils/tokenizers/baichuan.py @@ -114,9 +114,9 @@ def _convert_id_to_token(self, index): token = self.sp_model.IdToPiece(index) return token - def convert_tokens_to_string(self, tokens): + def convert_tokens_to_string(self, tokens: List[str]): """Converts a sequence of tokens (string) in a single string.""" - current_sub_tokens = [] + current_sub_tokens: List[str] = [] out_string = "" prev_is_special = False for i, token in enumerate(tokens): From 8f2ea22bde67161895152e7f7ad602ac96ea326e Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 23 Apr 2024 00:49:08 -0700 Subject: [PATCH 097/413] [Core] Some simplification of WorkerWrapper changes (#4183) --- vllm/executor/ray_gpu_executor.py | 90 ++++++++++++++----------------- vllm/worker/worker_base.py | 9 ++-- 2 files changed, 45 insertions(+), 54 deletions(-) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index f779b0f8a5113..d0b5e682bb6f7 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -2,6 +2,7 @@ import os import pickle from collections import defaultdict +from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from vllm.engine.ray_utils import RayWorkerWrapper, ray @@ -136,16 +137,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", VLLM_INSTANCE_ID = get_vllm_instance_id() # Set environment variables for the driver and workers. - all_args_to_update_environment_variables = [] - for (node_id, _) in worker_node_and_gpu_ids: - all_args_to_update_environment_variables.append([{ - "CUDA_VISIBLE_DEVICES": - ",".join(map(str, node_gpus[node_id])), - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, - "VLLM_TRACE_FUNCTION": - os.getenv("VLLM_TRACE_FUNCTION", "0"), - }]) + all_args_to_update_environment_variables = [({ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + os.getenv("VLLM_TRACE_FUNCTION", "0"), + }, ) for (node_id, _) in worker_node_and_gpu_ids] self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) @@ -156,10 +155,9 @@ def collect_arg_helper_func(**kwargs): # avoid writing `{"name": value}` manually return kwargs - init_worker_all_kwargs = [] - # Initialize the actual workers inside worker wrapper. - for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ): + init_worker_all_kwargs = [] + for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): local_rank = node_workers[node_id].index(rank) init_worker_all_kwargs.append( collect_arg_helper_func( @@ -265,40 +263,40 @@ def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any]] = None, + driver_args: Optional[Tuple[Any, ...]] = None, driver_kwargs: Optional[Dict[str, Any]] = None, - all_args: Optional[List[List[Any]]] = None, + all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, **kwargs, ) -> Any: - """Runs the given method on all workers. - all_args and all_kwargs are used to pass heterogeneous arguments, - i.e. different arguments for each worker. + """Runs the given method on all workers. Can be used in the following + ways: + + - args/kwargs: All workers share the same args/kwargs + - args/kwargs and driver_args/driver_kwargs: Driver worker has + different args + - all_args/all_kwargs: args/kwargs for each worker are specified + individually """ - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - - # for mypy type checking - assert driver_args is not None - assert driver_kwargs is not None - if all_args is None: - all_args = [driver_args] + [args] * len(self.workers) - if all_kwargs is None: - all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers) - - # for mypy type checking - assert all_args is not None - assert all_kwargs is not None if max_concurrent_workers: raise NotImplementedError( "max_concurrent_workers is not supported yet.") + if driver_args is None: + driver_args = args if all_args is None else all_args[0] + if driver_kwargs is None: + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + count = len(self.workers) + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, 1, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, 1, None) + if use_ray_compiled_dag: # Right now, compiled DAG can only accept a single # input. TODO(sang): Fix it. @@ -310,22 +308,17 @@ def _run_workers( worker.execute_method.remote(method, *worker_args, **worker_kwargs) for (worker, worker_args, worker_kwargs - ) in zip(self.workers, all_args[1:], all_kwargs[1:]) + ) in zip(self.workers, all_worker_args, all_worker_kwargs) ] - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - # Start the driver worker after all the ray workers. if not use_dummy_driver: driver_worker_output = self.driver_worker.execute_method( - method, *all_args[0], **all_kwargs[0]) + method, *driver_args, **driver_kwargs) else: driver_worker_output = ray.get( self.driver_dummy_worker.execute_method.remote( - method, *all_args[0], **all_kwargs[0])) + method, *driver_args, **driver_kwargs)) # Get the results of the ray workers. if self.workers: if use_ray_compiled_dag: @@ -383,6 +376,10 @@ def _check_if_any_actor_is_dead(self): class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_executor = make_async(self.driver_worker.execute_method) + async def _run_workers_async( self, method: str, @@ -399,13 +396,8 @@ async def _run_workers_async( if driver_kwargs is None: driver_kwargs = kwargs - # Run the driver worker asynchronously. - def helper(): - return self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) - - driver_executor = make_async(helper) - coros.append(driver_executor()) + coros.append( + self.driver_executor(method, *driver_args, **driver_kwargs)) # Run the ray workers asynchronously. for worker in self.workers: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 783dff3a43404..bcd04e0f98db6 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -108,7 +108,8 @@ def __init__(self, self.worker_class_name = worker_class_name self.worker = None - def update_environment_variables(self, envs: Dict[str, str]) -> None: + @staticmethod + def update_environment_variables(envs: Dict[str, str]) -> None: key = 'CUDA_VISIBLE_DEVICES' if key in envs and key in os.environ: # overwriting CUDA_VISIBLE_DEVICES is desired behavior @@ -138,10 +139,8 @@ def init_worker(self, *args, **kwargs): def execute_method(self, method, *args, **kwargs): try: - if hasattr(self, method): - executor = getattr(self, method) - else: - executor = getattr(self.worker, method) + target = self if self.worker is None else self.worker + executor = getattr(target, method) return executor(*args, **kwargs) except Exception as e: # if the driver worker also execute methods, From 050f285ff6e7bbe898ee751770b2571972166bef Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 23 Apr 2024 17:02:11 +0900 Subject: [PATCH 098/413] [Core] Scheduling optimization 2 (#4280) --- tests/core/test_scheduler.py | 3 ++- vllm/core/scheduler.py | 10 ++++++++-- vllm/sequence.py | 5 +++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index a2511238506b0..ab471d206618b 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -563,7 +563,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(output.preempted) == 2 # Verify budgets are updated. assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 1 + # NOTE: When enable_chunk is False, num_seqs budget is not updated. + # assert budget.num_curr_seqs == 1 # Both should be preempted, not swapped. assert output.blocks_to_swap_out == {} # Nothing is copied. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 8d7db09bbea08..99f7a34d336a4 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -395,12 +395,12 @@ def _schedule_running( # We can have up to 1 running prefill at any given time in running # queue, which means we can guarantee chunk size is at least 1. assert num_running_tokens != 0 - num_running_seqs = seq_group.get_max_num_running_seqs() running_queue.popleft() while not self._can_append_slots(seq_group): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) + num_running_seqs = seq_group.get_max_num_running_seqs() budget.subtract_num_seqs(seq_group.request_id, num_running_seqs) if curr_loras is not None and seq_group.lora_int_id > 0: @@ -439,7 +439,13 @@ def _schedule_running( token_chunk_size=1)) budget.add_num_batched_tokens(seq_group.request_id, num_running_tokens) - budget.add_num_seqs(seq_group.request_id, num_running_seqs) + # OPTIMIZATION: Note that get_max_num_running_seqs is + # expensive. For the default scheduling chase where + # enable_chunking is False, num_seqs are updated before running + # this method, so we don't have to update it again here. + if enable_chunking: + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.add_num_seqs(seq_group.request_id, num_running_seqs) if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) diff --git a/vllm/sequence.py b/vllm/sequence.py index 7dcacab6f2ab6..b296b37a84f15 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -508,6 +508,11 @@ def get_num_uncomputed_tokens(self) -> int: return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + # Optimization. We don't need to call get_seqs if we don't need to + # filter by states. + if status is None: + return len(self.seqs_dict) + return len(self.get_seqs(status)) def num_unfinished_seqs(self) -> int: From 62b8aebc6f06b5c8fafa1f27893cd4f9bb11fa8b Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 23 Apr 2024 01:02:36 -0700 Subject: [PATCH 099/413] [Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951) --- tests/samplers/test_rejection_sampler.py | 8 +- tests/samplers/test_sampler.py | 3 +- tests/spec_decode/e2e/__init__.py | 0 tests/spec_decode/e2e/conftest.py | 45 +- tests/spec_decode/e2e/test_compatibility.py | 169 ++++++ tests/spec_decode/e2e/test_correctness.py | 540 ++++++++++++++++-- tests/spec_decode/test_metrics.py | 4 +- tests/spec_decode/test_multi_step_worker.py | 4 +- tests/spec_decode/test_spec_decode_worker.py | 40 +- tests/spec_decode/utils.py | 7 +- vllm/config.py | 67 ++- vllm/engine/arg_utils.py | 18 +- vllm/engine/llm_engine.py | 38 +- vllm/engine/metrics.py | 23 +- vllm/executor/gpu_executor.py | 1 + .../layers/rejection_sampler.py | 7 + vllm/model_executor/layers/sampler.py | 182 +++++- vllm/spec_decode/batch_expansion.py | 70 ++- vllm/spec_decode/interfaces.py | 4 +- vllm/spec_decode/metrics.py | 31 +- vllm/spec_decode/multi_step_worker.py | 29 +- vllm/spec_decode/spec_decode_worker.py | 49 +- 22 files changed, 1164 insertions(+), 175 deletions(-) create mode 100644 tests/spec_decode/e2e/__init__.py create mode 100644 tests/spec_decode/e2e/test_compatibility.py diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index d2c3a798d3087..13b5b80cccfdc 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, bonus_token_ids, ) + # Bonus tokens are currently disabled. Verify they're set to -1. + # See https://github.com/vllm-project/vllm/issues/4212 + expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1 + if which_tokens_accepted == "all_tokens_accepted": # Expect all tokens to be equal to draft tokens. assert torch.equal(output_token_ids[:, :-1], draft_token_ids) # Expect all bonus tokens to be included. - assert torch.equal(output_token_ids[:, -1:], bonus_token_ids) + assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids) elif which_tokens_accepted == "no_tokens_accepted": # Expect first token to be equal to recovered tokens. assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0]) @@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, torch.ones_like(output_token_ids[:, 1:]) * -1) elif which_tokens_accepted == "some_tokens_accepted": recovered_plus_bonus = torch.cat( - (recovered_token_ids, bonus_token_ids), dim=-1) + (recovered_token_ids, expected_bonus_token_ids), dim=-1) # Assert first rejected token is a recovered token or bonus token. assert torch.equal( recovered_plus_bonus[torch.arange(0, batch_size), diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index dbbe13b8da060..52a2b0ca52aaa 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str): def mock_sample(probs, *args, **kwargs): nonlocal sample_probs sample_probs = probs - return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] + return ([[prob.topk(1, dim=-1).indices.tolist(), [0]] + for prob in probs], None) with patch("vllm.model_executor.layers.sampler._sample", mock_sample): sampler(logits=fake_logits, sampling_metadata=sampling_metadata) diff --git a/tests/spec_decode/e2e/__init__.py b/tests/spec_decode/e2e/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 1d99cb5d32219..59fb8311fc5b7 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,3 +1,5 @@ +from typing import List, Tuple + import pytest from tests.conftest import cleanup @@ -6,28 +8,34 @@ @pytest.fixture -def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, +def baseline_llm_generator(request, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + seed): + return create_llm_generator("baseline", request, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, seed) @pytest.fixture -def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, +def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs, test_llm_kwargs, seed): - return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed) + return create_llm_generator("test", request, common_llm_kwargs, + per_test_common_llm_kwargs, test_llm_kwargs, + seed) -def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - distinct_llm_kwargs, seed): +def create_llm_generator(baseline_or_test, request, common_llm_kwargs, + per_test_common_llm_kwargs, distinct_llm_kwargs, + seed): kwargs = { **common_llm_kwargs, **per_test_common_llm_kwargs, **distinct_llm_kwargs, } + test_name = request.node.name def generator_inner(): + print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') llm = LLM(**kwargs) set_random_seed(seed) @@ -36,6 +44,23 @@ def generator_inner(): del llm cleanup() - for llm in generator_inner(): - yield llm + def generator_outer(): + for llm in generator_inner(): + yield llm + del llm + + return generator_outer + + +def get_output_from_llm_generator( + llm_generator, prompts, + sampling_params) -> Tuple[List[str], List[List[int]]]: + tokens = [] + token_ids = [] + for llm in llm_generator(): + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + tokens = [output.outputs[0].text for output in outputs] del llm + + return tokens, token_ids diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py new file mode 100644 index 0000000000000..fde950c14382c --- /dev/null +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -0,0 +1,169 @@ +import pytest + +from vllm import SamplingParams + +from .conftest import get_output_from_llm_generator + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Expect failure as spec decode not supported by + # Ray backend. + "worker_use_ray": True, + }, + ]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_ray(test_llm_generator): + """Verify that speculative decoding with Ray fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(AssertionError, + match="Speculative decoding not yet supported for "): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "enable_chunked_prefill": True, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_chunked_prefill(test_llm_generator): + """Verify that speculative decoding with chunked prefill fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, + match="Speculative decoding and chunked prefill"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "meta-llama/Llama-2-7b-chat-hf", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Speculative max model len > overridden max model len should raise. + "max_model_len": 128, + "speculative_max_model_len": 129, + }, + { + # Speculative max model len > draft max model len should raise. + # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12 + "speculative_max_model_len": 2048 + 1, + }, + { + # Speculative max model len > target max model len should raise. + # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12 + "speculative_max_model_len": 4096 + 1, + }, + ]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): + """Verify that speculative decoding validates speculative_max_model_len. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, match="cannot be larger than"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +@pytest.mark.parametrize("common_llm_kwargs", [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, +}]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_xfail_block_manager_v1(test_llm_generator): + """Verify that speculative decoding with block manager v1 fails. + """ + output_len = 128 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises(ValueError, + match="Speculative decoding requires usage of the V2"): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index a8ebd66841eb2..0536cc4ecde76 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -1,11 +1,42 @@ +"""The tests in this file verify end-to-end speculative decoding correctness. + +This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. This gives us good coverage of temp=0. + +For temp>0, we rely on unit tests on the rejection sampler to verify that the +output distribution is the same with spec decode vs. no spec decode (this would +be prohibitively expensive to run with a real model). + +NOTE: Speculative decoding's distribution equality requires that the measured +distributions of the target model and proposal model be deterministic given the +same input. vLLM largely guarantees this. + +@cadedaniel has seen cases where the output probabilities of a draft/target +model change slightly with certain batch sizes or prompts, even with Torch +determinism flags set. It is unclear if this is a bug in vLLM, due to non- +determinism in on-device batched operations, a bug in vLLM's spec decode +implementation, or the "hardware numerics" limitations. Either way, rejection +sampling ensures the output distribution matches the target model, but it breaks +greedy-equality tests for those batch sizes/prompts. +""" + from itertools import cycle -from typing import List, Tuple import pytest from transformers import AutoTokenizer from vllm import SamplingParams +from .conftest import get_output_from_llm_generator + @pytest.mark.parametrize( "common_llm_kwargs", @@ -14,9 +45,6 @@ # Note this is repeated in the test body; to initialize a tokenizer. "model": "JackFram/llama-68m", - # Skip real loading for fast test. - "load_format": "dummy", - # Skip cuda graph recording for fast test. "enforce_eager": True, @@ -31,22 +59,15 @@ "num_speculative_tokens": 5, }, { - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 1, - }, - { - # No spec decode. + # Verify the detokenizer assertions in the test work when spec + # decode is disabled. }, ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [1]) -# NOTE: We should run more permutations of this test (more BS, more seeds). But -# because our spec decode generates gibberish token ids, the likelihood of -# emitting an invalid token combination is nontrivial. This causes divergence in -# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf- -# start" bytes are emitted. +@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): +def test_spec_decode_e2e_with_detokenization(test_llm_generator, + batch_size: int): """Run generation with speculative decoding on a batch. Verify the engine generates the correct number of tokens (via ignore_eos=True), and that the detokenization matches HF transformers. @@ -67,8 +88,6 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): max_tokens=output_len, ignore_eos=True, temperature=temperature, - skip_special_tokens=True, - spaces_between_special_tokens=False, ) batch_tokens, batch_token_ids = get_output_from_llm_generator( @@ -77,9 +96,10 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): # Expect a generation for each prompt in the batch. assert len(batch_token_ids) == len(prompts) - # Expect each generation to have expected number of tokens (note - # ignore_eos=True). - assert all(len(token_ids) == output_len for token_ids in batch_token_ids) + # Expect each generation to have expected number of tokens (note ignore_eos + # is True). + assert [len(token_ids) + for token_ids in batch_token_ids] == ([output_len] * batch_size) # Expect detokenized string to match. tok = AutoTokenizer.from_pretrained("JackFram/llama-68m") @@ -92,14 +112,111 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): @pytest.mark.parametrize( "common_llm_kwargs", [{ - # Use a small model for a fast test. - "model": "JackFram/llama-68m", + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # Try two different tiny base models. + # Note that one is equal to the draft model, another isn't. + { + "model": "JackFram/llama-68m", + }, + { + "model": "JackFram/llama-160m", + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use long output len for the small model test. + 1536, + ]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model with batch size of one. + + Since this test is cheaper than other e2e correctness tests, we generate + with a higher output_len. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) - # Skip real loading for fast test. - "load_format": "dummy", +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # Try two different tiny base models. + # Note that one is equal to the draft model, another isn't. + { + "model": "JackFram/llama-68m", + }, + { + "model": "JackFram/llama-160m", + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 256, + ]) +@pytest.mark.parametrize("batch_size", [64]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model and large batch size. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ # Skip cuda graph recording for fast test. "enforce_eager": True, @@ -109,43 +226,372 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int): @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ + # Try two different tiny base models. + # Note that one is equal to the draft model, another isn't. { - # Expect failure as spec decode not supported by - # Ray backend. - "worker_use_ray": True, + "model": "JackFram/llama-68m", + }, + { + "model": "JackFram/llama-160m", }, ]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("max_output_len", [ + 256, +]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( + baseline_llm_generator, test_llm_generator, batch_size: int, + max_output_len: int): + """Verify greedy equality on a tiny model, with a large batch size, and when + sampling respects the EOS token. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len=False) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # A "real" model (not tiny). + "model": "meta-llama/Llama-2-7b-chat-hf", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize( + "output_len", + [ + # Use decently long output len for a high quality test. + 256, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_real_model_bs1( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a "real" model and batch size of 1. This is + separate from large BS tests to make identifying the source of bugs easier. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # A "real" model (not tiny). + "model": "meta-llama/Llama-2-7b-chat-hf", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 64, + ]) @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail(test_llm_generator): - """Verify that speculative decoding with Ray fails. +def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality with a "real" model on a nontrivial batch size. + This is the closest test to a real production workload. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-160m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 256, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_greedy_correctness_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + # As of this writing, vLLM only compiles with these 3 block sizes by + # default. + { + "block_size": 8, + }, + { + "block_size": 16, + }, + { + "block_size": 32, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_different_block_size(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality over different block sizes. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. + "speculative_max_model_len": 32, + }, + ]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # This must be a good bit larger than speculative_max_model_len so that + # we can test the case where all seqs are skipped, but still small to + # ensure fast test. + 64, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_skip_speculation(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when some (or all) sequences skip speculation. + We do this by setting the max model len of the draft model to an + artificially low value, such that when the sequences grow beyond it, they + are skipped in speculative decoding. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": k, + } + # Try a range of common k, as well as large speculation. + for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, + output_len: int): + """Verify that speculative decoding produces exact equality to without spec + decode with many different values of k. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +def run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + print_tokens: bool = False): + """Helper method that compares the outputs of both the baseline LLM and + the test LLM. It asserts greedy equality, e.g. that the outputs are exactly + the same when temperature is zero. """ - output_len = 128 temperature = 0.0 prompts = [ "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", ] + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, + max_tokens=max_output_len, + ignore_eos=ignore_eos, temperature=temperature, ) - with pytest.raises(AssertionError, - match="Speculative decoding not yet supported for "): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) + spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + (baseline_batch_tokens, + baseline_batch_token_ids) = get_output_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) -def get_output_from_llm_generator( - llm_generator, prompts, - sampling_params) -> Tuple[List[str], List[List[int]]]: - for llm in llm_generator: - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - token_ids = [output.outputs[0].token_ids for output in outputs] - tokens = [output.outputs[0].text for output in outputs] - del llm + assert len(baseline_batch_token_ids) == len(prompts) + assert len(spec_batch_token_ids) == len(prompts) - return tokens, token_ids + for i, (baseline_token_ids, baseline_tokens, spec_token_ids, + spec_tokens) in enumerate( + zip(baseline_batch_token_ids, baseline_batch_tokens, + spec_batch_token_ids, spec_batch_tokens)): + if print_tokens: + print(f'{i=} {baseline_tokens=}') + print(f'{i=} {spec_tokens=}') + print(f'{i=} {baseline_token_ids=}') + print(f'{i=} {spec_token_ids=}') + assert baseline_token_ids == spec_token_ids diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py index 36e91672069dc..312878804b86e 100644 --- a/tests/spec_decode/test_metrics.py +++ b/tests/spec_decode/test_metrics.py @@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): num_draft_tokens = 0 k = 5 - num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens( + max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens( num_draft_tokens, k) rej_sampler = MagicMock() @@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): assert (metrics.draft_acceptance_rate == num_accepted_tokens / num_draft_tokens) assert (metrics.system_efficiency == num_emitted_tokens / - num_possible_tokens) + max_num_emitted_tokens) else: assert math.isnan(metrics.draft_acceptance_rate) assert math.isnan(metrics.system_efficiency) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index d6edbab579afd..e7aaa1ff4eff8 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations(): assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) - assert proposals.proposal_token_ids.shape == torch.Size([0, k]) - assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k]) + assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k]) assert proposals.proposal_lens.shape == torch.Size([batch_size]) assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)] diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 0a3110775e2d6..d24d726c9c0cf 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -1,4 +1,5 @@ import random +from types import SimpleNamespace from unittest.mock import MagicMock import pytest @@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int): """Verify SpecDecodeWorker calls the target model with correct inputs. Everything else is mocked out. """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() + draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) + target_worker = mock_worker(use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): """ vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) - target_worker = mock_worker(vocab_size=vocab_size) + draft_worker = mock_worker(cls=MultiStepWorker, + vocab_size=vocab_size, + use_spec=False) + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): num_lookahead_slots=k) assert len(rejection_sampler.call_args_list) == 1 - args, _ = rejection_sampler.call_args_list[0] - (actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs, - actual_proposal_token_ids) = args + _, kwargs = rejection_sampler.call_args_list[0] + actual = SimpleNamespace(**kwargs) - assert torch.equal(actual_bonus_token_ids, + assert torch.equal(actual.bonus_token_ids, target_token_ids.reshape(batch_size, k + 1)[:, -1:]) assert torch.equal( - actual_proposal_scores, + actual.target_probs, target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) - assert torch.equal(actual_proposal_token_ids, proposal_token_ids) - assert torch.equal(actual_proposal_probs, proposal_probs) + assert torch.equal(actual.draft_token_ids, proposal_token_ids) + assert torch.equal(actual.draft_probs, proposal_probs) @pytest.mark.parametrize('k', [1, 2, 6]) @@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int): """ vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) - target_worker = mock_worker(vocab_size=vocab_size) + draft_worker = mock_worker(cls=MultiStepWorker, + vocab_size=vocab_size, + use_spec=False) + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): """ vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size) - target_worker = mock_worker(vocab_size=vocab_size) + draft_worker = mock_worker(cls=MultiStepWorker, + vocab_size=vocab_size, + use_spec=False) + target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) @@ -500,8 +506,8 @@ def test_init_device(): """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as well as other GPU initialization. """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() + draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) + target_worker = mock_worker(use_spec=False) rejection_sampler = MagicMock(spec=RejectionSampler) rejection_sampler.token_id_dtype = torch.int64 metrics_collector = MagicMock(spec=AsyncMetricsCollector) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index d04b6029493f4..4f8295d25cf41 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -63,11 +63,14 @@ def create_execute_model_data( def mock_worker(cls=None, vocab_size: int = 30_000, max_model_len: int = 2048, - rank: int = 0) -> MagicMock: + rank: int = 0, + use_spec: bool = True) -> MagicMock: if cls is None: cls = Worker - worker = MagicMock(spec=cls) + spec = cls if use_spec else None + + worker = MagicMock(spec=spec) worker.vocab_size = vocab_size worker.max_model_len = max_model_len worker.rank = rank diff --git a/vllm/config.py b/vllm/config.py index 97ede0faa21ab..2ff42de08f8f7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -655,6 +655,9 @@ def maybe_create_spec_config( target_dtype: str, speculative_model: Optional[str], num_speculative_tokens: Optional[int], + speculative_max_model_len: Optional[int], + enable_chunked_prefill: bool, + use_v2_block_manager: bool, ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -672,6 +675,15 @@ def maybe_create_spec_config( model, if provided. num_speculative_tokens (Optional[int]): The number of speculative tokens, if provided. + speculative_max_model_len (Optional[int]): The maximum model len of + the speculative model. Used when testing the ability to skip + speculation for some sequences. + enable_chunked_prefill (bool): Whether vLLM is configured to use + chunked prefill or not. Used for raising an error since its not + yet compatible with spec decode. + use_v2_block_manager (bool): Whether vLLM is configured to use the + v2 block manager or not. Used for raising an error since the v2 + block manager is required with spec decode. Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if @@ -690,12 +702,21 @@ def maybe_create_spec_config( assert (speculative_model is not None and num_speculative_tokens is not None) + if enable_chunked_prefill: + raise ValueError( + "Speculative decoding and chunked prefill are " + f"currently mutually exclusive ({enable_chunked_prefill=}).") + + if not use_v2_block_manager: + raise ValueError( + "Speculative decoding requires usage of the V2 " + "block manager. Enable it with --use-v2-block-manager.") + # TODO: The user should be able to specify revision/quantization/max # model len for the draft model. It is not currently supported. draft_revision = None draft_code_revision = None draft_quantization = None - draft_max_model_len = None draft_model_config = ModelConfig( model=speculative_model, @@ -707,7 +728,7 @@ def maybe_create_spec_config( revision=draft_revision, code_revision=draft_code_revision, tokenizer_revision=target_model_config.tokenizer_revision, - max_model_len=draft_max_model_len, + max_model_len=None, quantization=draft_quantization, enforce_eager=target_model_config.enforce_eager, max_context_len_to_capture=target_model_config. @@ -715,6 +736,13 @@ def maybe_create_spec_config( max_logprobs=target_model_config.max_logprobs, ) + draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + speculative_max_model_len, + draft_model_config.max_model_len, + target_model_config.max_model_len, + )) + draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( target_parallel_config)) @@ -725,6 +753,41 @@ def maybe_create_spec_config( num_speculative_tokens, ) + @staticmethod + def _maybe_override_draft_max_model_len( + speculative_max_model_len: Optional[int], + draft_max_model_len: int, + target_max_model_len: int, + ) -> int: + """Determine the max sequence len for the draft model. This is usually + the draft_max_model_len, but may be the target_max_model_len if it is + less than the draft_max_model_len, or may be speculative_max_model_len + if it is specified. + + This is necessary so that sequences do not exceed the capacity of the + draft model or the target model. + + speculative_max_model_len is mainly used for testing that sequences can + skip speculation. + """ + + if speculative_max_model_len is not None: + + if speculative_max_model_len > draft_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}") + + if speculative_max_model_len > target_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}") + + return speculative_max_model_len + + return min( + draft_max_model_len, + target_max_model_len, + ) + @staticmethod def create_draft_parallel_config( target_parallel_config: ParallelConfig) -> ParallelConfig: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5de20633ffdd6..6a6ac49ae3211 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -73,6 +73,7 @@ class EngineArgs: # Speculative decoding configuration. speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None + speculative_max_model_len: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -237,7 +238,7 @@ def add_cli_args( parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32, 128], + choices=[8, 16, 32], help='Token block size for contiguous chunks of ' 'tokens.') @@ -420,17 +421,25 @@ def add_cli_args( parser.add_argument( '--speculative-model', type=str, - default=None, + default=EngineArgs.speculative_model, help= 'The name of the draft model to be used in speculative decoding.') parser.add_argument( '--num-speculative-tokens', type=int, - default=None, + default=EngineArgs.num_speculative_tokens, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding.') + parser.add_argument( + '--speculative-max-model-len', + type=str, + default=EngineArgs.speculative_max_model_len, + help='The maximum sequence length supported by the ' + 'draft model. Sequences over this length will skip ' + 'speculation.') + parser.add_argument('--model-loader-extra-config', type=str, default=EngineArgs.model_loader_extra_config, @@ -481,6 +490,9 @@ def create_engine_config(self, ) -> EngineConfig: target_dtype=self.dtype, speculative_model=self.speculative_model, num_speculative_tokens=self.num_speculative_tokens, + speculative_max_model_len=self.speculative_max_model_len, + enable_chunked_prefill=self.enable_chunked_prefill, + use_v2_block_manager=self.use_v2_block_manager, ) scheduler_config = SchedulerConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d96025ea1fb6a..19e58fb1722cf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -22,7 +22,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup) + SequenceGroup, SequenceStage) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -480,9 +480,12 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - # If uncomputed tokens > 0, it means prefill is chunked. - # We don't need to process outputs in that case. - if seq_group.get_num_uncomputed_tokens() == 0: + + # If all sequences in the sequence group are in DECODE, then we can + # process the output tokens. Otherwise, they are (chunked) prefill + # samples and should not be processed. + stages = [seq.data._stage for seq in seq_group.seqs_dict.values()] + if all(stage == SequenceStage.DECODE for stage in stages): self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. @@ -569,7 +572,8 @@ def step(self) -> List[RequestOutput]: # Log stats. if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs)) + self.stat_logger.log( + self._get_stats(scheduler_outputs, model_output=output)) return request_outputs @@ -578,9 +582,18 @@ def do_log_stats(self) -> None: if self.log_stats: self.stat_logger.log(self._get_stats(scheduler_outputs=None)) - def _get_stats(self, - scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: - """Get Stats to be Logged to Prometheus.""" + def _get_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None) -> Stats: + """Get Stats to be Logged to Prometheus. + + Args: + scheduler_outputs: Optional, used to populate metrics related to + the scheduled batch, + model_output: Optional, used to emit speculative decoding metrics + which are created by the workers. + """ now = time.time() # KV Cache Usage in %. @@ -637,6 +650,14 @@ def _get_stats(self, time_to_first_tokens = time_last_iters if prompt_run else [] time_per_output_tokens = [] if prompt_run else time_last_iters + # Spec decode, if enabled, emits specialized metrics from the worker in + # sampler output. + if model_output and (model_output[0].spec_decode_worker_metrics + is not None): + spec_decode_metrics = model_output[0].spec_decode_worker_metrics + else: + spec_decode_metrics = None + return Stats( now=now, num_running=num_running, @@ -649,6 +670,7 @@ def _get_stats(self, time_to_first_tokens=time_to_first_tokens, time_per_output_tokens=time_per_output_tokens, time_e2e_requests=time_e2e_requests, + spec_decode_metrics=spec_decode_metrics, ) def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 04e27e69ce0f3..25e96f6c7eaf7 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,6 +1,6 @@ import time from dataclasses import dataclass -from typing import Dict, List, Protocol +from typing import TYPE_CHECKING, Dict, List, Optional, Protocol import numpy as np from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, @@ -8,6 +8,9 @@ from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + logger = init_logger(__name__) disable_created_metrics() @@ -118,6 +121,8 @@ class Stats: time_per_output_tokens: List[float] time_e2e_requests: List[float] + spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + class SupportsMetricsInfo(Protocol): @@ -235,3 +240,19 @@ def log(self, stats: Stats) -> None: self.num_prompt_tokens = [] self.num_generation_tokens = [] self.last_local_log = stats.now + + if stats.spec_decode_metrics is not None: + logger.info( + self._format_spec_decode_metrics_str( + stats.spec_decode_metrics)) + + def _format_spec_decode_metrics_str( + self, metrics: "SpecDecodeWorkerMetrics") -> str: + + return ("Speculative metrics: " + f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " + f"System efficiency: {metrics.system_efficiency:.3f}, " + f"Number of speculative tokens: {metrics.num_spec_tokens}, " + f"Number of accepted tokens: {metrics.accepted_tokens}, " + f"Number of draft tokens tokens: {metrics.draft_tokens}, " + f"Number of emitted tokens tokens: {metrics.emitted_tokens}.") diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 77c997f97956e..d413a7d27ff37 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -83,6 +83,7 @@ def _init_spec_worker(self): scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, + # TODO allow draft-model specific load config. load_config=self.load_config, local_rank=0, rank=0, diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index ecd2bd0fce3a3..5edbbf2c70a49 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -144,6 +144,7 @@ def _batch_modified_rejection_sampling( recovered_probs = self._get_recovered_probs( target_probs, draft_probs).reshape(batch_size * k, vocab_size) + # NOTE: the recovered_probs are overwritten by this method. recovered_token_ids = _multinomial(recovered_probs, num_samples=1).reshape( batch_size, k) @@ -307,6 +308,12 @@ def _create_output( output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, bonus_token_ids, -1) + # We disable bonus tokens because it causes corrupt KV cache for + # proposal methods that require KV cache. We can fix it by "prefilling" + # the bonus token in the proposer. The following issue tracks the fix. + # https://github.com/vllm-project/vllm/issues/4212 + output_with_bonus_tokens[:, -1] = -1 + # Fill the recovered token ids. output.mul_(~after_false_mask).add_( recovered_token_ids.mul(after_false_mask)) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 03bf38caebe0e..c4b11cb33a677 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -35,6 +35,14 @@ class Sampler(nn.Module): in logits for each token in the input prompt. """ + def __init__(self): + super().__init__() + + # Whether or not the SamplerOutput should have on-device tensors + # containing the sampled token ids and probabilities. This is used by + # speculative decoding. + self.include_gpu_probs_tensor = False + def forward( self, logits: torch.Tensor, @@ -79,13 +87,45 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results = _sample(probs, logprobs, sampling_metadata, - sampling_tensors) + sample_results, maybe_sampled_tokens_tensor = _sample( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=self.include_gpu_probs_tensor, + modify_greedy_probs=self._should_modify_greedy_probs_inplace, + ) + + if self.include_gpu_probs_tensor: + assert maybe_sampled_tokens_tensor is not None + sampled_tokens_tensor = maybe_sampled_tokens_tensor + on_device_tensors = (probs, sampled_tokens_tensor) + else: + on_device_tensors = None + # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, sampling_metadata, - prompt_logprobs, sample_logprobs) + return _build_sampler_output(sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors) + + @property + def _should_modify_greedy_probs_inplace(self) -> bool: + """Whether or not the sampler should modify the probability distribution + of greedily-sampled tokens such that multinomial sampling would sample + the greedily-sampled token. + + In other words, if True then we set the probability of the greedily- + sampled token to 1. + + This is used by speculative decoding, which requires that the sampling + method be encoded into the probability distribution. + """ + # Modify greedy probs if include_gpu_probs_tensor is set. + return self.include_gpu_probs_tensor def _get_bin_counts_and_mask( @@ -359,7 +399,9 @@ def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, -) -> List[Tuple[List[int], List[int]]]: + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): @@ -371,6 +413,15 @@ def _sample_with_torch( sample_metadata = {} multinomial_samples = {} + # Create output tensor for sampled token ids. + if include_gpu_probs_tensor: + sampled_token_ids_tensor = torch.empty(logprobs.shape[0], + 1, + dtype=torch.long, + device=logprobs.device) + else: + sampled_token_ids_tensor = None + # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: @@ -383,9 +434,25 @@ def _sample_with_torch( is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] sample_metadata[sampling_type] = (seq_group_ids, seq_groups, is_prompts, sample_indices) + long_sample_indices = sample_indices.long() + if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[sample_indices.long()], + greedy_samples = torch.argmax(logprobs[long_sample_indices], dim=-1) + + if include_gpu_probs_tensor: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[ + long_sample_indices] = greedy_samples.unsqueeze(-1) + + if modify_greedy_probs: + # If required, modify the probabilities such that sampling from + # the modified distribution would always sample the argmax + # token id. + _modify_greedy_probs_inplace(logprobs, probs, + long_sample_indices, + greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): max_best_of_in_batch = 1 for seq_group, is_prompt in zip(seq_groups, is_prompts): @@ -397,15 +464,23 @@ def _sample_with_torch( "seq_groups": seq_groups, "generators": sampling_metadata.generators, } + multinomial_samples[sampling_type] = _multinomial( - probs[sample_indices.long()], max_best_of_in_batch, + probs[long_sample_indices], max_best_of_in_batch, **seeded_args) + + if include_gpu_probs_tensor: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[ + long_sample_indices] = multinomial_samples[sampling_type] + elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: raise ValueError(f"Unsupported sampling type: {sampling_type}") # GPU<->CPU sync happens in the loop below. + # This also converts the sample output to Python objects. for sampling_type in SamplingType: if sampling_type not in sample_metadata: @@ -427,7 +502,7 @@ def _sample_with_torch( sample_results_dict[i] for i in range(len(sampling_metadata.seq_groups)) ] - return sample_results + return sample_results, sampled_token_ids_tensor def _sample_with_triton_kernel( @@ -511,12 +586,17 @@ def _sample_with_triton_kernel( def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, -) -> List[Tuple[List[int], List[int]]]: - return _sample_with_torch(probs, logprobs, sampling_metadata) + probs: torch.Tensor, logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, modify_greedy_probs: bool +) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: + return _sample_with_torch( + probs, + logprobs, + sampling_metadata, + include_gpu_probs_tensor=include_gpu_probs_tensor, + modify_greedy_probs=modify_greedy_probs, + ) # TODO: Enable once Triton kernel & associated code is faster. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, @@ -680,12 +760,73 @@ def _get_logprobs( return result_prompt_logprobs, result_sample_logprobs +def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, + sample_indices: torch.Tensor, + greedy_samples: torch.Tensor) -> None: + """Modify the probability distributions of the greedily-sampled tokens such + that each sampled token has a "probability" of 1.0. This is required by + speculative decoding, which depends on the sampling method being encoded + within the probability distribution for correctness. + + # Why do we only need to do this for greedy sampling? + + vLLM's sampler performs the following steps for greedy or multinomial + (random) sampling: + 1. Get logits from model. + 2. Modify logits according to per-sequence sampling parameters. + - Multiply by temperature, top-k and top-p masking, penalize tokens + according to their frequency, etc. + 3. Sample a token. + - Random sampling simply samples from the modified probability + distribution. + - Greedy sampling performs `argmax` to obtain the token with the + highest likelihood. + + Ignoring greedy sampling for a moment, we find that the computed probability + distribution has the following property: we can sample from it independently + and find that the token sampled by the Sampler has a frequency corresponding + to how often we see it in our sampling. In other words, for tokens sampled + with vLLM's random SamplingType, the computed probability distribution + encodes the sampling methodology completely. + + Greedy sampling does not normally have this property. vLLM modifies logits + according to sampling params, then performs `argmax`, then returns the + sampled token and the computed probability distribution. If we sample from + the distribution, we'll find the likelihood of the greedily-sampled token + is not always 1.0. + + Since lossless speculative decoding requires that the sampling methodology + be encoded within the probability distribution, we are motivated to modify + the probability distribution such that the sampled token has probability 1 + when speculative decoding is used. + + NOTE: Alternatively, we could use an extremely low temperature to achieve + greedy sampling using multinomial computation and unite the codepaths. This + has implications on the overall design of the sampler, e.g. how to record + accurate logprobs for the user, so this improvement is deferred to later. + """ + logprobs[sample_indices, :] = -float('inf') + logprobs[sample_indices, greedy_samples] = 0.0 + probs[sample_indices, :] = 0 + probs[sample_indices, greedy_samples] = 1.0 + + def _build_sampler_output( sample_results: List[Tuple[List[int], List[int]]], sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]], ) -> SamplerOutput: + """Construct Python objects with the output of sampling. + + Args: + on_device_tensors: Tuple containing on-device tensors with the + probabilities used in sampling and the sampled token ids. This + allows post-processing without copies to CPU/serialization, e.g. in + speculative decoding rejection sampling. + """ + sampler_output = [] for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, @@ -701,4 +842,15 @@ def _build_sampler_output( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) - return SamplerOutput(outputs=sampler_output) + + # If not specified, store None values in SamplerOutput. + if on_device_tensors is not None: + sampled_token_probs, sampled_token_ids = on_device_tensors + else: + sampled_token_probs, sampled_token_ids = (None, None) + + return SamplerOutput( + outputs=sampler_output, + sampled_token_probs=sampled_token_probs, + sampled_token_ids=sampled_token_ids, + ) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index bbc5b1778854f..c29b838f854c0 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,8 +6,8 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) -from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors, - nvtx_range, sampler_output_to_torch, +from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, + sampler_output_to_torch, split_batch_by_proposal_len) from vllm.worker.worker_base import WorkerBase @@ -72,10 +72,16 @@ def score_proposals( proposal_lens_list = proposals.proposal_lens.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist() + # Filter the list to ignore -1 proposals. + proposal_token_ids_list_without_skips = [ + proposals for proposals in proposal_token_ids_list + if -1 not in proposals + ] + (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) = self._expand_batch( seq_group_metadata_list=seq_group_metadata_list, - proposal_token_ids_list=proposal_token_ids_list, + proposal_token_ids_list=proposal_token_ids_list_without_skips, proposal_lens_list=proposal_lens_list, ) @@ -89,7 +95,7 @@ def score_proposals( target_sampler_output = target_sampler_output[0] all_tokens, all_probs = self._contract_batch( - original_bs=len(seq_group_metadata_list), + contracted_bs=len(seq_group_metadata_list), target_sampler_output=target_sampler_output, proposals=proposals, num_scoring_tokens=num_scoring_tokens, @@ -128,14 +134,21 @@ def _expand_batch( select_proposal_len_zero=True) target_seq_group_metadata_list = self._create_scoring_model_input( - spec_seqs, proposal_token_ids_list) + seq_group_metadata_list=spec_seqs, + proposal_token_ids=proposal_token_ids_list, + # NOTE: We determine the seq ids in the expanded batch using the + # full seq_group_metadata_list, instead of only spec_seqs. + target_seq_ids_iter=self._create_target_seq_id_iterator( + seq_ids=get_all_seq_ids(seq_group_metadata_list)), + ) + num_scoring_tokens = len(target_seq_group_metadata_list) target_seq_group_metadata_list.extend(non_spec_seqs) return (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) - def _contract_batch(self, original_bs: int, + def _contract_batch(self, contracted_bs: int, target_sampler_output: List[SamplerOutput], proposals: SpeculativeProposals, num_scoring_tokens: int, non_spec_indices: List[int], @@ -144,42 +157,41 @@ def _contract_batch(self, original_bs: int, """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. - """ - - # We mock the device tensors until PR 7/9 is merged (e2e correctness). - # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer - maybe_mock_device_tensors( - sampler_output=target_sampler_output, - batch_size=len(non_spec_indices) + num_scoring_tokens, - vocab_size=self._vocab_size, - device=self._device, - ) + contracted_bs is the original batch size, and the batch size that the + target_sampler_output will be contracted to. + """ (target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token # of shape [batch_size * k + 1] back to [batch_size, k + 1]. - batch_size, k = proposals.proposal_token_ids.shape + expanded_batch_size, k = proposals.proposal_token_ids.shape + + # The number of tokens in the expanded batch used for speculation is + # equal to the total expanded batch size minus the number of samples for + # non-speculative sequences. + non_spec_expanded_bs, _ = non_spec_target_token_ids.shape + spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs target_token_ids = target_token_ids.squeeze().reshape( - batch_size, k + 1) - target_probs = target_probs.squeeze().reshape(batch_size, k + 1, + spec_expanded_bs, k + 1) + target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1, self._vocab_size) - all_tokens = torch.full(size=(original_bs, k + 1), + all_tokens = torch.full(size=(contracted_bs, k + 1), fill_value=-1, device=self._device, dtype=torch.long) - all_probs = torch.zeros(original_bs, + all_probs = torch.zeros(contracted_bs, k + 1, self._vocab_size, device=self._device, dtype=torch.float32) if non_spec_indices: - all_tokens[non_spec_indices, 0] = non_spec_target_token_ids + all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs if spec_indices: @@ -189,20 +201,22 @@ def _contract_batch(self, original_bs: int, return all_tokens, all_probs def _create_scoring_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + target_seq_ids_iter: Iterator[TargetSeqId], ) -> List[SequenceGroupMetadata]: """Given the original input sequences and proposed tokens from the draft model, create a list of target sequences that can be used for scoring. + + target_seq_ids_iter provides sequence ids for the expanded batch, + fulfilling the requirement that no seq id in the expanded batch is equal + to the seq id in the original batch. """ if not seq_group_metadata_list: return [] - target_seq_ids_iter = self._create_target_seq_id_iterator( - get_all_seq_ids(seq_group_metadata_list)) - target_seq_group_metadata = list( chain.from_iterable( self._create_target_seq_group_metadata( diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index f0715120192e5..dd040779922e9 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -24,9 +24,9 @@ class SpeculativeProposals: def __repr__(self): return (f"SpeculativeProposals(" - f"proposal_token_ids={self.proposal_token_ids.shape}, " + f"proposal_token_ids={self.proposal_token_ids}, " f"proposal_probs={self.proposal_probs.shape}, " - f"proposal_lens={self.proposal_lens.shape})") + f"proposal_lens={self.proposal_lens})") @dataclass diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index d1e72b6640548..ab1d96c558de7 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -147,15 +147,16 @@ def _collect_rejsample_metrics( emitted_tokens = self._aggregate_num_emitted_tokens.item() draft_tokens = self._aggregate_num_draft_tokens - num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k) + max_num_emitted_tokens = self.get_max_num_emitted_tokens( + draft_tokens, k) if draft_tokens > 0: draft_acceptance_rate = accepted_tokens / draft_tokens else: draft_acceptance_rate = float("nan") - if num_possible_tokens > 0: - system_efficiency = emitted_tokens / num_possible_tokens + if max_num_emitted_tokens > 0: + system_efficiency = emitted_tokens / max_num_emitted_tokens else: system_efficiency = float("nan") @@ -169,8 +170,22 @@ def _collect_rejsample_metrics( ) @staticmethod - def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int: - # Divide by k since batch size can be variable. - total_num_spec_seqs = draft_tokens / k - num_accepted_per_seq_if_all_accepted = k + 1 - return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted) + def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int: + """Calculate the number of emitted tokens, assuming all tokens are + accepted. + + This is equal to the number of sequences that have been speculated on, + times (speculation len + 1). The +1 comes from the bonus token. + """ + # Determine the number of sequences that have been speculated on. Since + # the batch size can be variable, we divide by k. + assert draft_tokens % k == 0 + total_num_spec_seqs = draft_tokens // k + + # A single sequence may emit k accepted tokens and one bonus token in + # the best case. + num_emitted_per_seq_if_all_accepted = k + 1 + + # The max num of emitted tokens is the number of speculated sequences + # times the max emitted per seq. + return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 8b722476853fa..7cf338bbae5f0 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -6,8 +6,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) -from vllm.spec_decode.util import (maybe_mock_device_tensors, - sampler_output_to_torch) +from vllm.spec_decode.util import sampler_output_to_torch from vllm.worker.worker import Worker @@ -329,12 +328,15 @@ def _merge_outputs( """ if maybe_sampler_output is None: # If no speculative tokens, the sampler output will be None. - # In this case we return empty tensors. - proposal_tokens = torch.zeros(0, - max_proposal_len, - dtype=torch.long, - device=self._device) - proposal_probs = torch.zeros(0, + # In this case we return empty proposals. + proposal_tokens = torch.full(size=( + batch_size, + max_proposal_len, + ), + fill_value=-1, + dtype=torch.long, + device=self._device) + proposal_probs = torch.zeros(batch_size, max_proposal_len, self._vocab_size, dtype=torch.float32, @@ -345,17 +347,6 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - - # We mock the device tensors until PR 7/9 is merged (e2e correctness). - # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer - for step_output in sampler_output: - maybe_mock_device_tensors( - sampler_output=step_output, - batch_size=len(proposal_lens), - vocab_size=self._vocab_size, - device=self._device, - ) - proposal_tokens, proposal_probs = sampler_output_to_torch( sampler_output) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 68a2a774ef4b7..2c6642f5a3c81 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -111,6 +111,32 @@ def init_device(self) -> None: device=self.device, vocab_size=self._vocab_size) + self._configure_model_sampler_for_spec_decode() + + def _configure_model_sampler_for_spec_decode(self): + """Configure model sampler to emit GPU tensors. This allows spec decode + to keep data on device without transferring to CPU and serializing, + which significantly reduces overhead of rejection sampling. + + NOTE(cade): This breaks abstraction boundaries pretty badly. The better + design is to have the "move to CPU and serialize" sampling decision be + done outside of the model/sampler; this way the "last-mile" worker + object which interfaces with the scheduler can serialize and incur the + performance hit as necessary. This allows us to run the worker several + iterations in a row without incurring the "move to CPU and serialize" + performance penalty. + + Since this requires a large change to vLLM, we defer it to later and + temporarily accept this broken abstraction boundary. + + NOTE(cade): This will require a special check if the proposer worker + does not have a sampler (e.g. ngram speculation). + """ + (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor + ) = True + (self.proposer_worker.model_runner.model.sampler. + include_gpu_probs_tensor) = True + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. @@ -286,15 +312,26 @@ def _verify_tokens( select_proposal_len_zero=True) original_indices = spec_indices + non_spec_indices - proposal_probs = proposal_scores.probs[spec_indices, :-1] - bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] + # Get probabilities of target model, excluding bonus token. + proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1] + + # Get non-speculative sampled tokens from target model. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] + # Get bonus tokens from target model. + bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] + + # Get probabilities according to proposal method. + proposal_probs = proposals.proposal_probs[spec_indices] + + # Get proposed tokens. + proposal_token_ids = proposals.proposal_token_ids[spec_indices] + accepted_token_ids = self.rejection_sampler( - proposal_probs, - bonus_token_ids, - proposals.proposal_probs, - proposals.proposal_token_ids, + target_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_probs=proposal_probs, + draft_token_ids=proposal_token_ids, ) # Append output tokens from non-speculative sequences to From d3c8180ac4143f4affd2ef26855058e96b72b5f5 Mon Sep 17 00:00:00 2001 From: Jack Gordley Date: Tue, 23 Apr 2024 12:06:29 +0100 Subject: [PATCH 100/413] [Bugfix] Fixing max token error message for openai compatible server (#4016) --- vllm/entrypoints/openai/serving_engine.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 610e807cae4c7..31da27a447c6c 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -206,6 +206,12 @@ def _validate_prompt_and_tokenize( token_num = len(input_ids) if request.max_tokens is None: + if token_num >= self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the messages, " + f"Please reduce the length of the messages.", ) request.max_tokens = self.max_model_len - token_num if token_num + request.max_tokens > self.max_model_len: From d87f39e9a9dd149f5dd7a58b4d98b21f713827b6 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Wed, 24 Apr 2024 00:28:35 +0800 Subject: [PATCH 101/413] [Bugfix] Add init_cached_hf_modules to RayWorkerWrapper (#4286) --- vllm/executor/ray_gpu_executor.py | 2 ++ vllm/worker/worker_base.py | 7 ++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index d0b5e682bb6f7..e69f104e7d5a4 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -100,6 +100,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", )(RayWorkerWrapper).remote( worker_module_name="vllm.worker.worker", worker_class_name="Worker", + trust_remote_code=self.model_config.trust_remote_code, ) worker_ip = ray.get(worker.get_node_ip.remote()) @@ -110,6 +111,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self.driver_worker = RayWorkerWrapper( worker_module_name="vllm.worker.worker", worker_class_name="Worker", + trust_remote_code=self.model_config.trust_remote_code, ) else: # Else, added to the list of workers. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index bcd04e0f98db6..b5dade0a770a0 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -103,10 +103,15 @@ class WorkerWrapperBase: def __init__(self, worker_module_name=None, - worker_class_name=None) -> None: + worker_class_name=None, + trust_remote_code: bool = False) -> None: self.worker_module_name = worker_module_name self.worker_class_name = worker_class_name self.worker = None + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() @staticmethod def update_environment_variables(envs: Dict[str, str]) -> None: From d86285a4a4b79b883620d2878c0b52b22ad4c640 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 23 Apr 2024 09:45:52 -0700 Subject: [PATCH 102/413] [Core][Logging] Add last frame information for better debugging (#4278) --- vllm/logger.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/vllm/logger.py b/vllm/logger.py index 046f0e9099a4b..341fc473585d7 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -83,13 +83,27 @@ def _trace_calls(log_path, root_dir, frame, event, arg=None): return # Log every function call or return try: + last_frame = frame.f_back + if last_frame is not None: + last_filename = last_frame.f_code.co_filename + last_lineno = last_frame.f_lineno + last_func_name = last_frame.f_code.co_name + else: + # initial frame + last_filename = "" + last_lineno = 0 + last_func_name = "" with open(log_path, 'a') as f: if event == 'call': f.write(f"{datetime.datetime.now()} Call to" - f" {func_name} in {filename}:{lineno}\n") + f" {func_name} in {filename}:{lineno}" + f" from {last_func_name} in {last_filename}:" + f"{last_lineno}\n") else: f.write(f"{datetime.datetime.now()} Return from" - f" {func_name} in {filename}:{lineno}\n") + f" {func_name} in {filename}:{lineno}" + f" to {last_func_name} in {last_filename}:" + f"{last_lineno}\n") except NameError: # modules are deleted during shutdown pass From 62b5166bd4c458c3a8f6eda89d3ef9db14a4c2c8 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 23 Apr 2024 09:51:41 -0700 Subject: [PATCH 103/413] [CI] Add ccache for wheel builds job (#4281) --- .github/workflows/publish.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index fc97e33c19af2..4b9fc3d04d872 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -56,6 +56,9 @@ jobs: - name: Checkout uses: actions/checkout@v3 + - name: Setup ccache + uses: hendrikmuhs/ccache-action@v1.2 + - name: Set up Linux Env if: ${{ runner.os == 'Linux' }} run: | From 2b7949c1c2e34de41d9cfc84dd0e377cc6bd58c2 Mon Sep 17 00:00:00 2001 From: James Fleming Date: Tue, 23 Apr 2024 13:59:33 -0400 Subject: [PATCH 104/413] AQLM CUDA support (#3287) Co-authored-by: mgoin --- CMakeLists.txt | 1 + benchmarks/kernels/benchmark_aqlm.py | 302 ++++++++ csrc/ops.h | 15 + csrc/pybind.cpp | 2 + csrc/quantization/aqlm/gemm_kernels.cu | 712 ++++++++++++++++++ examples/aqlm_example.py | 46 ++ tests/models/test_aqlm.py | 95 +++ vllm/model_executor/layers/linear.py | 41 +- .../layers/quantization/__init__.py | 2 + .../layers/quantization/aqlm.py | 373 +++++++++ .../model_executor/layers/quantization/awq.py | 4 +- .../layers/quantization/gptq.py | 3 +- .../layers/quantization/marlin.py | 3 +- .../layers/quantization/squeezellm.py | 4 +- 14 files changed, 1592 insertions(+), 11 deletions(-) create mode 100644 benchmarks/kernels/benchmark_aqlm.py create mode 100644 csrc/quantization/aqlm/gemm_kernels.cu create mode 100644 examples/aqlm_example.py create mode 100644 tests/models/test_aqlm.py create mode 100644 vllm/model_executor/layers/quantization/aqlm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 1845151181284..b2d0cf3e568b7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,6 +173,7 @@ set(VLLM_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC + "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/marlin_cuda_kernel.cu" "csrc/custom_all_reduce.cu") diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py new file mode 100644 index 0000000000000..9602d20bcbc74 --- /dev/null +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -0,0 +1,302 @@ +import argparse +import os +import sys +from typing import Optional + +import torch +import torch.nn.functional as F + +from vllm._C import ops +from vllm.model_executor.layers.quantization.aqlm import ( + dequantize_weight, generic_dequantize_gemm, get_int_dtype, + optimized_dequantize_gemm) + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +def torch_mult( + input: torch.Tensor, # [..., in_features] + weights: torch.Tensor, + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] +) -> torch.Tensor: + output = F.linear(input, weights) + return output + + +def dequant_out_scale( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + if bias is None: + output = F.linear(input, weights, bias) + orig_shape = output.shape + flattened_output = output.view(-1, output.size(-1)) + f_scales = scales.view(-1, scales.shape[0]) + b_scales = f_scales.expand(flattened_output.shape[0], -1) + flattened_output *= b_scales + return flattened_output.view(orig_shape) + else: + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +def dequant_weight_scale( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +def dequant_no_scale( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + return F.linear(input, weights, bias) + + +# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against +# the generic pytorch version. +# Just visual comparison. +def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None: + + n = parts.sum().item() + + device = torch.device('cuda:0') + + code_range = (1 << bits) // 2 + ingroups = 8 + + codes = torch.randint(-code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device) + + codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device) + + count = 0 + for index in range(16): + for i in range(8): + for book in range(nbooks): + codebooks[book, index, 0, i] = count * (10**book) + count += 1 + + print("codes shape", codes.shape) + + for i in range(16): + for book in range(nbooks): + codes[0, i, book] = i + codes[0, -i, book] = i + + weights = dequantize_weight(codes, codebooks, None) + weights2 = ops.aqlm_dequant(codes, codebooks, parts) + + print("weights shape:", weights.shape) + print("weights2 shape:", weights2.shape) + + print("weights are:", weights) + print("weights2 are:", weights2) + + print("first 128 weights are", weights[0, 0:128].to(torch.int32)) + print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32)) + + print("last 128 weights are", weights[0, -128:]) + print("last 128 weights2 are:", weights2[0, -128:]) + + +def main(): + + parser = argparse.ArgumentParser(description="Benchmark aqlm performance.") + + # Add arguments + parser.add_argument("--nbooks", + type=int, + default=1, + help="Number of codebooks (default: 1)") + parser.add_argument("--bits", + type=int, + default=16, + help="Number of bits per code element (default: 16)") + parser.add_argument( + "--test", + type=bool, + default=False, + help="Run the decompression/dequant tester rather than benchmarking " + "(default: False)") + + # Parse the arguments + args = parser.parse_args() + + # Extract values + nbooks = args.nbooks + bits = args.bits + + if args.test: + dequant_test(4096, torch.tensor((4096, )), nbooks, bits) + return + + # Otherwise, benchmark. + methods = [ + ops.aqlm_gemm, + dequant_out_scale, + generic_dequantize_gemm, + optimized_dequantize_gemm, + dequant_weight_scale, + torch_mult, + dequant_no_scale, + ] + + filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv" + print(f"writing benchmarks to file {filename}") + with open(filename, "w") as f: + sys.stdout = f + + print('m | k | n | n parts', end='') + for method in methods: + print(f" | {method.__name__.replace('_', ' ')} (µs)", end='') + print('') + + # These are reasonable prefill sizes. + ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )), + (4096, (11008, 11008)), (11008, (4096, ))) + + # reasonable ranges for m. + for m in [ + 1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112, + 128, 256, 512, 1024, 1536, 2048, 3072, 4096 + ]: + print(f'{m}', file=sys.__stdout__) + for ksp in ksandpartions: + run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, + methods) + + sys.stdout = sys.__stdout__ + + +def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, + methods): + + # I didn't see visible improvements from increasing these, but feel free :) + num_warmup_trials = 1 + num_trials = 1 + + num_calls = 100 + + # warmup. + for method in methods: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + m=m, + k=k, + parts=parts, + nbooks=nbooks, + bits=bits, + method=method, + ) + + n = parts.sum().item() + print(f'{m} | {k} | {n} | {parts.tolist()}', end='') + + for method in methods: + best_time_us = 1e20 + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + m=m, + k=k, + parts=parts, + nbooks=nbooks, + bits=bits, + method=method, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + + if kernel_dur_us < best_time_us: + best_time_us = kernel_dur_us + + print(f' | {kernel_dur_us:.0f}', end='') + + print('') + + +def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor, + nbooks: int, bits: int, method) -> float: + + n = parts.sum().item() + + device = torch.device('cuda:0') + + input = torch.randn((1, m, k), dtype=torch.float16, device=device) + + code_range = (1 << bits) // 2 + ingroups = 8 + + codes = torch.randint(-code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device) + + codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device) + + scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device) + + # for comparison to just a pytorch mult. + weights = torch.randn((n, k), dtype=torch.float16, device=device) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + + if method is torch_mult: + for i in range(num_calls): + torch_mult(input, weights, scales) + else: + for i in range(num_calls): + method(input, codes, codebooks, scales, parts, None) + + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/csrc/ops.h b/csrc/ops.h index 41ecc1e89371b..a379c910d9cf3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -86,6 +86,21 @@ void gelu_fast( torch::Tensor& input); #ifndef USE_ROCM +torch::Tensor aqlm_gemm( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias +); + +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes +); + torch::Tensor awq_gemm( torch::Tensor _in_feats, torch::Tensor _kernel, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index de02afc162113..42e92e5382e8e 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -63,6 +63,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Quantization ops #ifndef USE_ROCM + ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); + ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu new file mode 100644 index 0000000000000..4415316e1e8cd --- /dev/null +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -0,0 +1,712 @@ +/* + * Modified by Neural Magic + * Adapted from https://github.com/Vahe1994/AQLM + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + + +namespace vllm { +namespace aqlm { + +__global__ void Code1x16MatVec( + const int4* __restrict__ A, + const int4* __restrict__ B, + int4* __restrict__ C, + const int4* __restrict__ codebook, + const int prob_m, + const int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + const int codebook_stride // as int4. +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; + } + } + + int b_gl_rd = 0; + int c_gl_wr = a_gl_rd; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + __shared__ int4 sh_b[32 * 9]; + float res = 0; + + int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); + while (iters--) { + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { + if (b_gl_rd + i < prob_k / 8) + sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; + + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t dec[4]; + // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't + // actually help us; this brings > 2x speedup. + asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*) &codebook[enc[i]]) + ); + half2* a = reinterpret_cast(&dec); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; + #pragma unroll + for (int j = 0; j < 4; j++) + res2 = __hfma2(a[j], b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + b_sh_rd++; + } + a_gl_rd += 32; + } + } + + if (pred) { + #pragma unroll + for (int i = 16; i > 0; i /= 2) + res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } +} + +__global__ void Code2x8MatVec( + const int4* __restrict__ A, + const int4* __restrict__ B, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + const int codebook_stride // as int4. + +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; + } + } + + int b_gl_rd = 0; + int c_gl_wr = a_gl_rd; + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + int lane = threadIdx.x % 8; + + extern __shared__ int4 sh[]; + int4* sh_b = sh; + int4* sh_code = sh_b + 32 * 9; + int4* sh_code0 = sh_code; + int4* sh_code1 = sh_code + 256 * 8; + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; + #pragma unroll + for (int j = 0; j < 8; j++) + sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + + float res = 0; + + int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32); + while (iters--) { + // We pad shared memory to avoid bank conflicts during reads + __syncthreads(); + for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { + if (b_gl_rd + i < prob_k / 8) + sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + } + __syncthreads(); + b_gl_rd += 32 * 8; + + int b_sh_rd = 9 * (threadIdx.x % 32); + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2 res2 = {}; + #pragma unroll + for (int j = 0; j < 4; j++) + res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); + res += __half2float(res2.x) + __half2float(res2.y); + b_sh_rd++; + } + a_gl_rd += 32; + } + } + + if (pred) { + #pragma unroll + for (int i = 16; i > 0; i /= 2) + res += __shfl_down_sync(0xffffffff, res, i); + if (threadIdx.x % 32 == 0) + reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); + } +} + + +__global__ void Code1x16Dequant( + const int4* __restrict__ A, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m. + const int codebook_stride // as int4 +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; + } + } + + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + + int c_gl_stride = prob_k / 8; + int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; + + int iters = (prob_k / 8 - 1) / (8 * 32) + 1; + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; + auto dec = reinterpret_cast(&chunk); + // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't + // actually help us; this brings > 2x speedup. + asm volatile ( + "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*) &codebook[enc[i]]) + ); + + C[a_gl_rd * 8 + i] = chunk; + } + } + a_gl_rd += 32; + } +} + + +__global__ void Code2x8Dequant( + const int4* __restrict__ A, + int4* __restrict__ C, + const int4* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. + const int codebook_stride // as int4 +) { + int a_gl_stride = prob_k / 8 / 8; + int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + bool pred = a_gl_rd < prob_m; + + if (pred) + { + // advance to the correct codebook, this easy because we only multiply one column of the codebook. + auto codebook_size = &codebook_a_sizes.x; + while (a_gl_rd >= *codebook_size) + { + codebook += codebook_stride; + ++codebook_size; + } + } + + a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32; + int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32; + int lane = threadIdx.x % 8; + + int c_gl_stride = prob_k / 8; + int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); + c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8; + + extern __shared__ int4 sh[]; + int4* sh_code = sh; + int4* sh_code0 = sh_code; + int4* sh_code1 = sh_code + 256 * 8; + + for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { + int4 dec = codebook[i]; + #pragma unroll + for (int j = 0; j < 8; j++) + sh_code[8 * i + (j + lane) % 8] = dec; + } + __syncthreads(); + + float res = 0; + + int iters = (prob_k / 8 - 1) / (8 * 32) + 1; + while (iters--) { + if (pred && a_gl_rd < a_gl_end) { + const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); + #pragma unroll + for (int i = 0; i < 8; i++) { + int4 chunk; + half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); + C[a_gl_rd * 8 + i] = chunk; + } + } + a_gl_rd += 32; + } +} + +inline int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +const int THREAD_M = 16; + +void code1x16_matvec_cuda( + const void* __restrict__ A, + const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, + const int codebook_stride +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16MatVec<<>>( + (const int4*) A, + (const int4*) B, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride + ); +} + +void code2x8_matvec_cuda( + const void* __restrict__ A, + const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, + const int codebook_stride +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + int shared = 16 * (2 * 256 * 8 + 32 * 9); + cudaFuncSetAttribute( + Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared + ); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code2x8MatVec<<>>( + (const int4*) A, + (const int4*) B, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride + ); +} + +void code1x16_dequant_cuda( + const void* __restrict__ A, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + const int codebook_stride // as int4. +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + Code1x16Dequant<<>>( + (const int4*) A, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. + codebook_stride // as int4. + ); +} + +// Dequantizes the code and codebook into weights. +void code2x8_dequant_cuda( + const void* __restrict__ A, + void* __restrict__ C, + const void* __restrict__ codebook, + int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. + const int codebook_stride // as int4 +) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); + int waves = 0; + int thread_m; + do { + waves++; + thread_m = ceildiv(prob_m, waves * sms); + } while (thread_m > THREAD_M); + + int blocks = ceildiv(prob_m, thread_m); + int threads = 32 * thread_m; + int shared = 16 * (2 * 256 * 8 + 32 * 9); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + cudaFuncSetAttribute( + Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared + ); + Code2x8Dequant<<>>( + (const int4*) A, + (int4*) C, + (const int4*) codebook, + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride + ); +} + +int codebook_stride(const torch::Tensor& codebooks) +{ + return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); +} + +void code1x16_matvec( + const torch::Tensor& A, + const torch::Tensor& B, + torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + + code1x16_matvec_cuda( + A.data_ptr(), + B.data_ptr(), + C.data_ptr(), + codebook.data_ptr(), + prob_m, + prob_k, + codebook_a_sizes, + codebook_stride(codebook) + ); +} + +torch::Tensor code1x16_matmat( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty({flat_input.size(0), out_features}, + torch::TensorOptions() + .dtype(input.dtype()) + .device(input.device()) + ); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); + code1x16_matvec( + codes.squeeze(2), + input_vec, + output_vec, + codebooks, + codebook_a_sizes + ); + } + flat_output *= scales.flatten().unsqueeze(0); + + if (bias.has_value()) { + flat_output += bias->unsqueeze(0); + } + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + auto output = flat_output.reshape(output_sizes); + return output; +} + +void code2x8_matvec( + const torch::Tensor& A, + const torch::Tensor& B, + torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + int prob_m = C.size(0); + int prob_k = B.size(0); + code2x8_matvec_cuda( + A.data_ptr(), + B.data_ptr(), + C.data_ptr(), + codebook.data_ptr(), + prob_m, + prob_k, + codebook_a_sizes, + 2 * codebook_stride(codebook) + ); +} + +torch::Tensor code2x8_matmat( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias +) { + auto input_sizes = input.sizes(); + auto out_features = codes.size(0) * codebooks.size(2); + auto flat_input = input.reshape({-1, input.size(-1)}); + auto flat_output = torch::empty({flat_input.size(0), out_features}, + torch::TensorOptions() + .dtype(input.dtype()) + .device(input.device()) + ); + + for (int i = 0; i < flat_input.size(0); ++i) { + auto input_vec = flat_input.index({i}); + auto output_vec = flat_output.index({i}); + code2x8_matvec( + codes.squeeze(2), + input_vec, + output_vec, + codebooks, + codebook_a_sizes + ); + } + flat_output *= scales.flatten().unsqueeze(0); + if (bias.has_value()) { + flat_output += bias->unsqueeze(0); + } + + auto output_sizes = input_sizes.vec(); + output_sizes.pop_back(); + output_sizes.push_back(-1); + auto output = flat_output.reshape(output_sizes); + return output; +} + +// Accumulate the partition sizes. +int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) +{ + int4 cumulative_sizes; + auto cumulative_size = &cumulative_sizes.x; + int i = 0; + int last = 0; + assert(codebook_partition_sizes.size(0) <= 4); + for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) + { + *cumulative_size = codebook_partition_sizes[i].item() + last; + last = *cumulative_size; + } + // fill in the rest with unreachable. + for (; i < 4; ++i, ++cumulative_size) + { + *cumulative_size = last*10; + } + return cumulative_sizes; +} + +} // namespace aqlm +} // namespace vllm + + +torch::Tensor aqlm_gemm( + const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias +) +{ + int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); + + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const entries = codebooks.size(1); + + if (nbooks == 1 && entries == (1 << 16)) + { + return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + } + if (nbooks == 2 && entries == (1 << 8)) + { + return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + } + + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + return {}; +} + +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes +) +{ + int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); + + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const entries = codebooks.size(1); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); + int rows = codes.size(1); + int cols = codes.size(0); + + auto in_features = codes.size(1) * 8; + auto out_features = codes.size(0); + + assert(out_features = codebook_partition_sizes.sum().item()); + + auto weights = torch::empty({out_features, in_features}, + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device()) + ); + + if (nbooks == 1 && entries == (1 << 16)) + { + vllm::aqlm::code1x16_dequant_cuda( + codes.data_ptr(), + weights.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features, + cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) + // weights *= scales.index({"...", 0, 0}); + + return weights; + } + + if (nbooks == 2 && entries == (1 << 8)) + { + vllm::aqlm::code2x8_dequant_cuda( + codes.data_ptr(), + weights.data_ptr(), + codebooks.data_ptr(), + out_features, + in_features, + cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) + // weights *= scales.index({"...", 0, 0}); + + return weights; + } + + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + return {}; +} diff --git a/examples/aqlm_example.py b/examples/aqlm_example.py new file mode 100644 index 0000000000000..e7c17fa0362ae --- /dev/null +++ b/examples/aqlm_example.py @@ -0,0 +1,46 @@ +import argparse + +from vllm import LLM, SamplingParams + + +def main(): + + parser = argparse.ArgumentParser(description='AQLM examples') + + parser.add_argument('--model', + '-m', + type=str, + default=None, + help='model path, as for HF') + parser.add_argument('--choice', + '-c', + type=int, + default=0, + help='known good models by index, [0-4]') + parser.add_argument('--tensor_parallel_size', + '-t', + type=int, + default=1, + help='tensor parallel size') + + args = parser.parse_args() + + models = [ + "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", + "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf", + "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf", + "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf", + "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf", + ] + + model = LLM(args.model if args.model is not None else models[args.choice], + tensor_parallel_size=args.tensor_parallel_size) + + sampling_params = SamplingParams(max_tokens=100, temperature=0) + outputs = model.generate("Hello my name is", + sampling_params=sampling_params) + print(outputs[0].outputs[0].text) + + +if __name__ == '__main__': + main() diff --git a/tests/models/test_aqlm.py b/tests/models/test_aqlm.py new file mode 100644 index 0000000000000..a7abc011f57d7 --- /dev/null +++ b/tests/models/test_aqlm.py @@ -0,0 +1,95 @@ +"""Compare the outputs of a AQLM model between vLLM and HF Transformers + +Run `pytest tests/models/test_aqlm.py`. +""" + +import pytest +import torch + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +aqlm_not_supported = (capability < + QUANTIZATION_METHODS["aqlm"].get_min_capability()) + +# In this test we hardcode prompts and generations for the model so we don't +# need to require the AQLM package as a dependency +example_prompts = [ + 'vLLM is a high-throughput and memory-efficient inference and serving ' + 'engine for LLMs.\n', + 'Briefly describe the major milestones in the development of artificial ' + 'intelligence from 1950 to 2020.\n', + 'Compare and contrast artificial intelligence with human intelligence in ' + 'terms of processing information.\n', + 'Describe the basic components of a neural network and how it can be ' + 'trained.\n', + 'Write a short story about a robot that dreams for the first time.\n', + 'Analyze the impact of the COVID-19 pandemic on global economic structures ' + 'and future business models.\n', + 'Explain the cultural significance of the Mona Lisa painting, and how its ' + 'perception might vary in Western versus Eastern societies.\n', + "Translate the following English sentence into Japanese, French, and " + "Swahili: 'The early bird catches the worm.'\n" +] + +# These ground truth generations were generated using `transformers==4.38.1 +# aqlm==1.1.0 torch==2.2.0` +# and the below code: +# ```python +# from transformers import AutoTokenizer, AutoModelForCausalLM +# model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf" +# quantized_model = AutoModelForCausalLM.from_pretrained(model_id, +# torch_dtype="auto", device_map="cuda").cuda() +# tokenizer = AutoTokenizer.from_pretrained(model_id) +# outputs = [] +# for prompt in example_prompts: +# input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda") +# hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32) +# outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:])) +# print(outputs) +# ``` +ground_truth_generations = [ + '\n### Features\n\n- **High-throughput**: v', + 'The major milestones in the development of artificial intelligence from ' + '195', + 'Compare and contrast artificial intelligence with human intelligence in ' + 'terms of processing information. The', + 'Explain the difference between supervised and unsupervised learning.' + '\nExplain', + 'Write a short story about a robot that dreams for the first time. The', + 'Analyze the impact of the COVID-19 pandemic on global economic', + 'The Mona Lisa is a painting by Leonardo da Vinci, and it', + 'The early bird catches the worm.\nThe early bird catches the' +] + + +@pytest.mark.skipif(aqlm_not_supported, + reason="AQLM is not supported on this GPU type.") +@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"]) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("num_logprobs", [1]) +def test_models( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) + + # loop through the prompts to compare against the ground truth generations + for prompt_idx in range(len(example_prompts)): + vllm_output_ids, vllm_output_str, vllm_logprobs = vllm_outputs[ + prompt_idx] + + print("Prompt: ", repr(example_prompts[prompt_idx])) + print("Reference output:", repr(ground_truth_generations[prompt_idx])) + print("Output output: ", repr(vllm_output_str)) + assert vllm_output_str == ground_truth_generations[prompt_idx] diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index d466d8807fc64..e56af9075e2fd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -31,7 +31,7 @@ class LinearMethodBase(ABC): @abstractmethod def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): """Create weights for a linear layer. @@ -70,9 +70,10 @@ def __init__(self, separate_bias_add: bool = False): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype), @@ -127,7 +128,7 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_method.create_weights(self, self.input_size, - self.output_size, self.input_size, + [self.output_size], self.input_size, self.output_size, self.params_dtype) if bias: self.bias = Parameter( @@ -161,6 +162,8 @@ class ColumnParallelLinear(torch.nn.Module): skip adding bias but instead return it. params_dtype: Data type for the parameters. linear_method: (Maybe quantized) linear method. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. """ def __init__( @@ -172,6 +175,7 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, linear_method: Optional[LinearMethodBase] = None, + output_sizes: Optional[List[int]] = None, ): super().__init__() @@ -188,10 +192,12 @@ def __init__( self.params_dtype = params_dtype if linear_method is None: linear_method = UnquantizedLinearMethod() + if output_sizes is None: + output_sizes = [output_size] self.linear_method = linear_method self.linear_method.create_weights(self, self.input_size, - self.output_size_per_partition, + [x // tp_size for x in output_sizes], self.input_size, self.output_size, self.params_dtype, @@ -268,14 +274,17 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method) + skip_bias_add, params_dtype, linear_method, + self.output_sizes) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): + param_data = param.data output_dim = getattr(param, "output_dim", None) + is_metadata = getattr(param, "is_metadata", False) if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -328,6 +337,11 @@ def weight_loader(self, start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -393,8 +407,14 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + output_sizes = [ + self.num_heads * tp_size * self.head_size, + self.num_kv_heads * tp_size * self.head_size, + self.num_kv_heads * tp_size * self.head_size + ] + super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, linear_method) + params_dtype, linear_method, output_sizes) def weight_loader(self, param: Parameter, @@ -402,6 +422,7 @@ def weight_loader(self, loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + is_metadata = getattr(param, "is_metadata", False) if loaded_shard_id is None: # Loaded weight is already packed. @@ -469,6 +490,12 @@ def weight_loader(self, start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, + shard_size) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -536,7 +563,7 @@ def __init__( self.linear_method = linear_method self.linear_method.create_weights(self, self.input_size_per_partition, - self.output_size, + [self.output_size], self.input_size, self.output_size, self.params_dtype, diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 0344d6e4e3e45..a525add458499 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,5 +1,6 @@ from typing import Type +from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -9,6 +10,7 @@ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig QUANTIZATION_METHODS = { + "aqlm": AQLMConfig, "awq": AWQConfig, "fp8": FP8Config, "gptq": GPTQConfig, diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py new file mode 100644 index 0000000000000..6115b1de679ad --- /dev/null +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -0,0 +1,373 @@ +# Supports AQLM compression, see https://github.com/Vahe1994/AQLM +# and https://arxiv.org/pdf/2401.06118.pdf + +import math +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +def get_int_dtype(nbits: int) -> torch.dtype: + if nbits <= 8: + return torch.int8 + if nbits <= 16: + return torch.int16 + if nbits <= 32: + return torch.int32 + if nbits <= 64: + return torch.int64 + raise ValueError(f"No dtype available for {nbits}-bit codebooks") + + +@torch.inference_mode() +def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor: + return data.to(torch.int64) % (2**nbits) + + +def dequantize_weight(codes: torch.Tensor, + codebooks: torch.Tensor, + scales: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Decode float weights from quantization codes. Differentiable. + :param codes: tensor of integer quantization codes, shape + [*dims, num_out_groups, num_in_groups, num_codebooks] + :param codebooks: tensor of vectors for each quantization code, + [num_codebooks, codebook_size, out_group_size, in_group_size] + :param scales: weight will be multiplied by this factor, must be + broadcastble with + [*dims, out_groups, num_in_groups, out_group_size, in_group_size] + :return: reconstructed weight tensor of shape + [*dims, num_in_groups*group_size] + """ + num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] + num_codebooks, codebook_size, out_group_size, in_group_size = \ + codebooks.shape + out_features = num_out_groups * out_group_size + in_features = num_in_groups * in_group_size + codebook_offsets = torch.arange( + 0, num_codebooks * codebook_size, codebook_size, + device=codes.device) # shape: [num_codebooks] + reconstructed_weight_flat = F.embedding_bag( + codes.flatten(0, -2) + codebook_offsets, + codebooks.flatten(0, 1).flatten(-2, -1), + mode="sum" + ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size + # * in_group_size] + + reconstructed_weight_groupwise = reconstructed_weight_flat.view( + list(codes.shape[:-3]) + + [num_out_groups, num_in_groups, out_group_size, in_group_size]) + if scales is not None: + reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul( + scales) + return reconstructed_weight_groupwise.swapaxes( + -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) + + +def dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + bias: Optional[torch.Tensor], +) -> torch.Tensor: + dequantized_weight = dequantize_weight( + unpack_int_data(codes, codebooks.shape[1].bit_length() - 1), + codebooks, + scales, + ) + return F.linear(input, dequantized_weight, bias) + + +# Generic dequantization, slow but flexible. +def generic_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + output_shape = input.shape[:-1] + (scales.shape[0], ) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + num_outputs = len(output_partition_sizes) + + # break the inputs and codebooks apart then combine the outputs. + # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big + # multiply at the end. + num_codebooks = codebooks.shape[0] // num_outputs + assert (scales.shape[0] == codes.shape[0]) + assert (sum(output_partition_sizes) == scales.shape[0]) + output_offset = 0 + codebooks_offset = 0 + for output_size in output_partition_sizes: + shard_output = dequantize_gemm( + input, codes.narrow(0, output_offset, output_size), + codebooks.narrow(0, codebooks_offset, num_codebooks), + scales.narrow(0, output_offset, output_size), None + if bias is None else bias.narrow(0, output_offset, output_size)) + + output_slice = output.narrow(-1, output_offset, output_size) + assert (output_slice.shape == shard_output.shape) + output_slice.copy_(shard_output) + output_offset += output_size + codebooks_offset += num_codebooks + return output + + +# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8 +# at 6 and 9 times faster than the generic version above, respectively. +def optimized_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: torch.IntTensor, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + if bias is None: + # scaling the output is fastest, so we do that when possible. + output = F.linear(input, weights, bias) + orig_shape = output.shape + flattened_output = output.view(-1, output.size(-1)) + f_scales = scales.view(-1, scales.shape[0]) + b_scales = f_scales.expand(flattened_output.shape[0], -1) + flattened_output *= b_scales + return output.view(orig_shape) + else: + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +class AQLMConfig(QuantizationConfig): + """Config class for AQLM. + + Reference: https://github.com/Vahe1994/AQLM + """ + + def __init__( + self, + in_group_size: int, + nbits_per_codebook: int, + num_codebooks: int, + out_group_size: int, + ) -> None: + self.in_group_size = in_group_size + self.nbits_per_codebook = nbits_per_codebook + self.num_codebooks = num_codebooks + self.out_group_size = out_group_size + + # out_group_size > 1 is untested, and probably won't work as-is. + assert (self.out_group_size == 1) + self.pack_factor = (self.in_group_size * self.out_group_size) + + def __repr__(self) -> str: + return (f"AQLMConfig(in_group_size={self.in_group_size}, " + f"nbits_per_codebook={self.nbits_per_codebook}, " + f"num_codebooks={self.num_codebooks}, " + f"out_group_size={self.out_group_size})") + + @classmethod + def get_name(cls) -> str: + return "aqlm" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] # no extra configs. + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": + in_group_size = cls.get_from_keys(config, ["in_group_size"]) + nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) + num_code_books = cls.get_from_keys(config, ["num_codebooks"]) + out_group_size = cls.get_from_keys(config, ["out_group_size"]) + return cls(in_group_size, nbits_per_codebook, num_code_books, + out_group_size) + + def get_linear_method(self) -> "AQLMLinearMethod": + return AQLMLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class AQLMLinearMethod(LinearMethodBase): + """Linear method for AQLM. + + Args: + quant_config: The AQLM quantization config. + """ + + def __init__(self, quant_config: AQLMConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + del output_size # Unused. + del input_size # Unused. + + if params_dtype != torch.half: + raise ValueError("Only half is currently supported by aqlm") + if input_size_per_partition % self.quant_config.in_group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.out_group_size != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + codes = Parameter( + torch.empty( + # There could actually be two pack factors, one along input and + # one along output, but we don't currently support + # out_group_size, and only the one along output needs to be + # marked with "packed_dim" in order for QKVLinear to work. + output_size_per_partition, + input_size_per_partition // self.quant_config.pack_factor, + self.quant_config.num_codebooks, + dtype=get_int_dtype(self.quant_config.nbits_per_codebook), + ), + requires_grad=False, + ) + + set_weight_attrs( + codes, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) + + codebooks = Parameter( + torch.empty( + self.quant_config.num_codebooks * len(output_partition_sizes), + 2**self.quant_config.nbits_per_codebook, + self.quant_config.out_group_size, + self.quant_config.in_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + codebooks, + { + # metadata indicates fixed size concatenated along dim 0 + "is_metadata": + True, + "output_partition_sizes": + torch.tensor(output_partition_sizes, device='cpu'), + }, + ) + + scales = Parameter( + torch.empty( + ( + output_size_per_partition // + self.quant_config.out_group_size, + 1, + 1, + 1, + ), + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "output_dim": 0, + "packed_dim": 0, + "pack_factor": self.quant_config.out_group_size + }, + ) + + layer.register_parameter("codes", codes) + set_weight_attrs(codes, extra_weight_attrs) + layer.register_parameter("codebooks", codebooks) + set_weight_attrs(codebooks, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + codebooks = layer.codebooks + codes = layer.codes + scales = layer.scales + output_partition_sizes = getattr(codebooks, "output_partition_sizes", + None) + + nbooks = codes.shape[2] + ingroups = codebooks.shape[3] + outgroups = codebooks.shape[2] + bits = codebooks.shape[1] + + # We support these formats with dedicated gemm and decompression + # kernels. + if ingroups == 8 and outgroups == 1 and ( + (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)): + + # thresholds determined by timings on an A6000, one GPU + use_gemv = math.prod(x.shape[:-1]) <= 6 + + return ops.aqlm_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) if use_gemv else optimized_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) + + # fall back all unoptimized formats + return generic_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 98651aed8be0e..4f75134ee1889 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -81,7 +81,7 @@ def __init__(self, quant_config: AWQConfig): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: @@ -89,6 +89,8 @@ def create_weights(self, layer: torch.nn.Module, "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The output size is not aligned with the quantized " diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index f370b94a210ee..92a5cdb9af928 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -91,7 +91,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -103,6 +103,7 @@ def create_weights( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") + output_size_per_partition = sum(output_partition_sizes) if (output_size_per_partition % self.quant_config.pack_factor.numerator != 0): raise ValueError( diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index bf0500f1155a1..00c3c404c2d7a 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -93,7 +93,7 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, @@ -106,6 +106,7 @@ def create_weights( f"The params dtype must be float16, but got {params_dtype}") # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) if output_size_per_partition % self.quant_config.min_n_threads != 0: raise ValueError( f"Weight output_size_per_partition = " diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 661ff9c55d0d1..cc44447d347b8 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -70,7 +70,7 @@ def __init__(self, quant_config: SqueezeLLMConfig): def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: @@ -78,6 +78,8 @@ def create_weights(self, layer: torch.nn.Module, "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) qweight = Parameter( torch.empty( input_size_per_partition // self.quant_config.pack_factor, From 1e8f4252aa163041094a8fedbb701a98d7087d7c Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 24 Apr 2024 02:19:03 +0800 Subject: [PATCH 105/413] [Bugfix][Frontend] Raise exception when file-like chat template fails to be opened (#4292) --- tests/async_engine/test_chat_template.py | 19 ++++++++++++++----- vllm/entrypoints/openai/serving_chat.py | 23 +++++++++++++++-------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 6972ae1dee4a1..8d6ad6706fb0e 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -76,20 +76,29 @@ def test_load_chat_template(): {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 -def test_no_load_chat_template(): +def test_no_load_chat_template_filelike(): # Testing chatml template template = "../../examples/does_not_exist" tokenizer = MockTokenizer() + mock_serving_chat = MockServingChat(tokenizer) + + with pytest.raises(ValueError, match="looks like a file path"): + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) + + +def test_no_load_chat_template_literallike(): + # Testing chatml template + template = "{{ messages }}" + tokenizer = MockTokenizer() + mock_serving_chat = MockServingChat(tokenizer) OpenAIServingChat._load_chat_template(mock_serving_chat, chat_template=template) template_content = tokenizer.chat_template - # Test assertions - assert template_content is not None - # Hard coded value for template_chatml.jinja - assert template_content == """../../examples/does_not_exist""" + assert template_content == template @pytest.mark.asyncio diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d502dd0a4eb75..2ff335eb71073 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -319,23 +319,30 @@ async def chat_completion_full_generator( return response def _load_chat_template(self, chat_template): + tokenizer = self.tokenizer + if chat_template is not None: try: with open(chat_template, "r") as f: - self.tokenizer.chat_template = f.read() - except OSError: + tokenizer.chat_template = f.read() + except OSError as e: + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + # If opening a file fails, set chat template to be args to # ensure we decode so our escape are interpreted correctly - self.tokenizer.chat_template = codecs.decode( + tokenizer.chat_template = codecs.decode( chat_template, "unicode_escape") logger.info( - f"Using supplied chat template:\n{self.tokenizer.chat_template}" - ) - elif self.tokenizer.chat_template is not None: + f"Using supplied chat template:\n{tokenizer.chat_template}") + elif tokenizer.chat_template is not None: logger.info( - f"Using default chat template:\n{self.tokenizer.chat_template}" - ) + f"Using default chat template:\n{tokenizer.chat_template}") else: logger.warning( "No chat template provided. Chat API will not work.") From eace8bf0b9118877c390e6d490502214c39db132 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 23 Apr 2024 18:18:23 -0700 Subject: [PATCH 106/413] [Kernel] FP8 support for MoE kernel / Mixtral (#4244) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR is the first step towards fixing https://github.com/vllm-project/vllm/pull/3208 It implements dynamic per-tensor scaling (see https://github.com/vllm-project/vllm/pull/4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in https://github.com/vllm-project/vllm/pull/3954). With this PR, the results are as follows: Screenshot 2024-04-21 at 1 31 50 PM **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking! --- CMakeLists.txt | 1 + csrc/ops.h | 5 + csrc/pybind.cpp | 1 + csrc/quantization/fp8/fp8_cuda_kernels.cu | 103 ++++++++++++ vllm/_custom_ops.py | 10 +- ...me=NVIDIA_H100_80GB_HBM3,dtype=float8.json | 146 ++++++++++++++++++ .../layers/fused_moe/fused_moe.py | 93 ++++++++--- vllm/model_executor/model_loader/loader.py | 2 + vllm/model_executor/model_loader/utils.py | 1 + vllm/model_executor/models/mixtral.py | 44 +++++- 10 files changed, 385 insertions(+), 21 deletions(-) create mode 100644 csrc/quantization/fp8/fp8_cuda_kernels.cu create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json diff --git a/CMakeLists.txt b/CMakeLists.txt index b2d0cf3e568b7..4a99985d9abc4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/fp8/fp8_cuda_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") diff --git a/csrc/ops.h b/csrc/ops.h index a379c910d9cf3..ff7a3de1a0a8c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -146,6 +146,11 @@ void gptq_shuffle( torch::Tensor q_perm, int bit); +void scaled_fp8_quant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale); + void moe_align_block_size( torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 42e92e5382e8e..a5b16c5abc3ed 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -73,6 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); + ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); ops.def( "moe_align_block_size", &moe_align_block_size, diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu new file mode 100644 index 0000000000000..c3337cede1282 --- /dev/null +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -0,0 +1,103 @@ +#include +#include +#include + +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +namespace vllm { + +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) : + __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + + return old; +} + +// Compute the absolute maximum m of the input tensor and store +// m / float8_e4m3::max() in *scale. Each thread block performs a +// reduction tree and the memory in scale is atomically updated. +// So to get the right answer, *scale needs to be initialized to +// a value <= 0.0 and we need to wait for all thread blocks to +// finish before consuming *scale. +template +__global__ void segmented_max_reduction( + float* __restrict__ scale, + const scalar_t* __restrict__ input, + int64_t num_elems) { + __shared__ float cache[1024]; + int i = blockDim.x * blockIdx.x + threadIdx.x; + + // First store maximum for all values processes by + // the current thread in cache[threadIdx.x] + scalar_t tmp = 0.0; + while (i < num_elems) { + float x = static_cast(input[i]); + tmp = max(tmp, fabs(x)); + i += blockDim.x * gridDim.x; + } + cache[threadIdx.x] = tmp; + + __syncthreads(); + + // Now perform parallel reduction within the thread block + int ib = blockDim.x / 2; + while (ib != 0) { + if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { + cache[threadIdx.x] = cache[threadIdx.x + ib]; + } + __syncthreads(); + ib /= 2; + } + // Finally, since cache[0] contains the maximum for this thread block, + // atomically write the max to the target location + if (threadIdx.x == 0) { + atomicMaxFloat(scale, cache[0] / std::numeric_limits::max()); + } +} + +template +__global__ void scaled_fp8_quant_kernel( + c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + int64_t num_elems) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + while (i < num_elems) { + out[i] = static_cast(input[i] / *scale); + i += blockDim.x * gridDim.x; + } +} + +} // namespace vllm + +void scaled_fp8_quant( + torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] +{ + int64_t num_tokens = input.numel() / input.size(-1); + int64_t num_elems = input.numel(); + dim3 grid(num_tokens); + dim3 block(1024); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), + "scaled_fp8_quant_kernel", + [&] { + vllm::segmented_max_reduction<<>>( + scale.data_ptr(), + input.data_ptr(), + num_elems); + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + scale.data_ptr(), + num_elems); + }); +} + diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a0837a20875fe..e4b16ed918d1a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, Optional, Tuple import torch @@ -153,6 +153,14 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) +# fp8 +def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + vllm_ops.scaled_fp8_quant(output, input, scale) + return output, scale + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..2ad07bf79a25c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 377b6588dbf47..ac7c30e2a9727 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -21,6 +21,8 @@ def fused_moe_kernel( a_ptr, b_ptr, c_ptr, + a_scale_ptr, + b_scale_ptr, topk_weights_ptr, sorted_token_ids_ptr, expert_ids_ptr, @@ -49,6 +51,7 @@ def fused_moe_kernel( MUL_ROUTED_WEIGHT: tl.constexpr, top_k: tl.constexpr, compute_type: tl.constexpr, + use_fp8: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -111,6 +114,10 @@ def fused_moe_kernel( b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + if use_fp8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block @@ -129,7 +136,10 @@ def fused_moe_kernel( mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) # We accumulate along the K dimension. - accumulator += tl.dot(a, b) + if use_fp8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk @@ -140,7 +150,10 @@ def fused_moe_kernel( other=0) accumulator = accumulator * moe_weight[:, None] - accumulator = accumulator.to(compute_type) + if use_fp8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -207,15 +220,24 @@ def moe_align_block_size( def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, + B_scale: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, mul_routed_weight: bool, top_k: int, - config: Dict[str, Any]) -> None: + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8: bool) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + if not use_fp8: + A_scale = None + assert B_scale is None + else: + A, A_scale = ops.scaled_fp8_quant(A) + assert B_scale is not None + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) @@ -223,6 +245,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, A, B, C, + A_scale, + B_scale, topk_weights, sorted_token_ids, expert_ids, @@ -240,18 +264,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, C.stride(2), MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, - compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, + compute_type=compute_type, + use_fp8=use_fp8, **config, ) -def get_config_file_name(E: int, N: int) -> str: +def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: device_name = torch.cuda.get_device_name().replace(" ", "_") - return f"E={E},N={N},device_name={device_name}.json" + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" @functools.lru_cache -def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: +def get_moe_configs(E: int, N: int, + dtype: Optional[str]) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. @@ -263,7 +290,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N) + json_file_name = get_config_file_name(E, N, dtype) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) @@ -288,6 +315,9 @@ def fused_moe( renormalize: bool, inplace: bool = False, override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -305,6 +335,12 @@ def fused_moe( Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -358,7 +394,8 @@ def fused_moe( config = override_config else: # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2]) + configs = get_moe_configs(E, w2.shape[2], + "float8" if use_fp8 else None) if configs: # If an optimal configuration map has been found, look up the @@ -394,17 +431,37 @@ def fused_moe( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) - invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, - topk_weights, topk_ids, sorted_token_ids, - expert_ids, num_tokens_post_padded, False, - topk_ids.shape[1], config) + invoke_fused_moe_kernel(hidden_states, + w1, + intermediate_cache1, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=tl.float16, + use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, - topk_weights, topk_ids, sorted_token_ids, - expert_ids, num_tokens_post_padded, True, 1, - config) + invoke_fused_moe_kernel(intermediate_cache2, + w2, + intermediate_cache3, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=tl.float16, + use_fp8=use_fp8) if inplace: return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 64cd186506bdb..f75c35a69d4a9 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -232,6 +232,8 @@ def load_model(self, *, model_config: ModelConfig, linear_method = getattr(module, "linear_method", None) if linear_method is not None: linear_method.process_weights_after_loading(module) + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() return model.eval() diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index a0a3b2784614d..f7e0f56c1a46e 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -24,6 +24,7 @@ def get_model_architecture( # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. if (model_config.quantization is not None + and model_config.quantization != "fp8" and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 4d1755f2bbe63..a33b795d7088e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -39,6 +39,8 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod, + per_tensor_quantize) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -47,6 +49,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import SamplerOutput +from vllm.utils import print_warning_once class MixtralMoE(nn.Module): @@ -66,6 +69,7 @@ def __init__( intermediate_size: int, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -73,6 +77,9 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size // self.tp_size + # FIXME(pcmoritz): Make this more general to support different + # quantization schemes + self.use_fp8 = isinstance(linear_method, Fp8LinearMethod) if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -97,6 +104,16 @@ def __init__( device="cuda", dtype=self.params_dtype)) + # Scaling factors for FP8 weights + self.ws_scale = nn.Parameter( + torch.ones( + self.num_total_experts, device="cuda", dtype=torch.float32), + requires_grad=False) if self.use_fp8 else None + self.w2s_scale = nn.Parameter( + torch.ones( + self.num_total_experts, device="cuda", dtype=torch.float32), + requires_grad=False) if self.use_fp8 else None + set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, }) @@ -118,6 +135,18 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] + def process_weights_after_loading(self): + if self.use_fp8: + ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn) + w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn) + for expert in range(self.num_total_experts): + ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize( + self.ws.data[expert, :, :]) + w2s[expert, :, :], self.w2s_scale[ + expert] = per_tensor_quantize(self.w2s.data[expert, :, :]) + self.ws = nn.Parameter(ws, requires_grad=False) + self.w2s = nn.Parameter(w2s, requires_grad=False) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) @@ -129,7 +158,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits, self.top_k, renormalize=True, - inplace=True) + inplace=True, + use_fp8=self.use_fp8, + w1_scale=self.ws_scale, + w2_scale=self.w2s_scale) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -171,6 +203,13 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = sliding_window + if isinstance(linear_method, Fp8LinearMethod): + print_warning_once( + "For Mixtral FP8 quantization, we currently do not quantize " + "the attention layers until their FP8 performance is improved." + ) + linear_method = None + self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, @@ -238,7 +277,8 @@ def __init__( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) + intermediate_size=config.intermediate_size, + linear_method=linear_method) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, From 79a268c4ab2cbf44280eebd998b8efc383bac216 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 23 Apr 2024 21:26:33 -0400 Subject: [PATCH 107/413] [BUG] fixed fp8 conflict with aqlm (#4307) Fixes fp8 iterface which broke in AQLM merge. --- .buildkite/test-pipeline.yaml | 3 +++ vllm/model_executor/layers/linear.py | 16 +++++++++++++--- vllm/model_executor/layers/quantization/fp8.py | 3 ++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f7c1569696249..11cda053260ec 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -96,6 +96,9 @@ steps: - label: Metrics Test command: pytest -v -s metrics +- label: Quantization Test + command: pytest -v -s quantization + - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" commands: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index e56af9075e2fd..6ad7ae0f63197 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -34,9 +34,19 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - """Create weights for a linear layer. - - The weights will be set as attributes of the layer.""" + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ raise NotImplementedError @abstractmethod diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8df82e0e18edd..01e494c870e71 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -64,12 +64,13 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): + output_size_per_partition = sum(output_partition_sizes) weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype), From 91f50a6fe240b2c92a99e171bb11d083f82e4a84 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 23 Apr 2024 18:32:19 -0700 Subject: [PATCH 108/413] [Core][Distributed] use cpu/gloo to initialize pynccl (#4248) --- tests/distributed/test_pynccl.py | 15 ++- .../device_communicators/pynccl.py | 122 ++++++++++-------- .../device_communicators/pynccl_utils.py | 12 +- vllm/distributed/parallel_state.py | 6 + vllm/worker/worker.py | 9 +- 5 files changed, 93 insertions(+), 71 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index d58f621d36b86..6d7d4a5806bd0 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -5,6 +5,7 @@ from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, ncclGetUniqueId) +from vllm.distributed.parallel_state import init_distributed_environment from vllm.utils import update_environment_variables @@ -26,19 +27,23 @@ def distributed_run(fn, world_size): for p in processes: p.join() + for p in processes: + assert p.exitcode == 0 + -def update_env(fn): +def worker_fn_wrapper(fn): # `multiprocessing.Process` cannot accept environment variables directly # so we need to pass the environment variables as arguments # and update the environment variables in the function - def wrapper(env): + def wrapped_fn(env): update_environment_variables(env) + init_distributed_environment() fn() - return wrapper + return wrapped_fn -@update_env +@worker_fn_wrapper def worker_fn(): comm = NCCLCommunicator() tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) @@ -53,7 +58,7 @@ def test_pynccl(): distributed_run(worker_fn, 2) -@update_env +@worker_fn_wrapper def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 0707afe922f40..fcedf0fed34cb 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -20,14 +20,15 @@ # variable in the code. import ctypes -import datetime import platform +from typing import Optional, Union # ===================== import region ===================== import torch import torch.distributed as dist -from torch.distributed import ReduceOp +from torch.distributed import ProcessGroup, ReduceOp +from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger from vllm.utils import find_nccl_library, nccl_integrity_check @@ -59,6 +60,18 @@ ncclResult_t = ctypes.c_int +_c_ncclGetErrorString = nccl.ncclGetErrorString +_c_ncclGetErrorString.restype = ctypes.c_char_p +_c_ncclGetErrorString.argtypes = [ncclResult_t] + + +def NCCL_CHECK(result: ncclResult_t) -> None: + if result != 0: + error_str = _c_ncclGetErrorString(result) + error_str = error_str.decode("utf-8") + raise RuntimeError(f"NCCL error: {error_str}") + + # equivalent to c declaration: # ncclResult_t ncclGetVersion(int *version); _c_ncclGetVersion = nccl.ncclGetVersion @@ -68,8 +81,7 @@ def ncclGetVersion() -> str: version = ctypes.c_int() - result = _c_ncclGetVersion(ctypes.byref(version)) - assert result == 0 + NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version))) # something like 21903 --> "2.19.3" version_str = str(version.value) major = version_str[0].lstrip("0") @@ -91,8 +103,7 @@ class NcclUniqueId(ctypes.Structure): def ncclGetUniqueId() -> NcclUniqueId: unique_id = NcclUniqueId() - result = _c_ncclGetUniqueId(ctypes.byref(unique_id)) - assert result == 0 + NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id))) return unique_id @@ -199,66 +210,75 @@ class NCCLCommunicator: def __init__( self, - backend=None, - init_method=None, - timeout=datetime.timedelta(seconds=10), - world_size: int = -1, - rank: int = -1, - store=None, - group_name: str = "", - pg_options=None, - local_rank: int = -1, + group: Optional[ProcessGroup] = None, + device: Optional[Union[int, str, torch.device]] = None, ): - if not dist.is_initialized(): - backend = backend or "nccl" - assert backend == 'nccl', ( - "only use nccl backend for starting the NCCL communicator") - dist.init_process_group(backend=backend, - init_method=init_method, - timeout=timeout, - world_size=world_size, - rank=rank, - store=store, - group_name=group_name, - pg_options=pg_options) - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - if local_rank == -1: - local_rank = self.rank - self.local_rank = local_rank - # don't use these args, as they can be -1 - # use `self.rank`, `self.local_rank` and `self.world_size` instead - del world_size, rank, local_rank - torch.cuda.set_device(self.local_rank) + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the NCCLCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + assert dist.is_initialized() + group = get_cpu_world_group() if group is None else group + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "NCCLCommunicator should be attached to a non-NCCL group.") + self.group = group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) if self.rank == 0: self.unique_id = ncclGetUniqueId() else: self.unique_id = NcclUniqueId() - tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda( - self.local_rank) - dist.broadcast(tensor, src=0) - byte_list = tensor.cpu().tolist() + tensor = torch.ByteTensor(list(self.unique_id.internal)) + dist.broadcast(tensor, src=0, group=group) + byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte self.comm = ctypes.c_void_p() - result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, - self.unique_id, self.rank) - assert result == 0 - self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}") + if device is None: + local_rank = get_local_rank() + device = torch.device(f"cuda:{local_rank}") + elif isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + current_device = torch.cuda.current_device() + try: + torch.cuda.set_device(device) + NCCL_CHECK( + _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, + self.unique_id, self.rank)) + self.stream = torch.cuda.Stream() + finally: + torch.cuda.set_device(current_device) def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None): + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") if stream is None: stream = self.stream - result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), - ctypes.c_void_p(tensor.data_ptr()), - tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - ctypes.c_void_p(stream.cuda_stream)) - assert result == 0 + NCCL_CHECK( + _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), + ctypes.c_void_p(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + ctypes.c_void_p(stream.cuda_stream))) def __del__(self): # `dist` module might have been already destroyed diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index 916dc814af7eb..a717fddb695ba 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -2,7 +2,7 @@ from typing import Optional import torch -from torch.distributed import ReduceOp +from torch.distributed import ProcessGroup, ReduceOp from vllm.logger import init_logger @@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream): pass -def init_process_group(world_size: int, - rank: int, - init_method: str, - local_rank: int = -1) -> None: +def init_process_group(group: Optional[ProcessGroup] = None) -> None: assert not is_initialized() global comm logger.info(f"vLLM is using nccl=={ncclGetVersion()}") - comm = NCCLCommunicator(init_method=init_method, - world_size=world_size, - local_rank=local_rank, - rank=rank) + comm = NCCLCommunicator(group=group) def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e2473736375e0..515f2212511b7 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -4,6 +4,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" import contextlib +import os from typing import Optional import torch @@ -73,6 +74,11 @@ def init_distributed_environment( ranks = list(range(torch.distributed.get_world_size())) _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, backend="gloo") + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1 and distributed_init_method == "env://": + local_rank = int(os.environ['LOCAL_RANK']) global _LOCAL_RANK _LOCAL_RANK = local_rank diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 2203570b37ad6..39ad428f16fe3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -298,12 +298,9 @@ def init_worker_distributed_environment( elif parallel_config.world_size > 1: # NOTE(woosuk): We don't initialize pynccl process group when world size # is 1. - pynccl_utils.init_process_group( - world_size=parallel_config.world_size, - local_rank=local_rank, - rank=rank, - init_method=distributed_init_method, - ) + # NOTE(kaichao): By default, pynccl will use information inside + # `parallel_state` for initialization. + pynccl_utils.init_process_group() ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) From e4bf860a54a302ccb2d80489368d5df686e46923 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 23 Apr 2024 18:33:12 -0700 Subject: [PATCH 109/413] [CI][Build] change pynvml to nvidia-ml-py (#4302) --- requirements-cuda.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-cuda.txt b/requirements-cuda.txt index c6d2cd46aee54..1bddae4c6f40f 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -3,7 +3,7 @@ # Dependencies for NVIDIA GPUs ray >= 2.9 -pynvml == 11.5.0 +nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1 From 468d761b32e3b3c5d64eeaa797e54ab809b7e50c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 23 Apr 2024 18:54:33 -0700 Subject: [PATCH 110/413] [Misc] Reduce supported Punica dtypes (#4304) --- CMakeLists.txt | 12 -------- csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu | 4 --- csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu | 4 --- csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu | 4 --- csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu | 4 --- csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu | 4 --- csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu | 4 --- csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu | 4 --- csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu | 4 --- csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu | 4 --- csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu | 4 --- csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu | 4 --- csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu | 4 --- csrc/punica/bgmv/generator.py | 20 ++++++++++++ csrc/punica/punica_ops.cc | 17 ++++++++++ tests/lora/test_layers.py | 41 +++++++++++++++++-------- 16 files changed, 66 insertions(+), 72 deletions(-) delete mode 100644 csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu delete mode 100644 csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu delete mode 100644 csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu delete mode 100644 csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu delete mode 100644 csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a99985d9abc4..e9262b57d0867 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,23 +212,11 @@ define_gpu_extension_target( set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu" "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu" - "csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu" "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu" "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" - "csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu" - "csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu" "csrc/punica/punica_ops.cc") # diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu deleted file mode 100644 index e8202dff561d9..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu deleted file mode 100644 index 3e7cf31dead0f..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu deleted file mode 100644 index 68277fa6b7d56..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu deleted file mode 100644 index 3b7531b8fbcfc..0000000000000 --- a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu deleted file mode 100644 index b3b74aa3ec904..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu deleted file mode 100644 index 3cc87f5df76a1..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu deleted file mode 100644 index 9eda98bd8ddcf..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu deleted file mode 100644 index 060f9ebb8c2b1..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu deleted file mode 100644 index b37e44570bf40..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu deleted file mode 100644 index 06718cbb0a3e9..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu deleted file mode 100644 index 41fb0e45ef4e6..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu deleted file mode 100644 index 50b7ead9fcefd..0000000000000 --- a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu +++ /dev/null @@ -1,4 +0,0 @@ -#include "bgmv_config.h" -#include "bgmv_impl.cuh" - -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py index c347d4f2ab9f4..9bf7f6358880f 100644 --- a/csrc/punica/bgmv/generator.py +++ b/csrc/punica/bgmv/generator.py @@ -18,6 +18,26 @@ if weight_dtype == "fp32": # FP32 weights are not supported. continue + if output_dtype == "fp32": + # LoRA A matrix. + if input_dtype != weight_dtype: + # NOTE(woosuk): While Punica supports the case where the + # input and weight dtypes are different, we only generate + # the kernels the same dtypes to reduce the binary size. + continue + elif input_dtype == "fp32": + # LoRA B matrix. + if output_dtype != weight_dtype: + # NOTE(woosuk): While Punica supports the case where the + # output and weight dtypes are different, we only generate + # the kernels the same dtypes to reduce the binary size. + continue + elif not (input_dtype == output_dtype == weight_dtype): + # NOTE(woosuk): While Punica supports mixed data types for + # input, output, and weight, we only generate the kernels with + # the same data types to reduce the binary size. + continue + kernel_definition = TEMPLATE.format( input_dtype=DTYPE_MAP[input_dtype], output_dtype=DTYPE_MAP[output_dtype], diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc index 7ebfd851c4feb..a1eaa90e85f27 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cc @@ -50,6 +50,23 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, int64_t y_offset, int64_t full_y_size, int64_t batch_size, int64_t num_layers, int64_t layer_idx, float scale) { + // NOTE(woosuk): While Punica supports various combinations of input/output + // data types, we limit the supported data types to reduce the binary size. + constexpr bool is_input_float = std::is_same::value; + constexpr bool is_output_float = std::is_same::value; + if (is_input_float) { + if (!std::is_same::value) { + return false; + } + } else if (is_output_float) { + if (!std::is_same::value) { + return false; + } + } else if (!(std::is_same::value && + std::is_same::value)) { + return false; + } + switch (pack_u32(in_features, out_features)) { #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ case pack_u32(feat_in, feat_out): \ diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index e9e0c8554c1ef..1616fdfd4cff9 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -413,7 +413,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, def _pretest(): linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, - 1024, vocab_size) + 1024, + vocab_size, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, vocab_size:] = 0 logits_processor = LogitsProcessor( @@ -445,7 +447,7 @@ def _pretest(): num_inputs=8 * num_loras, # * 3, input_size=(1, 1024), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -494,7 +496,7 @@ def _pretest(): num_inputs=8 * num_loras * 3, input_size=(1, 1024), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -533,11 +535,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None: def create_random_linear_parallel_layer(): if orientation == "row": - linear = RowParallelLinear(4096, 4096, bias=False) + linear = RowParallelLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = RowParallelLinearWithLoRA(linear) else: - linear = ColumnParallelLinear(4096, 4096, bias=False) + linear = ColumnParallelLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = ColumnParallelLinearWithLoRA(linear) lora_linear.create_lora_weights(max_loras, lora_config) @@ -561,7 +569,7 @@ def create_random_linear_parallel_layer(): num_inputs=32 * num_loras, input_size=(1, 4096), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -600,7 +608,7 @@ def create_random_linear_parallel_layer(): num_inputs=32 * num_loras, input_size=(1, 4096), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -633,15 +641,24 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: def create_column_parallel_packed_layer(): if repeats == 2: linear = MergedColumnParallelLinear(4096, [4096] * repeats, - bias=False) + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = MergedColumnParallelLinearWithLoRA(linear) elif repeats == 3: - linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear = QKVParallelLinear(4096, + 64, + 32, + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = MergedQKVParallelLinearWithLora(linear) else: - linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear = QKVParallelLinear(4096, + 64, + 32, + bias=False, + params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) lora_linear = QKVParallelLinearWithLora(linear) @@ -676,7 +693,7 @@ class FakeConfig: num_inputs=32 * num_loras, input_size=(1, 4096), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) @@ -716,7 +733,7 @@ class FakeConfig: num_inputs=32 * num_loras, input_size=(1, 4096), input_range=(0, 1), - input_type=torch.float32, + input_type=torch.float16, ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) From 3cd9b5bb2d4a0d5eed07186ae140f5dc8f839708 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 24 Apr 2024 09:00:20 -0700 Subject: [PATCH 111/413] [Core][Distributed] use existing torch.cuda.device (#4318) [Core][Distributed] use existing torch.cuda.device context manager (#4318) --- vllm/distributed/device_communicators/pynccl.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index fcedf0fed34cb..e922beba44bfa 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -250,15 +250,13 @@ def __init__( assert isinstance(device, torch.device) self.device = device # nccl communicator and stream will use this device - current_device = torch.cuda.current_device() - try: - torch.cuda.set_device(device) + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): NCCL_CHECK( _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, self.unique_id, self.rank)) self.stream = torch.cuda.Stream() - finally: - torch.cuda.set_device(current_device) def all_reduce(self, tensor: torch.Tensor, From 7923dcad12edae7dcfd6e0cf7ce2984b3bcecf0f Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Wed, 24 Apr 2024 09:49:13 -0700 Subject: [PATCH 112/413] [Misc] Update ShareGPT Dataset Sampling in Serving Benchmark (#4279) --- benchmarks/benchmark_serving.py | 50 ++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 6054df439fa57..2c2d69da4a7d1 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -27,7 +27,7 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import AsyncGenerator, List, Tuple +from typing import AsyncGenerator, List, Optional, Tuple import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, @@ -58,7 +58,11 @@ def sample_sharegpt_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) @@ -68,38 +72,32 @@ def sample_sharegpt_requests( dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] - # some of these will be filtered out, so sample more than we need - sampled_indices = random.sample(range(len(dataset)), - int(num_requests * 1.2)) - dataset = [dataset[i] for i in sampled_indices] - - # Tokenize the prompts and completions. - prompts = [prompt for prompt, _ in dataset] - prompt_token_ids = tokenizer(prompts).input_ids - completions = [completion for _, completion in dataset] - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): - output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + # Shuffle the dataset. + random.shuffle(dataset) - # Filter out too long sequences. + # Filter out sequences that are too long or too short filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len if prompt_len < 4 or output_len < 4: # Prune too short sequences. - # This is because TGI causes errors when the input or output length - # is too short. continue if prompt_len > 1024 or prompt_len + output_len > 2048: # Prune too long sequences. continue filtered_dataset.append((prompt, prompt_len, output_len)) - # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) - return sampled_requests + return filtered_dataset def sample_sonnet_requests( @@ -361,6 +359,7 @@ def main(args: argparse.Namespace): dataset_path=args.dataset, num_requests=args.num_prompts, tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, ) elif args.dataset_name == "sharegpt": @@ -368,6 +367,7 @@ def main(args: argparse.Namespace): dataset_path=args.dataset_path, num_requests=args.num_prompts, tokenizer=tokenizer, + fixed_output_len=args.sharegpt_output_len, ) elif args.dataset_name == "sonnet": @@ -524,6 +524,12 @@ def main(args: argparse.Namespace): default=1000, help="Number of prompts to process.", ) + parser.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.") parser.add_argument( "--sonnet-input-len", type=int, From aae08249acca69060d0a8220cab920e00520932c Mon Sep 17 00:00:00 2001 From: alexm-nm <59768536+alexm-nm@users.noreply.github.com> Date: Wed, 24 Apr 2024 13:35:01 -0400 Subject: [PATCH 113/413] [Bugfix] Fix marlin kernel crash on H100 (#4218) This PR addresses the Marlin kernel H100 crash that was reported here: neuralmagic#187. The reason for the crash was the inline PTX assembly that introduced the async_copy with streaming behavior. The solution is to use the more standard PTX for async_copy (without the fractional L2 policy for "evict_first"). There is no performance difference between standard async_copy PTX and the previous one. --- .../quantization/marlin/marlin_cuda_kernel.cu | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index cf1b0afdec8b4..002a70001885d 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -67,20 +67,13 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, "r"(smem), "l"(glob_ptr), "n"(BYTES)); } -// Asynchronous global->shared copy with a cache hint indicating that the values -// may be evicted immediately; used for quantized weights B, which are only -// accessed precisely once and should thus not pollute the L2 cache which we -// need for inputs A and outputs C. -__device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) { +// Asynchronous global->shared copy +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .b64 p;\n" - " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" - " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); + asm volatile("{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. @@ -448,14 +441,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int4 *sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); B_ptr[i] += b_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { int4 *sh_s_stage = sh_s + s_sh_stage * pipe; if (s_sh_wr_pred) - cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -750,7 +743,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // write-out if (group_blocks == -1 && last) { if (s_sh_wr_pred) - cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } thread_block_reduce(); From 2768884ac4a026609efceef92edea55839af0c30 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 24 Apr 2024 14:09:44 -0700 Subject: [PATCH 114/413] [Doc] Add note for docker user (#4340) Co-authored-by: Simon Mo --- docs/source/serving/deploying_with_docker.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/serving/deploying_with_docker.rst b/docs/source/serving/deploying_with_docker.rst index 7ec769630300d..cfc462ff33b90 100644 --- a/docs/source/serving/deploying_with_docker.rst +++ b/docs/source/serving/deploying_with_docker.rst @@ -49,3 +49,6 @@ To run vLLM: --env "HUGGING_FACE_HUB_TOKEN=" \ vllm/vllm-openai +.. note:: + + vLLM docker image is currently designed to be run under the root user (contribution welcomed for changing this!). It will try to load library at runtime under the root user's home directory, e.g. `/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1` . If you are running the container under a different user, you may need to change the permissions of the library (and all the parent directories) to allow the user to access it. Then run vLLM with environment variable `VLLM_NCCL_SO_PATH=/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1` . From a395a638c2f18d549e7d01655cf7a6dbee566f91 Mon Sep 17 00:00:00 2001 From: zifeitong Date: Wed, 24 Apr 2024 14:10:24 -0700 Subject: [PATCH 115/413] [Misc] Use public API in benchmark_throughput (#4300) --- benchmarks/benchmark_throughput.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 6bb889d1eceba..695d06e7b243d 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -103,25 +103,22 @@ def run_vllm( ) # Add the requests to the engine. + prompts = [] + sampling_params = [] for prompt, _, output_len in requests: - sampling_params = SamplingParams( - n=n, - temperature=0.0 if use_beam_search else 1.0, - top_p=1.0, - use_beam_search=use_beam_search, - ignore_eos=True, - max_tokens=output_len, - ) - # FIXME(woosuk): Do not use internal method. - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - sampling_params=sampling_params, - ) + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) start = time.perf_counter() - # FIXME(woosuk): Do not use internal method. - llm._run_engine(use_tqdm=True) + llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() return end - start From 96e90fdeb3c4ebacfe24513556afccb918722b7c Mon Sep 17 00:00:00 2001 From: Caio Mendes Date: Thu, 25 Apr 2024 00:06:57 -0300 Subject: [PATCH 116/413] [Model] Adds Phi-3 support (#4298) --- README.md | 1 + docs/source/models/supported_models.rst | 4 + vllm/config.py | 2 +- .../model_executor/layers/rotary_embedding.py | 136 +++++++++++++++++- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/llama.py | 14 +- 6 files changed, 149 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 947d50d4ad764..fce3de3b70430 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) +- Phi3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.) - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) - Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) - Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 951fc3aac0c75..f4dd5d52ad873 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -115,6 +115,10 @@ Alongside each architecture, we include some popular models that use it. - Phi - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - + * - :code:`Phi3ForCausalLM` + - Phi-3 + - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc. + - * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. diff --git a/vllm/config.py b/vllm/config.py index 2ff42de08f8f7..311a69f822571 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1056,7 +1056,7 @@ def _get_and_verify_max_len( derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) - if rope_scaling is not None: + if rope_scaling is not None and rope_scaling["type"] != "su": assert "factor" in rope_scaling scaling_factor = rope_scaling["factor"] if rope_scaling["type"] == "yarn": diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index a5225148d7828..b8361af61ae3f 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -338,6 +338,114 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: return cache +class Phi3SuScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: int, + is_neox_style: bool, + short_factor: List[float], + long_factor: List[float], + short_mscale: float = 1.1, + long_mscale: float = 1.225, + ): + super().__init__() + + if rotary_dim != head_size: + raise ValueError( + f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \ + head_size ({rotary_dim}!={head_size}).") + if is_neox_style is False: + raise ValueError( + "`Phi3SuScaledRotaryEmbedding` only supports neox_style.") + + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + self.short_mscale = short_mscale + self.long_mscale = long_mscale + + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale) + short_cache = short_cache.to(torch.get_default_dtype()) + self.register_buffer("short_cos_sin_cache", + short_cache, + persistent=False) + + long_cache = self._compute_cos_sin_cache(max_position_embeddings, + long_factor, long_mscale) + long_cache = long_cache.to(torch.get_default_dtype()) + self.register_buffer("long_cos_sin_cache", + long_cache, + persistent=False) + + long_short_cache = torch.cat( + [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0) + self.register_buffer("long_short_cos_sin_cache", + long_short_cache, + persistent=False) + + def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( + 0, self.head_size, 2, dtype=torch.float) / self.head_size))) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: List[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + k = self.original_max_position_embeddings + long_prompt_offset = (torch.any(positions > k).float() * + torch.full_like(positions, k)).long() + idx = (torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None else positions) + self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to( + idx.device) + idx = torch.add(idx, offsets) if offsets is not None else idx + cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) + + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).unsqueeze(-2) + sin = sin.repeat(1, 2).unsqueeze(-2) + + query = query * cos + _rotate_neox(query) * sin + key = key * cos + _rotate_neox(key) * sin + + return query.flatten(-2), key.flatten(-2) + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -349,17 +457,26 @@ def get_rope( is_neox_style: bool = True, rope_scaling: Optional[Dict[str, Any]] = None, ) -> RotaryEmbedding: + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None key = (head_size, rotary_dim, max_position, base, is_neox_style, - tuple(rope_scaling.items()) if rope_scaling is not None else None) + rope_scaling_args) if key in _ROPE_DICT: return _ROPE_DICT[key] - if rope_scaling is None: rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style) else: scaling_type = rope_scaling["type"] - scaling_factor = rope_scaling["factor"] + if scaling_type != "su": + scaling_factor = rope_scaling["factor"] if scaling_type == "linear": rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, max_position, base, @@ -383,6 +500,19 @@ def get_rope( base, is_neox_style, scaling_factor, **extra_kwargs) + elif scaling_type == "su": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3SuScaledRotaryEmbedding( + head_size, rotary_dim, max_position, original_max_position, + base, is_neox_style, short_factor, long_factor, **extra_kwargs) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 17fc970568042..c0aeab5dd3032 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -46,6 +46,7 @@ "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), + "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 016e3b039d1e8..c102b40045c92 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -180,6 +180,10 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) sliding_window = getattr(config, "sliding_window", None) @@ -378,11 +382,11 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: From 479d69fad0538f04cb22bf13e76ff91cfeb8a4e5 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 24 Apr 2024 23:52:22 -0700 Subject: [PATCH 117/413] [Core] Move ray_utils.py from `engine` to `executor` package (#4347) --- vllm/__init__.py | 2 +- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 2 +- vllm/executor/ray_gpu_executor.py | 10 ++++++---- vllm/{engine => executor}/ray_utils.py | 0 vllm/transformers_utils/tokenizer_group/__init__.py | 2 +- .../tokenizer_group/ray_tokenizer_group.py | 2 +- 7 files changed, 11 insertions(+), 9 deletions(-) rename vllm/{engine => executor}/ray_utils.py (100%) diff --git a/vllm/__init__.py b/vllm/__init__.py index 5ca4680227598..ca454efd44b24 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -3,8 +3,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_ray_cluster from vllm.entrypoints.llm import LLM +from vllm.executor.ray_utils import initialize_ray_cluster from vllm.model_executor.models import ModelRegistry from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3a2f7db679358..4b007d71e9cfc 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -10,7 +10,7 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_ray_cluster, ray +from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 19e58fb1722cf..56c2417d6a6e6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -15,8 +15,8 @@ SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.util import create_output_by_sequence_group -from vllm.engine.ray_utils import initialize_ray_cluster from vllm.executor.executor_base import ExecutorBase +from vllm.executor.ray_utils import initialize_ray_cluster from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index e69f104e7d5a4..14b3f803782c6 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -5,8 +5,8 @@ from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple -from vllm.engine.ray_utils import RayWorkerWrapper, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -74,7 +74,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # The driver dummy worker does not actually use any resources. # It holds the resource for the driver worker. - self.driver_dummy_worker: RayWorkerWrapper = None + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None # The remaining workers are the actual ray actors. self.workers: List[RayWorkerWrapper] = [] @@ -318,6 +318,7 @@ def _run_workers( driver_worker_output = self.driver_worker.execute_method( method, *driver_args, **driver_kwargs) else: + assert self.driver_dummy_worker is not None driver_worker_output = ray.get( self.driver_dummy_worker.execute_method.remote( method, *driver_args, **driver_kwargs)) @@ -353,8 +354,9 @@ def _compiled_ray_dag(self): # a dummy value for now. It will be fixed soon. with InputNode() as input_data: forward_dag = MultiOutputNode([ - worker.execute_model_compiled_dag_remote.bind(input_data) - for worker in self.workers + worker.execute_model_compiled_dag_remote. + bind( # type: ignore[attr-defined] + input_data) for worker in self.workers ]) return forward_dag.experimental_compile() diff --git a/vllm/engine/ray_utils.py b/vllm/executor/ray_utils.py similarity index 100% rename from vllm/engine/ray_utils.py rename to vllm/executor/ray_utils.py diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 69380d67f9b94..0195c40c27f60 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,7 +1,7 @@ from typing import Optional from vllm.config import TokenizerPoolConfig -from vllm.engine.ray_utils import ray +from vllm.executor.ray_utils import ray from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py index f3cdc00564dbb..7c605416854b8 100644 --- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer from vllm.config import TokenizerPoolConfig -from vllm.engine.ray_utils import ray +from vllm.executor.ray_utils import ray from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) From fbf152d976e655f8734561452a1036e63081ccfd Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Fri, 26 Apr 2024 00:35:56 +0800 Subject: [PATCH 118/413] [Bugfix][Model] Refactor OLMo model to support new HF format in transformers 4.40.0 (#4324) Co-authored-by: Woosuk Kwon --- README.md | 2 +- docs/source/models/supported_models.rst | 2 +- requirements-dev.txt | 1 - vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/olmo.py | 305 ++++++++++++------------ 5 files changed, 150 insertions(+), 162 deletions(-) diff --git a/README.md b/README.md index fce3de3b70430..e59a1c60cc369 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) -- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.) +- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index f4dd5d52ad873..ceb658bbd5c66 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -101,7 +101,7 @@ Alongside each architecture, we include some popular models that use it. - * - :code:`OLMoForCausalLM` - OLMo - - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc. + - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc. - * - :code:`OPTForCausalLM` - OPT, OPT-IML diff --git a/requirements-dev.txt b/requirements-dev.txt index 1317e51b2dd11..d9816828d007d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -26,7 +26,6 @@ requests ray peft awscli -ai2-olmo # required for OLMo # Benchmarking aiohttp diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index c0aeab5dd3032..6afb2f31c1334 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -42,7 +42,7 @@ "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), - "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"), + "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index b92003bc0e067..15527569b9e20 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -1,53 +1,36 @@ # coding=utf-8 # Adapted from -# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and -# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py -# Copyright 2023 The vLLM team. -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. +# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. # -# BSD 3-Clause License +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # -# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. -# All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: +# http://www.apache.org/licenses/LICENSE-2.0 # -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Inference-only OLMo model compatible with HuggingFace weights.""" from typing import Iterable, List, Optional, Tuple import torch -# this model must need this dependency -from hf_olmo import OLMoConfig from torch import nn +from transformers import OlmoConfig from vllm.attention import Attention, AttentionMetadata from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, +from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -55,7 +38,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -70,55 +53,52 @@ class OlmoAttention(nn.Module): def __init__( self, - config: OLMoConfig, + config: OlmoConfig, linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.hidden_size = config.d_model - assert config.d_model % config.n_heads == 0 + self.hidden_size = config.hidden_size tensor_model_parallel_world_size = ( get_tensor_model_parallel_world_size()) - self.total_num_heads = self.config.n_heads + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // tensor_model_parallel_world_size) self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.clip_qkv = config.clip_qkv - # Layer norms. - self.attn_norm = nn.LayerNorm(config.d_model, - elementwise_affine=False, - bias=False) # Attention input projection. Projects x -> (q, k, v) - self.att_proj = QKVParallelLinear( - config.d_model, + self.qkv_proj = QKVParallelLinear( + self.hidden_size, self.head_dim, self.total_num_heads, - bias=config.include_bias, + bias=config.attention_bias, linear_method=linear_method, ) # Rotary embeddings. - if self.config.rope: - rope_theta = getattr(config, "rope_theta", 10000) - max_position_embeddings = getattr(config, - "max_position_embeddings", 8192) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) self.scaling = self.head_dim**-0.5 self.attn = Attention(self.num_heads, self.head_dim, scale=self.scaling) # Attention output projection. - self.attn_out = RowParallelLinear( - config.d_model, - config.d_model, - bias=config.include_bias, + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, linear_method=linear_method, ) @@ -129,13 +109,13 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - hidden_states = self.attn_norm(hidden_states) - qkv, _ = self.att_proj(hidden_states) + qkv, _ = self.qkv_proj(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) q, k, v = qkv.chunk(chunks=3, dim=-1) - if self.config.rope: - q, k = self.rotary_emb(positions, q, k) + q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.attn_out(attn_output) + output, _ = self.o_proj(attn_output) return output @@ -148,37 +128,30 @@ class OlmoMLP(nn.Module): def __init__( self, - config: OLMoConfig, + config: OlmoConfig, linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size - is not None else config.mlp_ratio * config.d_model) - - # Layer norms. - self.ff_norm = nn.LayerNorm(config.d_model, - elementwise_affine=False, - bias=False) + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size # Feed-forward input projection. - self.ff_proj = MergedColumnParallelLinear( - config.d_model, - [self.hidden_size // 2] * 2, - bias=config.include_bias, + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, linear_method=linear_method, ) # Activation function. - self.act = SiluAndMul() - self.act.output_multiplier = 0.5 - assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + self.act_fn = SiluAndMul() # Feed-forward output projection. - self.ff_out = RowParallelLinear( - int(self.act.output_multiplier * self.hidden_size), - config.d_model, - bias=config.include_bias, + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, linear_method=linear_method, ) @@ -186,19 +159,13 @@ def forward( self, x: torch.Tensor, ) -> torch.Tensor: - # Add feed-forward projection. - # shape: (batch_size, seq_len, d_model) - og_x = x - x = self.ff_norm(x) - x, _ = self.ff_proj(x) - x = self.act(x) - x, _ = self.ff_out(x) - x = og_x + x - + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) return x -class OlmoBlock(nn.Module): +class OlmoDecoderLayer(nn.Module): """ This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` @@ -206,15 +173,23 @@ class OlmoBlock(nn.Module): """ def __init__(self, - config: OLMoConfig, + config: OlmoConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() # Attention block. - self.attn = OlmoAttention(config, linear_method) + self.self_attn = OlmoAttention(config, linear_method) # MLP block. self.mlp = OlmoMLP(config, linear_method) + # LayerNorm + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + def forward( self, positions: torch.Tensor, @@ -223,52 +198,37 @@ def forward( attn_metadata: AttentionMetadata, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: # Attention block. - og_x = hidden_states - x = self.attn(positions, hidden_states, kv_cache, attn_metadata) - x = x + og_x + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = hidden_states + residual # MLP block. - hidden_states = self.mlp(x) + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states return hidden_states class OlmoModel(nn.Module): def __init__(self, - config: OLMoConfig, + config: OlmoConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config - self.transformer = nn.ModuleDict( - dict( - wte=VocabParallelEmbedding( - config.embedding_size or config.vocab_size, - config.d_model, - ), - ln_f=nn.LayerNorm(config.d_model, - elementwise_affine=False, - bias=False), - )) - - blocks = [ - OlmoBlock(config, linear_method) for i in range(config.n_layers) - ] - if self.config.block_group_size > 1: - raise NotImplementedError("Block group size > 1 not supported yet") - else: - self.transformer.update({"blocks": nn.ModuleList(blocks)}) - - if not config.weight_tying: - self.transformer.update({ - "ff_out": - ColumnParallelLinear( - config.d_model, - config.embedding_size or config.vocab_size, - bias=config.include_bias, - linear_method=linear_method, - ) - }) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.layers = nn.ModuleList([ + OlmoDecoderLayer(config, linear_method) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) def forward( self, @@ -282,39 +242,49 @@ def forward( """ # Get embeddings of input. # shape: (batch_size, seq_len, d_model) - x = self.transformer.wte(input_ids) # type: ignore + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds # Apply blocks one-by-one. - for block_idx, block in enumerate(self.transformer.blocks): + for layer_idx, decoder_layer in enumerate(self.layers): # shape: (batch_size, seq_len, d_model) - x = block( + hidden_states = decoder_layer( positions, - x, - kv_caches[block_idx], + hidden_states, + kv_caches[layer_idx], attn_metadata, ) # Apply final layer norm. # shape: (batch_size, seq_len or 1, d_model) - x = self.transformer.ln_f(x) # type: ignore - return x + hidden_states = self.norm(hidden_states) + return hidden_states -class OLMoForCausalLM(nn.Module): +class OlmoForCausalLM(nn.Module): """ Extremely barebones HF model wrapper. """ def __init__(self, - config: OLMoConfig, + config: OlmoConfig, linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config self.linear_method = linear_method self.model = OlmoModel(config, linear_method) - self.lm_head_weight = (self.model.transformer.wte.weight - if config.weight_tying else - self.model.transformer.ff_out.weight) + if config.tie_word_embeddings: + self.lm_head_weight = self.model.embed_tokens.weight + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.lm_head_weight = self.lm_head.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -348,20 +318,39 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: - # attention - if ".att" in name: - name = name.replace(".att", ".attn.att") - # mlp - if ".ff_proj" in name: - name = name.replace(".ff_proj", ".mlp.ff_proj") - # Reverse the weight for the MergeColumnParallelLinear - loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1]) - if ".ff_out" in name and "transformer.ff_out" not in name: - name = name.replace(".ff_out", ".mlp.ff_out") - # there is no bias in olmo - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 7ee82bef1e71febc28757807f5df8191bb36d88e Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Thu, 25 Apr 2024 11:37:20 -0500 Subject: [PATCH 119/413] [CI/Build] Adding functionality to reset the node's GPUs before processing. (#4213) --- .buildkite/run-amd-test.sh | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 83a56e25aca73..38aff57a410dc 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -5,6 +5,19 @@ set -ex # Print ROCm version rocminfo + +echo "reset" > /opt/amdgpu/etc/gpu_state + +while true; do + sleep 3 + if grep -q clean /opt/amdgpu/etc/gpu_state; then + echo "GPUs state is \"clean\"" + break + fi +done + + + # Try building the docker image docker build -t rocm -f Dockerfile.rocm . @@ -14,7 +27,8 @@ trap remove_docker_container EXIT remove_docker_container # Run the image -docker run --device /dev/kfd --device /dev/dri --network host --name rocm rocm python3 -m vllm.entrypoints.api_server & +export HIP_VISIBLE_DEVICES=1 +docker run --device /dev/kfd --device /dev/dri --network host -e HIP_VISIBLE_DEVICES --name rocm rocm python3 -m vllm.entrypoints.api_server & # Wait for the server to start wait_for_server_to_start() { From bd7a8eef25cd85be7eb9f2a94fd752d27ee7dce3 Mon Sep 17 00:00:00 2001 From: Caio Mendes Date: Thu, 25 Apr 2024 14:32:00 -0300 Subject: [PATCH 120/413] [Doc] README Phi-3 name fix. (#4372) Co-authored-by: Caio Mendes --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e59a1c60cc369..524d027137aba 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) -- Phi3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.) +- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.) - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) - Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) - Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) From f4bc4de1b1a1bd33bd3ec7dea3377eb75884250a Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Thu, 25 Apr 2024 19:03:56 +0000 Subject: [PATCH 121/413] [Core]refactor aqlm quant ops (#4351) --- benchmarks/kernels/benchmark_aqlm.py | 2 +- vllm/_custom_ops.py | 14 ++++++++++++++ vllm/model_executor/layers/quantization/aqlm.py | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py index 9602d20bcbc74..59392947b15c8 100644 --- a/benchmarks/kernels/benchmark_aqlm.py +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.aqlm import ( dequantize_weight, generic_dequantize_gemm, get_int_dtype, optimized_dequantize_gemm) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e4b16ed918d1a..508d35656eb00 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -153,6 +153,20 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) +# aqlm +def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, + codebook_partition_sizes, bias) + + +def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: torch.Tensor) -> torch.Tensor: + return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) + + # fp8 def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: scale = torch.zeros(1, device=input.device, dtype=torch.float32) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 6115b1de679ad..b48c6e1702be4 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( From b5b4a398a794e3276729a525e58eaa92f5fc0212 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 26 Apr 2024 04:13:50 +0900 Subject: [PATCH 122/413] [Mypy] Typing lora folder (#4337) --- .github/workflows/mypy.yaml | 7 ++-- format.sh | 2 +- vllm/lora/layers.py | 35 ++++++++++++-------- vllm/lora/lora.py | 28 +++++++++------- vllm/lora/models.py | 64 ++++++++++++++++++++----------------- vllm/lora/worker_manager.py | 21 ++++++------ vllm/worker/model_runner.py | 4 +-- 7 files changed, 91 insertions(+), 70 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 9f1855696e20a..089c7d18ad6f2 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -33,8 +33,6 @@ jobs: - name: Mypy run: | mypy vllm/attention --config-file pyproject.toml - # TODO(sang): Fix nested dir - mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml @@ -44,8 +42,9 @@ jobs: mypy vllm/engine --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml + mypy vllm/lora --config-file pyproject.toml + # TODO(sang): Fix nested dir mypy vllm/model_executor/*.py --config-file pyproject.toml - # TODO(sang): Fix nested dir - # mypy vllm/lora/*.py --config-file pyproject.toml + mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml diff --git a/format.sh b/format.sh index bd2e9e89e1806..4ac1842daef0a 100755 --- a/format.sh +++ b/format.sh @@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/model_executor/*.py --config-file pyproject.toml -# mypy vllm/lora/*.py --config-file pyproject.toml +mypy vllm/lora --config-file pyproject.toml CODESPELL_EXCLUDES=( diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index aac86351b15e1..98e74168002c4 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer + self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] def create_lora_weights( self, @@ -233,9 +235,10 @@ def create_lora_weights( self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], self.lora_a_stacked.shape[2], ) - self.indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None - self.embeddings_indices = None + # Lazily initialized. + self.indices: torch.Tensor + self.indices_len: List[int] + self.embeddings_indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -267,6 +270,7 @@ def set_lora( self.embeddings_tensors.shape[1], self.embeddings_tensors.shape[2] )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) def set_mapping( @@ -343,11 +347,12 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) - - self.indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None self.output_dim = self.lora_b_stacked.shape[2] + # lazily initialized. + self.indices: torch.Tensor + self.indices_len: List[int] + def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 @@ -475,8 +480,9 @@ def create_lora_weights( device=self.device, ) for _ in range(n_slices)) - self.indices: Optional[torch.Tensor] = None self.output_dim = self.lora_b_stacked[0].shape[2] + # Lazily initialized. + self.indices: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -690,7 +696,8 @@ def create_lora_weights( self.kv_proj_shard_size) self.packed_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None + # lazily initialized. + self.indices_len: List[int] def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -814,8 +821,9 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) - self.indices: Optional[torch.Tensor] = None - self.indices_len: Optional[List[int]] = None + # Lazily initialized + self.indices: torch.Tensor + self.indices_len: List[int] def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 @@ -991,9 +999,10 @@ def create_lora_weights( dtype=self.dtype, device=self.device, ) - self.indices = None - self.indices_padded = None - self.indices_len = None + # Lazily initialized. + self.indices: torch.Tensor + self.indices_len: List[int] + self.indices_padded: torch.Tensor def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index fefad16700fe3..d7794aa7cd35c 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -97,9 +97,9 @@ def __init__( self, module_name: str, rank: int, - lora_alphas: List[int], - lora_a: List[torch.Tensor], - lora_b: List[torch.Tensor], + lora_alphas: List[Optional[int]], + lora_a: List[Optional[torch.Tensor]], + lora_b: List[Optional[torch.Tensor]], scaling: Optional[List[float]] = None, ) -> None: super().__init__( @@ -108,17 +108,20 @@ def __init__( lora_alpha=0, lora_a=lora_a, lora_b=lora_b, - scaling=scaling, + scaling=scaling, # type: ignore embeddings_tensor=None, ) self.lora_alphas = lora_alphas if scaling is None: - self.scaling = [ - lora_alpha / self.rank for lora_alpha in self.lora_alphas + self.scaling = [ # type: ignore + lora_alpha / self.rank # type: ignore # noqa + for lora_alpha in self.lora_alphas ] @classmethod - def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": + def pack( + cls, loras: List[Optional["LoRALayerWeights"]] + ) -> "PackedLoRALayerWeights": """Pack a list of LoRAs into a single LoRA. If LoRA is None, it signifies that the submodule does not have a LoRA. @@ -136,16 +139,19 @@ def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": [lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras], - scaling=[1 if lora is not None else None for lora in loras]) + scaling=[ + 1 if lora is not None else None # type: ignore + for lora in loras + ]) return obj def optimize(self) -> "PackedLoRALayerWeights": """Optimize the LoRA by merging the scaling into lora_b.""" for i in range(len(self.lora_b)): - if self.scaling[i] == 1 or self.lora_b[i] is None: + if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore continue - self.lora_b[i] *= self.scaling[i] - self.scaling[i] = 1 + self.lora_b[i] *= self.scaling[i] # type: ignore + self.scaling[i] = 1 # type: ignore return self @property diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 6bb9fee27d535..c249497a4d893 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -3,7 +3,7 @@ import math import os import re -from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type +from typing import Callable, Dict, List, Optional, Tuple, Type import safetensors.torch import torch @@ -53,44 +53,46 @@ def convert_mapping( embeddings. indices_len: List of lengths of the above tensors. """ - indices = list(mapping.index_mapping).copy() - embedding_indices = indices.copy() - lora_indices = indices.copy() - prompt_mapping = [ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] lora_idx = None - for i in range(len(indices)): + for i in range(len(index_mapping_indices)): # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(indices[i]) - if indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if indices[i] > 0 else 0 - indices[i] = i + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + index_mapping_indices[i] = i lora_indices[i] = lora_idx - indices = torch.tensor([indices, lora_indices, embedding_indices], - dtype=torch.long, - device="cuda") - prompt_mapping = torch.tensor(prompt_mapping, - device="cuda", - dtype=torch.long) + indices = torch.tensor( + [index_mapping_indices, lora_indices, embedding_indices], + dtype=torch.long, + device="cuda") + prompt_mapping_tensor = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) embeddings_indices = torch.stack([ indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size) ]) embeddings_indices[embeddings_indices == -1] = max_loras - 1 base_indices = indices[1] - sampler_indices = prompt_mapping + sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 sampler_indices_padded = ( torch.arange( 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (sampler_indices_padded * len(sampler_indices_padded))) - indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], - embeddings_indices.shape[-1]) + indices_len = [ + base_indices.shape[-1], sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], embeddings_indices.shape[-1] + ] return (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, indices_len) @@ -149,6 +151,7 @@ def from_lora_tensors( if module_name not in loras: lora_embeddings_tensor = None if embeddings: + assert embedding_modules is not None embeddings_module = next( (k for k in embedding_modules if k in module_name), None) @@ -171,6 +174,7 @@ def from_lora_tensors( else: loras[module_name].lora_b = tensor.to(device=device, dtype=dtype).t() + assert embedding_padding_modules is not None if any(name in module_name for name in embedding_padding_modules ) and target_embedding_padding is not None: @@ -295,11 +299,10 @@ def __init__( self.max_num_batched_tokens, dtype=torch.long, device="cuda") - self.offsets = [] # 4 is the number of indicies tensors defined above # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices - self.indices_len = [None] * 4 + self.indices_len: List[Optional[int]] = [None] * 4 self.model: nn.Module = model if hasattr(self.model, "supported_lora_modules"): @@ -312,7 +315,7 @@ def __init__( self._registered_loras: Dict[int, LoRAModel] = {} # Dict instead of a Set for compatibility with LRUCache. self._active_loras: Dict[int, None] = {} - self._last_mapping = None + self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() self.model.lora_manager = self @@ -370,7 +373,7 @@ def deactivate_lora(self, lora_id: int) -> bool: return True return False - def _add_lora(self, lora: LoRAModel) -> bool: + def _add_lora(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) self._registered_loras[lora.id] = lora @@ -418,7 +421,7 @@ def list_loras(self) -> Dict[int, LoRAModel]: def get_lora(self, lora_id: int) -> Optional[LoRAModel]: return self._registered_loras.get(lora_id, None) - def remove_all_loras(self) -> bool: + def remove_all_loras(self): """Remove all LoRAModels from the manager.""" self._registered_loras.clear() self.lora_index_to_id = [None] * self.lora_slots @@ -467,6 +470,7 @@ def create_dummy_lora( continue parts = module_name.split(".") if module_name not in self.packed_modules: + assert embedding_modules is not None if parts[-1] in embedding_modules: input_dim = (module.base_layer.org_vocab_size + self.lora_config.lora_extra_vocab_size if @@ -500,7 +504,7 @@ def create_dummy_lora( else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] - subloras = [] + subloras: List[Optional["LoRALayerWeights"]] = [] for i, r in enumerate(replacements): lora = LoRALayerWeights.create_dummy_lora_weights( module_name + "." + r, @@ -538,7 +542,7 @@ def _register_packed_modules(self, module_full_name: str) -> None: def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names in self.packed_modules.items(): - replacement_loras = [] + replacement_loras: List[Optional[LoRALayerWeights]] = [] has_replacement = False for r in new_module_names: lora = lora_model.get_lora(r) @@ -557,12 +561,12 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: class LoRALRUCache(LRUCache[LoRAModel]): - def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], - None]): + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], + bool]): super().__init__(capacity) self.deactivate_lora_fn = deactivate_lora_fn - def _on_remove(self, key: Hashable, value: LoRAModel): + def _on_remove(self, key: int, value: LoRAModel): logger.debug(f"Removing LoRA. int id: {key}") self.deactivate_lora_fn(key) return super()._on_remove(key, value) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 5356b79537b05..ec3c10c591a18 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod, abstractproperty -from typing import Any, Dict, List, Optional, Set, Type +from typing import Any, Dict, List, Set, Type import torch @@ -37,7 +37,7 @@ def create_lora_manager( ... @abstractmethod - def set_active_loras(self, lora_requests: List[LoRARequest], + def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: ... @@ -54,7 +54,7 @@ def remove_lora(self, lora_id: int) -> bool: ... @abstractmethod - def remove_all_loras(self) -> bool: + def remove_all_loras(self): ... @abstractmethod @@ -81,10 +81,11 @@ def __init__( embedding_padding_modules: List[str], lora_model_cls: Type[LoRAModel] = LoRAModel, ): - self._lora_manager: Optional[LoRAModelManager] = None self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules + # Lazily initialized by create_lora_manager. + self._lora_manager: LoRAModelManager super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device) @@ -104,7 +105,7 @@ def create_lora_manager( lora_config=self.lora_config, lora_manager_cls=self._lora_manager_cls, ) - self._lora_manager: LoRAModelManager = lora_manager + self._lora_manager = lora_manager return lora_manager.model def set_active_loras(self, lora_requests: Set[LoRARequest], @@ -188,7 +189,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self._lora_manager.remove_lora(lora_id) - def remove_all_loras(self) -> bool: + def remove_all_loras(self): self._lora_manager.remove_all_loras() def list_loras(self) -> Set[int]: @@ -217,10 +218,10 @@ def create_lora_manager( lora_config=self.lora_config, max_num_batched_tokens=self.max_num_batched_tokens, ) - self._lora_manager: LRUCacheLoRAModelManager = lora_manager + self._lora_manager = lora_manager return lora_manager.model - def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request @@ -237,12 +238,14 @@ def add_lora(self, lora_request: LoRARequest) -> bool: if lora_request.lora_int_id not in self.list_loras(): # Remove before we load the new lora to save memory if len(self._lora_manager) + 1 > self._lora_manager.capacity: + assert isinstance(self._lora_manager, LRUCacheLoRAModelManager) self._lora_manager.remove_oldest_lora() lora = self._load_lora(lora_request) loaded = self._lora_manager.add_lora(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches - loaded = self._lora_manager.get_lora(lora_request.lora_int_id) + loaded = self._lora_manager.get_lora( + lora_request.lora_int_id) is not None self._lora_manager.activate_lora(lora_request.lora_int_id) return loaded diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 31e08789dfd1f..33dbf8d90c35d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -928,10 +928,10 @@ def profile_run(self) -> None: torch.cuda.synchronize() return - def remove_all_loras(self) -> bool: + def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_all_loras() + self.lora_manager.remove_all_loras() def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: From b6dcb4d44281a9e85cafcfa6376c373d02286779 Mon Sep 17 00:00:00 2001 From: Roy Date: Fri, 26 Apr 2024 03:43:32 +0800 Subject: [PATCH 123/413] [Misc] Fix flash attention backend log (#4368) --- vllm/attention/selector.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 554e802cd5513..7cc17f21dcd0e 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -25,7 +25,7 @@ class _Backend(enum.Enum): def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: backend = _which_attn_to_use(dtype) if backend == _Backend.FLASH_ATTN: - logger.info("Using FlashAttention backend.") + logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend @@ -62,12 +62,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: # NVIDIA GPUs. if torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. - logger.info("Cannot use FlashAttention backend for Volta and Turing " + logger.info("Cannot use FlashAttention-2 backend for Volta and Turing " "GPUs.") return _Backend.XFORMERS if dtype not in (torch.float16, torch.bfloat16): - logger.info("Cannot use FlashAttention backend for dtype other than " + logger.info("Cannot use FlashAttention-2 backend for dtype other than " "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS @@ -75,8 +75,8 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: import flash_attn # noqa: F401 except ImportError: logger.info( - "Cannot use FlashAttention backend because the flash_attn package " - "is not found. Please install it for better performance.") + "Cannot use FlashAttention-2 backend because the flash_attn " + "package is not found. Please install it for better performance.") return _Backend.XFORMERS backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) From 15e7c675b0dc36109c7b591f856f102e96493a94 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 25 Apr 2024 16:32:48 -0700 Subject: [PATCH 124/413] [Core] Add `shutdown()` method to `ExecutorBase` (#4349) --- vllm/engine/llm_engine.py | 6 ++++++ vllm/executor/executor_base.py | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 56c2417d6a6e6..7de60d738113e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -287,6 +287,12 @@ def __reduce__(self): # the closure used to initialize Ray worker actors raise RuntimeError("LLMEngine should not be pickled!") + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() + def get_tokenizer(self) -> "PreTrainedTokenizer": return self.tokenizer.get_lora_tokenizer(None) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 1839b5603ff3e..1838c34be2fda 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -95,6 +95,13 @@ def check_health(self) -> None: exception.""" raise NotImplementedError + def shutdown(self) -> None: + """Shutdown the executor.""" + return + + def __del__(self): + self.shutdown() + class ExecutorAsyncBase(ExecutorBase): From efffb63f584c1ce4fdcf4e7b7fd0bfc8b33a733a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 25 Apr 2024 16:45:12 -0700 Subject: [PATCH 125/413] [Core] Move function tracing setup to util function (#4352) --- vllm/utils.py | 21 ++++++++++++++++++++- vllm/worker/worker_base.py | 18 ++++-------------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 15c8818cc4506..79ac1db01fc69 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,10 +1,13 @@ import asyncio +import datetime import enum import gc import glob import os import socket import subprocess +import tempfile +import threading import uuid import warnings from collections import defaultdict @@ -18,7 +21,7 @@ import torch from packaging.version import Version, parse -from vllm.logger import init_logger +from vllm.logger import enable_trace_function_call, init_logger T = TypeVar("T") logger = init_logger(__name__) @@ -607,3 +610,19 @@ def find_nccl_library(): raise ValueError("NCCL only supports CUDA and ROCm backends.") logger.info(f"Found nccl from library {so_file}") return so_file + + +def enable_trace_function_call_for_thread() -> None: + """Set up function tracing for the current thread, + if enabled via the VLLM_TRACE_FUNCTION environment variable + """ + + if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): + tmp_dir = tempfile.gettempdir() + filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log").replace(" ", "_") + log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + filename) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b5dade0a770a0..0a89e3a79769f 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,15 +1,13 @@ -import datetime import importlib import os -import tempfile -import threading from abc import ABC, abstractmethod from typing import Dict, List, Set, Tuple -from vllm.logger import enable_trace_function_call, init_logger +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import get_vllm_instance_id, update_environment_variables +from vllm.utils import (enable_trace_function_call_for_thread, + update_environment_variables) logger = init_logger(__name__) @@ -128,15 +126,7 @@ def init_worker(self, *args, **kwargs): function tracing if required. Arguments are passed to the worker class constructor. """ - if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): - tmp_dir = tempfile.gettempdir() - filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" - f"_thread_{threading.get_ident()}_" - f"at_{datetime.datetime.now()}.log").replace(" ", "_") - log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), - filename) - os.makedirs(os.path.dirname(log_path), exist_ok=True) - enable_trace_function_call(log_path) + enable_trace_function_call_for_thread() mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) From cf29b7eda47d5a8c5fe6a7a53490271da8520563 Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Thu, 25 Apr 2024 21:12:25 -0400 Subject: [PATCH 126/413] [ROCm][Hardware][AMD][Doc] Documentation update for ROCm (#4376) Co-authored-by: WoosukKwon --- .../getting_started/amd-installation.rst | 165 +++++++----------- 1 file changed, 65 insertions(+), 100 deletions(-) diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 3d736bf7120ec..61fcd45a26347 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -3,9 +3,7 @@ Installation with ROCm ====================== -vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm. -At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported. -Data types currently supported in ROCm are FP16 and BF16. +vLLM supports AMD GPUs with ROCm 5.7 and 6.0. Requirements ------------ @@ -13,114 +11,57 @@ Requirements * OS: Linux * Python: 3.8 -- 3.11 * GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) -* Pytorch 2.0.1/2.1.1/2.2 -* ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9) +* ROCm 6.0 and ROCm 5.7 Installation options: -#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image ` -#. :ref:`Build from source ` #. :ref:`Build from source with docker ` +#. :ref:`Build from source ` -.. _quick_start_docker_rocm: - -(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image ---------------------------------------------------------------------------- - -This option is for ROCm 5.7 only: - -.. code-block:: console - - $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4 - $ docker run -it \ - --network=host \ - --group-add=video \ - --ipc=host \ - --cap-add=SYS_PTRACE \ - --security-opt seccomp=unconfined \ - --device /dev/kfd \ - --device /dev/dri \ - -v :/app/model \ - embeddedllminfo/vllm-rocm \ - bash - - -.. _build_from_source_rocm: - -Option 2: Build from source ---------------------------- - -You can build and install vLLM from source: - -Below instruction is for ROCm 5.7 only. -At the time of this documentation update, PyTorch on ROCm 6.0 wheel is not yet available on the PyTorch website. - -0. Install prerequisites (skip if you are already in an environment/docker with the following installed): - -- `ROCm `_ -- `Pytorch `_ - - .. code-block:: console - - $ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version - - -1. Install `flash attention for ROCm `_ - - Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_ - -.. note:: - - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly. - - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`. - - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) - -2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention +.. _build_from_source_docker_rocm: - .. code-block:: console +Option 1: Build from source with docker (recommended) +----------------------------------------------------- - $ pip install xformers==0.0.23 --no-deps - $ bash patch_xformers.rocm.sh +You can build and install vLLM from source. -3. Build vLLM. +First, build a docker image from `Dockerfile.rocm `_ and launch a docker container from the image. - .. code-block:: console +`Dockerfile.rocm `_ uses ROCm 6.0 by default, but also supports ROCm 5.7. +It provides flexibility to customize the build of docker image using the following arguments: - $ cd vllm - $ pip install -U -r requirements-rocm.txt - $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation +* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1` +* `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For `Radeon RX 7900 series (gfx1100) `_, this should be set to 0 before flash-attention supports this target. +* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942` +* `FA_BRANCH`: specifies the branch used to build the CK flash-attention in `ROCm's flash-attention repo `_. The default is `ae7928c` +* `BUILD_TRITON`: specifies whether to build triton flash-attention. The default value is 1. +Their values can be passed in when running ``docker build`` with ``--build-arg`` options. -.. _build_from_source_docker_rocm: -Option 3: Build from source with docker ------------------------------------------------------ +To build vllm on ROCm 6.0 for MI200 and MI300 series, you can use the default: -You can build and install vLLM from source: +.. code-block:: console -Build a docker image from `Dockerfile.rocm`, and launch a docker container. + $ docker build -f Dockerfile.rocm -t vllm-rocm . -The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later versions. It provides flexibility to customize the build of docker image using the following arguments: +To build vllm on ROCm 6.0 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: -* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1` -* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942` -* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo `_. The default is `3d2b6f5` -* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) `_, this should be set to 0 before flash-attention supports this target. +.. code-block:: console -Their values can be passed in when running ``docker build`` with ``--build-arg`` options. + $ docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm . -For example, to build docker image for vllm on ROCm 5.7, you can run: +To build docker image for vllm on ROCm 5.7, you can specify ``BASE_IMAGE`` as below: .. code-block:: console $ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \ -f Dockerfile.rocm -t vllm-rocm . -To build vllm on ROCm 6.0, you can use the default: +To run the above docker image ``vllm-rocm``, use the below command: .. code-block:: console - $ docker build -f Dockerfile.rocm -t vllm-rocm . $ docker run -it \ --network=host \ --group-add=video \ @@ -133,7 +74,13 @@ To build vllm on ROCm 6.0, you can use the default: vllm-rocm \ bash -Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below: +Where the `` is the location where the model is stored, for example, the weights for llama2 or llama3 models. + + +.. _build_from_source_rocm: + +Option 2: Build from source +--------------------------- 0. Install prerequisites (skip if you are already in an environment/docker with the following installed): @@ -141,32 +88,50 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from - `Pytorch `_ - `hipBLAS `_ -1. Install `flash attention for ROCm `_ +For installing PyTorch, you can start from a fresh docker image, e.g, `rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`. - Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_ +Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started `_ + +For rocm6.0: + +.. code-block:: console + + $ pip3 install torch --index-url https://download.pytorch.org/whl/rocm6.0 + + +For rocm5.7: + +.. code-block:: console + + $ pip install torch --index-url https://download.pytorch.org/whl/rocm5.7 + + +1. Install `Triton flash attention for ROCm `_ + +Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton `_ + +2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm `_ + +Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention `_ .. note:: - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly. - - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`. + - If you fail to install `ROCm/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`. - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) -2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention - - .. code-block:: console - - $ pip install xformers==0.0.23 --no-deps - $ bash patch_xformers.rocm.sh - 3. Build vLLM. - .. code-block:: console +.. code-block:: console - $ cd vllm - $ pip install -U -r requirements-rocm.txt - $ python setup.py install # This may take 5-10 minutes. + $ cd vllm + $ pip install -U -r requirements-rocm.txt + $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation -.. note:: - - You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation. +.. tip:: + - You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation. + - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. + - To use CK flash-attention, please use this flag ``export VLLM_USE_FLASH_ATTN_TRITON=0`` to turn off triton flash attention. + - The ROCm version of pytorch, ideally, should match the ROCm driver version. From a74dee9b62d10767eb0580f196f5e508e9e80a2d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 26 Apr 2024 10:10:48 +0800 Subject: [PATCH 127/413] [Bugfix] Fix parameter name in `get_tokenizer` (#4107) --- tests/tokenization/test_tokenizer.py | 20 ++++++++++++++++++++ vllm/transformers_utils/tokenizer.py | 11 ++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) create mode 100644 tests/tokenization/test_tokenizer.py diff --git a/tests/tokenization/test_tokenizer.py b/tests/tokenization/test_tokenizer.py new file mode 100644 index 0000000000000..8db7204f15d4e --- /dev/null +++ b/tests/tokenization/test_tokenizer.py @@ -0,0 +1,20 @@ +import pytest +from transformers import PreTrainedTokenizerBase + +from vllm.transformers_utils.tokenizer import get_tokenizer + +TOKENIZER_NAMES = [ + "facebook/opt-125m", + "gpt2", +] + + +@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES) +def test_tokenizer_revision(tokenizer_name: str): + # Assume that "main" branch always exists + tokenizer = get_tokenizer(tokenizer_name, revision="main") + assert isinstance(tokenizer, PreTrainedTokenizerBase) + + # Assume that "never" branch always does not exist + with pytest.raises(OSError, match='not a valid git identifier'): + get_tokenizer(tokenizer_name, revision="never") diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index c98a673bfed4b..afc02c434dd43 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -58,11 +58,12 @@ def get_tokenizer( *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, - tokenizer_revision: Optional[str] = None, + revision: Optional[str] = None, download_dir: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - """Gets a tokenizer for the given model name via Huggingface/modelscope.""" + """Gets a tokenizer for the given model name via HuggingFace or ModelScope. + """ if VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. @@ -74,7 +75,7 @@ def get_tokenizer( tokenizer_path = snapshot_download( model_id=tokenizer_name, cache_dir=download_dir, - revision=tokenizer_revision, + revision=revision, # Ignore weights - we only need the tokenizer. ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"]) tokenizer_name = tokenizer_path @@ -90,7 +91,7 @@ def get_tokenizer( tokenizer_name, *args, trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, + revision=revision, **kwargs) except ValueError as e: # If the error pertains to the tokenizer class not existing or not @@ -114,7 +115,7 @@ def get_tokenizer( tokenizer_name, *args, trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, + revision=revision, **kwargs) else: raise e From 2f30e7c72fca61c8225654880ee1ef89cad1690c Mon Sep 17 00:00:00 2001 From: Norman Mu Date: Thu, 25 Apr 2024 22:36:01 -0700 Subject: [PATCH 128/413] [Frontend] Add --log-level option to api server (#4377) --- vllm/entrypoints/api_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 587142adb9c6b..075de0b4efb2d 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -100,6 +100,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument("--log-level", type=str, default="debug") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() engine_args = AsyncEngineArgs.from_cli_args(args) @@ -110,7 +111,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: uvicorn.run(app, host=args.host, port=args.port, - log_level="debug", + log_level=args.log_level, timeout_keep_alive=TIMEOUT_KEEP_ALIVE, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile, From a88081bf768fcc1c662e4f588bd01ca9ddcc6aad Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 26 Apr 2024 16:16:58 +0900 Subject: [PATCH 129/413] [CI] Disable non-lazy string operation on logging (#4326) Co-authored-by: Danny Guinther --- docs/source/conf.py | 5 +- pyproject.toml | 1 + setup.py | 7 ++- vllm/config.py | 16 ++--- vllm/core/scheduler.py | 10 +-- .../device_communicators/custom_all_reduce.py | 8 +-- .../device_communicators/pynccl.py | 15 ++--- .../device_communicators/pynccl_utils.py | 4 +- vllm/distributed/parallel_state.py | 6 +- vllm/distributed/utils.py | 4 +- vllm/engine/async_llm_engine.py | 18 +++--- vllm/engine/llm_engine.py | 61 +++++++++++-------- vllm/engine/metrics.py | 21 ++++--- vllm/entrypoints/openai/api_server.py | 4 +- vllm/entrypoints/openai/serving_chat.py | 11 ++-- vllm/executor/cpu_executor.py | 2 +- vllm/executor/gpu_executor.py | 4 +- vllm/executor/ray_gpu_executor.py | 4 +- vllm/executor/ray_utils.py | 6 +- vllm/logger.py | 2 +- vllm/lora/models.py | 6 +- .../layers/fused_moe/fused_moe.py | 4 +- .../model_executor/model_loader/tensorizer.py | 8 +-- .../model_loader/weight_utils.py | 14 ++--- vllm/model_executor/models/__init__.py | 10 +-- vllm/model_executor/models/gemma.py | 6 +- vllm/spec_decode/spec_decode_worker.py | 3 +- vllm/transformers_utils/configs/dbrx.py | 13 ++-- vllm/transformers_utils/tokenizer.py | 5 +- vllm/utils.py | 20 +++--- vllm/worker/model_runner.py | 27 ++++---- 31 files changed, 176 insertions(+), 149 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index aac8cbb63ebeb..9da5a4991734d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -98,9 +98,10 @@ def setup(app): for mock_target in autodoc_mock_imports: if mock_target in sys.modules: logger.info( - f"Potentially problematic mock target ({mock_target}) found; " + "Potentially problematic mock target (%s) found; " "autodoc_mock_imports cannot mock modules that have already " - "been loaded into sys.modules when the sphinx build starts.") + "been loaded into sys.modules when the sphinx build starts.", + mock_target) class MockedClassDocumenter(autodoc.ClassDocumenter): diff --git a/pyproject.toml b/pyproject.toml index a171d45b4e064..2e026c1ac8911 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ select = [ "SIM", # isort # "I", + "G", ] ignore = [ # star imports diff --git a/setup.py b/setup.py index 4b672e1af8494..6ba36b85ea318 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def compute_num_jobs(self): num_jobs = os.environ.get("MAX_JOBS", None) if num_jobs is not None: num_jobs = int(num_jobs) - logger.info(f"Using MAX_JOBS={num_jobs} as the number of jobs.") + logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) else: try: # os.sched_getaffinity() isn't universally available, so fall @@ -81,8 +81,9 @@ def compute_num_jobs(self): nvcc_threads = os.getenv("NVCC_THREADS", None) if nvcc_threads is not None: nvcc_threads = int(nvcc_threads) - logger.info(f"Using NVCC_THREADS={nvcc_threads} as the number" - " of nvcc threads.") + logger.info( + "Using NVCC_THREADS=%d as the number of nvcc threads.", + nvcc_threads) else: nvcc_threads = 1 num_jobs = max(1, num_jobs // nvcc_threads) diff --git a/vllm/config.py b/vllm/config.py index 311a69f822571..887a73d9462dc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -167,9 +167,9 @@ def _verify_quantization(self) -> None: f"supported in ROCm.") if self.quantization != "marlin": logger.warning( - f"{self.quantization} quantization is not fully " + "%s quantization is not fully " "optimized yet. The speed can be slower than " - "non-quantized models.") + "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: if self.max_context_len_to_capture is None: @@ -360,7 +360,7 @@ def verify_with_parallel_config( if cpu_memory_usage > 0.7 * total_cpu_memory: raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: - logger.warning("Possibly too large swap space. " + msg) + logger.warning("Possibly too large swap space. %s", msg) @dataclass @@ -898,8 +898,8 @@ def verify_with_model_config(self, model_config: ModelConfig): "awq", "gptq" ]: # TODO support marlin and squeezellm - logger.warning(f"{model_config.quantization} quantization is not " - "tested with LoRA yet.") + logger.warning("%s quantization is not tested with LoRA yet.", + model_config.quantization) def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): if scheduler_config.max_num_batched_tokens > 65528: @@ -1008,7 +1008,7 @@ def _get_and_verify_dtype( pass else: # Casting between float16 and bfloat16 is allowed with a warning. - logger.warning(f"Casting {config_dtype} to {torch_dtype}.") + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) return torch_dtype @@ -1051,8 +1051,8 @@ def _get_and_verify_max_len( logger.warning( "The model's config.json does not contain any of the following " "keys to determine the original maximum length of the model: " - f"{possible_keys}. Assuming the model's maximum length is " - f"{default_max_len}.") + "%d. Assuming the model's maximum length is %d.", possible_keys, + default_max_len) derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 99f7a34d336a4..ac3bd7d228e94 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -617,8 +617,9 @@ def _schedule_prefills( if num_new_tokens > self.prompt_limit: logger.warning( - f"Input prompt ({num_new_tokens} tokens) is too long" - f" and exceeds limit of {self.prompt_limit}") + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", num_new_tokens, + self.prompt_limit) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -631,8 +632,9 @@ def _schedule_prefills( break elif can_allocate == AllocStatus.NEVER: logger.warning( - f"Input prompt ({num_new_tokens} tokens) is too long" - f" and exceeds the capacity of block_manager") + "Input prompt (%d tokens) is too long" + " and exceeds the capacity of block_manager", + num_new_tokens) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 9dbb427d91ff1..ec4533326e841 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -37,7 +37,7 @@ def init_custom_ar() -> None: return if world_size not in _SUPPORTED_WORLD_SIZES: - logger.warn( + logger.warning( "Custom allreduce is disabled due to an unsupported world size: " "%d. Supported world sizes: %s. To silence this warning, specify" " disable_custom_all_reduce=True explicitly.", world_size, @@ -47,7 +47,7 @@ def init_custom_ar() -> None: # note: num dev can be larger than world_size if we're only using # first few GPUs if num_dev < world_size: - logger.warn( + logger.warning( "Cannot test GPU P2P because not all GPUs are visible to the " "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" " is set.") @@ -62,7 +62,7 @@ def init_custom_ar() -> None: # this checks hardware and driver support for NVLink full_nvlink = _is_full_nvlink(device_ids) if world_size > 2 and not full_nvlink: - logger.warn( + logger.warning( "Custom allreduce is disabled because it's not supported on more" " than two PCIe-only GPUs. To silence this warning, specify" " disable_custom_all_reduce=True explicitly.") @@ -71,7 +71,7 @@ def init_custom_ar() -> None: # this is expensive to compute at the first time # then we cache the result if not _can_p2p(rank, world_size): - logger.warn( + logger.warning( "Custom allreduce is disabled because your platform lacks GPU P2P" " capability or P2P test failed. To silence this warning, specify" " disable_custom_all_reduce=True explicitly.") diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index e922beba44bfa..9434867e1b120 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -43,15 +43,16 @@ nccl = ctypes.CDLL(so_file) except Exception as e: logger.error( - f"Failed to load NCCL library from {so_file} ." + "Failed to load NCCL library from %s ." "It is expected if you are not running on NVIDIA/AMD GPUs." "Otherwise, the nccl library might not exist, be corrupted " - f"or it does not support the current platform {platform.platform()}." - f"One solution is to download libnccl2 version 2.18 from " - f"https://developer.download.nvidia.com/compute/cuda/repos/ " - f"and extract the libnccl.so.2 file. If you already have the " - f"library, please set the environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.") + "or it does not support the current platform %s." + "One solution is to download libnccl2 version 2.18 from " + "https://developer.download.nvidia.com/compute/cuda/repos/ " + "and extract the libnccl.so.2 file. If you already have the " + "library, please set the environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) raise e # === export types and functions from nccl to Python === diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py index a717fddb695ba..44e4f39217a41 100644 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ b/vllm/distributed/device_communicators/pynccl_utils.py @@ -14,7 +14,7 @@ except Exception as e: # in non-NVIDIA environments, we can't import the nccl module # e.g. when running on machines with AMD GPUs - logger.info(f"Failed to import NCCL library: {e}") + logger.info("Failed to import NCCL library: %s", e) logger.info("It is expected if you are not running on NVIDIA GPUs.") pass @@ -40,7 +40,7 @@ def set_pynccl_stream(stream: torch.cuda.Stream): def init_process_group(group: Optional[ProcessGroup] = None) -> None: assert not is_initialized() global comm - logger.info(f"vLLM is using nccl=={ncclGetVersion()}") + logger.info("vLLM is using nccl==%s", ncclGetVersion()) comm = NCCLCommunicator(group=group) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 515f2212511b7..6ca6fc5b5f9fe 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -57,8 +57,10 @@ def init_distributed_environment( local_rank: int = -1, backend: str = "nccl", ): - logger.debug(f"{world_size=} {rank=} {local_rank=} " - f"{distributed_init_method=} {backend=}") + logger.debug( + "world_size=%d rank=%d local_rank=%d " + "distributed_init_method=%s backend=%s", world_size, rank, local_rank, + distributed_init_method, backend) if not torch.distributed.is_initialized(): assert distributed_init_method is not None, ( "distributed_init_method must be provided when initializing " diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index e0a871ebe1756..9a13b94c3ada1 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -112,7 +112,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: and (not os.path.exists(path)): # only the local master process (with local_rank == 0) can # enter this block to calculate the cache - logger.info(f"generating GPU P2P access cache for in {path}") + logger.info("generating GPU P2P access cache for in %s", path) cache = {} for _i in range(num_dev): for _j in range(num_dev): @@ -126,7 +126,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: if is_distributed: cpu_world_group = get_cpu_world_group() dist.barrier(cpu_world_group) - logger.info(f"reading GPU P2P access cache from {path}") + logger.info("reading GPU P2P access cache from %s", path) with open(path, "r") as f: cache = json.load(f) _gpu_p2p_access_cache = cache diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 4b007d71e9cfc..518532e4a280d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -117,7 +117,7 @@ def process_request_output(self, self._request_streams[request_id].put(request_output) if request_output.finished: if verbose: - logger.info(f"Finished request {request_id}.") + logger.info("Finished request %s.", request_id) self.abort_request(request_id) def process_exception(self, @@ -128,7 +128,7 @@ def process_exception(self, """Propagate an exception from the engine.""" self._request_streams[request_id].put(exception) if verbose: - logger.info(f"Finished request {request_id}.") + logger.info("Finished request %s.", request_id) self.abort_request(request_id) def add_request(self, request_id: str, @@ -151,7 +151,7 @@ def add_request(self, request_id: str, def abort_request(self, request_id: str, *, verbose: bool = False) -> None: """Abort a request during next background loop iteration.""" if verbose: - logger.info(f"Aborted request {request_id}.") + logger.info("Aborted request %s.", request_id) self._finished_requests.put_nowait(request_id) @@ -521,11 +521,11 @@ async def add_request( if shortened_token_ids is not None: shortened_token_ids = shortened_token_ids[:self. max_log_len] - logger.info(f"Received request {request_id}: " - f"prompt: {shortened_prompt!r}, " - f"sampling_params: {sampling_params}, " - f"prompt_token_ids: {shortened_token_ids}, " - f"lora_request: {lora_request}.") + logger.info( + "Received request %s: prompt: %r, " + "sampling_params: %s, prompt_token_ids: %s, " + "lora_request: %s.", request_id, shortened_prompt, + sampling_params, shortened_token_ids, lora_request) if not self.is_running: if self.start_engine_loop: @@ -717,4 +717,4 @@ async def check_health(self) -> None: raise RuntimeError("Engine is dead.") from e else: await self.engine.check_health_async() - logger.debug(f"Health check took {time.perf_counter()-t}s") + logger.debug("Health check took %fs", time.perf_counter() - t) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7de60d738113e..d2f5379e621c6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -96,29 +96,38 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, ) -> None: logger.info( - f"Initializing an LLM engine (v{vllm.__version__}) with config: " - f"model={model_config.model!r}, " - f"speculative_config={speculative_config!r}, " - f"tokenizer={model_config.tokenizer!r}, " - f"skip_tokenizer_init={model_config.skip_tokenizer_init}, " - f"tokenizer_mode={model_config.tokenizer_mode}, " - f"revision={model_config.revision}, " - f"tokenizer_revision={model_config.tokenizer_revision}, " - f"trust_remote_code={model_config.trust_remote_code}, " - f"dtype={model_config.dtype}, " - f"max_seq_len={model_config.max_model_len}, " - f"download_dir={load_config.download_dir!r}, " - f"load_format={load_config.load_format}, " - f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " - f"disable_custom_all_reduce=" - f"{parallel_config.disable_custom_all_reduce}, " - f"quantization={model_config.quantization}, " - f"enforce_eager={model_config.enforce_eager}, " - f"kv_cache_dtype={cache_config.cache_dtype}, " - f"quantization_param_path={model_config.quantization_param_path}, " - f"device_config={device_config.device}, " - f"decoding_config={decoding_config!r}, " - f"seed={model_config.seed})") + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, " + "max_seq_len=%d, download_dir=%r, load_format=%s, " + "tensor_parallel_size=%d, disable_custom_all_reduce=%s" + "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, seed=%d)", + vllm.__version__, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + model_config.seed, + ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config @@ -237,8 +246,10 @@ def _initialize_kv_caches(self) -> None: if self.cache_config.num_gpu_blocks_override is not None: num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override - logger.info(f"Overriding {num_gpu_blocks=} with " - f"{num_gpu_blocks_override=}") + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) num_gpu_blocks = num_gpu_blocks_override self.cache_config.num_gpu_blocks = num_gpu_blocks diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 25e96f6c7eaf7..d3560f5fefff1 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -227,14 +227,19 @@ def log(self, stats: Stats) -> None: # Log to stdout. logger.info( - f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, " - f"Avg generation throughput: " - f"{generation_throughput:.1f} tokens/s, " - f"Running: {stats.num_running} reqs, " - f"Swapped: {stats.num_swapped} reqs, " - f"Pending: {stats.num_waiting} reqs, " - f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, " - f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%") + "Avg prompt throughput: %.1f tokens/s, " + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Swapped: %d reqs, " + "Pending: %d reqs, GPU KV cache usage: %.1f%, " + "CPU KV cache usage: %.1f%", + prompt_throughput, + generation_throughput, + stats.num_running, + stats.num_swapped, + stats.num_waiting, + stats.gpu_cache_usage * 100, + stats.cpu_cache_usage * 100, + ) # Reset tracked stats for next interval. self.num_prompt_tokens = [] diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 37d76b8e74055..af9ba7a3bc825 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -148,8 +148,8 @@ async def authentication(request: Request, call_next): raise ValueError(f"Invalid middleware {middleware}. " f"Must be a function or a class.") - logger.info(f"vLLM API server version {vllm.__version__}") - logger.info(f"args: {args}") + logger.info("vLLM API server version %s", vllm.__version__) + logger.info("args: %s", args) if args.served_model_name is not None: served_model_names = args.served_model_name diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2ff335eb71073..f6011b6fc4cb6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -57,8 +57,7 @@ async def create_chat_completion( tokenize=False, add_generation_prompt=request.add_generation_prompt) except Exception as e: - logger.error( - f"Error in applying chat template from request: {str(e)}") + logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) request_id = f"cmpl-{random_uuid()}" @@ -338,11 +337,11 @@ def _load_chat_template(self, chat_template): tokenizer.chat_template = codecs.decode( chat_template, "unicode_escape") - logger.info( - f"Using supplied chat template:\n{tokenizer.chat_template}") + logger.info("Using supplied chat template:\n%s", + tokenizer.chat_template) elif tokenizer.chat_template is not None: - logger.info( - f"Using default chat template:\n{tokenizer.chat_template}") + logger.info("Using default chat template:\n%s", + tokenizer.chat_template) else: logger.warning( "No chat template provided. Chat API will not work.") diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 8d6a1fff91fd8..aa810f9743395 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -69,7 +69,7 @@ def initialize_cache(self, num_gpu_blocks: int, # NOTE: `cpu block` for CPU backend is located on CPU memory but is # referred as `gpu block`. Because we want to reuse the existing block # management procedure. - logger.info(f"# CPU blocks: {num_gpu_blocks}") + logger.info("# CPU blocks: %d", num_gpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model(self, diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index d413a7d27ff37..d2c60a3b68e14 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -116,8 +116,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: # NOTE: This is logged in the executor because there can be >1 worker # with other executors. We could log in the engine level, but work # remains to abstract away the device for non-GPU configurations. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 14b3f803782c6..6f72babe14fd5 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -214,8 +214,8 @@ def initialize_cache(self, num_gpu_blocks: int, # NOTE: We log here to avoid multiple logs when number of workers is # greater than one. We could log in the engine, but not all executors # have GPUs. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index febae42b84549..9db3ae2ff8298 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -43,9 +43,9 @@ def execute_model_compiled_dag_remote(self, ignored): return output except ImportError as e: - logger.warning(f"Failed to import Ray with {e!r}. " - "For distributed inference, please install Ray with " - "`pip install ray`.") + logger.warning( + "Failed to import Ray with %r. For distributed inference, " + "please install Ray with `pip install ray`.", e) ray = None # type: ignore RayWorkerWrapper = None # type: ignore diff --git a/vllm/logger.py b/vllm/logger.py index 341fc473585d7..3928e5367d1e6 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -126,7 +126,7 @@ def enable_trace_function_call(log_file_path: str, "VLLM_TRACE_FUNCTION is enabled. It will record every" " function executed by Python. This will slow down the code. It " "is suggested to be used for debugging hang or crashes only.") - logger.info(f"Trace frame log is saved to {log_file_path}") + logger.info("Trace frame log is saved to %s", log_file_path) if root_dir is None: # by default, this is the vllm root directory root_dir = os.path.dirname(os.path.dirname(__file__)) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index c249497a4d893..6a077e9b0c755 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -345,8 +345,8 @@ def activate_lora( index, _ = first_free_slot self._active_loras[lora_id] = None lora_model = self._registered_loras[lora_id] - logger.debug( - f"Activating LoRA. int id: {lora_model.id}, slot index: {index}") + logger.debug("Activating LoRA. int id: %d, slot index: %d", + lora_model.id, index) self.lora_index_to_id[index] = lora_model.id for module_name, module in self.modules.items(): module_lora = lora_model.get_lora(module_name) @@ -567,7 +567,7 @@ def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], self.deactivate_lora_fn = deactivate_lora_fn def _on_remove(self, key: int, value: LoRAModel): - logger.debug(f"Removing LoRA. int id: {key}") + logger.debug("Removing LoRA. int id: %d", key) self.deactivate_lora_fn(key) return super()._on_remove(key, value) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ac7c30e2a9727..aed2c350bdd10 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -296,8 +296,8 @@ def get_moe_configs(E: int, N: int, os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) if os.path.exists(config_file_path): with open(config_file_path) as f: - logger.info( - f"Using configuration from {config_file_path} for MoE layer.") + logger.info("Using configuration from %s for MoE layer.", + config_file_path) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 16be0ecf9ce07..7e65d54bc522f 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -334,10 +334,10 @@ def deserialize(self): per_second = convert_bytes(deserializer.total_tensor_bytes / duration) after_mem = get_mem_usage() deserializer.close() - logger.info(f"Deserialized {total_bytes_str} in " - f"{end - start:0.2f}s, {per_second}/s") - logger.info(f"Memory usage before: {before_mem}") - logger.info(f"Memory usage after: {after_mem}") + logger.info("Deserialized %s in %0.2fs, %f/s", total_bytes_str, + end - start, per_second) + logger.info("Memory usage before: %s", before_mem) + logger.info("Memory usage after: %s", after_mem) self._check_tensors_on_meta_device() self._resize_lora_embeddings() diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 9995f2afe3cf7..c0905b9051314 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -190,7 +190,7 @@ def download_weights_from_hf(model_name_or_path: str, allow_patterns = [pattern] break - logger.info(f"Using model weights format {allow_patterns}") + logger.info("Using model weights format %s", allow_patterns) # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): @@ -310,17 +310,17 @@ def kv_cache_scales_loader( return layer_scales_map.items() except FileNotFoundError: - logger.error(f"File or directory '{filename}' not found.") + logger.error("File or directory '%s' not found.", filename) except json.JSONDecodeError: - logger.error(f"Error decoding JSON in file '{filename}'.") + logger.error("Error decoding JSON in file '%s'.", filename) except Exception as e: - logger.error(f"An error occurred while reading '{filename}': {e}") + logger.error("An error occurred while reading '%s': %s", filename, e) # This section is reached if and only if any of the excepts are hit # Return an empty iterable (list) => no KV cache scales are loaded # which ultimately defaults to 1.0 scales - logger.warning("Defaulting to KV cache scaling factors = 1.0 " - f"for all layers in TP rank {tp_rank} " - "as an error occurred during loading.") + logger.warning( + "Defaulting to KV cache scaling factors = 1.0 for all " + "layers in TP rank %d as an error occurred during loading.", tp_rank) return [] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 6afb2f31c1334..c5cdc059473b3 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -91,8 +91,8 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: "ROCm for now.") if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: logger.warning( - f"Model architecture {model_arch} is partially supported " - "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + "Model architecture %s is partially supported by ROCm: %s", + model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) module_name, model_cls_name = _MODELS[model_arch] module = importlib.import_module( @@ -107,9 +107,9 @@ def get_supported_archs() -> List[str]: def register_model(model_arch: str, model_cls: Type[nn.Module]): if model_arch in _MODELS: logger.warning( - f"Model architecture {model_arch} is already registered, " - "and will be overwritten by the new model " - f"class {model_cls.__name__}.") + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls.__name__) global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 6d01537c5c344..c3193258d6418 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -55,10 +55,10 @@ def _get_gemma_act_fn( "in the config JSON file when it was initially released. " "Changing the activation function to approximate GeLU " "(`gelu_pytorch_tanh`). If you want to use the legacy " - f"`{hidden_act}`, edit the config JSON to set " - f"`hidden_activation={hidden_act}` instead of `hidden_act`. " + "`%s`, edit the config JSON to set " + "`hidden_activation=%s` instead of `hidden_act`. " "See https://github.com/huggingface/transformers/pull/29402 " - "for more details.") + "for more details.", hidden_act, hidden_act) return GeluAndMul(approximate="tanh") elif hidden_activation == "gelu_pytorch_tanh": return GeluAndMul(approximate="tanh") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 2c6642f5a3c81..4e70ea9686005 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -183,7 +183,8 @@ def execute_model( "speculative decoding " "requires non-None seq_group_metadata_list") - logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}") + logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d", + num_lookahead_slots) # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py index 1d2724f22abd6..0dc9664723d34 100644 --- a/vllm/transformers_utils/configs/dbrx.py +++ b/vllm/transformers_utils/configs/dbrx.py @@ -72,9 +72,10 @@ def from_pretrained( and config_dict["model_type"] != cls.model_type ): logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all configurations of " + "models and can yield errors.", + config_dict["model_type"], cls.model_type) return cls.from_dict(config_dict, **kwargs) @@ -151,9 +152,9 @@ def from_pretrained( and config_dict["model_type"] != cls.model_type ): logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all " + "configurations of models and can yield errors.", config_dict["model_type"], cls.model_type) return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index afc02c434dd43..2fcddc3bea5ab 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -138,9 +138,8 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, # No tokenizer was found in the LoRA folder, # use base model tokenizer logger.warning( - f"No tokenizer found in {lora_request.lora_local_path}, " - "using base model tokenizer instead. " - f"(Exception: {str(e)})") + "No tokenizer found in %s, using base model tokenizer instead. " + "(Exception: %s)", lora_request.lora_local_path, e) tokenizer = None return tokenizer diff --git a/vllm/utils.py b/vllm/utils.py index 79ac1db01fc69..76c2fc66e47c3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -289,8 +289,9 @@ def get_open_port() -> int: def update_environment_variables(envs: Dict[str, str]): for k, v in envs.items(): if k in os.environ and os.environ[k] != v: - logger.warning(f"Overwriting environment variable {k} " - f"from '{os.environ[k]}' to '{v}'") + logger.warning( + "Overwriting environment variable %s " + "from '%s' to '%s'", k, os.environ[k], v) os.environ[k] = v @@ -310,11 +311,12 @@ def get_nvcc_cuda_version() -> Optional[Version]: if not cuda_home: cuda_home = '/usr/local/cuda' if os.path.isfile(cuda_home + '/bin/nvcc'): - logger.info(f'CUDA_HOME is not found in the environment. ' - f'Using {cuda_home} as CUDA_HOME.') + logger.info( + 'CUDA_HOME is not found in the environment. ' + 'Using %s as CUDA_HOME.', cuda_home) else: - logger.warning( - f'Not found nvcc in {cuda_home}. Skip cuda version check!') + logger.warning('Not found nvcc in %s. Skip cuda version check!', + cuda_home) return None nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True) @@ -599,8 +601,8 @@ def find_nccl_library(): # manually load the nccl library if so_file: logger.info( - f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}" - ) + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", + so_file) else: if torch.version.cuda is not None: so_file = vllm_nccl_path or find_library("libnccl.so.2") @@ -608,7 +610,7 @@ def find_nccl_library(): so_file = find_library("librccl.so.1") else: raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.info(f"Found nccl from library {so_file}") + logger.info("Found nccl from library %s", so_file) return so_file diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 33dbf8d90c35d..c6da28f110325 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -170,8 +170,8 @@ def load_model(self) -> None: ) self.model_memory_usage = m.consumed_memory - logger.info(f"Loading model weights took " - f"{self.model_memory_usage / float(2**30):.4f} GB") + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) if self.lora_config: assert hasattr(self.model, "supported_lora_modules" @@ -196,18 +196,19 @@ def load_model(self) -> None: self.model.load_kv_cache_scales( self.model_config.quantization_param_path) else: - raise RuntimeError("Using FP8 KV cache and scaling " - "factors provided but model " - f"{self.model.__class__} does not " - "support loading scaling factors.") + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__) else: - logger.warn("Using FP8 KV cache but no scaling factors " - "provided. Defaulting to scaling factors of 1.0. " - "This may lead to less accurate results!") + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!") elif self.model_config.quantization_param_path is not None: - logger.warn("KV cache scaling factors provided, " - "but the KV cache data type is not FP8. " - "KV cache scaling factors will not be used.") + logger.warning("KV cache scaling factors provided, " + "but the KV cache data type is not FP8. " + "KV cache scaling factors will not be used.") def set_block_size(self, block_size: int) -> None: self.block_size = block_size @@ -1054,7 +1055,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: end_time = time.perf_counter() elapsed_time = end_time - start_time # This usually takes < 10 seconds. - logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") + logger.info("Graph capturing finished in %.0f secs.", elapsed_time) def __del__(self) -> None: # Delete the CUDA graphs before deleting the pynccl communicator. From 603ad8481594321ceae7d54e2c0050b3638c6502 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 26 Apr 2024 22:02:02 +0900 Subject: [PATCH 130/413] [Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309) --- tests/samplers/test_logprobs.py | 44 +- tests/samplers/test_sampler.py | 47 +- tests/test_logits_processor.py | 10 +- tests/worker/test_model_runner.py | 19 +- vllm/core/scheduler.py | 15 + vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 25 +- vllm/engine/output_processor/interfaces.py | 6 + vllm/engine/output_processor/multi_step.py | 9 + vllm/engine/output_processor/single_step.py | 22 +- vllm/engine/output_processor/util.py | 7 +- .../model_executor/layers/logits_processor.py | 27 +- vllm/model_executor/layers/sampler.py | 544 ++++++++++++------ vllm/model_executor/sampling_metadata.py | 349 ++++++++--- vllm/sequence.py | 11 +- vllm/worker/cpu_model_runner.py | 116 +--- vllm/worker/model_runner.py | 123 +--- vllm/worker/neuron_model_runner.py | 119 +--- 18 files changed, 862 insertions(+), 633 deletions(-) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 41b7f3da1e839..57d6d2a410ee5 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -9,15 +9,26 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) +@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size def test_get_prompt_logprobs( hf_runner, vllm_runner, model, dtype, + chunked_prefill_token_size: int, + num_top_logprobs: int, example_prompts, ): + max_num_seqs = 256 + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) + max_num_batched_tokens = chunked_prefill_token_size + max_tokens = 5 - num_top_logprobs = 6 hf_model = hf_runner(model, dtype=dtype) hf_logprobs = hf_model.generate_greedy_logprobs( example_prompts, @@ -25,10 +36,17 @@ def test_get_prompt_logprobs( ) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) + vllm_model = vllm_runner( + model, + dtype=dtype, + max_logprobs=num_top_logprobs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs, + ) vllm_sampling_params = SamplingParams(max_tokens=max_tokens, logprobs=num_top_logprobs, - prompt_logprobs=5, + prompt_logprobs=num_top_logprobs, temperature=0.0) vllm_results = vllm_model.model.generate( example_prompts, sampling_params=vllm_sampling_params) @@ -52,9 +70,18 @@ def test_get_prompt_logprobs( "The output text from the top logprob for each token position " "should be the same as the output text in the result.") + # The first prompt logprob is always None + assert result.prompt_logprobs[0] is None + for prompt_logprobs in result.prompt_logprobs[1:]: + # If the prompt token is not included in the top X + # logprob, it can return 1 more data + assert (len(prompt_logprobs) == num_top_logprobs + or len(prompt_logprobs) == num_top_logprobs + 1) + # Test whether prompt logprobs are consistent with HF for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs): # Check prompt logprobs + # The first prompt logprob is always None, so we compare it from 1:. vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): for token_id, logprob in vllm_prompt_logprob_dict.items(): @@ -74,6 +101,17 @@ def test_get_prompt_logprobs( "The token should be decoded by the time it is returned " " to the user.") + # Test if prompt logprobs are correctly set. + for vllm_result in vllm_results: + token_ids = vllm_result.prompt_token_ids + prompt_logprobs = vllm_result.prompt_logprobs + + # The first token doesn't have logprob. + assert prompt_logprobs[0] is None + + for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]): + assert token_id in logprob_dict + def test_max_logprobs(): runner = VllmRunner("facebook/opt-125m", max_logprobs=1) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 52a2b0ca52aaa..6f2145f8cdcf4 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -8,6 +8,7 @@ from transformers import GenerationConfig, GenerationMixin from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import Counter @@ -54,6 +55,7 @@ def _do_sample( sampler: MockLogitsSampler, model_runner: ModelRunner, sampling_params: SamplingParams, + device: str, ): seq_group_metadata_list = [] prompt_lens = [] @@ -68,9 +70,12 @@ def _do_sample( )) prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=device, + pin_memory=model_runner.pin_memory) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str): sampling_params = SamplingParams(temperature=0) sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, - sampling_params) + sampling_params, device) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str): n=random.randint(1, 10), ) sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, - sampling_params) + sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str): seed=random.randint(0, 10000), ) sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, - sampling_params) + sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): seed=random.randint(0, 10000), ) first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params) + model_runner, sampling_params, device) second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params) + model_runner, sampling_params, device) assert first_sampler_output == second_sampler_output @@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str): best_of=2, use_beam_search=True, ) - _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) + _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params, + device) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler @@ -443,10 +449,12 @@ def run_test_case(*, "batch size") _, fake_logits, sampler, model_runner = _prepare_test(batch_size) - sampling_metadata = model_runner._prepare_sample( + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, prompt_lens=prompt_lens if prompt_lens else None, - subquery_lens=prompt_lens if prompt_lens else None) + subquery_lens=prompt_lens if prompt_lens else None, + device=device, + pin_memory=model_runner.pin_memory) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) def test_sampling(model_runner: ModelRunner): - sampling_metadata = model_runner._prepare_sample( - seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=device, + pin_memory=model_runner.pin_memory) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): )) prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=device, + pin_memory=model_runner.pin_memory) sample_probs = None diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 5bb93ca74855b..dbaeb4de18258 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -6,6 +6,7 @@ import torch from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner @@ -82,9 +83,12 @@ def pick_ith(token_ids, logits): )) prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=model_runner.device, + pin_memory=model_runner.pin_memory) logits_processor_output = logits_processor( embedding=None, hidden_states=input_tensor, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 59bed2ce0dad3..abb401f25c100 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -2,6 +2,7 @@ import torch from vllm.config import ModelConfig, SchedulerConfig +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size): assert len(input_positions) == sum(prompt_lens) torch.testing.assert_close(input_tokens, input_positions) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=model_runner.device, + pin_memory=model_runner.pin_memory) assert len(input_tokens) == sum(prompt_lens) assert len(input_positions) == sum(prompt_lens) actual = sampling_metadata.selected_token_indices @@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size): for prompt_len in prompt_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=model_runner.device, + pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ac3bd7d228e94..7439f7dc33e8d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -915,6 +915,20 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: self.block_manager.get_common_computed_block_ids( seq_group.get_seqs(status=SequenceStatus.RUNNING))) + do_sample = True + if seq_group.is_prefill(): + seqs = seq_group.get_seqs() + # Prefill has only 1 sequence. + assert len(seqs) == 1 + # In the next iteration, all prompt tokens are not computed. + # It means the prefill is chunked, and we don't need sampling. + # NOTE: We use get_len instead of get_prompt_len because when + # a sequence is preempted, prefill includes previous generated + # output tokens. + if (token_chunk_size + seqs[0].data.get_num_computed_tokens() < + seqs[0].data.get_len()): + do_sample = False + # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. is_prompt = seq_group.is_prefill() @@ -924,6 +938,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + do_sample=do_sample, token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 518532e4a280d..89ee3f0db491c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -219,7 +219,7 @@ async def step_async(self) -> List[RequestOutput]: request_outputs = self._process_model_outputs( output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups) + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) # Log stats. if self.log_stats: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d2f5379e621c6..741d3bcd80890 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -22,7 +22,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceStage) + SequenceGroup, SequenceGroupMetadata) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -476,9 +476,12 @@ def has_unfinished_requests(self) -> bool: return self.scheduler.has_unfinished_seqs() def _process_model_outputs( - self, output: List[SamplerOutput], - scheduled_seq_groups: List[SequenceGroup], - ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: + self, + output: List[SamplerOutput], + scheduled_seq_groups: List[SequenceGroup], + ignored_seq_groups: List[SequenceGroup], + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> List[RequestOutput]: """Apply the model output to the sequences in the scheduled seq groups. Returns RequestOutputs that can be returned to the client. @@ -492,17 +495,15 @@ def _process_model_outputs( sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) # Update the scheduled sequence groups with the model outputs. - for scheduled_seq_group, outputs in zip(scheduled_seq_groups, - output_by_sequence_group): + for scheduled_seq_group, outputs, seq_group_meta in zip( + scheduled_seq_groups, output_by_sequence_group, + seq_group_metadata_list): seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - # If all sequences in the sequence group are in DECODE, then we can - # process the output tokens. Otherwise, they are (chunked) prefill - # samples and should not be processed. - stages = [seq.data._stage for seq in seq_group.seqs_dict.values()] - if all(stage == SequenceStage.DECODE for stage in stages): + self.output_processor.process_prompt_logprob(seq_group, outputs) + if seq_group_meta.do_sample: self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. @@ -585,7 +586,7 @@ def step(self) -> List[RequestOutput]: request_outputs = self._process_model_outputs( output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups) + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) # Log stats. if self.log_stats: diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index f307ea4da3011..9ddb6a3648b8c 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -68,3 +68,9 @@ def process_outputs(self, sequence_group: SequenceGroup, scheduler. """ pass + + @abstractmethod + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Update prompt logprobs received from outputs to seq_group.""" + pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 39e99d06ed875..9abd87a4d5a9a 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -44,6 +44,15 @@ def __init__( self.get_tokenizer_for_seq = get_tokenizer_for_seq self.stop_checker = stop_checker + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + # TODO(sang): Prompt logprob currently not implemented in multi step + # workers. + logger.warning( + "Prompt logprob is not supported by multi step workers. " + "(e.g., speculative decode uses multi step workers).") + pass + def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: """Append new tokens in the outputs to sequences in the sequence group. diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 7e9d652446703..07b140584bbe2 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -55,17 +55,23 @@ def process_outputs(self, sequence_group: SequenceGroup, ), f"{type(self)} does not support multiple outputs per step" return self._process_sequence_group_outputs(sequence_group, outputs[0]) - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: - - # Process prompt logprobs - prompt_logprobs = outputs.prompt_logprobs - if prompt_logprobs is not None and \ - seq_group.sampling_params.detokenize and self.detokenizer: + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + assert len(outputs) == 1, ("Single step should only has 1 output.") + output = outputs[0] + prompt_logprobs = output.prompt_logprobs + if (prompt_logprobs is not None + and seq_group.sampling_params.detokenize and self.detokenizer): self.detokenizer.decode_prompt_logprobs_inplace( seq_group, prompt_logprobs) - seq_group.prompt_logprobs = prompt_logprobs + if not seq_group.prompt_logprobs: + # The first prompt token's logprob is None because it doesn't + # have tokens that are precedent. + seq_group.prompt_logprobs = [None] + seq_group.prompt_logprobs.extend(prompt_logprobs) + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput) -> None: # Process samples samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index d076fee8c2a36..9816e966c1e36 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -1,10 +1,11 @@ from typing import List -from vllm.sequence import SamplerOutput +from vllm.sequence import SamplerOutput, SequenceGroupOutput -def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], - num_seq_groups: int): +def create_output_by_sequence_group( + sampler_outputs: List[SamplerOutput], + num_seq_groups: int) -> List[List[SequenceGroupOutput]]: """Helper method which transforms a 2d list organized by [step][sequence group] into [sequence group][step]. """ diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index e556e31f99378..22620d9fc86d9 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -83,30 +83,27 @@ def _apply_logits_processors( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - logits_row_idx = 0 found_logits_processors = False - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group + logits_processed = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params logits_processors = sampling_params.logits_processors - # handle prompt_logprobs by skipping rows in logits added for - # the prompt tokens (prompt logprobs are not processed) - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - assert len(seq_ids) == 1 - logits_row_idx += sampling_metadata.prompt_lens[i] - 1 if logits_processors: found_logits_processors = True - for seq_id in seq_ids: + for seq_id, logits_row_idx in zip(seq_ids, + seq_group.sample_indices): logits_row = logits[logits_row_idx] - token_ids = sampling_metadata.seq_data[seq_id].output_token_ids + token_ids = seq_group.seq_data[seq_id].output_token_ids for logits_processor in logits_processors: logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row - logits_row_idx += 1 - else: - logits_row_idx += len(seq_ids) + + logits_processed += len(seq_group.sample_indices) + len( + seq_group.prompt_logprob_indices) + if found_logits_processors: # verifies that no rows in logits were missed unexpectedly - assert logits_row_idx == logits.shape[0] + assert logits_processed == logits.shape[0] return logits diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c4b11cb33a677..2ffa8227cc4ed 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -7,11 +7,11 @@ from vllm.model_executor.layers.ops.sample import sample as sample_triton from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingTensors) -from vllm.sampling_params import SamplingParams, SamplingType + SamplingTensors, + SequenceGroupToSample) +from vllm.sampling_params import SamplingType from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, - SamplerOutput, SequenceData, SequenceGroupOutput, - SequenceOutput) + SamplerOutput, SequenceGroupOutput, SequenceOutput) class Sampler(nn.Module): @@ -48,11 +48,14 @@ def forward( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: + """ + Args: + logits: (num_tokens, vocab_size). + sampling_metadata: Metadata for sampling. + """ assert logits is not None _, vocab_size = logits.shape - # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens - # have not been generated yet logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Prepare sampling tensors with pinned memory to avoid blocking. @@ -83,7 +86,6 @@ def forward( # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) # Compute the log probabilities. - # Use log_softmax to ensure numerical stability. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. @@ -149,24 +151,28 @@ def _apply_min_tokens_penalty( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: + """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens + have not been generated yet + """ # list of indices in logits that will be set to -inf logits_to_penalize = [] - start_idx = 0 - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group - - # handle prompt_logprobs by skipping rows in logits added for the prompt - # tokens (prompt logprobs are not penalized) - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - assert len(seq_ids) == 1 - start_idx += sampling_metadata.prompt_lens[i] - 1 + logits_applied = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + + sample_indices = seq_group.sample_indices + logits_applied += len(sample_indices) + len( + seq_group.prompt_logprob_indices) + if not seq_group.do_sample: + continue + start_idx = sample_indices[0] min_tokens = sampling_params.min_tokens if min_tokens > 0: seqs_to_penalize = [] for i, seq_id in enumerate(seq_ids): - seq_data = sampling_metadata.seq_data[seq_id] + seq_data = seq_group.seq_data[seq_id] if len(seq_data.output_token_ids) < min_tokens: seqs_to_penalize.append(i) @@ -180,15 +186,13 @@ def _apply_min_tokens_penalty( logits_to_penalize.extend( itertools.product(seqs_to_penalize, token_ids_to_penalize)) - start_idx += len(seq_ids) - if logits_to_penalize: # use zip and * to group indices along each dimension # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) logits[tuple(zip(*logits_to_penalize))] = -float("inf") # verifies that no rows in logits were missed unexpectedly - assert start_idx == logits.shape[0] + assert logits_applied == logits.shape[0] return logits @@ -265,14 +269,30 @@ def _apply_min_p( def _greedy_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], + selected_seq_groups: List[SequenceGroupToSample], samples: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: + """Run greedy sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + samples: (num_selected_samples,) A tensor of samples. The length of + samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ samples = samples.tolist() sample_idx = 0 results = [] for seq_group in selected_seq_groups: - seq_ids, _ = seq_group + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids num_parent_seqs = len(seq_ids) assert num_parent_seqs == 1, ( "Greedy sampling should have only one seq.") @@ -284,16 +304,33 @@ def _greedy_sample( def _random_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], + selected_seq_groups: List[SequenceGroupToSample], random_samples: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: + """Run random sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + random_samples: (num_selected_samples,) A tensor of samples. The + length of samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ # Find the maximum best_of value of the prompt phase requests. random_samples = random_samples.cpu() sample_idx = 0 results = [] - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - seq_ids, sampling_params = seq_group + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt num_parent_seqs = len(seq_ids) if is_prompt: # Prompt phase. @@ -311,11 +348,20 @@ def _random_sample( def _beam_search_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], - seq_data: Dict[int, SequenceData], + selected_seq_groups: List[SequenceGroupToSample], logprobs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: + """Run beam sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + logprobs: (num_selected_samples, vocab_size,) A tensor of logprob + on selected sample indices. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ # We sample 2 * beam_width candidates to make sure that with high # probability we can get `beam_width` candidates in addition to # the finished sequences for the next iteration. See @@ -327,8 +373,13 @@ def _beam_search_sample( # other sampling methods. sample_idx = 0 results = [] - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - seq_ids, sampling_params = seq_group + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + is_prompt = seq_group.is_prompt + seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params num_parent_seqs = len(seq_ids) beam_width = sampling_params.best_of seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] @@ -343,7 +394,8 @@ def _beam_search_sample( else: # Generation phase. cumulative_logprobs = [ - seq_data[seq_id].cumulative_logprob for seq_id in seq_ids + seq_group.seq_data[seq_id].cumulative_logprob + for seq_id in seq_ids ] cumulative_logprobs = torch.tensor( cumulative_logprobs, @@ -371,8 +423,7 @@ def _beam_search_sample( def _multinomial( probs: torch.Tensor, num_samples: int, - seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None, - generators: Optional[List[torch.Generator]] = None, + seq_groups: Optional[List[SequenceGroupToSample]] = None, ) -> torch.Tensor: if num_samples > 1: # This is equivalent to torch.repeat_interleaved (which also @@ -388,9 +439,11 @@ def _multinomial( q.exponential_() else: sample_idx = 0 - for (seq_ids, _), generator in zip(seq_groups, generators): + for seq_group in seq_groups: + seq_ids = seq_group.seq_ids next_sample_idx = sample_idx + len(seq_ids) * num_samples - q[sample_idx:next_sample_idx].exponential_(generator=generator) + q[sample_idx:next_sample_idx].exponential_( + generator=seq_group.generator) sample_idx = next_sample_idx return probs.div_(q).argmax(dim=1).view(-1, num_samples) @@ -405,7 +458,7 @@ def _sample_with_torch( categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): - _, sampling_params = seq_group + sampling_params = seq_group.sampling_params sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) @@ -429,13 +482,11 @@ def _sample_with_torch( num_tokens = len(sample_indices) if num_tokens == 0: continue - seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] - sample_metadata[sampling_type] = (seq_group_ids, seq_groups, - is_prompts, sample_indices) - long_sample_indices = sample_indices.long() + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups) + long_sample_indices = sample_indices.long() if sampling_type == SamplingType.GREEDY: greedy_samples = torch.argmax(logprobs[long_sample_indices], dim=-1) @@ -455,14 +506,13 @@ def _sample_with_torch( elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): max_best_of_in_batch = 1 - for seq_group, is_prompt in zip(seq_groups, is_prompts): - if is_prompt: - _, sampling_params = seq_group + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) seeded_args = {} if sampling_type == SamplingType.RANDOM else { "seq_groups": seq_groups, - "generators": sampling_metadata.generators, } multinomial_samples[sampling_type] = _multinomial( @@ -481,25 +531,22 @@ def _sample_with_torch( # GPU<->CPU sync happens in the loop below. # This also converts the sample output to Python objects. - for sampling_type in SamplingType: if sampling_type not in sample_metadata: continue - seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[ - sampling_type] + (seq_group_id, seq_groups) = sample_metadata[sampling_type] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample(seq_groups, greedy_samples) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, is_prompts, + sample_results = _random_sample(seq_groups, multinomial_samples[sampling_type]) elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, is_prompts, - sampling_metadata.seq_data, + sample_results = _beam_search_sample(seq_groups, beam_search_logprobs) - sample_results_dict.update(zip(seq_group_ids, sample_results)) + sample_results_dict.update(zip(seq_group_id, sample_results)) sample_results = [ - sample_results_dict[i] + sample_results_dict.get(i, ([], [])) for i in range(len(sampling_metadata.seq_groups)) ] return sample_results, sampled_token_ids_tensor @@ -514,7 +561,7 @@ def _sample_with_triton_kernel( categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): - _, sampling_params = seq_group + sampling_params = seq_group.sampling_params sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) @@ -530,17 +577,16 @@ def _sample_with_triton_kernel( num_tokens = len(sample_indices) if num_tokens == 0: continue - seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] - sample_metadata[sampling_type] = (seq_group_ids, seq_groups, - is_prompts, sample_indices, + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups, + sample_indices, sampled_token_indices) if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, SamplingType.RANDOM_SEED): - for seq_group, is_prompt in zip(seq_groups, is_prompts): - if is_prompt: - _, sampling_params = seq_group + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) elif sampling_type == SamplingType.BEAM: @@ -564,22 +610,21 @@ def _sample_with_triton_kernel( for sampling_type in SamplingType: if sampling_type not in sample_metadata: continue - (seq_group_ids, seq_groups, is_prompts, sample_indices, + (seq_group_id, seq_groups, sample_indices, sampled_token_indices) = sample_metadata[sampling_type] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample( seq_groups, sampled_tokens[sampled_token_indices][:, 0]) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): sample_results = _random_sample( - seq_groups, is_prompts, sampled_tokens[sampled_token_indices]) + seq_groups, sampled_tokens[sampled_token_indices]) elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, is_prompts, - sampling_metadata.seq_data, + sample_results = _beam_search_sample(seq_groups, beam_search_logprobs) - sample_results_dict.update(zip(seq_group_ids, sample_results)) + sample_results_dict.update(zip(seq_group_id, sample_results)) sample_results = [ - sample_results_dict[i] + sample_results_dict.get(i, ([], [])) for i in range(len(sampling_metadata.seq_groups)) ] return sample_results @@ -590,6 +635,18 @@ def _sample( sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: + """ + Args: + probs: (num_query_tokens_in_batch, num_vocab) + logprobs: (num_query_tokens_in_batch, num_vocab) + sampling_metadata: The metadata for a batch for sampling. + sampling_tensors: Tensors that include sampling related metadata. + + Returns: + (next_token_ids, parent_seq_ids) for each seq group in a batch. + If sampling is skipped, it returns ([], []) + sampled_token_ids_tensor: A tensor of sampled token ids. + """ return _sample_with_torch( probs, logprobs, @@ -626,56 +683,97 @@ def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sample_results: List[Tuple[List[int], List[int]]], -) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[ - int, float]]]]: - # Prepare query indices - batched_logprobs_query_seq_indices: List[int] = [] - batched_logprobs_query_token_indices: List[int] = [] - # at least get one logprob for each token +) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: + """Return sample lobprobs and prompt logprobs. + + The logic consists of 3 parts. + - Select indices to compute logprob from, ranks of token ids, and + the top k token ids from logprobs. + - Compute prompt logprobs if required. + - Compute sample logprobs if required. + + Args: + logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's + logprob per vocab. Sequence groups' query tokens are batched in a + single flattened tensor. For example, assuming there are N + seq groups, it is sorted by prefill tokens for seq_group_1 (if + prompt logprob is enabled), decode tokens for seq_group_1 (if + sampling is required), prefill tokens for seq_group_2, ... + sampling_metadata: The sampling metadata. + sample_results: (num_seq_groups) The tuple of (next_token_ids, + parent_ids) for each sequence group. When beam search is enabled, + sample_results can contain different number of seq_ids from + sampling_metadata.seq_groups. It is because beam search creates + 2 * BEAM_WIDTH number of samples (whereas there are only up to + BEAM_WIDTH number of seq_ids). + + Returns: + A tuple of prompt and sample logprobs per sequence group in a batch. + """ + # The index of query token to calculate logprobs. It includes both + # prompt and sample logprob indices. + query_indices: List[int] = [] + # The next token ids to get the logprob value from. + next_token_ids: List[int] = [] + # The largest requested number of logprobs. We find logprobs as many as the + # largest num logprobs in this API. largest_num_logprobs = 1 - sample_idx = 0 - for i, (seq_group, sample_result) in enumerate( - zip(sampling_metadata.seq_groups, sample_results)): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result - num_parent_seqs = len(seq_ids) - if (i < sampling_metadata.num_prompts + + # Select indices to compute logprob from, ranks of token ids, and the top + # k token ids from logprobs. + for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, + sample_results): + sampling_params = seq_group.sampling_params + + # Update indices and tokens for prompt logprobs. + if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): largest_num_logprobs = max(largest_num_logprobs, sampling_params.prompt_logprobs) - prompt_len = sampling_metadata.prompt_lens[i] - prompt_tokens = sampling_metadata.seq_data[ - seq_ids[0]].prompt_token_ids - batched_logprobs_query_seq_indices.extend( - sample_idx + j for j in range(prompt_len - 1)) - batched_logprobs_query_token_indices.extend( - token_id for token_id in prompt_tokens[1:]) - sample_idx += prompt_len - 1 - batched_logprobs_query_seq_indices.extend( - [sample_idx + parent_id for parent_id in parent_ids]) - batched_logprobs_query_token_indices.extend(next_token_ids) - if sampling_params.logprobs is not None: - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.logprobs) - sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) - - batched_logprobs_query_seq_indices_gpu = torch.tensor( - batched_logprobs_query_seq_indices, device=logprobs.device) - batched_logprobs_query_token_indices_gpu = torch.tensor( - batched_logprobs_query_token_indices, device=logprobs.device) - - # Batched query for logprobs of selected token - batched_logprobs_query_result = logprobs[[ - batched_logprobs_query_seq_indices_gpu, - batched_logprobs_query_token_indices_gpu + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + query_indices.extend(seq_group.prompt_logprob_indices) + next_token_ids.extend(next_prompt_tokens) + + # Update indices and next tokenes for sample logprob. + if seq_group.do_sample: + token_ids, parent_seq_ids = sample_result + # NOTE: We cannot directly use sample_indices because + # sample_indices only contain parent seq_ids of a previous step. + # The current step may have different number of seq_ids, and + # we can obtain it from `sample_result[1]`. + query_idx = seq_group.sample_indices[0] + query_indices.extend( + [query_idx + parent_id for parent_id in parent_seq_ids]) + next_token_ids.extend(token_ids) + + if sampling_params.logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.logprobs) + + assert len(next_token_ids) == len(query_indices) + + if len(query_indices) == 0: + empty_sampled_logprob = [] + empty_prompt_logprob = None + return [empty_prompt_logprob], [empty_sampled_logprob] + + query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) + next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device) + + # (num_selected_query_tokens, num_logprobs). Note that query_indices can + # contain duplicates if beam search is enabled. + selected_logprobs = logprobs[[ + query_indices_gpu, + next_token_ids_gpu, ]] + ranks = _get_ranks( + logprobs[query_indices_gpu], + next_token_ids_gpu, + ) + assert selected_logprobs.shape[0] == ranks.shape[0] - batched_ranks_query_result = _get_ranks( - logprobs[batched_logprobs_query_seq_indices_gpu], - batched_logprobs_query_token_indices_gpu) - - # Batched query for logprobs of topk tokens + # Logprobs of topk tokens for a batch of sequence groups. + # (num_query_tokens_across_batch). if largest_num_logprobs > 0: top_logprobs, top_token_ids = torch.topk(logprobs, largest_num_logprobs, @@ -685,79 +783,136 @@ def _get_logprobs( else: top_logprobs, top_token_ids = None, None - batched_logprobs_query_result = batched_logprobs_query_result.cpu() - batched_ranks_query_result = batched_ranks_query_result.cpu() - - # Gather results - result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] - result_sample_logprobs: List[SampleLogprobs] = [] - sample_idx = 0 - query_result_idx = 0 - for i, (seq_group, sample_result) in enumerate( - zip(sampling_metadata.seq_groups, sample_results)): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result + selected_logprobs = selected_logprobs.cpu() + ranks = ranks.cpu() + + # Find prompt/sample logprobs. + prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] + sample_logprobs_per_seq_group: List[SampleLogprobs] = [] + top_logprob_idx = 0 + selected_logprobs_idx = 0 + + for seq_group, sample_result in zip(sampling_metadata.seq_groups, + sample_results): + (prompt_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_prompt_logprob_if_needed( + seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, + selected_logprobs_idx, top_logprob_idx) + prompt_logprobs_per_seq_group.append(prompt_logprobs) + + (sampled_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_sampled_logprob_if_needed( + seq_group, sample_result, selected_logprobs, ranks, top_token_ids, + top_logprobs, selected_logprobs_idx, top_logprob_idx) + sample_logprobs_per_seq_group.append(sampled_logprobs) + + return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group + + +def _get_prompt_logprob_if_needed( + seq_group: SequenceGroupToSample, + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the prompt logprob from a sequence group if needed.""" + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt + + # Find prompt logprobs + prompt_logprobs: Optional[PromptLogprobs] = None + if (is_prompt and sampling_params.prompt_logprobs is not None): + prompt_logprobs = [] + num_logprobs = sampling_params.prompt_logprobs + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + for token_id in next_prompt_tokens: + # Calculate the prompt logprob of the real prompt tokens. + # Use tuple here for performance (to use to_list()). + # {token_id: (logprob, rank_from_vocab)} + prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { + token_id: (selected_logprobs[selected_logprobs_idx].item(), + ranks[selected_logprobs_idx].item()) + } - # Prompt logprobs - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - num_logprobs = sampling_params.prompt_logprobs - prompt_tokens = sampling_metadata.seq_data[ - seq_ids[0]].prompt_token_ids - group_prompt_logprobs: PromptLogprobs = [None] - for token_id in prompt_tokens[1:]: - prompt_logprobs_dict = { - token_id: - (batched_logprobs_query_result[query_result_idx].item(), - batched_ranks_query_result[query_result_idx].item()) - } - if num_logprobs > 0: - prompt_logprobs_dict.update( + # Add top K prompt logprobs along with its rank. + if num_logprobs > 0: + prompt_logprobs_dict.update( + zip( + top_token_ids[top_logprob_idx, :num_logprobs].tolist(), zip( - top_token_ids[sample_idx, :num_logprobs].tolist(), - zip( - top_logprobs[ - sample_idx, :num_logprobs].tolist(), - range(1, num_logprobs + 1)))) - group_prompt_logprobs.append({ - token_id: Logprob(*logprob_rank) - for token_id, logprob_rank in prompt_logprobs_dict.items() - }) - sample_idx += 1 - query_result_idx += 1 - result_prompt_logprobs.append(group_prompt_logprobs) - else: - result_prompt_logprobs.append(None) - - # Sample logprobs - num_logprobs = sampling_params.logprobs - if num_logprobs is None: - num_logprobs = 0 - group_sample_logprobs: SampleLogprobs = [] - for next_token_id, parent_id in zip(next_token_ids, parent_ids): - sample_logprobs_dict = { + top_logprobs[ + top_logprob_idx, :num_logprobs].tolist(), + # This is ranks. Since top_logprob is sorted, + # we can just use a range here. + range(1, num_logprobs + 1)))) + prompt_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in prompt_logprobs_dict.items() + }) + # + 1 to go to the next prompt token. + top_logprob_idx += 1 + selected_logprobs_idx += 1 + return prompt_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _get_sampled_logprob_if_needed( + seq_group: SequenceGroupToSample, + sample_result: Tuple[List[int], List[int]], + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the sample logprob if needed.""" + seq_ids = seq_group.seq_ids + num_logprobs = seq_group.sampling_params.logprobs + if num_logprobs is None: + num_logprobs = 0 + sampled_logprobs: SampleLogprobs = [] + next_token_ids, parent_seq_ids = sample_result + + if seq_group.do_sample: + assert len(next_token_ids) > 0 + for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids): + # Calculate the sample logprob of the real sampled tokens. + # Use tuple here for performance (to use to_list()). + # token_id: (logprob, rank_from_vocab) + sampled_logprobs_dict: Dict[int, Tuple[float, int]] = { next_token_id: - (batched_logprobs_query_result[query_result_idx].item(), - batched_ranks_query_result[query_result_idx].item()) + (selected_logprobs[selected_logprobs_idx].item(), + ranks[selected_logprobs_idx].item()) } - query_result_idx += 1 + # +1 to go to the next sampled token. Note that + # selected_logprobs can contain duplicates unlike top_logprobs + # when beam search is enabled. + selected_logprobs_idx += 1 + + # Second, add top K logprobs along with its rank. if num_logprobs >= 0: - sample_logprobs_dict.update( + sampled_logprobs_dict.update( zip( - top_token_ids[sample_idx + + top_token_ids[top_logprob_idx + parent_id, :num_logprobs].tolist(), zip( - top_logprobs[sample_idx + + top_logprobs[top_logprob_idx + parent_id, :num_logprobs].tolist(), + # This is rank. Since top_logprob is sorted, we + # can just use a range here. range(1, num_logprobs + 1)))) - group_sample_logprobs.append({ - token_id: Logprob(*logprob_rank) - for token_id, logprob_rank in sample_logprobs_dict.items() + sampled_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in + sampled_logprobs_dict.items() }) - result_sample_logprobs.append(group_sample_logprobs) - sample_idx += len(seq_ids) - - return result_prompt_logprobs, result_sample_logprobs + # There are len(seq_ids) number of sampled tokens for the current + # sequence group in top_logprobs. Jump to the next seq_group. + top_logprob_idx += len(seq_ids) + return sampled_logprobs, top_logprob_idx, selected_logprobs_idx def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, @@ -832,7 +987,7 @@ def _build_sampler_output( group_sample_logprobs) in zip(sampling_metadata.seq_groups, sample_results, prompt_logprobs, sample_logprobs): - seq_ids, _ = seq_group + seq_ids = seq_group.seq_ids next_token_ids, parent_ids = sample_result seq_outputs = [] for parent_id, next_token_id, logprobs in zip(parent_ids, @@ -854,3 +1009,36 @@ def _build_sampler_output( sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, ) + + +def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]: + """Get a list of next prompt tokens to compute logprob from a + given sequence group. + + It is used to compute prompt logprob. Imagine you have logprob for each + query token. Query token needs to know the next prompt token id to compute + prompt logprob. This is a helper to obtain next prompt token ids. + + This API has to be used only when the caller knows seq_group is in prefill + stage. + + Returns: + A list of next prompt tokens to compute logprob. + """ + assert seq_group.is_prompt, ( + "Caller should ensure the sequence group is in a prefill stage.") + seq_ids = seq_group.seq_ids + subquery_len = seq_group.subquery_len + assert subquery_len is not None + # prompt has only 1 seq id. + assert len(seq_ids) == 1 + seq_data = seq_group.seq_data[seq_ids[0]] + computed_len = seq_data.get_num_computed_tokens() + prompt_tokens = seq_data.prompt_token_ids + # +1 because we are looking for a next prompt token. + next_token_index_start = computed_len + 1 + next_token_index_end = min(computed_len + subquery_len + 1, + len(prompt_tokens)) + next_prompt_tokens = prompt_tokens[ + next_token_index_start:next_token_index_end] + return next_prompt_tokens diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 31032c4cead20..12156b2ba1aa2 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -6,57 +6,275 @@ from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SequenceData -from vllm.utils import is_pin_memory_available +from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.utils import (async_tensor_h2d, is_pin_memory_available, + maybe_expand_dim) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 +@dataclass +class SequenceGroupToSample: + # Sequence ids for the sequence group in a previous step. + seq_ids: List[int] + sampling_params: SamplingParams + # seq_id -> sequence data. + seq_data: Dict[int, SequenceData] + # The length of the prompt of the sequence group. None if it is in a decode + # stage. + prompt_len: Optional[int] + # The length of the query tokens to compute in the current step. None if it + # is in a decode stage. The length of subquery_len <= prompt_len. + subquery_len: Optional[int] + # A random number generator for sampling. + generator: Optional[torch.Generator] + # True if the sequence group is in prefill stage. False if it is in a + # decode stage. + is_prompt: bool + # Query token indices from logits. to compute prompt logprob. Empty if + # prompt logprob is not required. + prompt_logprob_indices: List[int] + # Sample token indices from logits. Empty if sampling is not required. + sample_indices: List[int] + + @property + def do_sample(self): + return len(self.sample_indices) > 0 + + def __post_init__(self): + if len(self.prompt_logprob_indices) > 0: + assert self.sampling_params.prompt_logprobs is not None + if self.is_prompt: + assert self.prompt_len is not None + assert self.subquery_len is not None + + class SamplingMetadata: """Metadata for input sequences. Used in sampler. + The usage is as follow; + ``` + hidden_states = execute_model(...) + logits = hidden_states[sampling_metadata.selected_token_indices] + sample(logits) + + def sample(logits): + # Use categorized_sample_indices for sampling.... + ``` + Args: - seq_groups: List of (seq_ids, sampling_params). - seq_data: Seq_id -> SequenceData. - prompt_lens: Lengths of prompts. - selected_token_indices: Token indices selected for sampling. + seq_groups: List of batched sequence groups. + selected_token_indices: (num_query_tokens_to_logprob). Indices to find + logits from the initial model output hidden states. categorized_sample_indices: SamplingType -> token indices to sample. - generators: List of torch.Generators to use for seeded sampling - perform_sampling: Whether to perform sampling. This option is used to - make the sampling only happens in the driver worker, and disable - sampling in other worker processes. + Each token indices is 2D tensor of (num_indices, num_indices) where + the first item means the sample index within the returned logit + (before pruning padding), and the second item means the sample + index after pruning using selected_token_indices. + For example, if the returned logit is [1, 2, 3], and we select + [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, + The first tuple is [1, 2] (sampled index within original logit), + and the second tuple is [0, 1] (sampled index within pruned logit). + num_prompts: Number of prompt sequence groups in seq_groups. """ def __init__( self, - seq_groups: Optional[List[Tuple[List[int], SamplingParams]]], - seq_data: Optional[Dict[int, SequenceData]], - prompt_lens: Optional[List[int]], + seq_groups: List[SequenceGroupToSample], selected_token_indices: torch.Tensor, - categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]], - generators: Optional[List[torch.Generator]] = None, - perform_sampling: bool = True, + categorized_sample_indices: Dict[SamplingType, torch.Tensor], + num_prompts: int, ) -> None: self.seq_groups = seq_groups - self.seq_data = seq_data - self.prompt_lens = prompt_lens self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices - self.generators = generators - self.perform_sampling = perform_sampling + self.num_prompts = num_prompts - self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0 + @staticmethod + def prepare( + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + subquery_lens: Optional[List[int]], + device: str, + pin_memory: bool, + ) -> "SamplingMetadata": + ( + seq_groups, + selected_token_indices, + categorized_sample_indices, + num_prompts, + ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens, + subquery_lens, device) + selected_token_indices = async_tensor_h2d(selected_token_indices, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory) + categorized_sample_indices = { + t: maybe_expand_dim( + async_tensor_h2d(seq_ids, + dtype=torch.int, + target_device=device, + pin_memory=pin_memory), 2, 2) + for t, seq_ids in categorized_sample_indices.items() + } + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + num_prompts=num_prompts, + ) + return sampling_metadata def __repr__(self) -> str: return ( "SamplingMetadata(" f"seq_groups={self.seq_groups}, " - f"seq_data={self.seq_data}, " - f"prompt_lens={self.prompt_lens}, " f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices}), " - f"perform_sampling={self.perform_sampling})") + f"categorized_sample_indices={self.categorized_sample_indices}), ") + + +def _prepare_seq_groups( + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + subquery_lens: Optional[List[int]], + device: str, +) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ + SamplingType, List[Tuple[int, int]]], int]: + """Prepare sequence groups and indices for sampling. + + Args: + seq_group_metadata_list: A list of sequence group to batch. + prompt_lens: A list of prompt lens per sequence group. + Index of prompt len should match with seq_group_metadata_list. + subquery_lens: A list of query lengths. Prompt lens include the length + of entire prompt tokens, and it could be shorter. + device: A device to use for random number generator, + `SequenceGroupToSample.generator`. + + Returns: + seq_groups: A list of sequence group to sample. + selected_token_indices: See the definition from `SamplingMetadata`. + categorized_sample_indices: See the definition from `SamplingMetadata`. + num_prompts: Total number of prompts from `seq_group_metadata_list`. + """ + # Batched sequence groups for the current model forward stsep. + seq_groups: List[SequenceGroupToSample] = [] + # A list of token indices to sample/compute logprob. It is used to + # prune the outcome logits from the model for the performance. + selected_token_indices: List[int] = [] + # Used for selected_token_indices. + model_output_idx = 0 + + # Sampling type -> ( + # indices to sample/prompt logprob within pruned output logits, + # indices to sample within pruned logits) + categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = { + t: [] + for t in SamplingType + } + # Index of logits to compute logprob. Logits include both prompt logprob + # and sample logprob indices. + logit_idx = 0 + # Index to sample from a sample tensor. It is used by triton sample kernel. + # See `_sample_with_triton_kernel` for more details. + sample_idx = 0 + # Total number of prompts from given sequence groups. + num_prompts = 0 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + is_prompt = seq_group_metadata.is_prompt + generator: Optional[torch.Generator] = None + # If the current seq group is in decode stage, it is None. + prompt_len: Optional[int] = None + subquery_len: Optional[int] = None + prompt_logprob_indices: List[int] = [] + sample_indices: List[int] = [] + do_sample = seq_group_metadata.do_sample + + if seq_group_metadata.is_prompt: + if sampling_params.seed is not None: + seq_group_metadata.state.generator = torch.Generator( + device=device).manual_seed(sampling_params.seed) + + num_prompts += 1 + num_prefill_sample = len(seq_ids) + assert num_prefill_sample == 1 + assert subquery_lens is not None and prompt_lens is not None + subquery_len, prompt_len = subquery_lens[i], prompt_lens[i] + # If we need sampling, exclude num_prefill_sample tokens from + # prompt logprob. + prompt_logprob_len = (subquery_len - num_prefill_sample + if do_sample else subquery_len) + sample_len = num_prefill_sample if do_sample else 0 + else: + # Decode + prompt_logprob_len = 0 + sample_len = len(seq_ids) if do_sample else 0 + + # Update indices to select from the model output. + """ + This blocks computes selected_token_indices which is used in the + following way. + + hidden_states = model(...) + logits = hidden_states[selected_token_indices] + """ + + if sampling_params.prompt_logprobs: + selected_token_indices.extend( + range(model_output_idx, model_output_idx + prompt_logprob_len)) + model_output_idx += prompt_logprob_len + if do_sample: + selected_token_indices.extend( + range(model_output_idx, model_output_idx + sample_len)) + model_output_idx += sample_len + + # We now find indices for logprob computation and sampling. + """ + This block computes categorized_sample_indices which is used in the + following way. + + hidden_states = model(...) + logits = hidden_states[selected_token_indices] + def sample(logits): + # Use categorized_sample_indices for sampling. + # prompt_logprob_indices to find prompt logprob indices. + # sample_indices to find sample indices. + """ + + if sampling_params.prompt_logprobs is not None: + prompt_logprob_indices.extend( + range(logit_idx, logit_idx + prompt_logprob_len)) + logit_idx += prompt_logprob_len + if do_sample: + sample_indices.extend(range(logit_idx, logit_idx + sample_len)) + categorized_sample_indices[sampling_params.sampling_type].extend( + list( + zip(range(logit_idx, logit_idx + sample_len), + range(sample_idx, sample_idx + sample_len)))) + logit_idx += sample_len + sample_idx += sample_len + + if sampling_params.seed is not None: + generator = seq_group_metadata.state.generator + + seq_groups.append( + SequenceGroupToSample( + seq_ids=seq_ids, + sampling_params=sampling_params, + seq_data=seq_group_metadata.seq_data, + prompt_len=prompt_len, + subquery_len=subquery_len, + generator=generator, + is_prompt=is_prompt, + prompt_logprob_indices=list(prompt_logprob_indices), + sample_indices=list(sample_indices))) + return (seq_groups, selected_token_indices, categorized_sample_indices, + num_prompts) @dataclass @@ -112,11 +330,10 @@ def from_sampling_metadata( seeds_to_generate = (extra_seeds_to_generate + get_num_triton_sampler_splits(vocab_size)) - sample_indices_start_idx = 0 assert sampling_metadata.seq_groups is not None - assert sampling_metadata.seq_data is not None - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params temperature = sampling_params.temperature p = sampling_params.presence_penalty f = sampling_params.frequency_penalty @@ -145,45 +362,46 @@ def from_sampling_metadata( or abs(r - 1.0) >= _SAMPLING_EPS): do_penalties = True - if (i < sampling_metadata.num_prompts + is_prompt = seq_group.is_prompt + if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs - assert sampling_metadata.prompt_lens is not None - prompt_len = sampling_metadata.prompt_lens[i] - temperatures += [temperature] * (prompt_len - 1) - top_ps += [top_p] * (prompt_len - 1) - top_ks += [top_k] * (prompt_len - 1) - min_ps += [min_p] * (prompt_len - 1) - presence_penalties += [0] * (prompt_len - 1) - frequency_penalties += [0] * (prompt_len - 1) - repetition_penalties += [1] * (prompt_len - 1) - prompt_tokens.extend([] for _ in range(prompt_len - 1)) - output_tokens.extend([] for _ in range(prompt_len - 1)) - for seq_id in seq_ids: - seq_data = sampling_metadata.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids) - output_tokens.append(seq_data.output_token_ids) - temperatures += [temperature] * len(seq_ids) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) - - is_prompt = i < sampling_metadata.num_prompts + subquery_len = seq_group.subquery_len + assert subquery_len is not None + prefill_len = len(seq_group.prompt_logprob_indices) + temperatures += [temperature] * prefill_len + top_ps += [top_p] * prefill_len + top_ks += [top_k] * prefill_len + min_ps += [min_p] * prefill_len + presence_penalties += [0] * prefill_len + frequency_penalties += [0] * prefill_len + repetition_penalties += [1] * prefill_len + prompt_tokens.extend([] for _ in range(prefill_len)) + output_tokens.extend([] for _ in range(prefill_len)) + + if seq_group.do_sample: + sample_lens = len(seq_group.sample_indices) + assert sample_lens == len(seq_ids) + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids) + output_tokens.append(seq_data.output_token_ids) + temperatures += [temperature] * len(seq_ids) + top_ps += [top_p] * len(seq_ids) + top_ks += [top_k] * len(seq_ids) + min_ps += [min_p] * len(seq_ids) + presence_penalties += [p] * len(seq_ids) + frequency_penalties += [f] * len(seq_ids) + repetition_penalties += [r] * len(seq_ids) + if is_prompt: prompt_best_of.append(sampling_params.best_of) - assert sampling_metadata.prompt_lens is not None - prompt_len = sampling_metadata.prompt_lens[i] + subquery_len = seq_group.subquery_len + assert subquery_len is not None - if sampling_params.prompt_logprobs is not None: - # NOTE: the sampling position is the last token - # in the prompt - sample_indices_start_idx += prompt_len - 1 for seq_id in seq_ids: - seq_data = sampling_metadata.seq_data[seq_id] + seq_data = seq_group.seq_data[seq_id] extra_entropy = extra_entropy or () seq_seeds = cls._get_sequence_seeds( seed, @@ -193,8 +411,7 @@ def from_sampling_metadata( seeds_to_generate=seeds_to_generate, is_greedy=is_greedy) sampling_seeds.append(seq_seeds) - sample_indices.append(sample_indices_start_idx) - sample_indices_start_idx += 1 + sample_indices.extend(seq_group.sample_indices) sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, @@ -217,12 +434,14 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], # Note that the performance will be very bad without # pinned memory. pin_memory = is_pin_memory_available() - prompt_max_len = max(len(tokens) for tokens in prompt_tokens) + prompt_max_len = max([len(tokens) for tokens in prompt_tokens], + default=0) prompt_padded_tokens = [ tokens + [vocab_size] * (prompt_max_len - len(tokens)) for tokens in prompt_tokens ] - output_max_len = max(len(tokens) for tokens in output_tokens) + output_max_len = max([len(tokens) for tokens in output_tokens], + default=0) output_padded_tokens = [ tokens + [vocab_size] * (output_max_len - len(tokens)) for tokens in output_tokens diff --git a/vllm/sequence.py b/vllm/sequence.py index b296b37a84f15..567fca5709518 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -28,7 +28,10 @@ class Logprob: decoded_token: Optional[str] = None +# {token_id -> logprob} per each sequence group. None if the corresponding +# sequence group doesn't require prompt logprob. PromptLogprobs = List[Optional[Dict[int, Logprob]]] +# {token_id -> logprob} for each sequence group. SampleLogprobs = List[Dict[int, Logprob]] @@ -215,7 +218,7 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request - self.data = SequenceData(prompt_token_ids) + self.data: SequenceData = SequenceData(prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -559,6 +562,9 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + do_sample: True if sampling is required. Sampling is not required when + e.g., prefill is chunked, and the current iteration only computes + query tokens for prefill, we don't need sampling. token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. state: Internal state tied to this sequence group. @@ -573,6 +579,7 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + do_sample: bool = True, token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, @@ -589,6 +596,7 @@ def __init__( self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state self._token_chunk_size = token_chunk_size + self.do_sample = do_sample if self._token_chunk_size is None: if is_prompt: @@ -650,6 +658,7 @@ def __init__( prompt_logprobs: Optional[PromptLogprobs], ) -> None: self.samples = samples + # Prompt logprob for each prompt query token. self.prompt_logprobs = prompt_logprobs def __repr__(self) -> str: diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index bf0a6c84e6f07..34d7d3dffea18 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -10,9 +10,8 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import make_tensor_with_pad, maybe_expand_dim +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad logger = init_logger(__name__) @@ -38,6 +37,8 @@ def __init__( self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + # Currently, CPU worker doesn't support chunked prefill. + assert self.scheduler_config.chunked_prefill_enabled is False self.lora_config = lora_config self.vision_language_config = vision_language_config self.load_config = load_config @@ -252,99 +253,6 @@ def _prepare_decode( attn_metadata, ) - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - ) -> SamplingMetadata: - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - selected_token_indices: List[int] = [] - generators: List[torch.Generator] = [] - selected_token_start_idx = 0 - categorized_sample_indices: Dict[SamplingType, - List[Tuple[int, int]]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices_start_idx = 0 - categorized_sampled_token_indices_start_idx = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - if seq_group_metadata.is_prompt: - assert len(seq_ids) == 1 - subquery_len = prompt_lens[i] - if sampling_params.prompt_logprobs is not None: - # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += subquery_len - 1 - - categorized_sample_indices[ - sampling_params.sampling_type].append( - (categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx)) - categorized_sample_indices_start_idx += 1 - categorized_sampled_token_indices_start_idx += 1 - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + subquery_len - 1)) - selected_token_indices.append(selected_token_start_idx + - subquery_len - 1) - selected_token_start_idx += subquery_len - - if sampling_params.seed is not None: - seq_group_metadata.state.generator = torch.Generator( - device=self.device).manual_seed(sampling_params.seed) - else: - num_seqs = len(seq_ids) - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) - selected_token_start_idx += num_seqs - - categorized_sample_indices[ - sampling_params.sampling_type].extend( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx + - num_seqs))) - categorized_sample_indices_start_idx += num_seqs - categorized_sampled_token_indices_start_idx += num_seqs - - if sampling_params.seed is not None: - generators.append(seq_group_metadata.state.generator) - - selected_token_indices = torch.tensor(selected_token_indices, - dtype=torch.long) - - categorized_sample_indices = { - t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2) - for t, seq_ids in categorized_sample_indices.items() - } - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - generators=generators, - ) - return sampling_metadata - def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -364,8 +272,15 @@ def prepare_input_tensors( (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + # subquery_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use prompt_lens instead. + prompt_lens, + self.device, + pin_memory=False) # Broadcast the metadata. metadata_dict = { "input_tokens": input_tokens, @@ -389,7 +304,6 @@ def prepare_input_tensors( selected_token_indices=selected_token_indices, categorized_sample_indices=None, generators=None, - perform_sampling=False, ) return (input_tokens, input_positions, attn_metadata, @@ -421,7 +335,7 @@ def execute_model( logits = self.model.compute_logits(hidden_states, sampling_metadata) # Only perform sampling in the driver worker. - if not sampling_metadata.perform_sampling: + if not self.is_driver_worker: return None # Sample the next token. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c6da28f110325..0704f5fec54d0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -20,12 +20,11 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip, - is_pin_memory_available, make_tensor_with_pad, - maybe_expand_dim) +from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available, + make_tensor_with_pad) logger = init_logger(__name__) @@ -547,108 +546,6 @@ def _prepare_decode( slot_mapping=slot_mapping, ) - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], - ) -> SamplingMetadata: - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - selected_token_indices: List[int] = [] - generators: List[torch.Generator] = [] - selected_token_start_idx = 0 - categorized_sample_indices: Dict[SamplingType, - List[Tuple[int, int]]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices_start_idx = 0 - categorized_sampled_token_indices_start_idx = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - if seq_group_metadata.is_prompt: - assert len(seq_ids) == 1 - assert subquery_lens is not None - subquery_len = subquery_lens[i] - if sampling_params.prompt_logprobs is not None: - # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += subquery_len - 1 - - categorized_sample_indices[ - sampling_params.sampling_type].append( - (categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx)) - categorized_sample_indices_start_idx += 1 - categorized_sampled_token_indices_start_idx += 1 - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + subquery_len - 1)) - selected_token_indices.append(selected_token_start_idx + - subquery_len - 1) - selected_token_start_idx += subquery_len - - if sampling_params.seed is not None: - seq_group_metadata.state.generator = torch.Generator( - device=self.device).manual_seed(sampling_params.seed) - else: - num_seqs = len(seq_ids) - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) - selected_token_start_idx += num_seqs - - categorized_sample_indices[ - sampling_params.sampling_type].extend( - list( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx - + num_seqs)))) - categorized_sample_indices_start_idx += num_seqs - categorized_sampled_token_indices_start_idx += num_seqs - - if sampling_params.seed is not None: - generators.append(seq_group_metadata.state.generator) - - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=self.device, - pin_memory=self.pin_memory) - - categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=self.pin_memory), 2, 2) - for t, seq_ids in categorized_sample_indices.items() - } - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - generators=generators, - ) - return sampling_metadata - def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -685,9 +582,9 @@ def prepare_input_tensors( decode_lora_requests, decode_slot_mapping, ) = self._prepare_decode(decode_reqs) - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, prompt_lens, subquery_lens, + self.device, self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 @@ -788,12 +685,9 @@ def prepare_input_tensors( **metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, - seq_data=None, - prompt_lens=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, - generators=None, - perform_sampling=False, + num_prompts=0, ) # if it is a mixed batch, decode attn_metadata is broadcasted @@ -852,7 +746,7 @@ def execute_model( logits = self.model.compute_logits(hidden_states, sampling_metadata) # Only perform sampling in the driver worker. - if not sampling_metadata.perform_sampling: + if not self.is_driver_worker: return None # Sample the next token. @@ -860,6 +754,7 @@ def execute_model( logits=logits, sampling_metadata=sampling_metadata, ) + return output @torch.inference_mode() diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 487df334d73e3..a974e85c22f45 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -8,10 +8,8 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import (async_tensor_h2d, is_pin_memory_available, - make_tensor_with_pad, maybe_expand_dim) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import is_pin_memory_available, make_tensor_with_pad logger = init_logger(__name__) @@ -141,106 +139,6 @@ def _prepare_decode( return input_tokens, input_positions, input_block_ids - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - ) -> SamplingMetadata: - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - selected_token_indices: List[int] = [] - generators: List[torch.Generator] = [] - selected_token_start_idx = 0 - categorized_sample_indices: Dict[SamplingType, - List[Tuple[int, int]]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices_start_idx = 0 - categorized_sampled_token_indices_start_idx = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - if seq_group_metadata.is_prompt: - assert len(seq_ids) == 1 - assert prompt_lens is not None - prompt_len = prompt_lens[i] - if sampling_params.prompt_logprobs is not None: - # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += prompt_len - 1 - - categorized_sample_indices[ - sampling_params.sampling_type].append( - (categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx)) - categorized_sample_indices_start_idx += 1 - categorized_sampled_token_indices_start_idx += 1 - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + prompt_len - 1)) - selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += prompt_len - - if sampling_params.seed is not None: - seq_group_metadata.state.generator = torch.Generator( - device=self.device).manual_seed(sampling_params.seed) - else: - num_seqs = len(seq_ids) - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) - selected_token_start_idx += num_seqs - - categorized_sample_indices[ - sampling_params.sampling_type].extend( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx + - num_seqs))) - categorized_sample_indices_start_idx += num_seqs - categorized_sampled_token_indices_start_idx += num_seqs - - if sampling_params.seed is not None: - generators.append(seq_group_metadata.state.generator) - - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=self.device, - pin_memory=self.pin_memory) - - categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=self.pin_memory), 2, 2) - for t, seq_ids in categorized_sample_indices.items() - } - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - generators=generators, - ) - return sampling_metadata - def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -256,8 +154,15 @@ def prepare_input_tensors( (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + # subquery_lens is not needed if chunked prefill is not + # supported. Since neuron worker doesn't support chunked prefill + # just use prompt_lens instead. + prompt_lens, + self.device, + self.pin_memory) return (input_tokens, input_positions, input_block_ids, sampling_metadata) From a62aaf1df558d69658a42c1ab749368ab0325f35 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 26 Apr 2024 13:41:14 -0700 Subject: [PATCH 131/413] [Misc][Refactor] Generalize linear_method to be quant_method (#4373) --- tests/quantization/test_fp8.py | 2 +- tests/tensorizer_loader/test_tensorizer.py | 4 +- vllm/lora/layers.py | 30 ++-- vllm/model_executor/layers/linear.py | 169 +++++++++--------- .../layers/quantization/__init__.py | 4 +- .../layers/quantization/aqlm.py | 13 +- .../model_executor/layers/quantization/awq.py | 19 +- .../layers/quantization/base_config.py | 31 +++- .../model_executor/layers/quantization/fp8.py | 60 +++---- .../layers/quantization/gptq.py | 19 +- .../layers/quantization/marlin.py | 13 +- .../layers/quantization/squeezellm.py | 24 +-- vllm/model_executor/model_loader/loader.py | 38 ++-- .../model_executor/model_loader/tensorizer.py | 13 +- vllm/model_executor/models/baichuan.py | 43 ++--- vllm/model_executor/models/bloom.py | 33 ++-- vllm/model_executor/models/chatglm.py | 37 ++-- vllm/model_executor/models/commandr.py | 33 ++-- vllm/model_executor/models/dbrx.py | 35 ++-- vllm/model_executor/models/decilm.py | 7 +- vllm/model_executor/models/deepseek.py | 45 +++-- vllm/model_executor/models/falcon.py | 33 ++-- vllm/model_executor/models/gemma.py | 33 ++-- vllm/model_executor/models/gpt2.py | 33 ++-- vllm/model_executor/models/gpt_bigcode.py | 33 ++-- vllm/model_executor/models/gpt_j.py | 33 ++-- vllm/model_executor/models/gpt_neox.py | 33 ++-- vllm/model_executor/models/internlm2.py | 33 ++-- vllm/model_executor/models/jais.py | 33 ++-- vllm/model_executor/models/llama.py | 32 ++-- vllm/model_executor/models/llava.py | 9 +- vllm/model_executor/models/minicpm.py | 35 ++-- vllm/model_executor/models/mixtral.py | 44 ++--- vllm/model_executor/models/mixtral_quant.py | 41 ++--- vllm/model_executor/models/mpt.py | 33 ++-- vllm/model_executor/models/olmo.py | 32 ++-- vllm/model_executor/models/opt.py | 37 ++-- vllm/model_executor/models/orion.py | 33 ++-- vllm/model_executor/models/phi.py | 35 ++-- vllm/model_executor/models/qwen.py | 33 ++-- vllm/model_executor/models/qwen2.py | 33 ++-- vllm/model_executor/models/qwen2_moe.py | 45 +++-- vllm/model_executor/models/stablelm.py | 29 +-- vllm/model_executor/models/starcoder2.py | 32 ++-- vllm/model_executor/models/xverse.py | 33 ++-- 45 files changed, 759 insertions(+), 713 deletions(-) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index fa10e60de10a7..607544a1c8394 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None: model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model fc1 = model.model.decoder.layers[0].fc1 - assert isinstance(fc1.linear_method, Fp8LinearMethod) + assert isinstance(fc1.quant_method, Fp8LinearMethod) assert fc1.weight.dtype == torch.float8_e4m3fn diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index a97cc0b3706b4..df1db4e6c4001 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config): mock_agent_instance.deserialize.return_value = MagicMock() result = load_with_tensorizer(tensorizer_config, - linear_method=mock_linear_method) + quant_method=mock_linear_method) mock_agent.assert_called_once_with(tensorizer_config, - linear_method=mock_linear_method) + quant_method=mock_linear_method) mock_agent_instance.deserialize.assert_called_once() assert result == mock_agent_instance.deserialize.return_value diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 98e74168002c4..4eaf73fbcfda4 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -389,10 +389,9 @@ def set_mapping( self.indices = base_indices self.indices_len = indices_len - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) _apply_lora( x, self.lora_a_stacked, @@ -416,7 +415,7 @@ def forward(self, input_): if not self.base_layer.skip_bias_add else None) # Matrix multiply. - output_parallel = self.apply_weights(input_, bias) + output_parallel = self.apply(input_, bias) if self.base_layer.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -523,10 +522,9 @@ def set_lora( index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( lora_b[1].T, non_blocking=True) - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -765,10 +763,9 @@ def set_lora( index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( lora_a[2].T, non_blocking=True) - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -862,9 +859,8 @@ def set_mapping( self.indices = base_indices self.indices_len = indices_len - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x) + def apply(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) _apply_lora( x, self.lora_a_stacked, @@ -897,7 +893,7 @@ def forward(self, input_): input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. - output_parallel = self.apply_weights(input_parallel) + output_parallel = self.apply(input_parallel) if self.base_layer.reduce_results and self.base_layer.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 6ad7ae0f63197..1bd6c42ab3fd8 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,9 +1,8 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import List, Optional import torch import torch.nn.functional as F -from torch import nn from torch.nn.parameter import Parameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -12,6 +11,8 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -25,7 +26,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -class LinearMethodBase(ABC): +class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod @@ -50,22 +51,15 @@ def create_weights(self, layer: torch.nn.Module, raise NotImplementedError @abstractmethod - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights in layer to the input tensor. Expects create_weights to have been called before on the layer.""" raise NotImplementedError - def process_weights_after_loading(self, layer: nn.Module) -> None: - """Process the weight after loading. - - This can be used for example, to transpose weights for computation. - """ - return - class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization. @@ -92,10 +86,10 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: weight = layer.weight if self.separate_bias_add: if bias is not None: @@ -104,8 +98,8 @@ def apply_weights(self, return F.linear(x, weight, bias) -class ReplicatedLinear(torch.nn.Module): - """Replicated linear layer. +class LinearBase(torch.nn.Module): + """Base linear layer. Args: input_size: input dimension of the linear layer. @@ -113,17 +107,16 @@ class ReplicatedLinear(torch.nn.Module): bias: If true, add bias. skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ def __init__( self, input_size: int, output_size: int, - bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -134,12 +127,43 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - if linear_method is None: - linear_method = UnquantizedLinearMethod() - self.linear_method = linear_method - self.linear_method.create_weights(self, self.input_size, - [self.output_size], self.input_size, - self.output_size, self.params_dtype) + if quant_config is None: + self.quant_method = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config) + + self.quant_method.create_weights(self, self.input_size, + [self.output_size], self.input_size, + self.output_size, self.params_dtype) + if bias: self.bias = Parameter( torch.empty(self.output_size, dtype=self.params_dtype)) @@ -149,12 +173,12 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None - output = self.linear_method.apply_weights(self, x, bias) + output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias -class ColumnParallelLinear(torch.nn.Module): +class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. The linear layer is defined as Y = XA + b. A is parallelized along @@ -171,7 +195,7 @@ class ColumnParallelLinear(torch.nn.Module): bias can be fused with other element-wise operations. we skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. """ @@ -184,34 +208,26 @@ def __init__( gather_output: bool = False, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, ): - super().__init__() + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config) - # Keep input parameters - self.input_size = input_size - self.output_size = output_size self.gather_output = gather_output + # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() self.output_size_per_partition = divide(output_size, tp_size) - self.skip_bias_add = skip_bias_add - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - if linear_method is None: - linear_method = UnquantizedLinearMethod() if output_sizes is None: output_sizes = [output_size] - self.linear_method = linear_method - self.linear_method.create_weights(self, - self.input_size, - [x // tp_size for x in output_sizes], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights(self, + self.input_size, + [x // tp_size for x in output_sizes], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -239,7 +255,7 @@ def forward(self, input_): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. - output_parallel = self.linear_method.apply_weights(self, input_, bias) + output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -267,7 +283,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): bias can be fused with other element-wise operations. we skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ def __init__( @@ -278,13 +294,13 @@ def __init__( gather_output: bool = False, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method, + skip_bias_add, params_dtype, quant_config, self.output_sizes) def weight_loader(self, @@ -384,7 +400,7 @@ class QKVParallelLinear(ColumnParallelLinear): bias can be fused with other element-wise operations. we skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ def __init__( @@ -396,7 +412,7 @@ def __init__( bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): self.hidden_size = hidden_size self.head_size = head_size @@ -424,7 +440,7 @@ def __init__( ] super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, linear_method, output_sizes) + params_dtype, quant_config, output_sizes) def weight_loader(self, param: Parameter, @@ -517,7 +533,7 @@ def weight_loader(self, param_data.copy_(loaded_weight) -class RowParallelLinear(torch.nn.Module): +class RowParallelLinear(LinearBase): """Linear layer with row parallelism. The linear layer is defined as Y = XA + b. A is parallelized along @@ -540,7 +556,7 @@ class RowParallelLinear(torch.nn.Module): bias can be fused with other element-wise operations. We skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ def __init__( @@ -552,32 +568,24 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): - super().__init__() - # Keep input parameters - self.input_size = input_size - self.output_size = output_size + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config) + self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) - self.skip_bias_add = skip_bias_add - if linear_method is None: - linear_method = UnquantizedLinearMethod() - self.linear_method = linear_method - self.linear_method.create_weights(self, - self.input_size_per_partition, - [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights(self, + self.input_size_per_partition, + [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " @@ -616,8 +624,7 @@ def forward(self, input_): input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. - output_parallel = self.linear_method.apply_weights( - self, input_parallel) + output_parallel = self.quant_method.apply(self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index a525add458499..0820f17c5c50d 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,7 +4,7 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import FP8Config +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig @@ -12,7 +12,7 @@ QUANTIZATION_METHODS = { "aqlm": AQLMConfig, "awq": AWQConfig, - "fp8": FP8Config, + "fp8": Fp8Config, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index b48c6e1702be4..83e24fadc1405 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -9,10 +9,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs def get_int_dtype(nbits: int) -> torch.dtype: @@ -207,8 +207,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": return cls(in_group_size, nbits_per_codebook, num_code_books, out_group_size) - def get_linear_method(self) -> "AQLMLinearMethod": - return AQLMLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]: + if isinstance(layer, LinearBase): + return AQLMLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -321,7 +324,7 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("scales", scales) set_weight_attrs(scales, extra_weight_attrs) - def apply_weights( + def apply( self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 4f75134ee1889..f4fc7ce020e95 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,10 +4,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs class AWQConfig(QuantizationConfig): @@ -62,8 +62,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": zero_point = cls.get_from_keys(config, ["zero_point"]) return cls(weight_bits, group_size, zero_point) - def get_linear_method(self) -> "AWQLinearMethod": - return AWQLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]: + if isinstance(layer, LinearBase): + return AWQLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] @@ -147,10 +150,10 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("scales", scales) set_weight_attrs(scales, extra_weight_attrs) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = layer.qweight scales = layer.scales qzeros = layer.qzeros diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 6115e7c3be956..b755b1328504a 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -2,8 +2,33 @@ from typing import Any, Dict, List import torch +from torch import nn -from vllm.model_executor.layers.linear import LinearMethodBase + +class QuantizeMethodBase(ABC): + """Base class for different quantized methods.""" + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, *weight_args, + **extra_weight_attrs): + """Create weights for a layer. + + The weights will be set as attributes of the layer.""" + raise NotImplementedError + + @abstractmethod + def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + def process_weights_after_loading(self, layer: nn.Module) -> None: + """Process the weight after loading. + + This can be used for example, to transpose weights for computation. + """ + return class QuantizationConfig(ABC): @@ -51,8 +76,8 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") @abstractmethod - def get_linear_method(self) -> LinearMethodBase: - """Get the linear method to use for the quantized linear layer.""" + def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase: + """Get the quantize method to use for the quantized layer.""" raise NotImplementedError @abstractmethod diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 01e494c870e71..39679834b545c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,16 +1,17 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import torch from torch.nn import Module from torch.nn.parameter import Parameter -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs -class FP8Config(QuantizationConfig): +class Fp8Config(QuantizationConfig): """Config class for FP8.""" @classmethod @@ -33,11 +34,14 @@ def get_config_filenames(cls) -> List[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "FP8Config": + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": return cls() - def get_linear_method(self) -> "Fp8LinearMethod": - return Fp8LinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return Fp8LinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -57,7 +61,7 @@ class Fp8LinearMethod(LinearMethodBase): quant_config: The quantization config. """ - def __init__(self, quant_config: FP8Config): + def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config def create_weights( @@ -86,24 +90,24 @@ def create_weights( layer.register_parameter("weight_scaling_factor", w_scale) def process_weights_after_loading(self, layer: Module) -> None: - # Although the linear_method is propagated to all layers, + # Although the quant_method is propagated to all layers, # only linear layers invoke "create_weights". So we check # whether "weight_scaling_facor" is registered to determine # whether the layer is a linear layer that requires quantization. if not hasattr(layer, "weight_scaling_factor"): return - qweight, weight_scale = per_tensor_quantize(layer.weight) + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight) # torch._scaled_mm requires column-major in the second # input (weight), so we transpose the quantized weight. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scaling_factor.data.copy_(weight_scale) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qinput, x_scale = per_tensor_quantize(x) + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qinput, x_scale = ops.scaled_fp8_quant(x) output, _ = torch._scaled_mm( qinput, layer.weight, @@ -113,27 +117,3 @@ def apply_weights(self, bias=bias, ) return output - - -def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: - """Quantize a tensor using per-tensor static scaling factor. - - Args: - tensor: The input tensor. - """ - finfo = torch.finfo(torch.float8_e4m3fn) - # Calculate the scale as dtype max divided by absmax. - # Since .abs() creates a new tensor, we use aminmax to get - # the min and max first and then calculate the absmax. - min_val, max_val = tensor.aminmax() - amax = min_val.abs().max(max_val.abs()) - scale = finfo.max / amax.clamp(min=1e-12) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(torch.float8_e4m3fn) - scale = scale.float().reciprocal() - return qweight, scale diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 92a5cdb9af928..ae9f7019f0592 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -7,10 +7,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs class GPTQConfig(QuantizationConfig): @@ -63,8 +63,11 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": desc_act = cls.get_from_keys(config, ["desc_act"]) return cls(weight_bits, group_size, desc_act) - def get_linear_method(self) -> "GPTQLinearMethod": - return GPTQLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -194,10 +197,10 @@ def create_weights( layer.exllama_state = exllama_state - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = layer.qweight out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 00c3c404c2d7a..94aba620ea083 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -4,10 +4,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs class MarlinConfig(QuantizationConfig): @@ -72,8 +72,11 @@ def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) return cls(group_size) - def get_linear_method(self) -> "MarlinLinearMethod": - return MarlinLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: + if isinstance(layer, LinearBase): + return MarlinLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -197,7 +200,7 @@ def create_weights( layer.register_parameter("workspace", workspace) set_weight_attrs(workspace, extra_weight_attrs) - def apply_weights( + def apply( self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index cc44447d347b8..971078fe25a9b 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -4,10 +4,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs from vllm.utils import is_hip @@ -51,14 +51,18 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": weight_bits = cls.get_from_keys(config, ["wbits"]) return cls(weight_bits) - def get_linear_method(self) -> "SqueezeLLMLinearMethod": - return SqueezeLLMLinearMethod(self) + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]: + if isinstance(layer, LinearBase): + return SqueezeLLMLinearMethod(self) + return def get_scaled_act_names(self) -> List[str]: return [] -class SqueezeLLMLinearMethod(LinearMethodBase): +class SqueezeLLMLinearMethod(QuantizeMethodBase): """Linear method for SqueezeLLM. Args: @@ -112,10 +116,10 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("lookup_table", lookup_table) set_weight_attrs(lookup_table, extra_weight_attrs) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = layer.qweight lookup_table = layer.lookup_table out_shape = x.shape[:-1] + (qweight.shape[-1], ) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index f75c35a69d4a9..ad80243019a65 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -3,8 +3,7 @@ import glob import os from abc import ABC, abstractmethod -from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, - Type) +from typing import Any, Dict, Generator, List, Optional, Tuple, Type import torch from torch import nn @@ -13,6 +12,8 @@ LoadFormat, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer, tensorizer_weights_iterator) @@ -24,9 +25,6 @@ pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models.llava import LlavaForConditionalGeneration -if TYPE_CHECKING: - from vllm.model_executor.layers.linear import LinearMethodBase - _VISION_MODEL_CLASSES = [ LlavaForConditionalGeneration, ] @@ -34,11 +32,10 @@ logger = init_logger(__name__) -def _get_linear_method( +def _get_quantization_config( model_config: ModelConfig, - load_config: LoadConfig) -> Optional["LinearMethodBase"]: - """Get the (maybe quantized) linear method.""" - linear_method = None + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) capability = torch.cuda.get_device_capability() @@ -55,9 +52,8 @@ def _get_linear_method( f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " f"{supported_dtypes}") - - linear_method = quant_config.get_linear_method() - return linear_method + return quant_config + return None def _get_model_initialization_kwargs( @@ -85,10 +81,10 @@ def _initialize_model( vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: """Initialize a model with the given configurations.""" model_class = get_model_architecture(model_config)[0] - linear_method = _get_linear_method(model_config, load_config) + quant_config = _get_quantization_config(model_config, load_config) return model_class(config=model_config.hf_config, - linear_method=linear_method, + quant_config=quant_config, **_get_model_initialization_kwargs( model_class, lora_config, vision_language_config)) @@ -229,9 +225,11 @@ def load_model(self, *, model_config: ModelConfig, "fall_back_to_pt_during_load", True)), ) for _, module in model.named_modules(): - linear_method = getattr(module, "linear_method", None) - if linear_method is not None: - linear_method.process_weights_after_loading(module) + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. if hasattr(module, "process_weights_after_loading"): module.process_weights_after_loading() return model.eval() @@ -314,11 +312,11 @@ def _load_model_serialized( with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] - linear_method = _get_linear_method(model_config, - self.load_config) + quant_config = _get_quantization_config( + model_config, self.load_config) extra_kwargs = _get_model_initialization_kwargs( model_class, lora_config, vision_language_config) - extra_kwargs["linear_method"] = linear_method + extra_kwargs["quant_config"] = quant_config tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 7e65d54bc522f..8fc6d16672117 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -13,7 +13,8 @@ from vllm.config import ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -251,7 +252,7 @@ class TensorizerAgent: """ def __init__(self, tensorizer_config: TensorizerConfig, - linear_method: LinearMethodBase, **extra_kwargs): + quant_config: QuantizationConfig, **extra_kwargs): if tensorizer_load_fail is not None: raise ImportError( "Tensorizer is not installed. Please install tensorizer " @@ -262,10 +263,10 @@ def __init__(self, tensorizer_config: TensorizerConfig, self.tensorizer_args = ( self.tensorizer_config._construct_tensorizer_args()) self.extra_kwargs = extra_kwargs - if extra_kwargs.get("linear_method", None) is not None: - self.linear_method = extra_kwargs["linear_method"] + if extra_kwargs.get("quant_config", None) is not None: + self.quant_config = extra_kwargs["quant_config"] else: - self.linear_method = linear_method + self.quant_config = quant_config self.model = self._init_model() def _init_model(self): @@ -274,7 +275,7 @@ def _init_model(self): with no_init_or_tensor(): return self.tensorizer_config.model_class( config=model_args, - linear_method=self.linear_method, + quant_config=self.quant_config, **self.extra_kwargs) def _resize_lora_embeddings(self): diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 69162b0a92d65..186cee2584369 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -31,11 +31,12 @@ get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -77,17 +78,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -110,7 +111,7 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = hidden_size @@ -132,13 +133,13 @@ def __init__( self.total_num_heads, self.total_num_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) # Create the alibi slopes and slice them. if self.postion_embedding == "ALIBI": @@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -196,13 +197,13 @@ def __init__(self, position_embedding=position_embedding, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.mlp = BaiChuanMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -254,7 +255,7 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, linear_method) + BaiChuanDecoderLayer(config, position_embedding, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -303,13 +304,13 @@ def __init__( self, config, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.model = BaiChuanModel(config, position_embedding, linear_method) + self.quant_config = quant_config + self.model = BaiChuanModel(config, position_embedding, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", linear_method, lora_config) + super().__init__(config, "ROPE", quant_config, lora_config) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", linear_method, lora_config) + super().__init__(config, "ALIBI", quant_config, lora_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, "ROPE", linear_method, lora_config) + super().__init__(config, "ROPE", quant_config, lora_config) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 14f325e624f41..b425af4863c36 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -28,10 +28,11 @@ get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -70,7 +71,7 @@ class BloomAttention(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -87,13 +88,13 @@ def __init__( self.head_dim, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) # Create the alibi slopes and slice them. @@ -129,21 +130,21 @@ class BloomMLP(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.dense_h_to_4h = ColumnParallelLinear( hidden_size, 4 * hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) + quant_config = getattr(quant_config, "quant_config", None) self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -158,17 +159,17 @@ class BloomBlock(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, linear_method) + self.self_attention = BloomAttention(config, quant_config) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) - self.mlp = BloomMLP(config, linear_method) + self.mlp = BloomMLP(config, quant_config) self.apply_residual_connection_post_layernorm = ( config.apply_residual_connection_post_layernorm) @@ -214,7 +215,7 @@ class BloomModel(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -229,7 +230,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - BloomBlock(config, linear_method) + BloomBlock(config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -262,12 +263,12 @@ class BloomForCausalLM(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = BloomModel(config, linear_method) + self.quant_config = quant_config + self.transformer = BloomModel(config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 3cdb7a7bca1c1..e116af2ed080d 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -13,11 +13,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -33,7 +34,7 @@ class GLMAttention(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -65,13 +66,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=config.add_bias_linear or config.add_qkv_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear, - linear_method=linear_method, + quant_config=quant_config, ) # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 @@ -123,7 +124,7 @@ class GLMMLP(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -134,7 +135,7 @@ def __init__( config.hidden_size, [config.ffn_hidden_size] * 2, bias=config.add_bias_linear, - linear_method=linear_method, + quant_config=quant_config, ) self.activation_func = SiluAndMul() @@ -144,7 +145,7 @@ def __init__( config.ffn_hidden_size, config.hidden_size, bias=config.add_bias_linear, - linear_method=linear_method, + quant_config=quant_config, ) def forward(self, hidden_states): @@ -166,7 +167,7 @@ class GLMBlock(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.apply_residual_connection_post_layernorm = ( @@ -180,7 +181,7 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config, linear_method) + self.self_attention = GLMAttention(config, quant_config) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -188,7 +189,7 @@ def __init__( config.hidden_size, eps=config.layernorm_epsilon) # MLP - self.mlp = GLMMLP(config, linear_method) + self.mlp = GLMMLP(config, quant_config) def forward( self, @@ -236,7 +237,7 @@ class GLMTransformer(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.post_layer_norm = config.post_layer_norm @@ -246,7 +247,7 @@ def __init__( # Transformer layers. self.layers = nn.ModuleList( - [GLMBlock(config, linear_method) for i in range(self.num_layers)]) + [GLMBlock(config, quant_config) for i in range(self.num_layers)]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -291,7 +292,7 @@ def __init__( self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, linear_method) + self.encoder = GLMTransformer(config, quant_config) self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) @@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module): def __init__( self, config: ChatGLMConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config: ChatGLMConfig = config - self.linear_method = linear_method - self.transformer = ChatGLMModel(config, linear_method) + self.quant_config = quant_config + self.transformer = ChatGLMModel(config, quant_config) self.lm_head_weight = self.transformer.output_layer.weight self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index d80969773e163..17c2f1223d96b 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -32,11 +32,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -91,7 +92,7 @@ class CohereMLP(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -101,13 +102,13 @@ def __init__( self.hidden_size, [self.intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.down_proj = RowParallelLinear( self.intermediate_size, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.act_fn = SiluAndMul() @@ -123,7 +124,7 @@ class CohereAttention(nn.Module): def __init__( self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -158,13 +159,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, linear_method=linear_method) + self.self_attn = CohereAttention(config, quant_config=quant_config) - self.mlp = CohereMLP(config, linear_method=linear_method) + self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) @@ -257,7 +258,7 @@ class CohereModel(nn.Module): def __init__( self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -265,7 +266,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - CohereDecoderLayer(config, linear_method=linear_method) + CohereDecoderLayer(config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = LayerNorm(param_shape=(config.hidden_size), @@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module): def __init__( self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.logits_processor = LogitsProcessor(config.vocab_size, scale=config.logit_scale) - self.model = CohereModel(config, linear_method) + self.model = CohereModel(config, quant_config) self.sampler = Sampler() @torch.no_grad() diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 179094b8fd7aa..a4a0ae50c645e 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -9,11 +9,12 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,7 +45,7 @@ def __init__( self.num_total_experts, bias=False, params_dtype=params_dtype, - linear_method=None, + quant_config=None, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -63,7 +64,7 @@ class DbrxExperts(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, ): super().__init__() @@ -165,7 +166,7 @@ class DbrxAttention(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model @@ -183,13 +184,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = RowParallelLinear( self.d_model, self.d_model, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model - self.attn = DbrxAttention(config, linear_method) + self.attn = DbrxAttention(config, quant_config) self.norm_1 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model) @@ -278,11 +279,11 @@ class DbrxBlock(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method) - self.ffn = DbrxExperts(config, linear_method) + self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config) + self.ffn = DbrxExperts(config, quant_config) def forward( self, @@ -307,7 +308,7 @@ class DbrxModel(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.wte = VocabParallelEmbedding( @@ -315,7 +316,7 @@ def __init__( config.d_model, ) self.blocks = nn.ModuleList( - [DbrxBlock(config, linear_method) for _ in range(config.n_layers)]) + [DbrxBlock(config, quant_config) for _ in range(config.n_layers)]) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, @@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, linear_method) + self.transformer = DbrxModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index d476630ee6f11..be9a6b6813f8f 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -29,7 +29,8 @@ from transformers import PretrainedConfig from vllm.config import LoRAConfig -from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM @@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, config: Optional[PretrainedConfig] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: config.num_key_value_heads = max(config.num_key_value_heads_per_layer) delattr(config, "num_key_value_heads_per_layer") super().__init__(config=config, - linear_method=linear_method, + quant_config=quant_config, lora_config=lora_config) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 46101a152ec0d..e5f7ba086a35d 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -34,12 +34,13 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -56,18 +57,18 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, reduce_results=reduce_results) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -103,7 +104,7 @@ def __init__( DeepseekMLP(hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False) for idx in range(self.n_routed_experts) ]) @@ -112,7 +113,7 @@ def __init__( self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, bias=False, - linear_method=None) + quant_config=None) if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * @@ -121,7 +122,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False, ) @@ -177,7 +178,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -208,14 +209,14 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -251,7 +252,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -266,18 +267,18 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, linear_method=linear_method) + self.mlp = DeepseekMoE(config=config, quant_config=quant_config) else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -320,7 +321,7 @@ class DeepseekModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -331,9 +332,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, - layer_idx, - linear_method=linear_method) + DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = DeepseekModel(config, linear_method) + self.quant_config = quant_config + self.model = DeepseekModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 25ce239d14662..4be1f064cdd3e 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -32,10 +32,11 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -76,7 +77,7 @@ class FalconAttention(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -115,7 +116,7 @@ def __init__( self.total_num_kv_heads, bias=config.bias, skip_bias_add=True, - linear_method=linear_method, + quant_config=quant_config, ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -129,7 +130,7 @@ def __init__( self.hidden_size, bias=config.bias, skip_bias_add=True, - linear_method=linear_method, + quant_config=quant_config, reduce_results=self.reduce_row_parallel_results) self.use_rotary = config.rotary @@ -192,7 +193,7 @@ class FalconMLP(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -201,8 +202,8 @@ def __init__( 4 * hidden_size, bias=config.bias, skip_bias_add=True, - linear_method=linear_method) - quant_config = getattr(linear_method, "quant_config", None) + quant_config=quant_config) + quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) @@ -212,7 +213,7 @@ def __init__( bias=config.bias, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results, - linear_method=linear_method) + quant_config=quant_config) def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. @@ -229,13 +230,13 @@ class FalconDecoderLayer(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config, linear_method) - self.mlp = FalconMLP(config, linear_method) + self.self_attention = FalconAttention(config, quant_config) + self.mlp = FalconMLP(config, quant_config) self.config = config if config.new_decoder_architecture: @@ -311,7 +312,7 @@ class FalconModel(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -327,7 +328,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - FalconDecoderLayer(config, linear_method) + FalconDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -359,12 +360,12 @@ class FalconForCausalLM(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = FalconModel(config, linear_method) + self.quant_config = quant_config + self.transformer = FalconModel(config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index c3193258d6418..bb73ff4d206da 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -27,11 +27,12 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -77,17 +78,17 @@ def __init__( intermediate_size: int, hidden_act: Optional[str] = None, hidden_activation: Optional[str] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) def forward(self, x): @@ -106,7 +107,7 @@ def __init__(self, head_dim: int, max_position_embeddings: int = 8192, rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None) -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -135,13 +136,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, config: GemmaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -187,14 +188,14 @@ def __init__( head_dim=config.head_dim, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, - linear_method=linear_method, + quant_config=quant_config, ) self.mlp = GemmaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, hidden_activation=getattr(config, "hidden_activation", None), - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -235,7 +236,7 @@ class GemmaModel(nn.Module): def __init__( self, config: GemmaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -245,7 +246,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, linear_method) + GemmaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module): def __init__( self, config: GemmaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: del lora_config # Unused. super().__init__() self.config = config - self.linear_method = linear_method - self.model = GemmaModel(config, linear_method) + self.quant_config = quant_config + self.model = GemmaModel(config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 850050c7232d0..ac1dce6dec8a6 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -27,10 +27,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -44,7 +45,7 @@ class GPT2Attention(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -61,13 +62,13 @@ def __init__( self.head_dim, total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) @@ -90,7 +91,7 @@ def __init__( self, intermediate_size: int, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -98,15 +99,15 @@ def __init__( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) + quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -122,7 +123,7 @@ class GPT2Block(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -130,9 +131,9 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, linear_method) + self.attn = GPT2Attention(config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, config, linear_method) + self.mlp = GPT2MLP(inner_dim, config, quant_config) def forward( self, @@ -163,7 +164,7 @@ class GPT2Model(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -174,7 +175,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block(config, linear_method) + GPT2Block(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -203,12 +204,12 @@ class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = GPT2Model(config, linear_method) + self.quant_config = quant_config + self.transformer = GPT2Model(config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 8278ba02514d5..e52ac679f5d03 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -28,10 +28,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -72,14 +73,14 @@ def __init__( total_num_heads, total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -111,7 +112,7 @@ def __init__( self, intermediate_size: int, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -119,15 +120,15 @@ def __init__( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) + quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -143,7 +144,7 @@ class GPTBigCodeBlock(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -151,9 +152,9 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, linear_method) + self.attn = GPTBigCodeAttention(config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigMLP(inner_dim, config, linear_method) + self.mlp = GPTBigMLP(inner_dim, config, quant_config) def forward( self, @@ -184,7 +185,7 @@ class GPTBigCodeModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -195,7 +196,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPTBigCodeBlock(config, linear_method) + GPTBigCodeBlock(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -224,12 +225,12 @@ class GPTBigCodeForCausalLM(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = GPTBigCodeModel(config, linear_method) + self.quant_config = quant_config + self.transformer = GPTBigCodeModel(config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 7a830d7f9c965..287f4186f7469 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -26,10 +26,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,7 +45,7 @@ class GPTJAttention(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -56,13 +57,13 @@ def __init__( self.head_size, self.total_num_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = RowParallelLinear( config.hidden_size, config.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -105,21 +106,21 @@ def __init__( self, intermediate_size: int, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.n_embd self.fc_in = ColumnParallelLinear( hidden_size, intermediate_size, - linear_method=linear_method, + quant_config=quant_config, ) self.fc_out = RowParallelLinear( intermediate_size, hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) + quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -135,14 +136,14 @@ class GPTJBlock(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() inner_dim = (4 * config.n_embd if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, linear_method) - self.mlp = GPTJMLP(inner_dim, config, linear_method) + self.attn = GPTJAttention(config, quant_config) + self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( self, @@ -169,7 +170,7 @@ class GPTJModel(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -179,7 +180,7 @@ def __init__( self.embed_dim, ) self.h = nn.ModuleList( - [GPTJBlock(config, linear_method) for _ in range(config.n_layer)]) + [GPTJBlock(config, quant_config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -207,13 +208,13 @@ class GPTJForCausalLM(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(config, linear_method) + self.transformer = GPTJModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index b946aed92ed35..cbc5115bd377b 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -26,10 +26,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -63,13 +64,13 @@ def __init__( self.head_size, self.total_num_heads, bias=self.bias, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( config.hidden_size, config.hidden_size, bias=self.bias, - linear_method=linear_method, + quant_config=quant_config, ) scaling = self.head_size**-0.5 rotary_dim = int(self.head_size * config.rotary_pct) @@ -105,20 +106,20 @@ class GPTNeoXMLP(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.dense_h_to_4h = ColumnParallelLinear( config.hidden_size, config.intermediate_size, - linear_method=linear_method, + quant_config=quant_config, ) self.dense_4h_to_h = RowParallelLinear( config.intermediate_size, config.hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) + quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, config.intermediate_size) @@ -134,7 +135,7 @@ class GPTNeoXLayer(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -142,8 +143,8 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, linear_method) - self.mlp = GPTNeoXMLP(config, linear_method) + self.attention = GPTNeoXAttention(config, quant_config) + self.mlp = GPTNeoXMLP(config, quant_config) def forward( self, @@ -182,7 +183,7 @@ class GPTNeoXModel(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -192,7 +193,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GPTNeoXLayer(config, linear_method) + GPTNeoXLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, @@ -223,12 +224,12 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.gpt_neox = GPTNeoXModel(config, linear_method) + self.quant_config = quant_config + self.gpt_neox = GPTNeoXModel(config, quant_config) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index db1da8bdc4fb9..5811cae83bf8b 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -9,11 +9,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -30,17 +31,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.w2 = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -63,7 +64,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -94,13 +95,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.wo = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -150,13 +151,13 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.feed_forward = InternLM2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -195,7 +196,7 @@ class InternLM2Model(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -206,7 +207,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, linear_method) + InternLMDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = InternLM2Model(config, linear_method) + self.quant_config = quant_config + self.model = InternLM2Model(config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index e7ee749e824e4..bd6a180ec8dfc 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -29,10 +29,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -68,7 +69,7 @@ class JAISAttention(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -88,13 +89,13 @@ def __init__( self.head_dim, total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) tp_rank = get_tensor_model_parallel_rank() @@ -128,7 +129,7 @@ def __init__( self, intermediate_size: int, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -137,19 +138,19 @@ def __init__( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_fc2 = (ColumnParallelLinear( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) if self.swiglu else None) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.act = SwiGLUActivation() @@ -169,7 +170,7 @@ class JAISBlock(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -177,9 +178,9 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, linear_method) + self.attn = JAISAttention(config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = JAISMLP(inner_dim, config, linear_method) + self.mlp = JAISMLP(inner_dim, config, quant_config) def forward( self, @@ -210,7 +211,7 @@ class JAISModel(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -227,7 +228,7 @@ def __init__( else: self.embeddings_scale = config.mup_embeddings_scale self.h = nn.ModuleList([ - JAISBlock(config, linear_method) + JAISBlock(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = JAISModel(config, linear_method) + self.quant_config = quant_config + self.transformer = JAISModel(config, quant_config) self.lm_head_weight = self.transformer.wte.weight if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c102b40045c92..f6d7fc8733fce 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -33,11 +33,12 @@ get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -56,17 +57,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QKVParallelLinear] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -89,7 +90,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, ) -> None: @@ -131,13 +132,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -174,7 +175,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -199,7 +200,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, bias=attention_bias, sliding_window=sliding_window, ) @@ -207,7 +208,7 @@ def __init__( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -248,7 +249,7 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -264,7 +265,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) + LlamaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -329,13 +330,12 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = LlamaModel(config, linear_method, lora_config=lora_config) + self.model = LlamaModel(config, quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 314a2792bf167..dcde4dfa0795e 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -9,8 +9,9 @@ from vllm.attention import AttentionMetadata from vllm.config import VisionLanguageConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module): def __init__(self, config: "LlavaConfig", vision_language_config: VisionLanguageConfig, - linear_method: Optional["LinearMethodBase"] = None) -> None: + quant_config: Optional["QuantizationConfig"] = None) -> None: super().__init__() self.config = config @@ -83,8 +84,8 @@ def __init__(self, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) - self.linear_method = linear_method - self.language_model = LlamaModel(config.text_config, linear_method) + self.quant_config = quant_config + self.language_model = LlamaModel(config.text_config, quant_config) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index f0d72fafcaf70..c90bcfbfc4707 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -35,12 +35,13 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -84,7 +85,7 @@ def __init__( self.num_total_experts, bias=False, params_dtype=self.params_dtype, - linear_method=None) + quant_config=None) self.ws = nn.Parameter( torch.empty(self.num_total_experts, @@ -147,17 +148,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -180,7 +181,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -211,13 +212,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -274,7 +275,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: @@ -282,7 +283,7 @@ def __init__( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) else: self.mlp = MiniCPMMoE(num_experts=config.num_experts, @@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -345,7 +346,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, linear_method) + MiniCPMDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.num_experts = getattr(self.config, "num_experts", 0) - self.linear_method = linear_method + self.quant_config = quant_config self.model = MiniCPMModel(config, - linear_method, + quant_config, lora_config=lora_config) unpadded_vocab_size = config.vocab_size if lora_config: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a33b795d7088e..7847df735ab44 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -27,6 +27,7 @@ from torch import nn from transformers import MixtralConfig +from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, @@ -34,13 +35,13 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod, - per_tensor_quantize) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -69,7 +70,7 @@ def __init__( intermediate_size: int, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -79,7 +80,7 @@ def __init__( self.intermediate_size = intermediate_size // self.tp_size # FIXME(pcmoritz): Make this more general to support different # quantization schemes - self.use_fp8 = isinstance(linear_method, Fp8LinearMethod) + self.use_fp8 = isinstance(quant_config, Fp8Config) if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -89,7 +90,7 @@ def __init__( self.num_total_experts, bias=False, params_dtype=self.params_dtype, - linear_method=None) + quant_config=None) self.ws = nn.Parameter( torch.empty(self.num_total_experts, @@ -140,10 +141,10 @@ def process_weights_after_loading(self): ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn) w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn) for expert in range(self.num_total_experts): - ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize( + ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant( self.ws.data[expert, :, :]) w2s[expert, :, :], self.w2s_scale[ - expert] = per_tensor_quantize(self.w2s.data[expert, :, :]) + expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :]) self.ws = nn.Parameter(ws, requires_grad=False) self.w2s = nn.Parameter(w2s, requires_grad=False) @@ -178,7 +179,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -203,12 +204,12 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = sliding_window - if isinstance(linear_method, Fp8LinearMethod): + if isinstance(quant_config, Fp8Config): print_warning_once( "For Mixtral FP8 quantization, we currently do not quantize " "the attention layers until their FP8 performance is improved." ) - linear_method = None + quant_config = None self.qkv_proj = QKVParallelLinear( hidden_size, @@ -216,13 +217,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -259,7 +260,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -272,13 +273,13 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, - linear_method=linear_method) + quant_config=quant_config) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - linear_method=linear_method) + quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -318,7 +319,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -334,7 +335,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, linear_method=linear_method) + MixtralDecoderLayer(config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -384,14 +385,13 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method self.model = MixtralModel(config, - linear_method, + quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index acd13cc27f159..38c62afced28a 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -34,11 +34,12 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -55,7 +56,7 @@ def __init__( num_experts: int, hidden_size: int, intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.num_experts = num_experts @@ -65,15 +66,15 @@ def __init__( self.w1 = ReplicatedLinear(self.hidden_dim, self.ffn_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.w2 = ReplicatedLinear(self.ffn_dim, self.hidden_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.w3 = ReplicatedLinear(self.hidden_dim, self.ffn_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) # TODO: Use vllm's SiluAndMul self.act_fn = nn.SiLU() @@ -92,7 +93,7 @@ class MixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -115,14 +116,14 @@ def __init__( MixtralMLP(self.num_total_experts, config.hidden_size, config.intermediate_size, - linear_method=linear_method) + quant_config=quant_config) if idx in self.expert_indicies else None for idx in range(self.num_total_experts) ]) self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, bias=False, - linear_method=None) + quant_config=None) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -162,7 +163,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -193,13 +194,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -249,9 +250,9 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, - linear_method=linear_method) + quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, - linear_method=linear_method) + quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -291,7 +292,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -302,7 +303,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, linear_method=linear_method) + MixtralDecoderLayer(config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = MixtralModel(config, linear_method) + self.quant_config = quant_config + self.model = MixtralModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 340f63286739b..8c5e7e77c9306 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -11,10 +11,11 @@ get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -42,7 +43,7 @@ class MPTAttention(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model @@ -65,7 +66,7 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) if self.qk_ln: self.q_ln = nn.LayerNorm(self.d_model) @@ -74,7 +75,7 @@ def __init__( self.d_model, self.d_model, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -133,7 +134,7 @@ class MPTMLP(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model @@ -143,15 +144,15 @@ def __init__( hidden_size, intermediate_size, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) + quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn("gelu", quant_config, intermediate_size) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -166,14 +167,14 @@ class MPTBlock(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, linear_method) + self.attn = MPTAttention(config, quant_config) self.norm_2 = nn.LayerNorm(hidden_size) - self.ffn = MPTMLP(config, linear_method) + self.ffn = MPTMLP(config, quant_config) def forward( self, @@ -201,7 +202,7 @@ class MPTModel(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() assert config.embedding_fraction == 1.0 @@ -212,7 +213,7 @@ def __init__( config.d_model, ) self.blocks = nn.ModuleList( - [MPTBlock(config, linear_method) for _ in range(config.n_layers)]) + [MPTBlock(config, quant_config) for _ in range(config.n_layers)]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -246,14 +247,14 @@ class MPTForCausalLM(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config assert config.tie_word_embeddings - self.linear_method = linear_method + self.quant_config = quant_config - self.transformer = MPTModel(config, linear_method) + self.transformer = MPTModel(config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 15527569b9e20..f212ea2166e1d 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -30,11 +30,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -54,7 +55,7 @@ class OlmoAttention(nn.Module): def __init__( self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -79,7 +80,7 @@ def __init__( self.head_dim, self.total_num_heads, bias=config.attention_bias, - linear_method=linear_method, + quant_config=quant_config, ) # Rotary embeddings. @@ -99,7 +100,7 @@ def __init__( self.hidden_size, self.hidden_size, bias=config.attention_bias, - linear_method=linear_method, + quant_config=quant_config, ) def forward( @@ -129,7 +130,7 @@ class OlmoMLP(nn.Module): def __init__( self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -141,7 +142,7 @@ def __init__( self.hidden_size, [self.intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) # Activation function. @@ -152,7 +153,7 @@ def __init__( self.intermediate_size, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) def forward( @@ -174,13 +175,13 @@ class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, linear_method) + self.self_attn = OlmoAttention(config, quant_config) # MLP block. - self.mlp = OlmoMLP(config, linear_method) + self.mlp = OlmoMLP(config, quant_config) # LayerNorm self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -216,14 +217,14 @@ class OlmoModel(nn.Module): def __init__(self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - OlmoDecoderLayer(config, linear_method) + OlmoDecoderLayer(config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, @@ -270,11 +271,10 @@ class OlmoForCausalLM(nn.Module): def __init__(self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.linear_method = linear_method - self.model = OlmoModel(config, linear_method) + self.model = OlmoModel(config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight else: diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 89263166bca81..838a2f0adc4d1 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -27,11 +27,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -60,7 +61,7 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.embed_dim = embed_dim @@ -77,13 +78,13 @@ def __init__( self.head_dim, total_num_heads, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = RowParallelLinear( embed_dim, embed_dim, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module): def __init__( self, config: OPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -116,7 +117,7 @@ def __init__( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.do_layer_norm_before = config.do_layer_norm_before @@ -127,16 +128,16 @@ def __init__( self.embed_dim, config.ffn_dim, bias=config.enable_bias, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) + quant_config = getattr(quant_config, "quant_config", None) self.activation_fn = get_act_fn(config.activation_function, quant_config, config.ffn_dim) self.fc2 = RowParallelLinear( config.ffn_dim, self.embed_dim, bias=config.enable_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.final_layer_norm = nn.LayerNorm( self.embed_dim, @@ -181,7 +182,7 @@ class OPTDecoder(nn.Module): def __init__( self, config: OPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -202,7 +203,7 @@ def __init__( self.project_out = ReplicatedLinear(config.hidden_size, config.word_embed_proj_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) else: self.project_out = None @@ -210,7 +211,7 @@ def __init__( self.project_in = ReplicatedLinear(config.word_embed_proj_dim, config.hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) else: self.project_in = None @@ -226,7 +227,7 @@ def __init__( self.final_layer_norm = None self.layers = nn.ModuleList([ - OPTDecoderLayer(config, linear_method) + OPTDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -259,10 +260,10 @@ class OPTModel(nn.Module): def __init__( self, config: OPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.decoder = OPTDecoder(config, linear_method) + self.decoder = OPTDecoder(config, quant_config) def forward( self, @@ -279,12 +280,12 @@ class OPTForCausalLM(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.model = OPTModel(config, linear_method) + self.quant_config = quant_config + self.model = OPTModel(config, quant_config) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index bbb9fa5347cc8..9ab5dfb97c19a 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -13,11 +13,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -34,17 +35,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -67,7 +68,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -98,13 +99,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -154,13 +155,13 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.mlp = OrionMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -201,7 +202,7 @@ class OrionModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -212,7 +213,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - OrionDecoderLayer(config, linear_method) + OrionDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = OrionModel(config, linear_method) + self.quant_config = quant_config + self.model = OrionModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index f974b78a0fbda..7a9b8dcd6a509 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -45,10 +45,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -62,7 +63,7 @@ class PhiAttention(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -80,12 +81,12 @@ def __init__(self, self.head_size, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) scaling = self.head_size**-0.5 @@ -125,7 +126,7 @@ class PhiMLP(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() n_inner = getattr(config, "n_inner", None) @@ -134,14 +135,14 @@ def __init__(self, self.fc1 = ColumnParallelLinear( config.hidden_size, n_inner, - linear_method=linear_method, + quant_config=quant_config, ) self.fc2 = RowParallelLinear( n_inner, config.hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) + quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, n_inner) def forward(self, hidden_states): @@ -155,12 +156,12 @@ class PhiLayer(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, linear_method) - self.mlp = PhiMLP(config, linear_method) + self.self_attn = PhiAttention(config, quant_config) + self.mlp = PhiMLP(config, quant_config) def forward( self, @@ -186,14 +187,14 @@ class PhiModel(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - PhiLayer(config, linear_method) + PhiLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layernorm = nn.LayerNorm(config.hidden_size, @@ -225,12 +226,12 @@ class PhiForCausalLM(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config - self.model = PhiModel(config, linear_method) + self.model = PhiModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a77da7cb15984..e5e0028888c88 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -14,11 +14,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -35,17 +36,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str = "silu", - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.c_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -67,7 +68,7 @@ def __init__( max_position_embeddings: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = hidden_size @@ -83,13 +84,13 @@ def __init__( self.head_dim, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.scaling = self.head_dim**-0.5 @@ -122,7 +123,7 @@ class QWenBlock(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -134,13 +135,13 @@ def __init__( config.max_position_embeddings, rope_theta=rope_theta, rope_scaling=rope_scaling, - linear_method=linear_method) + quant_config=quant_config) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, - linear_method=linear_method) + quant_config=quant_config) def forward( self, @@ -174,7 +175,7 @@ class QWenModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -185,7 +186,7 @@ def __init__( config.hidden_size, ) self.h = nn.ModuleList([ - QWenBlock(config, linear_method) + QWenBlock(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = QWenModel(config, linear_method) + self.quant_config = quant_config + self.transformer = QWenModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 71b906e20ac19..62bc7fe22c367 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -33,11 +33,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -54,17 +55,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -86,7 +87,7 @@ def __init__(self, max_position: int = 4096 * 32, rope_theta: float = 10000, use_sliding_window: bool = False, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -117,13 +118,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -159,7 +160,7 @@ def __init__( self, config: Qwen2Config, layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -174,13 +175,13 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, use_sliding_window=use_sliding_window, - linear_method=linear_method, + quant_config=quant_config, sliding_window=config.sliding_window) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -221,7 +222,7 @@ class Qwen2Model(nn.Module): def __init__( self, config: Qwen2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -233,7 +234,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, linear_method) + Qwen2DecoderLayer(config, layer_idx, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module): def __init__( self, config: Qwen2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: del lora_config super().__init__() self.config = config - self.linear_method = linear_method - self.model = Qwen2Model(config, linear_method) + self.quant_config = quant_config + self.model = Qwen2Model(config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 59908bc9ef26a..8da89a2b7ba6c 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -36,12 +36,13 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -58,18 +59,18 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, reduce_results=reduce_results) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -105,7 +106,7 @@ def __init__( Qwen2MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False) for idx in range(self.n_routed_experts) ]) @@ -114,13 +115,13 @@ def __init__( self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, bias=False, - linear_method=None) + quant_config=None) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False, ) else: @@ -186,7 +187,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -217,14 +218,14 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -260,7 +261,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -275,18 +276,18 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) if (config.num_experts is not None and (layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen2MoeSparseMoeBlock(config=config, - linear_method=linear_method) + quant_config=quant_config) else: self.mlp = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -338,9 +339,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2MoeDecoderLayer(config, - layer_idx, - linear_method=linear_method) + Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = Qwen2MoeModel(config, linear_method) + self.quant_config = quant_config + self.model = Qwen2MoeModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 3e6c2db6f3c65..3d4f4f700f867 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -28,11 +28,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -46,7 +47,7 @@ class StablelmMLP(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -54,7 +55,7 @@ def __init__(self, self.gate_up_proj = MergedColumnParallelLinear( config.hidden_size, [config.intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False) @@ -71,7 +72,7 @@ class StablelmAttention(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -109,11 +110,11 @@ def __init__(self, self.total_num_heads, self.total_num_key_value_heads, self.qkv_bias, - linear_method=linear_method) + quant_config=quant_config) self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, self.hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.rotary_ndims, @@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.self_attn = StablelmAttention(config) - self.mlp = StablelmMLP(config, linear_method) + self.mlp = StablelmMLP(config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) @@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ - StablelmDecoderLayer(config, linear_method) + StablelmDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) norm_eps = getattr(config, "norm_eps", @@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = StableLMEpochModel(config, linear_method) + self.quant_config = quant_config + self.model = StableLMEpochModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index b90f3da141c2e..29d887b21032b 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -28,10 +28,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -79,13 +80,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=self.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=self.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -121,21 +122,21 @@ class Starcoder2MLP(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.c_fc = ColumnParallelLinear( config.hidden_size, config.intermediate_size, bias=config.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=config.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) + quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, config.intermediate_size) @@ -150,12 +151,11 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, - linear_method=linear_method) - self.mlp = Starcoder2MLP(config, linear_method=linear_method) + self.self_attn = Starcoder2Attention(config, quant_config=quant_config) + self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, @@ -192,7 +192,7 @@ class Starcoder2Model(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -202,7 +202,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, linear_method=linear_method) + Starcoder2DecoderLayer(config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -227,10 +227,10 @@ class Starcoder2ForCausalLM(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = Starcoder2Model(config, linear_method=linear_method) + self.model = Starcoder2Model(config, quant_config=quant_config) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 4e905390c2340..0fb2662b2f715 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -31,11 +31,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -52,17 +53,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -85,7 +86,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, ) -> None: @@ -112,13 +113,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -171,7 +172,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, bias=getattr(config, "bias", False), sliding_window=sliding_window, ) @@ -179,7 +180,7 @@ def __init__( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -220,7 +221,7 @@ class XverseModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -236,7 +237,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - XverseDecoderLayer(config, linear_method) + XverseDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config=None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = XverseModel(config, linear_method) + self.quant_config = quant_config + self.model = XverseModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() From aba47be3fef57f37196a91f899068506b9075f4f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 26 Apr 2024 15:47:45 -0700 Subject: [PATCH 132/413] [Misc] add RFC issue template (#4401) Co-authored-by: Simon Mo --- .github/ISSUE_TEMPLATE/750-RFC.yml | 49 ++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/750-RFC.yml diff --git a/.github/ISSUE_TEMPLATE/750-RFC.yml b/.github/ISSUE_TEMPLATE/750-RFC.yml new file mode 100644 index 0000000000000..5382b124dcd79 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/750-RFC.yml @@ -0,0 +1,49 @@ +name: 💬 Request for comments (RFC). +description: Ask for feedback on major architectural changes or design choices. +title: "[RFC]: " +labels: ["RFC"] + +body: +- type: markdown + attributes: + value: > + #### Please take a look at previous [RFCs](https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference. +- type: textarea + attributes: + label: Motivation. + description: > + The motivation of the RFC. + validations: + required: true +- type: textarea + attributes: + label: Proposed Change. + description: > + The proposed change of the RFC. + validations: + required: true +- type: textarea + attributes: + label: Feedback Period. + description: > + The feedback period of the RFC. Usually at least one week. + validations: + required: false +- type: textarea + attributes: + label: CC List. + description: > + The list of people you want to CC. + validations: + required: false +- type: textarea + attributes: + label: Any Other Things. + description: > + Any other things you would like to mention. + validations: + required: false +- type: markdown + attributes: + value: > + Thanks for contributing 🎉! From 258a2c58d08fc7a242556120877a89404861fbce Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 26 Apr 2024 21:14:26 -0700 Subject: [PATCH 133/413] [Core] Introduce `DistributedGPUExecutor` abstract class (#4348) --- vllm/executor/distributed_gpu_executor.py | 114 ++++++++++++++++++++++ vllm/executor/ray_gpu_executor.py | 94 ++---------------- 2 files changed, 122 insertions(+), 86 deletions(-) create mode 100644 vllm/executor/distributed_gpu_executor.py diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py new file mode 100644 index 0000000000000..9dccfa4946391 --- /dev/null +++ b/vllm/executor/distributed_gpu_executor.py @@ -0,0 +1,114 @@ +from abc import abstractmethod +from typing import Any, Dict, Optional, Set, Tuple + +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput + +logger = init_logger(__name__) + + +class DistributedGPUExecutor(GPUExecutor): + """Abstract superclass of multi-GPU executor implementations.""" + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks", ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model(self, *args, **kwargs) -> SamplerOutput: + all_outputs = self._run_workers("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. + return all_outputs[0] + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> Set[int]: + return self._run_workers("list_loras") + + @abstractmethod + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + +class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): + + @abstractmethod + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + async def execute_model_async(self, *args, **kwargs) -> SamplerOutput: + all_outputs = await self._run_workers_async("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. + return all_outputs[0] diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6f72babe14fd5..1082984828357 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,12 +3,12 @@ import pickle from collections import defaultdict from itertools import islice, repeat -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) @@ -27,7 +27,7 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) -class RayGPUExecutor(ExecutorBase): +class RayGPUExecutor(DistributedGPUExecutor): def _init_executor(self) -> None: assert (not self.speculative_config @@ -179,50 +179,9 @@ def collect_arg_helper_func(**kwargs): self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - Tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers("determine_num_available_blocks", ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. - """ - - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -244,23 +203,6 @@ def execute_model(self, output = all_outputs[0] return output - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) - - def list_loras(self) -> Set[int]: - return self._run_workers("list_loras") - def _run_workers( self, method: str, @@ -378,7 +320,7 @@ def _check_if_any_actor_is_dead(self): f"Dead Workers: {dead_actors}. ") -class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): +class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -409,23 +351,3 @@ async def _run_workers_async( all_outputs = await asyncio.gather(*coros) return all_outputs - - async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: - all_outputs = await self._run_workers_async( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }) - - # Only the driver worker returns the sampling results. - output = all_outputs[0] - return output From 12628d3c787efd3483aaa74b5ae4175b28fd5805 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 26 Apr 2024 21:49:59 -0700 Subject: [PATCH 134/413] [Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales (#4343) Co-authored-by: Woosuk Kwon --- csrc/ops.h | 7 ++- csrc/pybind.cpp | 3 +- csrc/quantization/fp8/fp8_cuda_kernels.cu | 25 ++++++++++- vllm/_custom_ops.py | 12 +++-- .../layers/fused_moe/fused_moe.py | 13 ++++-- .../model_executor/layers/quantization/fp8.py | 9 +++- vllm/model_executor/models/mixtral.py | 44 ++++++++++++++++--- 7 files changed, 95 insertions(+), 18 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index ff7a3de1a0a8c..03bb1e24dc68e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -146,7 +146,12 @@ void gptq_shuffle( torch::Tensor q_perm, int bit); -void scaled_fp8_quant( +void static_scaled_fp8_quant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale); + +void dynamic_scaled_fp8_quant( torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a5b16c5abc3ed..2250c7f69f0ab 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); + ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); + ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); ops.def( "moe_align_block_size", &moe_align_block_size, diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index c3337cede1282..2477051eb60d7 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel( } // namespace vllm -void scaled_fp8_quant( +void static_scaled_fp8_quant( + torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] +{ + int64_t num_tokens = input.numel() / input.size(-1); + int64_t num_elems = input.numel(); + dim3 grid(num_tokens); + dim3 block(1024); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), + "scaled_fp8_quant_kernel", + [&] { + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + scale.data_ptr(), + num_elems); + }); +} + +void dynamic_scaled_fp8_quant( torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] torch::Tensor& scale) // [1] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 508d35656eb00..5ba104bada7ac 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -168,10 +168,16 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, # fp8 -def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) - vllm_ops.scaled_fp8_quant(output, input, scale) + if scale is None: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + vllm_ops.static_scaled_fp8_quant(output, input, scale) return output, scale diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index aed2c350bdd10..d37837a0b2ce8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -220,8 +220,9 @@ def moe_align_block_size( def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - B_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, @@ -232,10 +233,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, assert sorted_token_ids.stride(0) == 1 if not use_fp8: - A_scale = None + assert A_scale is None assert B_scale is None else: - A, A_scale = ops.scaled_fp8_quant(A) + A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ @@ -318,6 +319,8 @@ def fused_moe( use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -434,6 +437,7 @@ def fused_moe( invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, + a1_scale, w1_scale, topk_weights, topk_ids, @@ -451,6 +455,7 @@ def fused_moe( invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, + a2_scale, w2_scale, topk_weights, topk_ids, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 39679834b545c..ba9f3149649c1 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -14,6 +14,12 @@ class Fp8Config(QuantizationConfig): """Config class for FP8.""" + def __init__( + self, + activation_scheme: str = "dynamic", + ) -> None: + self.activation_scheme = activation_scheme + @classmethod def get_name(cls) -> str: return "fp8" @@ -35,7 +41,8 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": - return cls() + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(activation_scheme) def get_quant_method( self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 7847df735ab44..c5dd1a63e2f7a 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -105,6 +105,13 @@ def __init__( device="cuda", dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + # Scaling factors for FP8 weights self.ws_scale = nn.Parameter( torch.ones( @@ -115,12 +122,23 @@ def __init__( self.num_total_experts, device="cuda", dtype=torch.float32), requires_grad=False) if self.use_fp8 else None - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) + # Scaling factors for FP8 activations + need_act_scales = (self.use_fp8 + and quant_config.activation_scheme == "static") + self.as_scale = nn.Parameter( + torch.zeros(1, device="cuda", dtype=torch.float32), + requires_grad=False) if need_act_scales else None + self.a2s_scale = nn.Parameter( + torch.zeros(1, device="cuda", dtype=torch.float32), + requires_grad=False) if need_act_scales else None + + if need_act_scales: + set_weight_attrs(self.as_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.a2s_scale, { + "weight_loader": self.weight_loader, + }) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): @@ -135,6 +153,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_size:2 * shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] + if "act_scale" in weight_name: + param_data[:] = param_data[:].max(loaded_weight) def process_weights_after_loading(self): if self.use_fp8: @@ -162,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: inplace=True, use_fp8=self.use_fp8, w1_scale=self.ws_scale, - w2_scale=self.w2s_scale) + w2_scale=self.w2s_scale, + a1_scale=self.as_scale, + a2_scale=self.a2s_scale) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -443,11 +465,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] expert_params_mapping = [ + # These are the weights for the experts # (param_name, weight_name, expert_id) ("ws" if weight_name in ["w1", "w3"] else "w2s", f"experts.{expert_id}.{weight_name}.weight", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale", + f"experts.{expert_id}.{weight_name}.act_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) From 8947bc3c156963dfc66e7ca1e4c436506ed6a512 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 27 Apr 2024 13:08:24 +0800 Subject: [PATCH 135/413] [Frontend][Bugfix] Disallow extra fields in OpenAI API (#4355) --- requirements-common.txt | 1 + requirements-dev.txt | 1 - tests/entrypoints/test_openai_server.py | 16 +++++ vllm/entrypoints/openai/cli_args.py | 4 +- vllm/entrypoints/openai/protocol.py | 64 ++++++++++--------- vllm/entrypoints/openai/serving_chat.py | 55 ++++++++++++---- vllm/entrypoints/openai/serving_completion.py | 9 +-- vllm/entrypoints/openai/serving_engine.py | 18 +++--- 8 files changed, 113 insertions(+), 55 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 3cc7bba8f84db..e9db261c6aec9 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,6 +8,7 @@ py-cpuinfo transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3. tokenizers >= 0.19.1 # Required for Llama 3. fastapi +openai uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index d9816828d007d..324039186142b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -21,7 +21,6 @@ pytest-rerunfailures pytest-shard httpx einops # required for MPT -openai requests ray peft diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 85a7ef464c032..68332228ace08 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -15,6 +15,7 @@ import requests # downloading lora to test lora requests from huggingface_hub import snapshot_download +from openai import BadRequestError from vllm.transformers_utils.tokenizer import get_tokenizer @@ -770,6 +771,21 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI): assert loaded == {"result": 2}, loaded +async def test_extra_fields(server, client: openai.AsyncOpenAI): + with pytest.raises(BadRequestError) as exc_info: + await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "system", + "content": "You are a helpful assistant.", + "extra_field": "0", + }], # type: ignore + temperature=0, + seed=0) + + assert "extra_forbidden" in exc_info.value.message + + async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 5c361b4d184ee..16c5b6c08d37f 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -9,7 +9,7 @@ import ssl from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.serving_engine import LoRA +from vllm.entrypoints.openai.serving_engine import LoRAModulePath class LoRAParserAction(argparse.Action): @@ -18,7 +18,7 @@ def __call__(self, parser, namespace, values, option_string=None): lora_list = [] for item in values: name, path = item.split('=') - lora_list.append(LoRA(name, path)) + lora_list.append(LoRAModulePath(name, path)) setattr(namespace, self.dest, lora_list) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d9763d024eb83..0a949f9867754 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -4,14 +4,20 @@ from typing import Dict, List, Literal, Optional, Union import torch -from pydantic import BaseModel, Field, model_validator +from openai.types.chat import ChatCompletionMessageParam +from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Annotated from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid -class ErrorResponse(BaseModel): +class OpenAIBaseModel(BaseModel): + # OpenAI API does not allow extra fields + model_config = ConfigDict(extra="forbid") + + +class ErrorResponse(OpenAIBaseModel): object: str = "error" message: str type: str @@ -19,7 +25,7 @@ class ErrorResponse(BaseModel): code: int -class ModelPermission(BaseModel): +class ModelPermission(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") object: str = "model_permission" created: int = Field(default_factory=lambda: int(time.time())) @@ -34,7 +40,7 @@ class ModelPermission(BaseModel): is_blocking: bool = False -class ModelCard(BaseModel): +class ModelCard(OpenAIBaseModel): id: str object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) @@ -44,26 +50,26 @@ class ModelCard(BaseModel): permission: List[ModelPermission] = Field(default_factory=list) -class ModelList(BaseModel): +class ModelList(OpenAIBaseModel): object: str = "list" data: List[ModelCard] = Field(default_factory=list) -class UsageInfo(BaseModel): +class UsageInfo(OpenAIBaseModel): prompt_tokens: int = 0 total_tokens: int = 0 completion_tokens: Optional[int] = 0 -class ResponseFormat(BaseModel): +class ResponseFormat(OpenAIBaseModel): # type must be "json_object" or "text" type: Literal["text", "json_object"] -class ChatCompletionRequest(BaseModel): +class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create - messages: List[Dict[str, str]] + messages: List[ChatCompletionMessageParam] model: str frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None @@ -204,7 +210,7 @@ def check_guided_decoding_count(cls, data): return data -class CompletionRequest(BaseModel): +class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create model: str @@ -343,19 +349,19 @@ def check_guided_decoding_count(cls, data): return data -class LogProbs(BaseModel): +class LogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None -class CompletionResponseChoice(BaseModel): +class CompletionResponseChoice(OpenAIBaseModel): index: int text: str logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = Field( + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel): ) -class CompletionResponse(BaseModel): +class CompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) @@ -373,12 +379,12 @@ class CompletionResponse(BaseModel): usage: UsageInfo -class CompletionResponseStreamChoice(BaseModel): +class CompletionResponseStreamChoice(OpenAIBaseModel): index: int text: str logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = Field( + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel): ) -class CompletionStreamResponse(BaseModel): +class CompletionStreamResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) @@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel): usage: Optional[UsageInfo] = Field(default=None) -class ChatMessage(BaseModel): +class ChatMessage(OpenAIBaseModel): role: str content: str -class ChatCompletionResponseChoice(BaseModel): +class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None -class ChatCompletionResponse(BaseModel): +class ChatCompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") object: str = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) @@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel): usage: UsageInfo -class DeltaMessage(BaseModel): +class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None -class ChatCompletionResponseStreamChoice(BaseModel): +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): index: int delta: DeltaMessage logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None -class ChatCompletionStreamResponse(BaseModel): +class ChatCompletionStreamResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") object: str = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index f6011b6fc4cb6..629dd929dc1af 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,8 +1,11 @@ import codecs import time -from typing import AsyncGenerator, AsyncIterator, List, Optional, Union +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, + Optional, Tuple, TypedDict, Union, final) from fastapi import Request +from openai.types.chat import (ChatCompletionContentPartParam, + ChatCompletionRole) from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( @@ -10,7 +13,8 @@ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -20,20 +24,41 @@ logger = init_logger(__name__) +@final # So that it should be compatible with Dict[str, str] +class ConversationMessage(TypedDict): + role: str + content: str + + class OpenAIServingChat(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], response_role: str, - lora_modules: Optional[List[LoRA]] = None, - chat_template=None): + lora_modules: Optional[List[LoRAModulePath]] = None, + chat_template: Optional[str] = None): super().__init__(engine=engine, served_model_names=served_model_names, lora_modules=lora_modules) self.response_role = response_role self._load_chat_template(chat_template) + def _parse_chat_message_content( + self, + role: ChatCompletionRole, + content: Optional[Union[str, + Iterable[ChatCompletionContentPartParam]]], + ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]: + if content is None: + return [], [] + if isinstance(content, str): + return [ConversationMessage(role=role, content=content)], [] + + # To be implemented: https://github.com/vllm-project/vllm/pull/3467 + # To be implemented: https://github.com/vllm-project/vllm/pull/4200 + raise NotImplementedError("Complex input not supported yet") + async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Request ) -> Union[ErrorResponse, AsyncGenerator[str, None], @@ -52,10 +77,19 @@ async def create_chat_completion( return error_check_ret try: + conversation: List[ConversationMessage] = [] + + for m in request.messages: + messages, _ = self._parse_chat_message_content( + m["role"], m["content"]) + + conversation.extend(messages) + prompt = self.tokenizer.apply_chat_template( - conversation=request.messages, + conversation=conversation, tokenize=False, - add_generation_prompt=request.add_generation_prompt) + add_generation_prompt=request.add_generation_prompt, + ) except Exception as e: logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) @@ -105,9 +139,8 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: async def chat_completion_stream_generator( self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], request_id: str - ) -> Union[ErrorResponse, AsyncGenerator[str, None]]: - + result_generator: AsyncIterator[RequestOutput], + request_id: str) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) chunk_object_type = "chat.completion.chunk" @@ -252,7 +285,7 @@ async def chat_completion_full_generator( model_name = self.served_model_names[0] created_time = int(time.time()) - final_res: RequestOutput = None + final_res: Optional[RequestOutput] = None async for res in result_generator: if await raw_request.is_disconnected(): @@ -317,7 +350,7 @@ async def chat_completion_full_generator( return response - def _load_chat_template(self, chat_template): + def _load_chat_template(self, chat_template: Optional[str]): tokenizer = self.tokenizer if chat_template is not None: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 211b2e0424c3e..7904bb698c45a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -11,7 +11,8 @@ CompletionResponseStreamChoice, CompletionStreamResponse, LogProbs, UsageInfo) -from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], - lora_modules: Optional[List[LoRA]] = None): + lora_modules: Optional[List[LoRAModulePath]] = None): super().__init__(engine=engine, served_model_names=served_model_names, lora_modules=lora_modules) @@ -84,7 +85,7 @@ async def create_completion(self, request: CompletionRequest, created_time = int(time.time()) # Schedule the request and get the result generator. - generators = [] + generators: List[AsyncIterator[RequestOutput]] = [] try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) @@ -148,7 +149,7 @@ async def create_completion(self, request: CompletionRequest, num_prompts=len(prompts)) # Non-streaming response - final_res_batch: RequestOutput = [None] * len(prompts) + final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) try: async for i, res in result_generator: if await raw_request.is_disconnected(): diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 31da27a447c6c..3d5ed328b9d19 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -22,17 +22,15 @@ @dataclass -class LoRA: +class LoRAModulePath: name: str local_path: str class OpenAIServing: - def __init__(self, - engine: AsyncLLMEngine, - served_model_names: List[str], - lora_modules=Optional[List[LoRA]]): + def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]]): self.engine = engine self.served_model_names = served_model_names if lora_modules is None: @@ -158,7 +156,9 @@ def create_streaming_error_response( }) return json_str - async def _check_model(self, request) -> Optional[ErrorResponse]: + async def _check_model( + self, request: Union[CompletionRequest, ChatCompletionRequest] + ) -> Optional[ErrorResponse]: if request.model in self.served_model_names: return None if request.model in [lora.lora_name for lora in self.lora_requests]: @@ -168,14 +168,16 @@ async def _check_model(self, request) -> Optional[ErrorResponse]: err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - def _maybe_get_lora(self, request) -> Optional[LoRARequest]: + def _maybe_get_lora( + self, request: Union[CompletionRequest, ChatCompletionRequest] + ) -> Optional[LoRARequest]: if request.model in self.served_model_names: return None for lora in self.lora_requests: if request.model == lora.lora_name: return lora # if _check_model has been called earlier, this will be unreachable - raise ValueError("The model `{request.model}` does not exist.") + raise ValueError(f"The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( self, From 87f545ba6fdbbbe9813736fc398874563e2604a7 Mon Sep 17 00:00:00 2001 From: Roy Date: Sat, 27 Apr 2024 13:45:02 +0800 Subject: [PATCH 136/413] [Misc] Fix logger format typo (#4396) --- vllm/engine/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index d3560f5fefff1..eb54f5641171e 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -230,8 +230,8 @@ def log(self, stats: Stats) -> None: "Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, " "Running: %d reqs, Swapped: %d reqs, " - "Pending: %d reqs, GPU KV cache usage: %.1f%, " - "CPU KV cache usage: %.1f%", + "Pending: %d reqs, GPU KV cache usage: %.1f%%, " + "CPU KV cache usage: %.1f%%", prompt_throughput, generation_throughput, stats.num_running, From 18d23f642af9c56f45ccf684a22f386fc54c6ead Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Sat, 27 Apr 2024 02:37:40 -0400 Subject: [PATCH 137/413] [ROCm][Hardware][AMD] Enable group query attention for triton FA (#4406) --- vllm/attention/backends/rocm_flash_attn.py | 53 +++++++++----------- vllm/attention/ops/triton_flash_attention.py | 24 ++++----- 2 files changed, 36 insertions(+), 41 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 7c5863a030ff5..934acea0a3d60 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -253,36 +253,31 @@ def forward( # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - if self.use_triton_flash_attn or self.use_naive_attn: + if self.use_triton_flash_attn: + out, _ = self.attn_func( + query, + key, + value, + None, + prefill_meta.seq_start_loc, + prefill_meta.seq_start_loc, + prefill_meta.max_prompt_len, + prefill_meta.max_prompt_len, + True, + self.scale, + ) + elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) - if self.use_naive_attn: - out = self.attn_func( - query, - key, - value, - prefill_meta.prompt_lens, - self.scale, - ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out - else: - out, _ = self.attn_func( - query, - key, - value, - None, - prefill_meta.seq_start_loc, - prefill_meta.seq_start_loc, - prefill_meta.max_prompt_len, - prefill_meta.max_prompt_len, - True, - self.scale, - ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + out = self.attn_func( + query, + key, + value, + prefill_meta.prompt_lens, + self.scale, + ) else: out = self.attn_func( q=query, @@ -295,8 +290,10 @@ def forward( softmax_scale=self.scale, causal=True, ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + + # common code for prefill + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention output[:num_prefill_tokens] = PagedAttention.forward_prefix( diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index e160411859f0b..1147664183ff1 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -293,7 +293,7 @@ def _attn_fwd_inner( num_warps=4, ), ], - key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], ) @triton.jit def attn_fwd( @@ -330,8 +330,8 @@ def attn_fwd( philox_seed, philox_offset_base, encoded_softmax, - hq, - hk, + HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, @@ -403,7 +403,7 @@ def attn_fwd( # We still need to write 0s to the result # tl.store(O_block_ptr, # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q # + offs_m # We store inf to LSE, not -inf because in the bwd pass, # we subtract this @@ -414,11 +414,9 @@ def attn_fwd( # TODO: Should dropout and return encoded softmax be handled here? return - is_mqa = hq != hk - if is_mqa: # noqa: SIM108 - off_h_k = off_h_q % hk - else: - off_h_k = off_h_q + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q n_extra_tokens = 0 if seqlen_k < BLOCK_N: @@ -471,7 +469,7 @@ def attn_fwd( bias_ptr = None if ENABLE_DROPOUT: batch_philox_offset = philox_offset_base \ - + (off_z * hq + off_h_q) \ + + (off_z * HQ + off_h_q) \ * seqlen_q * seqlen_k else: batch_philox_offset = 0 @@ -624,7 +622,7 @@ def attn_fwd( z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # If seqlen_q not multiple of BLOCK_M, we need to mask out the last # few rows. This is only true for the last M block. For others, # overflow_size will be -ve @@ -784,8 +782,8 @@ def forward( philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, - hq=nheads_q, - hk=nheads_k, + HQ=nheads_q, + HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, From eefeb16464af5f3a61e3052d1a4128480bff7f47 Mon Sep 17 00:00:00 2001 From: Austin Veselka <50646302+FurtherAI@users.noreply.github.com> Date: Sat, 27 Apr 2024 02:03:48 -0500 Subject: [PATCH 138/413] [Kernel] Full Tensor Parallelism for LoRA Layers (#3524) Co-authored-by: Antoni Baum --- csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu | 1 + csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu | 1 + csrc/punica/bgmv/bgmv_config.h | 78 +++++++ csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu | 1 + csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu | 1 + csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu | 1 + csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu | 1 + csrc/punica/bgmv/bgmv_impl.cuh | 5 +- csrc/punica/bgmv/generator.py | 1 + csrc/punica/punica_ops.cc | 2 +- tests/lora/test_layers.py | 29 ++- tests/lora/test_punica.py | 51 ++++- vllm/config.py | 1 + vllm/engine/arg_utils.py | 10 + vllm/lora/fully_sharded_layers.py | 262 ++++++++++++++++++++++++ vllm/lora/layers.py | 243 +++++++++++++--------- vllm/lora/models.py | 6 +- vllm/lora/punica.py | 43 ++++ vllm/lora/utils.py | 60 +++++- 19 files changed, 686 insertions(+), 111 deletions(-) create mode 100644 vllm/lora/fully_sharded_layers.py diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu index c642e94925fe5..86846c274c90f 100644 --- a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu +++ b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu index 0607cebfeac40..de39c3121f5d3 100644 --- a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu +++ b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index fec484d693055..19c058cacfbc4 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA // and vllm/tests/lora/test_punica.py +// Used for defining kernels going from the variety of +// dim in to the narrow dim out + // Using it for the fully sharded column + // parallel LoRA A which splits the rank dim +#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, 128, narrow) \ + f(in_T, out_T, W_T, 256, narrow) \ + f(in_T, out_T, W_T, 512, narrow) \ + f(in_T, out_T, W_T, 640, narrow) \ + f(in_T, out_T, W_T, 768, narrow) \ + f(in_T, out_T, W_T, 1024, narrow) \ + f(in_T, out_T, W_T, 1152, narrow) \ + f(in_T, out_T, W_T, 1280, narrow) \ + f(in_T, out_T, W_T, 1536, narrow) \ + f(in_T, out_T, W_T, 1728, narrow) \ + f(in_T, out_T, W_T, 1792, narrow) \ + f(in_T, out_T, W_T, 2048, narrow) \ + f(in_T, out_T, W_T, 2304, narrow) \ + f(in_T, out_T, W_T, 2560, narrow) \ + f(in_T, out_T, W_T, 2752, narrow) \ + f(in_T, out_T, W_T, 2816, narrow) \ + f(in_T, out_T, W_T, 3072, narrow) \ + f(in_T, out_T, W_T, 3456, narrow) \ + f(in_T, out_T, W_T, 3584, narrow) \ + f(in_T, out_T, W_T, 4096, narrow) \ + f(in_T, out_T, W_T, 4608, narrow) \ + f(in_T, out_T, W_T, 5120, narrow) \ + f(in_T, out_T, W_T, 5504, narrow) \ + f(in_T, out_T, W_T, 5632, narrow) \ + f(in_T, out_T, W_T, 6144, narrow) \ + f(in_T, out_T, W_T, 6848, narrow) \ + f(in_T, out_T, W_T, 6912, narrow) \ + f(in_T, out_T, W_T, 7168, narrow) \ + f(in_T, out_T, W_T, 8192, narrow) \ + f(in_T, out_T, W_T, 9216, narrow) \ + f(in_T, out_T, W_T, 10240, narrow) \ + f(in_T, out_T, W_T, 11008, narrow) \ + f(in_T, out_T, W_T, 12288, narrow) \ + f(in_T, out_T, W_T, 13696, narrow) \ + f(in_T, out_T, W_T, 13824, narrow) \ + f(in_T, out_T, W_T, 14336, narrow) \ + f(in_T, out_T, W_T, 15360, narrow) \ + f(in_T, out_T, W_T, 16384, narrow) \ + f(in_T, out_T, W_T, 20480, narrow) \ + f(in_T, out_T, W_T, 22016, narrow) \ + f(in_T, out_T, W_T, 24576, narrow) \ + f(in_T, out_T, W_T, 27392, narrow) \ + f(in_T, out_T, W_T, 28672, narrow) \ + f(in_T, out_T, W_T, 32000, narrow) \ + f(in_T, out_T, W_T, 32256, narrow) \ + f(in_T, out_T, W_T, 32512, narrow) \ + f(in_T, out_T, W_T, 32768, narrow) \ + f(in_T, out_T, W_T, 33024, narrow) \ + f(in_T, out_T, W_T, 36864, narrow) \ + f(in_T, out_T, W_T, 43264, narrow) \ + f(in_T, out_T, W_T, 49152, narrow) \ + f(in_T, out_T, W_T, 64000, narrow) \ + f(in_T, out_T, W_T, 64256, narrow) \ + f(in_T, out_T, W_T, 64512, narrow) \ + f(in_T, out_T, W_T, 102400, narrow) \ + f(in_T, out_T, W_T, 102656, narrow) \ + f(in_T, out_T, W_T, 102912, narrow) \ + f(in_T, out_T, W_T, 128000, narrow) \ + f(in_T, out_T, W_T, 128256, narrow) \ + f(in_T, out_T, W_T, 128512, narrow) \ +// Keep above in sync with vllm/lora/layers::SamplerWithLoRA + + // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ @@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + +#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ + FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \ + FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \ + FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \ + f(in_T, out_T, W_T, 8, 64) \ + f(in_T, out_T, W_T, 16, 64) \ + f(in_T, out_T, W_T, 32, 64) \ + f(in_T, out_T, W_T, 64, 64) + // clang-format on diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu index f1db6df5f7338..d225a1eaa82b0 100644 --- a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu +++ b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu index c01ddd009d74e..b37d288a75561 100644 --- a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu +++ b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu index f45183ffd3486..a1ab2deecbabf 100644 --- a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu +++ b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu index 4097743488087..0b35bf5699898 100644 --- a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu +++ b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index 995de26e8bada..dad8805c750cb 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, constexpr int tz = 4; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if constexpr (feat_in < feat_out) { + if constexpr (feat_in <= feat_out) { static_assert(feat_in % vec_size == 0); constexpr int tx = feat_in / vec_size; @@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ int64_t num_layers, int64_t layer_idx, float scale); +#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \ + INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) + #define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ INST_BGMV(narrow, wide, in_T, out_T, W_T) \ INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py index 9bf7f6358880f..972df5a7208c2 100644 --- a/csrc/punica/bgmv/generator.py +++ b/csrc/punica/bgmv/generator.py @@ -10,6 +10,7 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype}) """.lstrip() # noqa: E501 for input_dtype in DTYPES: diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc index a1eaa90e85f27..8797fde85744a 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cc @@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) FOR_BGMV_WIDE_NARROW(CASE, _, _, _) + FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _) #undef CASE #undef CASE_ONESIDE default: return false; } - return true; } diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 1616fdfd4cff9..0eb04f4ccd133 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,6 +8,10 @@ import torch.nn.functional as F from vllm.config import LoRAConfig +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA) # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, @@ -524,13 +528,16 @@ def _pretest(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) +@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_linear_parallel(dist_init, num_loras, orientation, device) -> None: +def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, + device) -> None: torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, + fully_sharded_loras=fully_shard, lora_dtype=torch.float16) def create_random_linear_parallel_layer(): @@ -540,14 +547,17 @@ def create_random_linear_parallel_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = RowParallelLinearWithLoRA(linear) + lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard + else RowParallelLinearWithShardedLoRA(linear)) else: linear = ColumnParallelLinear(4096, 4096, bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = ColumnParallelLinearWithLoRA(linear) + lora_linear = (ColumnParallelLinearWithLoRA(linear) + if not fully_shard else + ColumnParallelLinearWithShardedLoRA(linear)) lora_linear.create_lora_weights(max_loras, lora_config) return linear, lora_linear @@ -629,13 +639,16 @@ def create_random_linear_parallel_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [1, 2, 3]) +@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: +def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, + device) -> None: torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, + fully_sharded_loras=fully_shard, lora_dtype=torch.float16) def create_column_parallel_packed_layer(): @@ -644,7 +657,9 @@ def create_column_parallel_packed_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = MergedColumnParallelLinearWithLoRA(linear) + lora_linear = (MergedColumnParallelLinearWithLoRA(linear) + if not fully_shard else + MergedColumnParallelLinearWithShardedLoRA(linear)) elif repeats == 3: linear = QKVParallelLinear(4096, 64, @@ -652,7 +667,9 @@ def create_column_parallel_packed_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = MergedQKVParallelLinearWithLora(linear) + lora_linear = (MergedQKVParallelLinearWithLora(linear) + if not fully_shard else + MergedQKVParallelLinearWithShardedLora(linear)) else: linear = QKVParallelLinear(4096, 64, diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index f3b9bd5912967..fd2a1b75f460c 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -34,11 +34,14 @@ def _lora_ref_impl( for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): xi = x[i].unsqueeze(0).to(torch.float32) wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) - wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + if wb_T_all is not None: + wb = wb_T_all[lora_idx, layer_idx].transpose(-1, + -2).to(torch.float32) tmp = xi @ wa y_stage_1[i] = tmp.squeeze(0) - y_final[i] += (tmp @ wb).squeeze(0) * s + y_final[i] += ((tmp @ wb).squeeze(0) * + s if wb_T_all is not None else y_stage_1[i]) return y_final, y_stage_1 @@ -91,12 +94,56 @@ def _lora_ref_impl( 128000, 128256, ] +H2 = [64] + H2 +R = [1, 2, 4] SEED = [0xabcdabcd987] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("r", R) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_a_extra_shapes(dtype_str, h1, r, seed): + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + bs = 32 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, r, dtype=dtype, device=device) + + y_ref = y.clone() + _lora_ref_impl( + y_ref, + x, + wa_T_all, + None, + indices, + layer_idx, + 1.0, + ) + + y_our = y.clone() + punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0) + + assert_close(y_ref, y_our) + + @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) @pytest.mark.parametrize("h1", H1) @pytest.mark.parametrize("h2", H2) diff --git a/vllm/config.py b/vllm/config.py index 887a73d9462dc..aedb589247646 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -862,6 +862,7 @@ def __repr__(self) -> str: class LoRAConfig: max_lora_rank: int max_loras: int + fully_sharded_loras: bool = False max_cpu_loras: Optional[int] = None lora_dtype: Optional[torch.dtype] = None lora_extra_vocab_size: int = 256 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6a6ac49ae3211..bd6437ee44c28 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -52,6 +52,7 @@ class EngineArgs: enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 + fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 lora_dtype = 'auto' max_cpu_loras: Optional[int] = None @@ -376,6 +377,14 @@ def add_cli_args( help=('Maximum number of LoRAs to store in CPU memory. ' 'Must be >= than max_num_seqs. ' 'Defaults to max_num_seqs.')) + parser.add_argument( + '--fully-sharded-loras', + action='store_true', + help=('By default, only half of the LoRA computation is ' + 'sharded with tensor parallelism. ' + 'Enabling this will use the fully sharded layers. ' + 'At high sequence length, max rank or ' + 'tensor parallel size, this is likely faster.')) parser.add_argument("--device", type=str, default=EngineArgs.device, @@ -509,6 +518,7 @@ def create_engine_config(self, ) -> EngineConfig: lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py new file mode 100644 index 0000000000000..1720566840bb1 --- /dev/null +++ b/vllm/lora/fully_sharded_layers.py @@ -0,0 +1,262 @@ +# pylint: disable=unused-argument +from typing import TYPE_CHECKING, List, Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + RowParallelLinearWithLoRA) +from vllm.lora.punica import bgmv, dispatch_bgmv_low_level + +if TYPE_CHECKING: + pass + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs['lora_config'].fully_sharded_loras) + + return dec + + +# these layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked.shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + + bgmv(buffer, x, self.lora_a_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + buffer = tensor_model_parallel_all_gather(buffer) + bgmv(output, buffer, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + # now have column partitioned output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +def _mcp_apply_weights(x, bias, layer): + """ + MergedColumnParallelLinearWithShardedLoRA and + QKVParallelLinearWithShardedLora share the same + LoRa weight application method. + + The main difference is the step by shard_size for lora_b which can + vary for QKVParallelLinearWithShardedLora but is constant for + MergedColumnParallelLinearWithShardedLoRA. + """ + # expecting 2 for column parallel and 3 for qkv + n = len(layer.lora_a_stacked) + output = layer.base_layer.linear_method.apply_weights( + layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device) + for idx in range(n): + bgmv(buffers[idx], x, layer.lora_a_stacked[idx], + layer.indices[:layer.indices_len[0]], 0, 1.0) + + buffers = tensor_model_parallel_all_gather(buffers) + left_offset = 0 + for idx in range(n): + shard_size = layer.lora_b_stacked[idx].shape[2] + dispatch_bgmv_low_level(output, buffers[idx], + layer.lora_b_stacked[idx], + layer.indices[:layer.indices_len[0]], 0, 1.0, + left_offset, shard_size) + left_offset += shard_size + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[i][:, output_start_idx:output_start_idx + output_shard_size] + for i in range(2) + ] + return lora_a + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply_weights(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): + """ + Differs from QKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[i][:, start_idx[i]:start_idx[i] + + shard_size[i]] if lora_a[i] is not None else None + for i in range(3) + ] + return lora_a + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply_weights(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked.shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + bgmv(buffer, x, self.lora_a_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + shard_size = self.lora_b_stacked.shape[2] + start_idx = self.tp_rank * shard_size + dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0, + start_idx, shard_size) + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 4eaf73fbcfda4..b3609666b2ec7 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,8 +1,7 @@ # pylint: disable=unused-argument -import inspect import math from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple import torch import torch.nn as nn @@ -16,6 +15,7 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, tensor_model_parallel_gather) +from vllm.distributed.utils import divide from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -23,7 +23,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + VocabParallelEmbedding) if TYPE_CHECKING: pass @@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: raise ValueError(f"Unsupported base layer: {base_layer}") +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True + condition = (not kwargs['lora_config'].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + def _apply_lora( x: torch.Tensor, lora_a_stacked: torch.Tensor, @@ -130,6 +145,14 @@ def __post_init__(self): class BaseLayerWithLoRA(nn.Module): + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + """Slice lora b if splitting with tensor parallelism.""" + ... + def create_lora_weights( self, max_loras: int, @@ -317,6 +340,11 @@ def can_replace_layer(cls, source_layer: nn.Module, class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + + LoRA B is sliced for tensor parallelism. + """ def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() @@ -331,10 +359,15 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config + self.tp_size = get_tensor_model_parallel_world_size() + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) self.lora_a_stacked = torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -357,6 +390,17 @@ def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + def set_lora( self, index: int, @@ -365,12 +409,11 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], ): self.reset_lora(index) + if self.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_dim - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( lora_a.T, non_blocking=True) @@ -426,6 +469,7 @@ def forward(self, input_): return output, output_bias @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -451,6 +495,7 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config n_slices = 2 if not (len(self.base_layer.output_sizes) == n_slices and self.base_layer.output_sizes[0] @@ -459,12 +504,17 @@ def create_lora_weights( "LoRAColumnParallelLinear2Slice requires 2 slices with " "the same size.") self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) self.lora_a_stacked = tuple( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -489,6 +539,18 @@ def reset_lora(self, index: int): self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[1][index] = 0 + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + return lora_a + + def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: + shard_size = self.output_dim + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = [ + lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx] + ] + return lora_b + def set_lora( self, index: int, @@ -499,13 +561,8 @@ def set_lora( self.reset_lora(index) if self.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_dim - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[0][:, - start_idx:end_idx], lora_b[1][:, - start_idx:end_idx] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) if lora_a[0] is not None: self.lora_a_stacked[0][ @@ -536,6 +593,7 @@ def apply(self, x: torch.Tensor, return output @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -627,21 +685,25 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config self.tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() + self.tp_rank = get_tensor_model_parallel_rank() self.q_proj_shard_size = (self.base_layer.num_heads * self.base_layer.head_size) self.kv_proj_shard_size = (self.base_layer.num_kv_heads * self.base_layer.head_size) - self.q_shard_id = tp_rank - self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) # q, k, v self.lora_a_stacked = ( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -649,7 +711,7 @@ def create_lora_weights( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -657,7 +719,7 @@ def create_lora_weights( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -705,6 +767,25 @@ def reset_lora(self, index: int): self.lora_a_stacked[2][index] = 0 self.lora_b_stacked[2][index] = 0 + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + return lora_a + + def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + if lora_b[2] is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + lora_b = [lora_b_q, lora_b_k, lora_b_v] + return lora_b + def set_lora( self, index: int, @@ -715,40 +796,24 @@ def set_lora( self.reset_lora(index) if self.tp_size > 1: - if lora_b[0] is not None: - lora_b_q = lora_b[0][:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - self.lora_b_stacked[0][ - index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( - lora_b_q.T, non_blocking=True) - if lora_b[1] is not None: - lora_b_k = lora_b[1][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - self.lora_b_stacked[1][ - index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( - lora_b_k.T, non_blocking=True) - if lora_b[2] is not None: - lora_b_v = lora_b[2][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - self.lora_b_stacked[2][ - index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( - lora_b_v.T, non_blocking=True) - else: - if lora_b[0] is not None: - self.lora_b_stacked[0][ - index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( - lora_b[0].T, non_blocking=True) - if lora_b[1] is not None: - self.lora_b_stacked[1][ - index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( - lora_b[1].T, non_blocking=True) - if lora_b[2] is not None: - self.lora_b_stacked[2][ - index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_( - lora_b[2].T, non_blocking=True) + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + if lora_b[0] is not None: + lora_b_q = lora_b[0] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + if lora_b[2] is not None: + lora_b_v = lora_b[2] + self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) if lora_a[0] is not None: self.lora_a_stacked[0][ @@ -777,6 +842,7 @@ def apply(self, x: torch.Tensor, return output @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -798,6 +864,8 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config + self.tp_rank = get_tensor_model_parallel_rank() self.lora_a_stacked = torch.zeros( ( max_loras, @@ -808,11 +876,16 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) + tp_size = get_tensor_model_parallel_world_size() + lora_b_output_size_per_partition = ( + self.output_size if not lora_config.fully_sharded_loras else + divide(self.output_size, tp_size)) + self.lora_b_stacked = torch.zeros( ( max_loras, 1, - self.output_size, + lora_b_output_size_per_partition, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -826,6 +899,17 @@ def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.input_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + def set_lora( self, index: int, @@ -834,12 +918,10 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], ): self.reset_lora(index) + if self.base_layer.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.input_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_a = lora_a[start_idx:end_idx, :] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -915,6 +997,7 @@ def weight(self): self.base_layer, "weight") else self.base_layer.qweight @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -1096,37 +1179,3 @@ def can_replace_layer(cls, source_layer: nn.Module, model_config: Optional[PretrainedConfig]) -> bool: # Special handling for the LogitsProcessor. return False - - -_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { - cls - for cls in globals().values() if inspect.isclass(cls) - and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA -} - - -def from_layer(layer: nn.Module, - max_loras: int, - lora_config: LoRAConfig, - packed_modules_list: List, - model_config: Optional[PretrainedConfig] = None) -> nn.Module: - for lora_cls in _all_lora_classes: - if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list, - model_config): - ret = lora_cls(layer) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret - return layer - - -def from_layer_logits_processor( - layer: LogitsProcessor, - lm_head: ParallelLMHead, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, -) -> LogitsProcessorWithLoRA: - ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, - lm_head.weight.dtype, lm_head.weight.device) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 6a077e9b0c755..50d7e9133e0e8 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -11,10 +11,10 @@ from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer, - from_layer_logits_processor) +from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule +from vllm.lora.utils import (from_layer, from_layer_logits_processor, + parse_fine_tuned_lora_name, replace_submodule) from vllm.utils import LRUCache, is_pin_memory_available logger = init_logger(__name__) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index fc74269e55876..c87bed54726fc 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -49,6 +49,49 @@ def bgmv( punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) +def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, indicies: torch.LongTensor, + layer_idx: int, scale: float, y_offset: int, + y_slice_size: int): + """ + Same as `bgmv` but you can operate on slices of y. + Pass whole y, define y_offset and y_slice_size. + + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of + all of the transposed LoRA matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + y_offset: Offset to apply to the starting column of y. + y_slice_size: Size of the y column slice. + """ + try: + import vllm._punica_C as punica_kernels + except ImportError as e: + _raise_import_error(e) + punica_kernels.dispatch_bgmv_low_level( + y, + x, + w_t_all, + indicies, + layer_idx, + scale, + x.size(1), + y_slice_size, + y_offset, + ) + + def add_lora(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor, diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 39e08f0412e4a..9942a5fd40dec 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -1,11 +1,69 @@ -from typing import Tuple +from typing import List, Optional, Set, Tuple, Type from torch import nn +from transformers import PretrainedConfig +from vllm.config import LoRAConfig from vllm.logger import init_logger +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA) +# being imported for _all_lora_classes below +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead logger = init_logger(__name__) +_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { + VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA, + LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA +} + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig] = None) -> nn.Module: + for lora_cls in _all_lora_classes: + # specifying kwargs so they can be easily accessed in decorator + if lora_cls.can_replace_layer(source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config): + ret = lora_cls(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_logits_processor( + layer: LogitsProcessor, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LogitsProcessorWithLoRA: + ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, + lm_head.weight.dtype, lm_head.weight.device) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + def replace_submodule(model: nn.Module, module_name: str, new_module: nn.Module) -> nn.Module: From 3da24c2df735354ccb463650c29cca8ce506fa07 Mon Sep 17 00:00:00 2001 From: Caio Mendes Date: Sat, 27 Apr 2024 07:08:15 -0300 Subject: [PATCH 139/413] [Model] Phi-3 4k sliding window temp. fix (#4380) --- vllm/core/block_manager_v1.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index be093922b84f2..1fac2636e86fa 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -1,4 +1,5 @@ """A block manager that manages token blocks.""" +import math from abc import ABC, abstractmethod from itertools import count, takewhile from os.path import commonprefix @@ -220,9 +221,9 @@ def __init__( self.block_sliding_window = None if sliding_window is not None: - assert sliding_window % block_size == 0, (sliding_window, - block_size) - self.block_sliding_window = sliding_window // block_size + # Round up to nearest block size to regularize sliding window + # allocation sizes. + self.block_sliding_window = math.ceil(sliding_window / block_size) self.watermark = watermark assert watermark >= 0.0 From 7134303cbbb7c82cdfcb0c87d59bb48fe6ad642b Mon Sep 17 00:00:00 2001 From: Roy Date: Sat, 27 Apr 2024 19:30:08 +0800 Subject: [PATCH 140/413] [Bugfix][Core] Fix get decoding config from ray (#4335) --- tests/async_engine/test_async_llm_engine.py | 2 + tests/async_engine/test_openapi_server_ray.py | 157 ++++++++++++++++++ vllm/engine/async_llm_engine.py | 10 +- vllm/engine/llm_engine.py | 4 + vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 2 +- 6 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 tests/async_engine/test_openapi_server_ray.py diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index cb125a7bfec30..b69cdc0a21409 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -91,4 +91,6 @@ async def test_new_requests_event(): assert engine.engine.step_calls == old_step_calls + 1 engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True) + assert engine.get_model_config() is not None assert engine.get_tokenizer() is not None + assert engine.get_decoding_config() is not None diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py new file mode 100644 index 0000000000000..4b97af88012b9 --- /dev/null +++ b/tests/async_engine/test_openapi_server_ray.py @@ -0,0 +1,157 @@ +# imports for guided decoding tests +import os +import subprocess +import sys +import time + +import openai # use the official client for correctness check +import pytest +# using Ray for overall ease of process management, parallel requests, +# and debugging. +import ray +import requests + +MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds +# any model with a chat template should work here +MODEL_NAME = "facebook/opt-125m" + + +@ray.remote(num_gpus=1) +class ServerRunner: + + def __init__(self, args): + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self.proc = subprocess.Popen( + ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server() + + def ready(self): + return True + + def _wait_for_server(self): + # run health check + start = time.time() + while True: + try: + if requests.get( + "http://localhost:8000/health").status_code == 200: + break + except Exception as err: + if self.proc.poll() is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > MAX_SERVER_START_WAIT_S: + raise RuntimeError( + "Server failed to start in time.") from err + + def __del__(self): + if hasattr(self, "proc"): + self.proc.terminate() + + +@pytest.fixture(scope="session") +def server(): + ray.init() + server_runner = ServerRunner.remote([ + "--model", + MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--enforce-eager", + "--engine-use-ray" + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +@pytest.fixture(scope="session") +def client(): + client = openai.AsyncOpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + yield client + + +@pytest.mark.asyncio +async def test_check_models(server, client: openai.AsyncOpenAI): + models = await client.models.list() + models = models.data + served_model = models[0] + assert served_model.id == MODEL_NAME + assert all(model.root == MODEL_NAME for model in models) + + +@pytest.mark.asyncio +async def test_single_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create(model=MODEL_NAME, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + + +@pytest.mark.asyncio +async def test_single_chat_session(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert chat_completion.id is not None + assert chat_completion.choices is not None and len( + chat_completion.choices) == 1 + assert chat_completion.choices[0].message is not None + assert chat_completion.choices[0].logprobs is not None + assert chat_completion.choices[0].logprobs.top_logprobs is not None + assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 89ee3f0db491c..7c1eb2ecbe550 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,7 @@ from transformers import PreTrainedTokenizer -from vllm.config import ModelConfig +from vllm.config import DecodingConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray @@ -697,6 +697,14 @@ async def get_model_config(self) -> ModelConfig: else: return self.engine.get_model_config() + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_decoding_config.remote( # type: ignore + ) + else: + return self.engine.get_decoding_config() + async def do_log_stats(self) -> None: if self.engine_use_ray: await self.engine.do_log_stats.remote() # type: ignore diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 741d3bcd80890..292504631b06d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -467,6 +467,10 @@ def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config + def get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" return self.scheduler.get_num_unfinished_seq_groups() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 629dd929dc1af..5ed042ef386ea 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -101,7 +101,7 @@ async def create_chat_completion( request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) - decoding_config = self.engine.engine.decoding_config + decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7904bb698c45a..6a7f29c4c96f2 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -89,7 +89,7 @@ async def create_completion(self, request: CompletionRequest, try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) - decoding_config = self.engine.engine.decoding_config + decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logit_processor = ( From dfea17314827845d55dabb03ebe905f58e6682e4 Mon Sep 17 00:00:00 2001 From: Ruoyu Qin Date: Sun, 28 Apr 2024 00:48:37 +0800 Subject: [PATCH 141/413] [Bugfix] Abort requests when the connection to /v1/completions is interrupted (#4363) --- .../test_merge_async_iterators.py | 41 +++++++++++++++++++ vllm/utils.py | 17 +++++--- 2 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 tests/async_engine/test_merge_async_iterators.py diff --git a/tests/async_engine/test_merge_async_iterators.py b/tests/async_engine/test_merge_async_iterators.py new file mode 100644 index 0000000000000..ea453526c77f8 --- /dev/null +++ b/tests/async_engine/test_merge_async_iterators.py @@ -0,0 +1,41 @@ +import asyncio +from typing import AsyncIterator, Tuple + +import pytest + +from vllm.utils import merge_async_iterators + + +@pytest.mark.asyncio +async def test_merge_async_iterators(): + + async def mock_async_iterator(idx: int) -> AsyncIterator[str]: + try: + while True: + yield f"item from iterator {idx}" + await asyncio.sleep(0.1) + except asyncio.CancelledError: + pass + + iterators = [mock_async_iterator(i) for i in range(3)] + merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( + *iterators) + + async def stream_output(generator: AsyncIterator[Tuple[int, str]]): + async for idx, output in generator: + print(f"idx: {idx}, output: {output}") + + task = asyncio.create_task(stream_output(merged_iterator)) + await asyncio.sleep(0.5) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + for iterator in iterators: + try: + await asyncio.wait_for(anext(iterator), 1) + except StopAsyncIteration: + # All iterators should be cancelled and print this message. + print("Iterator was cancelled normally") + except (Exception, asyncio.CancelledError) as e: + raise AssertionError() from e diff --git a/vllm/utils.py b/vllm/utils.py index 76c2fc66e47c3..88447878f1706 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -225,11 +225,18 @@ async def producer(i: int, iterator: AsyncIterator[T]): ] async def consumer(): - while not all(finished) or not queue.empty(): - item = await queue.get() - if isinstance(item, Exception): - raise item - yield item + try: + while not all(finished) or not queue.empty(): + item = await queue.get() + if isinstance(item, Exception): + raise item + yield item + except (Exception, asyncio.CancelledError) as e: + for task in _tasks: + # NOTE: Pass the error msg in cancel() + # when only Python 3.9+ is supported. + task.cancel() + raise e await asyncio.gather(*_tasks) return consumer() From 81661da7b2d446cff7065fd6b34f1b7051098d24 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 27 Apr 2024 09:52:46 -0700 Subject: [PATCH 142/413] [BugFix] Fix `min_tokens` when `eos_token_id` is None (#4389) Co-authored-by: DefTruth <31974251+deftruth@users.noreply.github.com> --- tests/samplers/test_sampler.py | 9 +++------ vllm/engine/llm_engine.py | 5 +++-- vllm/model_executor/layers/sampler.py | 14 ++++++-------- vllm/sampling_params.py | 4 ++-- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 6f2145f8cdcf4..7859f0b21812f 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -207,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): def create_sampling_params(min_tokens, eos_token_id=0, *, - stop_token_ids: Optional[List[str]] = None, + stop_token_ids: Optional[List[int]] = None, prompt_logprobs: Optional[int] = None): sampling_params = SamplingParams( min_tokens=min_tokens, @@ -216,7 +216,7 @@ def create_sampling_params(min_tokens, # requesting prompt_logprobs changes the structure of `logits` prompt_logprobs=prompt_logprobs, ) - sampling_params.eos_token_id = eos_token_id + sampling_params.all_stop_token_ids.add(eos_token_id) return sampling_params def create_sequence_data(num_input=3, num_generated=0): @@ -461,10 +461,7 @@ def run_test_case(*, for logits_idx, (should_penalize, sampling_params) in enumerate( zip(expected_penalization, sampling_params_per_row)): - tokens_to_check = [sampling_params.eos_token_id] - if sampling_params.stop_token_ids: - tokens_to_check.extend(sampling_params.stop_token_ids) - tokens_to_check = set(tokens_to_check) + tokens_to_check = sampling_params.all_stop_token_ids if should_penalize: for token_id in tokens_to_check: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 292504631b06d..7e9553d839cea 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -431,9 +431,10 @@ def add_request( # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() - # inject the eos token id into the sampling_params to support min_tokens + # Add the eos token id into the sampling_params to support min_tokens # processing - sampling_params.eos_token_id = seq.eos_token_id + if seq.eos_token_id is not None: + sampling_params.all_stop_token_ids.add(seq.eos_token_id) sampling_params.update_from_generation_config( self.generation_config_fields) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2ffa8227cc4ed..4ef25edecfd24 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -169,19 +169,17 @@ def _apply_min_tokens_penalty( start_idx = sample_indices[0] min_tokens = sampling_params.min_tokens - if min_tokens > 0: + token_ids_to_penalize = sampling_params.all_stop_token_ids + if min_tokens > 0 and token_ids_to_penalize: seqs_to_penalize = [] - for i, seq_id in enumerate(seq_ids): + for j, seq_id in enumerate(seq_ids): seq_data = seq_group.seq_data[seq_id] if len(seq_data.output_token_ids) < min_tokens: - seqs_to_penalize.append(i) + seqs_to_penalize.append(j) if seqs_to_penalize: # convert to the index into logits - seqs_to_penalize = [start_idx + i for i in seqs_to_penalize] - # use set() to remove any duplicates - token_ids_to_penalize = set(sampling_params.stop_token_ids + - [sampling_params.eos_token_id]) + seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] # itertools.product pairs each seq index with every token id logits_to_penalize.extend( itertools.product(seqs_to_penalize, token_ids_to_penalize)) @@ -645,7 +643,7 @@ def _sample( Returns: (next_token_ids, parent_seq_ids) for each seq group in a batch. If sampling is skipped, it returns ([], []) - sampled_token_ids_tensor: A tensor of sampled token ids. + sampled_token_ids_tensor: A tensor of sampled token ids. """ return _sample_with_torch( probs, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index dc0e60344d858..0ed6a01a62212 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -185,8 +185,8 @@ def __init__( self.top_k = -1 self.min_p = 0.0 self._verify_greedy_sampling() - # injected by the engine - self.eos_token_id = None + # eos_token_id is added to this by the engine + self.all_stop_token_ids = set(self.stop_token_ids) def _verify_args(self) -> None: if self.n < 1: From d6e520e1700f78de2d5efdb8607a76cbab61182e Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Sat, 27 Apr 2024 09:59:55 -0700 Subject: [PATCH 143/413] [Core] Support offline use of local cache for models (#4374) Signed-off-by: Prashant Gupta Co-authored-by: Travis Johnson --- tests/model_executor/weight_utils.py | 30 +++++++++- vllm/model_executor/model_loader/loader.py | 5 +- .../model_loader/weight_utils.py | 59 +++++++++++-------- vllm/transformers_utils/tokenizer.py | 2 + 4 files changed, 69 insertions(+), 27 deletions(-) diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py index b0086dd7a7d71..c8b9bed691bba 100644 --- a/tests/model_executor/weight_utils.py +++ b/tests/model_executor/weight_utils.py @@ -1,9 +1,12 @@ import os +import tempfile import huggingface_hub.constants import pytest +from huggingface_hub.utils import LocalEntryNotFoundError -from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf, enable_hf_transfer) def test_hf_transfer_auto_activation(): @@ -22,5 +25,30 @@ def test_hf_transfer_auto_activation(): HF_TRANFER_ACTIVE) +def test_download_weights_from_hf(): + with tempfile.TemporaryDirectory() as tmpdir: + # assert LocalEntryNotFoundError error is thrown + # if offline is set and model is not cached + huggingface_hub.constants.HF_HUB_OFFLINE = True + with pytest.raises(LocalEntryNotFoundError): + download_weights_from_hf("facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir) + + # download the model + huggingface_hub.constants.HF_HUB_OFFLINE = False + download_weights_from_hf("facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir) + + # now it should work offline + huggingface_hub.constants.HF_HUB_OFFLINE = True + assert download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir) is not None + + if __name__ == "__main__": test_hf_transfer_auto_activation() + test_download_weights_from_hf() diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index ad80243019a65..70e64167f8698 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Tuple, Type +import huggingface_hub import torch from torch import nn @@ -131,7 +132,9 @@ def _maybe_download_from_modelscope( model_path = snapshot_download( model_id=model, cache_dir=self.load_config.download_dir, - revision=revision) + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ) else: model_path = model return model_path diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index c0905b9051314..c1abde9af7701 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig, if not is_local: # Download the config files. with get_lock(model_name_or_path, load_config.download_dir): - hf_folder = snapshot_download(model_name_or_path, - revision=model_config.revision, - allow_patterns="*.json", - cache_dir=load_config.download_dir, - tqdm_class=DisabledTqdm) + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) else: hf_folder = model_name_or_path @@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig, return quant_cls.from_config(config) -def download_weights_from_hf(model_name_or_path: str, - cache_dir: Optional[str], - allow_patterns: List[str], - revision: Optional[str] = None) -> str: +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, +) -> str: """Download model weights from Hugging Face Hub. - + Args: model_name_or_path (str): The model name or path. cache_dir (Optional[str]): The cache directory to store the model @@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str, Returns: str: The path to the downloaded model weights. """ - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break logger.info("Using model weights format %s", allow_patterns) # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download(model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=DisabledTqdm, - revision=revision) + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) return hf_folder diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2fcddc3bea5ab..fa4693cb7dac1 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,6 +1,7 @@ import os from typing import Optional, Union +import huggingface_hub from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) @@ -76,6 +77,7 @@ def get_tokenizer( model_id=tokenizer_name, cache_dir=download_dir, revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, # Ignore weights - we only need the tokenizer. ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"]) tokenizer_name = tokenizer_path From ba4be44c32761d30f1e17656b863d2cc078af9e4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 27 Apr 2024 11:17:45 -0700 Subject: [PATCH 144/413] [BugFix] Fix return type of executor execute_model methods (#4402) --- vllm/executor/cpu_executor.py | 2 +- vllm/executor/distributed_gpu_executor.py | 7 ++++--- vllm/executor/executor_base.py | 2 +- vllm/executor/gpu_executor.py | 2 +- vllm/executor/neuron_executor.py | 2 +- vllm/executor/ray_gpu_executor.py | 2 +- 6 files changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index aa810f9743395..e4436b2144bd3 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -109,7 +109,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + ) -> List[SamplerOutput]: output = await make_async(self.driver_worker.execute_model)( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 9dccfa4946391..4c922ef63ee04 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, Dict, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor @@ -52,7 +52,7 @@ def initialize_cache(self, num_gpu_blocks: int, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) - def execute_model(self, *args, **kwargs) -> SamplerOutput: + def execute_model(self, *args, **kwargs) -> List[SamplerOutput]: all_outputs = self._run_workers("execute_model", driver_args=args, driver_kwargs=kwargs) @@ -105,7 +105,8 @@ async def _run_workers_async( """Runs the given method on all workers.""" raise NotImplementedError - async def execute_model_async(self, *args, **kwargs) -> SamplerOutput: + async def execute_model_async(self, *args, + **kwargs) -> List[SamplerOutput]: all_outputs = await self._run_workers_async("execute_model", driver_args=args, driver_kwargs=kwargs) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 1838c34be2fda..c36aa18fb25bb 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -112,7 +112,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + ) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index d2c60a3b68e14..5ac62f02b99c7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -163,7 +163,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + ) -> List[SamplerOutput]: output = await make_async(self.driver_worker.execute_model)( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 5a137d1bdcb3b..f406287f3c1d8 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -84,7 +84,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + ) -> List[SamplerOutput]: output = await make_async(self.driver_worker.execute_model)( seq_group_metadata_list=seq_group_metadata_list, ) return output diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 1082984828357..b6bcda4e6b18c 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -188,7 +188,7 @@ def execute_model(self, blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int = 0) -> SamplerOutput: + num_lookahead_slots: int = 0) -> List[SamplerOutput]: all_outputs = self._run_workers( "execute_model", driver_kwargs={ From 4ea1f9678dd93f02424ab3de2149f83a490e6c6f Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sat, 27 Apr 2024 14:35:33 -0400 Subject: [PATCH 145/413] [BugFix] Resolved Issues For LinearMethod --> QuantConfig (#4418) --- vllm/model_executor/models/bloom.py | 1 - vllm/model_executor/models/falcon.py | 1 - vllm/model_executor/models/gpt2.py | 1 - vllm/model_executor/models/gpt_bigcode.py | 1 - vllm/model_executor/models/gpt_j.py | 1 - vllm/model_executor/models/gpt_neox.py | 1 - vllm/model_executor/models/mpt.py | 1 - vllm/model_executor/models/opt.py | 1 - vllm/model_executor/models/phi.py | 1 - vllm/model_executor/models/starcoder2.py | 1 - 10 files changed, 10 deletions(-) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index b425af4863c36..1d7e5d2517c72 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -139,7 +139,6 @@ def __init__( 4 * hidden_size, quant_config=quant_config, ) - quant_config = getattr(quant_config, "quant_config", None) self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 4be1f064cdd3e..08dd69923dc6d 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -203,7 +203,6 @@ def __init__( bias=config.bias, skip_bias_add=True, quant_config=quant_config) - quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index ac1dce6dec8a6..75eaebf0dbd15 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -107,7 +107,6 @@ def __init__( bias=True, quant_config=quant_config, ) - quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index e52ac679f5d03..d057fd928fdb5 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -128,7 +128,6 @@ def __init__( bias=True, quant_config=quant_config, ) - quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 287f4186f7469..8d7fe8a5beef7 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -120,7 +120,6 @@ def __init__( hidden_size, quant_config=quant_config, ) - quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index cbc5115bd377b..bab563b9c5a39 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -119,7 +119,6 @@ def __init__( config.hidden_size, quant_config=quant_config, ) - quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, config.intermediate_size) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 8c5e7e77c9306..6fa5c5bd3014a 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -146,7 +146,6 @@ def __init__( bias=not config.no_bias, quant_config=quant_config, ) - quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn("gelu", quant_config, intermediate_size) self.down_proj = RowParallelLinear( intermediate_size, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 838a2f0adc4d1..336f765ababaa 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -130,7 +130,6 @@ def __init__( bias=config.enable_bias, quant_config=quant_config, ) - quant_config = getattr(quant_config, "quant_config", None) self.activation_fn = get_act_fn(config.activation_function, quant_config, config.ffn_dim) self.fc2 = RowParallelLinear( diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 7a9b8dcd6a509..4a45879201af3 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -142,7 +142,6 @@ def __init__(self, config.hidden_size, quant_config=quant_config, ) - quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, n_inner) def forward(self, hidden_states): diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 29d887b21032b..33998e2aad5c5 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -136,7 +136,6 @@ def __init__(self, bias=config.use_bias, quant_config=quant_config, ) - quant_config = getattr(quant_config, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, config.intermediate_size) From 9c7306ac114da3e31a5ff040a76f6c640354cce8 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Sun, 28 Apr 2024 18:58:30 +0800 Subject: [PATCH 146/413] [Misc] fix typo in llm_engine init logging (#4428) --- vllm/engine/llm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7e9553d839cea..53a680580390f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -101,7 +101,7 @@ def __init__( "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, " "max_seq_len=%d, download_dir=%r, load_format=%s, " - "tensor_parallel_size=%d, disable_custom_all_reduce=%s" + "tensor_parallel_size=%d, disable_custom_all_reduce=%s, " "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, seed=%d)", From bf480c53027cb009427581af0100de75ac7a2e5f Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Mon, 29 Apr 2024 01:59:33 +0300 Subject: [PATCH 147/413] Add more Prometheus metrics (#2764) Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Co-authored-by: Robert Shaw --- examples/production_monitoring/grafana.json | 283 ++++++++++++++++++++ requirements-common.txt | 1 + vllm/core/scheduler.py | 2 +- vllm/engine/llm_engine.py | 171 ++++++++---- vllm/engine/metrics.py | 221 +++++++++++---- vllm/sequence.py | 18 +- 6 files changed, 582 insertions(+), 114 deletions(-) diff --git a/examples/production_monitoring/grafana.json b/examples/production_monitoring/grafana.json index 071f134c6e5e0..5e9bd5bd03869 100644 --- a/examples/production_monitoring/grafana.json +++ b/examples/production_monitoring/grafana.json @@ -873,6 +873,289 @@ ], "title": "Cache Utilization", "type": "timeseries" + }, + { + "type": "heatmap", + "title": "Request Prompt Length", + "description": "Heatmap of request prompt length", + "gridPos": { + "x": 0, + "y": 24, + "w": 12, + "h": 8 + }, + "datasource": { + "uid": "prometheus", + "type": "prometheus" + }, + "id": 12, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "refId": "A", + "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "range": true, + "instant": false, + "editorMode": "builder", + "legendFormat": "{{le}}", + "useBackend": false, + "disableTextWrap": false, + "fullMetaSearch": false, + "includeNullMetadata": true, + "format": "heatmap" + } + ], + "options": { + "calculate": false, + "yAxis": { + "axisPlacement": "left", + "reverse": false, + "unit": "none", + "axisLabel": "Prompt Length" + }, + "rowsFrame": { + "layout": "auto", + "value": "Request count" + }, + "color": { + "mode": "scheme", + "fill": "dark-orange", + "scale": "exponential", + "exponent": 0.5, + "scheme": "Spectral", + "steps": 64, + "reverse": false, + "min": 0 + }, + "cellGap": 1, + "filterValues": { + "le": 1e-9 + }, + "tooltip": { + "show": true, + "yHistogram": true + }, + "legend": { + "show": true + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "cellValues": { + "unit": "none" + } + }, + "fieldConfig": { + "defaults": { + "custom": { + "scaleDistribution": { + "type": "linear" + }, + "hideFrom": { + "tooltip": false, + "viz": false, + "legend": false + } + } + }, + "overrides": [] + }, + "pluginVersion": "10.2.0" + }, + { + "datasource": { + "uid": "prometheus", + "type": "prometheus" + }, + "type": "heatmap", + "title": "Request Generation Length", + "description": "Heatmap of request generation length", + "gridPos": { + "x": 12, + "y": 24, + "w": 12, + "h": 8 + }, + "id": 13, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "refId": "A", + "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "range": true, + "instant": false, + "editorMode": "builder", + "legendFormat": "{{le}}", + "useBackend": false, + "disableTextWrap": false, + "fullMetaSearch": false, + "includeNullMetadata": true, + "format": "heatmap" + } + ], + "options": { + "calculate": false, + "yAxis": { + "axisPlacement": "left", + "reverse": false, + "unit": "none", + "axisLabel": "Generation Length" + }, + "rowsFrame": { + "layout": "auto", + "value": "Request count" + }, + "color": { + "mode": "scheme", + "fill": "dark-orange", + "scale": "exponential", + "exponent": 0.5, + "scheme": "Spectral", + "steps": 64, + "reverse": false, + "min": 0 + }, + "cellGap": 1, + "filterValues": { + "le": 1e-9 + }, + "tooltip": { + "show": true, + "yHistogram": true + }, + "legend": { + "show": true + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "cellValues": { + "unit": "none" + } + }, + "fieldConfig": { + "defaults": { + "custom": { + "scaleDistribution": { + "type": "linear" + }, + "hideFrom": { + "tooltip": false, + "viz": false, + "legend": false + } + } + }, + "overrides": [] + }, + "pluginVersion": "10.2.0" + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "custom": { + "drawStyle": "line", + "lineInterpolation": "linear", + "barAlignment": 0, + "lineWidth": 1, + "fillOpacity": 0, + "gradientMode": "none", + "spanNulls": false, + "insertNulls": false, + "showPoints": "auto", + "pointSize": 5, + "stacking": { + "mode": "none", + "group": "A" + }, + "axisPlacement": "auto", + "axisLabel": "", + "axisColorMode": "text", + "axisBorderShow": false, + "scaleDistribution": { + "type": "linear" + }, + "axisCenteredZero": false, + "hideFrom": { + "tooltip": false, + "viz": false, + "legend": false + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "color": { + "mode": "palette-classic" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 32 + }, + "id": 11, + "options": { + "tooltip": { + "mode": "single", + "sort": "none" + }, + "legend": { + "showLegend": true, + "displayMode": "list", + "placement": "bottom", + "calcs": [] + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(finished_reason) (increase(vllm:request_success_total{model_name=\"$model_name\"}[$__rate_interval]))", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "interval": "", + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Finish Reason", + "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.", + "type": "timeseries" } ], "refresh": "", diff --git a/requirements-common.txt b/requirements-common.txt index e9db261c6aec9..3abb828116680 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -12,6 +12,7 @@ openai uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 +prometheus-fastapi-instrumentator >= 7.0.0 tiktoken == 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.9.8 outlines == 0.0.34 # Requires torch >= 2.1.0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7439f7dc33e8d..024b7e7013441 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -320,7 +320,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: for seq_group in state_queue: if not request_ids: # Using 'break' here may add two extra iterations, - # but is acceptable to reduce complexity . + # but is acceptable to reduce complexity. break if seq_group.request_id in request_ids: # Appending aborted group into pending list. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 53a680580390f..835803fd4e75d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -22,7 +22,8 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata) + SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -217,7 +218,8 @@ def __init__( if self.log_stats: self.stat_logger = StatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.model)) + labels=dict(model_name=model_config.model), + max_model_len=self.model_config.max_model_len) self.stat_logger.info("cache_config", self.cache_config) # Create sequence output processor, e.g. for beam search or @@ -619,59 +621,109 @@ def _get_stats( """ now = time.time() - # KV Cache Usage in %. + # System State + # Scheduler State + num_running_sys = len(self.scheduler.running) + num_swapped_sys = len(self.scheduler.swapped) + num_waiting_sys = len(self.scheduler.waiting) + + # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() - gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks - cpu_cache_usage = 0. + cpu_cache_usage_sys = 0. if num_total_cpu > 0: num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( ) - cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu) - - # Scheduler State - num_running = len(self.scheduler.running) - num_swapped = len(self.scheduler.swapped) - num_waiting = len(self.scheduler.waiting) - - # Iteration stats if we have scheduler output. - num_prompt_tokens = 0 - num_generation_tokens = 0 - time_to_first_tokens = [] - time_per_output_tokens = [] - time_e2e_requests = [] + cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + + # Iteration stats + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + time_to_first_tokens_iter: List[float] = [] + time_per_output_tokens_iter: List[float] = [] + + # Request stats + # Latency + time_e2e_requests: List[float] = [] + # Metadata + num_prompt_tokens_requests: List[int] = [] + num_generation_tokens_requests: List[int] = [] + best_of_requests: List[int] = [] + n_requests: List[int] = [] + finished_reason_requests: List[str] = [] + + # NOTE: This loop assumes prefill seq_groups are before + # decode seq_groups in scheduled_seq_groups. if scheduler_outputs is not None: - prompt_run = scheduler_outputs.num_prefill_groups > 0 - - # Number of Tokens. - if prompt_run: - num_prompt_tokens = sum( - len(scheduled_seq_group.seq_group.prompt_token_ids) - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups) - num_generation_tokens = sum( - scheduled_seq_group.seq_group.num_seqs() - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups) - else: - num_generation_tokens = scheduler_outputs.num_batched_tokens - - # Latency Timings. - time_last_iters = [] - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + num_generation_tokens_from_prefill_groups = 0. + if scheduler_outputs.num_prefill_groups > 0 and len( + scheduler_outputs.scheduled_seq_groups + ) != scheduler_outputs.num_prefill_groups: + print("DETECTED CHUNKED") + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + group_was_prefill = idx < scheduler_outputs.num_prefill_groups seq_group = scheduled_seq_group.seq_group - # Time since last token. - # (n.b. updates seq_group.metrics.last_token_time) - time_last_iters.append(seq_group.get_last_latency(now)) - # Time since arrival for all finished requests. + + # NOTE: a seq_group that completed all of its prefill tokens + # in the last iteration will have seq_group.is_prefill() = False + # with group_was_prefill = True + if group_was_prefill: + # Number of prompt tokens. + num_prompt_tokens_iter += ( + scheduled_seq_group.token_chunk_size) + + # If the seq_group just finished the prefill state + # get TTFT. + if not seq_group.is_prefill(): + latency = seq_group.get_last_latency(now) + time_to_first_tokens_iter.append(latency) + + # One generation token per finished prefill. + num_generation_tokens_from_prefill_groups += ( + seq_group.num_seqs()) + else: + # TPOTs. + latency = seq_group.get_last_latency(now) + time_per_output_tokens_iter.append(latency) + + # Because of chunked prefill, we can have a single sequence + # group that does multiple prompt_runs. To prevent logging + # the same metadata more than once per request, we standardize + # on logging request level information for finished requests, + # which can only happen once. if seq_group.is_finished(): + # Latency timings time_e2e_requests.append(now - seq_group.metrics.arrival_time) - time_to_first_tokens = time_last_iters if prompt_run else [] - time_per_output_tokens = [] if prompt_run else time_last_iters + # Metadata + num_prompt_tokens_requests.append( + len(seq_group.prompt_token_ids)) + num_generation_tokens_requests.extend([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ]) + best_of_requests.append(seq_group.sampling_params.best_of) + n_requests.append(seq_group.sampling_params.n) + finished_reason_requests.extend([ + SequenceStatus.get_finished_reason(seq.status) + for seq in seq_group.get_finished_seqs() + ]) + + # Number of generation tokens. + # num_batched_tokens equals the number of prompt_tokens plus the + # number of decode_tokens in a single iteration. So, + # num_generation_tokens = num_batched_tokens - num_prompt_tokens + # + num_generation_tokens_from_prefill_groups (since we generate + # one token on prefills on iters where the prefill finishes). + num_generation_tokens_iter = ( + scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter + + num_generation_tokens_from_prefill_groups) # Spec decode, if enabled, emits specialized metrics from the worker in # sampler output. @@ -683,17 +735,32 @@ def _get_stats( return Stats( now=now, - num_running=num_running, - num_swapped=num_swapped, - num_waiting=num_waiting, - gpu_cache_usage=gpu_cache_usage, - cpu_cache_usage=cpu_cache_usage, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=num_generation_tokens, - time_to_first_tokens=time_to_first_tokens, - time_per_output_tokens=time_per_output_tokens, - time_e2e_requests=time_e2e_requests, + + # System stats + # Scheduler State + num_running_sys=num_running_sys, + num_swapped_sys=num_swapped_sys, + num_waiting_sys=num_waiting_sys, + # KV Cache Usage in % + gpu_cache_usage_sys=gpu_cache_usage_sys, + cpu_cache_usage_sys=cpu_cache_usage_sys, + + # Iteration stats + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_generation_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_per_output_tokens_iter, spec_decode_metrics=spec_decode_metrics, + + # Request stats + # Latency + time_e2e_requests=time_e2e_requests, + # Metadata + num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + best_of_requests=best_of_requests, + n_requests=n_requests, + finished_reason_requests=finished_reason_requests, ) def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index eb54f5641171e..45bfad03ec867 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,6 +1,8 @@ import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Protocol +from typing import TYPE_CHECKING +from typing import Counter as CollectionsCounter +from typing import Dict, List, Optional, Protocol, Union import numpy as np from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, @@ -21,8 +23,9 @@ # begin-metrics-definitions class Metrics: + labelname_finish_reason = "finished_reason" - def __init__(self, labelnames: List[str]): + def __init__(self, labelnames: List[str], max_model_len: int): # Unregister any existing vLLM collectors for collector in list(REGISTRY._collector_to_names): if hasattr(collector, "_name") and "vllm" in collector._name: @@ -34,18 +37,20 @@ def __init__(self, labelnames: List[str]): documentation='information of cache_config') # System stats + # Scheduler State self.gauge_scheduler_running = Gauge( name="vllm:num_requests_running", documentation="Number of requests currently running on GPU.", labelnames=labelnames) - self.gauge_scheduler_swapped = Gauge( - name="vllm:num_requests_swapped", - documentation="Number of requests swapped to CPU.", - labelnames=labelnames) self.gauge_scheduler_waiting = Gauge( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", labelnames=labelnames) + self.gauge_scheduler_swapped = Gauge( + name="vllm:num_requests_swapped", + documentation="Number of requests swapped to CPU.", + labelnames=labelnames) + # KV Cache Usage in % self.gauge_gpu_cache_usage = Gauge( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", @@ -55,7 +60,7 @@ def __init__(self, labelnames: List[str]): documentation="CPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames) - # Raw stats from last model iteration + # Iteration stats self.counter_prompt_tokens = Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", @@ -80,18 +85,51 @@ def __init__(self, labelnames: List[str]): 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5 ]) - self.histogram_e2e_request_latency = Histogram( + + # Request stats + # Latency + self.histogram_e2e_time_request = Histogram( name="vllm:e2e_request_latency_seconds", documentation="Histogram of end to end request latency in seconds.", labelnames=labelnames, buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) + # Metadata + self.histogram_num_prompt_tokens_request = Histogram( + name="vllm:request_prompt_tokens", + documentation="Number of prefill tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_num_generation_tokens_request = Histogram( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_best_of_request = Histogram( + name="vllm:request_params_best_of", + documentation="Histogram of the best_of request parameter.", + labelnames=labelnames, + buckets=[1, 2, 5, 10, 20], + ) + self.histogram_n_request = Histogram( + name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + labelnames=labelnames, + buckets=[1, 2, 5, 10, 20], + ) + self.counter_request_success = Counter( + name="vllm:request_success", + documentation="Count of successfully processed requests.", + labelnames=labelnames + [Metrics.labelname_finish_reason]) - # Legacy metrics + # Deprecated in favor of vllm:prompt_tokens_total self.gauge_avg_prompt_throughput = Gauge( name="vllm:avg_prompt_throughput_toks_per_s", documentation="Average prefill throughput in tokens/s.", labelnames=labelnames, ) + # Deprecated in favor of vllm:generation_tokens_total self.gauge_avg_generation_throughput = Gauge( name="vllm:avg_generation_throughput_toks_per_s", documentation="Average generation throughput in tokens/s.", @@ -102,24 +140,57 @@ def __init__(self, labelnames: List[str]): # end-metrics-definitions +def build_1_2_5_buckets(max_value: int): + """ + Builds a list of buckets with increasing powers of 10 multiplied by + mantissa values (1, 2, 5) until the value exceeds the specified maximum. + + Example: + >>> build_1_2_5_buckets(100) + [1, 2, 5, 10, 20, 50, 100] + """ + mantissa_lst = [1, 2, 5] + exponent = 0 + buckets = [] + while True: + for m in mantissa_lst: + value = m * 10**exponent + if value <= max_value: + buckets.append(value) + else: + return buckets + exponent += 1 + + @dataclass class Stats: """Created by LLMEngine for use by StatLogger.""" now: float - # System stats. - num_running: int - num_waiting: int - num_swapped: int - gpu_cache_usage: float - cpu_cache_usage: float - - # Raw stats from last model iteration. - num_prompt_tokens: int - num_generation_tokens: int - time_to_first_tokens: List[float] - time_per_output_tokens: List[float] + # System stats (should have _sys suffix) + # Scheduler State + num_running_sys: int + num_waiting_sys: int + num_swapped_sys: int + # KV Cache Usage in % + gpu_cache_usage_sys: float + cpu_cache_usage_sys: float + + # Iteration stats (should have _iter suffix) + num_prompt_tokens_iter: int + num_generation_tokens_iter: int + time_to_first_tokens_iter: List[float] + time_per_output_tokens_iter: List[float] + + # Request stats (should have _requests suffix) + # Latency time_e2e_requests: List[float] + # Metadata + num_prompt_tokens_requests: List[int] + num_generation_tokens_requests: List[int] + best_of_requests: List[int] + n_requests: List[int] + finished_reason_requests: List[str] spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None @@ -133,7 +204,8 @@ def metrics_info(self) -> Dict[str, str]: class StatLogger: """StatLogger is used LLMEngine to log to Promethus and Stdout.""" - def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: + def __init__(self, local_interval: float, labels: Dict[str, str], + max_model_len: int) -> None: # Metadata for logging locally. self.last_local_log = time.time() self.local_interval = local_interval @@ -144,7 +216,8 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: # Prometheus metrics self.labels = labels - self.metrics = Metrics(labelnames=list(labels.keys())) + self.metrics = Metrics(labelnames=list(labels.keys()), + max_model_len=max_model_len) def info(self, type: str, obj: SupportsMetricsInfo) -> None: if type == "cache_config": @@ -158,34 +231,66 @@ def _local_interval_elapsed(self, now: float) -> bool: return elapsed_time > self.local_interval def _log_prometheus(self, stats: Stats) -> None: - # Set system stat gauges. - self.metrics.gauge_scheduler_running.labels(**self.labels).set( - stats.num_running) - self.metrics.gauge_scheduler_swapped.labels(**self.labels).set( - stats.num_swapped) - self.metrics.gauge_scheduler_waiting.labels(**self.labels).set( - stats.num_waiting) - self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set( - stats.gpu_cache_usage) - self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set( - stats.cpu_cache_usage) - - # Add to token counters. - self.metrics.counter_prompt_tokens.labels(**self.labels).inc( - stats.num_prompt_tokens) - self.metrics.counter_generation_tokens.labels(**self.labels).inc( - stats.num_generation_tokens) - - # Observe request level latencies in histograms. - for ttft in stats.time_to_first_tokens: - self.metrics.histogram_time_to_first_token.labels( - **self.labels).observe(ttft) - for tpot in stats.time_per_output_tokens: - self.metrics.histogram_time_per_output_token.labels( - **self.labels).observe(tpot) - for e2e in stats.time_e2e_requests: - self.metrics.histogram_e2e_request_latency.labels( - **self.labels).observe(e2e) + # System state data + self._log_gauge(self.metrics.gauge_scheduler_running, + stats.num_running_sys) + self._log_gauge(self.metrics.gauge_scheduler_swapped, + stats.num_swapped_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, + stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, + stats.gpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_cache_usage, + stats.cpu_cache_usage_sys) + + # Iteration level data + self._log_counter(self.metrics.counter_prompt_tokens, + stats.num_prompt_tokens_iter) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) + self._log_histogram(self.metrics.histogram_time_to_first_token, + stats.time_to_first_tokens_iter) + self._log_histogram(self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter) + + # Request level data + # Latency + self._log_histogram(self.metrics.histogram_e2e_time_request, + stats.time_e2e_requests) + # Metadata + finished_reason_counter = CollectionsCounter( + stats.finished_reason_requests) + self._log_counter_labels(self.metrics.counter_request_success, + finished_reason_counter, + Metrics.labelname_finish_reason) + self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests) + self._log_histogram( + self.metrics.histogram_num_generation_tokens_request, + stats.num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) + self._log_histogram(self.metrics.histogram_best_of_request, + stats.best_of_requests) + + def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def _log_counter(self, counter: Counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + counter.labels(**self.labels).inc(data) + + def _log_counter_labels(self, counter: Counter, data: CollectionsCounter, + label_key: str) -> None: + # Convenience function for collection counter of labels. + for label, count in data.items(): + counter.labels(**{**self.labels, label_key: label}).inc(count) + + def _log_histogram(self, histogram: Histogram, + data: Union[List[int], List[float]]) -> None: + # Convenience function for logging list to histogram. + for datum in data: + histogram.labels(**self.labels).observe(datum) def _log_prometheus_interval(self, prompt_throughput: float, generation_throughput: float) -> None: @@ -210,8 +315,8 @@ def log(self, stats: Stats) -> None: self._log_prometheus(stats) # Save tracked stats for token counters. - self.num_prompt_tokens.append(stats.num_prompt_tokens) - self.num_generation_tokens.append(stats.num_generation_tokens) + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) # Log locally every local_interval seconds. if self._local_interval_elapsed(stats.now): @@ -234,11 +339,11 @@ def log(self, stats: Stats) -> None: "CPU KV cache usage: %.1f%%", prompt_throughput, generation_throughput, - stats.num_running, - stats.num_swapped, - stats.num_waiting, - stats.gpu_cache_usage * 100, - stats.cpu_cache_usage * 100, + stats.num_running_sys, + stats.num_swapped_sys, + stats.num_waiting_sys, + stats.gpu_cache_usage_sys * 100, + stats.cpu_cache_usage_sys * 100, ) # Reset tracked stats for next interval. diff --git a/vllm/sequence.py b/vllm/sequence.py index 567fca5709518..0e931ebbb6571 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -442,15 +442,27 @@ def prompt_token_ids(self) -> List[int]: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 - def get_last_latency(self, now: float) -> float: - """Gets last token latency for Request level timings.""" + def get_last_latency(self, now: float) -> Optional[float]: + """Sets the last token time for Request level timings.""" + # If still in prefill phase, raise Error. + if self.is_prefill(): + raise ValueError( + "seq_group.get_last_latency() should not be called " + "if the seq_group is in prefill phase.") + + # Otherwise return token latency. latency = now - self.metrics.last_token_time self.metrics.last_token_time = now return latency def maybe_set_first_token_time(self, time: float) -> None: """Sets the first token time for Request level timings.""" - if self.metrics.first_token_time is None: + # Note: in a case where a sequence_group is swapped and + # recomputed, the time between iterations is counted + # in TPOT, rather than recalculating TTFT (since from the ) + # POV of the user, there is simply a long generation delay. + if (self.metrics.first_token_time is None + and self.get_seqs()[0].get_output_len() == 1): self.metrics.first_token_time = time def maybe_set_first_scheduled_time(self, time: float) -> None: From 03dd7d52bfcc4f21ba964a0cfc3fb6e7a47fb242 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Sun, 28 Apr 2024 16:32:07 -0700 Subject: [PATCH 148/413] [CI] clean docker cache for neuron (#4441) --- .buildkite/run-neuron-test.sh | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh index 8ba03b78e8dbf..252c0f7fecd12 100644 --- a/.buildkite/run-neuron-test.sh +++ b/.buildkite/run-neuron-test.sh @@ -4,6 +4,20 @@ set -e # Try building the docker image aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com + +# prune old image and containers to save disk space, and only once a day +# by using a timestamp file in tmp. +if [ -f /tmp/neuron-docker-build-timestamp ]; then + last_build=$(cat /tmp/neuron-docker-build-timestamp) + current_time=$(date +%s) + if [ $((current_time - last_build)) -gt 86400 ]; then + docker system prune -f + echo $current_time > /tmp/neuron-docker-build-timestamp + fi +else + echo $(date +%s) > /tmp/neuron-docker-build-timestamp +fi + docker build -t neuron -f Dockerfile.neuron . # Setup cleanup From df29793dc73a83f3c86c19de967adffda1a28a93 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Mon, 29 Apr 2024 11:01:26 +0900 Subject: [PATCH 149/413] [mypy][5/N] Support all typing on model executor (#4427) --- .github/workflows/mypy.yaml | 2 +- format.sh | 2 +- .../lm_format_enforcer_decoding.py | 1 + vllm/model_executor/layers/linear.py | 12 ++++- .../layers/quantization/__init__.py | 4 +- .../layers/quantization/base_config.py | 14 ++++-- .../layers/quantization/squeezellm.py | 5 +- .../model_executor/layers/rotary_embedding.py | 4 +- vllm/model_executor/layers/sampler.py | 47 +++++++++++-------- .../model_executor/model_loader/tensorizer.py | 4 +- 10 files changed, 61 insertions(+), 34 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 089c7d18ad6f2..a19be8525f902 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -43,8 +43,8 @@ jobs: mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml + mypy vllm/model_executor --config-file pyproject.toml # TODO(sang): Fix nested dir - mypy vllm/model_executor/*.py --config-file pyproject.toml mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml diff --git a/format.sh b/format.sh index 4ac1842daef0a..bd12e61d77806 100755 --- a/format.sh +++ b/format.sh @@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml mypy vllm/engine --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml -mypy vllm/model_executor/*.py --config-file pyproject.toml +mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index 0d74a5f8e81ff..d0a5ca5592f9d 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: return schema if isinstance(schema, BaseModel): return schema.model_json_schema() + raise AssertionError(f"Unsupported schema type {schema}") @lru_cache diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1bd6c42ab3fd8..4d43ed4c5f14a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -128,7 +128,8 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype if quant_config is None: - self.quant_method = UnquantizedLinearMethod() + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedLinearMethod() else: self.quant_method = quant_config.get_quant_method(self) @@ -160,6 +161,8 @@ def __init__( super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) + # All the linear layer supports quant method. + assert self.quant_method is not None self.quant_method.create_weights(self, self.input_size, [self.output_size], self.input_size, self.output_size, self.params_dtype) @@ -173,6 +176,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias @@ -221,6 +225,8 @@ def __init__( self.output_size_per_partition = divide(output_size, tp_size) if output_sizes is None: output_sizes = [output_size] + # All the linear layer supports quant method. + assert self.quant_method is not None self.quant_method.create_weights(self, self.input_size, [x // tp_size for x in output_sizes], @@ -255,6 +261,7 @@ def forward(self, input_): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. + assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. @@ -579,6 +586,8 @@ def __init__( # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) + # All the linear layer supports quant method. + assert self.quant_method is not None self.quant_method.create_weights(self, self.input_size_per_partition, [self.output_size], @@ -624,6 +633,7 @@ def forward(self, input_): input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. + assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 0820f17c5c50d..70e0a7cfe3e3b 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Dict, Type from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig @@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig -QUANTIZATION_METHODS = { +QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, "fp8": Fp8Config, diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index b755b1328504a..ff5cf0b2bd61a 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -76,8 +76,16 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") @abstractmethod - def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase: - """Get the quantize method to use for the quantized layer.""" + def get_quant_method( + self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + """Get the quantize method to use for the quantized layer. + + Args: + layer: The layer for the quant method. + Returns: + The quantize method. None if the given layer doesn't support quant + method. + """ raise NotImplementedError @abstractmethod diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 971078fe25a9b..207dbcee8afc5 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -52,11 +52,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": return cls(weight_bits) def get_quant_method( - self, - layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]: + self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: if isinstance(layer, LinearBase): return SqueezeLLMLinearMethod(self) - return + return None def get_scaled_act_names(self) -> List[str]: return [] diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index b8361af61ae3f..25365a9b50a1f 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -431,8 +431,8 @@ def forward( torch.full_like(positions, k)).long() idx = (torch.add(positions, long_prompt_offset) if long_prompt_offset is not None else positions) - self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to( - idx.device) + self.long_short_cos_sin_cache: torch.Tensor = ( + self.long_short_cos_sin_cache.to(idx.device)) idx = torch.add(idx, offsets) if offsets is not None else idx cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 4ef25edecfd24..d79c99e5d0a45 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -13,6 +13,9 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceGroupOutput, SequenceOutput) +# (num_token_ids, num_parent_ids) per sequence group. +SampleResultType = List[Tuple[List[int], List[int]]] + class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -155,7 +158,7 @@ def _apply_min_tokens_penalty( have not been generated yet """ # list of indices in logits that will be set to -inf - logits_to_penalize = [] + logits_to_penalize: List[Tuple[int, int]] = [] logits_applied = 0 for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -269,7 +272,7 @@ def _apply_min_p( def _greedy_sample( selected_seq_groups: List[SequenceGroupToSample], samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: """Run greedy sampling on a given samples. Args: @@ -284,7 +287,7 @@ def _greedy_sample( """ samples = samples.tolist() sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -304,7 +307,7 @@ def _greedy_sample( def _random_sample( selected_seq_groups: List[SequenceGroupToSample], random_samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: """Run random sampling on a given samples. Args: @@ -320,7 +323,7 @@ def _random_sample( # Find the maximum best_of value of the prompt phase requests. random_samples = random_samples.cpu() sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -348,7 +351,7 @@ def _random_sample( def _beam_search_sample( selected_seq_groups: List[SequenceGroupToSample], logprobs: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: """Run beam sampling on a given samples. Args: @@ -370,7 +373,7 @@ def _beam_search_sample( # NOTE: Beam search is not vectorized, so its speed can be slower than # other sampling methods. sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -391,16 +394,16 @@ def _beam_search_sample( next_token_ids = next_token_ids.tolist() else: # Generation phase. - cumulative_logprobs = [ + cumulative_logprobs: List[int] = [ seq_group.seq_data[seq_id].cumulative_logprob for seq_id in seq_ids ] - cumulative_logprobs = torch.tensor( + cumulative_logprobs_tensor = torch.tensor( cumulative_logprobs, dtype=torch.float, device=seq_group_logprobs.device) seq_group_logprobs = (seq_group_logprobs + - cumulative_logprobs.unsqueeze(dim=1)) + cumulative_logprobs_tensor.unsqueeze(dim=1)) _, topk_ids = torch.topk(seq_group_logprobs.flatten(), 2 * beam_width) topk_ids = topk_ids.tolist() @@ -452,8 +455,10 @@ def _sample_with_torch( sampling_metadata: SamplingMetadata, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: - categorized_seq_group_ids = {t: [] for t in SamplingType} +) -> Tuple[SampleResultType, Optional[torch.Tensor]]: + categorized_seq_group_ids: Dict[SamplingType, + List[int]] = {t: [] + for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): sampling_params = seq_group.sampling_params @@ -555,8 +560,10 @@ def _sample_with_triton_kernel( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, -) -> List[Tuple[List[int], List[int]]]: - categorized_seq_group_ids = {t: [] for t in SamplingType} +) -> SampleResultType: + categorized_seq_group_ids: Dict[SamplingType, + List[int]] = {t: [] + for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): sampling_params = seq_group.sampling_params @@ -632,7 +639,7 @@ def _sample( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool -) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: +) -> Tuple[SampleResultType, Optional[torch.Tensor]]: """ Args: probs: (num_query_tokens_in_batch, num_vocab) @@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, - sample_results: List[Tuple[List[int], List[int]]], + sample_results: SampleResultType, ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: """Return sample lobprobs and prompt logprobs. @@ -751,8 +758,8 @@ def _get_logprobs( assert len(next_token_ids) == len(query_indices) if len(query_indices) == 0: - empty_sampled_logprob = [] - empty_prompt_logprob = None + empty_sampled_logprob: SampleLogprobs = [] + empty_prompt_logprob: Optional[PromptLogprobs] = None return [empty_prompt_logprob], [empty_sampled_logprob] query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) @@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( - sample_results: List[Tuple[List[int], List[int]]], + sample_results: SampleResultType, sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], @@ -1009,7 +1016,7 @@ def _build_sampler_output( ) -def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]: +def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: """Get a list of next prompt tokens to compute logprob from a given sequence group. diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 8fc6d16672117..2d654b2fefb8d 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -64,7 +64,7 @@ def _construct_tensorizer_args(self) -> "TensorizerArgs": "s3_secret_access_key": self.s3_secret_access_key, "s3_endpoint": self.s3_endpoint, } - return TensorizerArgs(**tensorizer_args) + return TensorizerArgs(**tensorizer_args) # type: ignore def verify_with_parallel_config( self, @@ -270,8 +270,10 @@ def __init__(self, tensorizer_config: TensorizerConfig, self.model = self._init_model() def _init_model(self): + assert self.tensorizer_config.hf_config is not None model_args = self.tensorizer_config.hf_config model_args.torch_dtype = self.tensorizer_config.dtype + assert self.tensorizer_config.model_class is not None with no_init_or_tensor(): return self.tensorizer_config.model_class( config=model_args, From 73c8d677e57e42374bcfb2271b8f1cf7f2c0a486 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Mon, 29 Apr 2024 12:35:34 -0400 Subject: [PATCH 150/413] [Kernel] Marlin Expansion: Support AutoGPTQ Models with Marlin (#3922) Co-authored-by: alexm Co-authored-by: mgoin --- CMakeLists.txt | 2 + csrc/ops.h | 18 + csrc/pybind.cpp | 2 + csrc/quantization/gptq_marlin/gptq_marlin.cu | 1520 +++++++++++++++++ csrc/quantization/gptq_marlin/gptq_marlin.cuh | 74 + .../gptq_marlin/gptq_marlin_repack.cu | 324 ++++ tests/models/test_gptq_marlin.py | 93 + tests/models/test_marlin.py | 45 +- tests/models/utils.py | 29 + .../test_autogptq_marlin_configs.py | 64 - tests/quantization/test_configs.py | 73 + vllm/config.py | 39 +- .../layers/quantization/__init__.py | 3 + .../layers/quantization/gptq_marlin.py | 444 +++++ 14 files changed, 2626 insertions(+), 104 deletions(-) create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin.cu create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin.cuh create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin_repack.cu create mode 100644 tests/models/test_gptq_marlin.py create mode 100644 tests/models/utils.py delete mode 100644 tests/quantization/test_autogptq_marlin_configs.py create mode 100644 tests/quantization/test_configs.py create mode 100644 vllm/model_executor/layers/quantization/gptq_marlin.py diff --git a/CMakeLists.txt b/CMakeLists.txt index e9262b57d0867..1558dbf313ce7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -177,6 +177,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/marlin_cuda_kernel.cu" + "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/custom_all_reduce.cu") endif() diff --git a/csrc/ops.h b/csrc/ops.h index 03bb1e24dc68e..04b97d1784cd2 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -124,6 +124,24 @@ torch::Tensor marlin_gemm( int64_t size_m, int64_t size_n, int64_t size_k); + +torch::Tensor gptq_marlin_gemm( + torch::Tensor &a, + torch::Tensor &b_q_weight, + torch::Tensor &b_scales, + torch::Tensor &g_idx, + torch::Tensor &perm, + torch::Tensor &workspace, + int64_t size_m, + int64_t size_n, + int64_t size_k, + bool is_k_full); + +torch::Tensor gptq_marlin_repack( + torch::Tensor &b_q_weight, + torch::Tensor &perm, + int64_t size_k, + int64_t size_n); #endif void squeezellm_gemm( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 2250c7f69f0ab..9839bfc0331c4 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -67,6 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu new file mode 100644 index 0000000000000..9902f55167d89 --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -0,0 +1,1520 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include "gptq_marlin.cuh" + +template inline std::string str(T x) { return std::to_string(x); } + +namespace gptq_marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) {} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void +Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &g_idx, + torch::Tensor &perm, torch::Tensor &workspace, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, + FragC &frag_c) { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + float *c = reinterpret_cast(&frag_c); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template __device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2, + FragS &frag_s_3, FragS &frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int *lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int *lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const *a_row_half = + reinterpret_cast(a_int4_ptr + offset); + half *out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void +Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_g_idx = sh_b + (stages * b_sh_stage); + int4 *sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const *cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4_stream(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4_stream(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4 *sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + + FragB frag_b0 = dequant(b_quant); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + FragB frag_b1 = dequant(b_quant_shift); + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped portioning + // minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up the +// compiler and lead to slowdowns, hence we also use async-copies even though +// these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half *>(&c_red)[j]); + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half *>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS &s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here + if constexpr (!has_act_order && group_blocks == -1) { + res = __hmul2(res, s[0]); + } + + ((half2 *)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + if (s_sh_wr_pred) { + cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + + // if (blockIdx.x == 0 && threadIdx.x == 0) { + // printf("Move\n"); + // } + start_pipes(); + } + } + } +} + +#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ + prob_k, locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 128, 128}, // Reduce N 2X, same K + // {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + + // TODO: Enable if needed after some more testing + if (prob_m <= 0) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, + void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k, + void *workspace, bool has_act_order, bool is_k_full, + int num_groups, int group_size, int dev = 0, + cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, + int sms = -1, int max_par = 16) { + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, default_threads}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), + "Invalid thread config: thread_k = " + str(th_config.thread_k) + + ", thread_n = " + str(th_config.thread_n) + + ", num_threads = " + str(th_config.num_threads) + + " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + + str(prob_n) + "]"); + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + const int4 *s_ptr = (const int4 *)s; + const int *g_idx_ptr = (const int *)g_idx; + const int *perm_ptr = (const int *)perm; + int4 *a_tmp_ptr = (int4 *)a_tmp; + + int *locks = (int *)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by having + // a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + // Main loop + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // Define kernel configurations + if (false) { + } + CALL_IF(16, 4, 256) + CALL_IF(8, 8, 256) + CALL_IF(8, 4, 128) + CALL_IF(4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &g_idx, + torch::Tensor &perm, torch::Tensor &workspace, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full) { + // Verify A + TORCH_CHECK(a.size(0) == size_m, + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + TORCH_CHECK(a.size(1) == size_k, + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + + // Verify B + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(gptq_marlin::tile_size)); + TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + + ", tile_size = " + str(gptq_marlin::tile_size)); + TORCH_CHECK( + b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(gptq_marlin::tile_size)); + int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) * + gptq_marlin::pack_factor_4bit; + TORCH_CHECK(size_n == actual_size_n, + "size_n = " + str(size_n) + + ", actual_size_n = " + str(actual_size_n)); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || + (g_idx.size(0) == size_k && perm.size(0) == size_k), + "Unexpected g_idx.size(0) = " + str(g_idx.size(0)) + + " and perm.size(0) = " + str(perm.size(0)) + + ", where size_k = " + str(size_k)); + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(0) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), + " is not size_n = ", size_n); + num_groups = b_scales.size(0); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, + "size_k = " + str(size_k) + + ", is not divisible by num_groups = " + str(num_groups)); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + if (num_groups > 1) { + TORCH_CHECK(size_k % num_groups == 0, + "size_k = " + str(size_k) + + ", is not divisible by b_scales.size(0) = " + + str(b_scales.size(0))); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + // Verify workspace size + TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(gptq_marlin::min_thread_n)); + int min_workspace_size = + (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + int dev = a.get_device(); + gptq_marlin::marlin_cuda( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups, + group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, + sms, gptq_marlin::max_par); + + return c; +} + +#endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh new file mode 100644 index 0000000000000..8cfce6b2575d5 --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -0,0 +1,74 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace gptq_marlin { + +// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per +// schedule allows some more latency hiding. At the same time, we want relatively few warps to have +// many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit + +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + // No support for async +#else + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace gptq_marlin diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu new file mode 100644 index 0000000000000..fa45ce68a0c77 --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -0,0 +1,324 @@ +#include "gptq_marlin.cuh" + +namespace gptq_marlin { + +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template +__global__ void +marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, int size_k, int size_n) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, + int64_t size_k, int64_t size_n) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template +__global__ void +marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, int size_k, int size_n) { + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + int start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4 *sh_perm_ptr = sh; + int4 *sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = + has_perm ? tile_k_size : tile_k_size / pack_factor_4bit; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + uint32_t const *sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor_4bit; + + cp_async4_stream( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&( + b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + + } else { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor_4bit; + + cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + + int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[pack_factor_4bit]; + + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor_4bit; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + + } else { + + uint32_t b1_val_1 = sh_stage_int_ptr[cur_n]; + uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n]; + + uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8]; + uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8]; + +#pragma unroll + for (int i = 0; i < 2; i++) { + int cur_elem = tc_row + tc_offsets[i]; + vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf; + vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf; + } + +#pragma unroll + for (int i = 2; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i] - 8; + vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf; + vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf; + } + } + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < pack_factor_4bit; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, + int64_t size_k, int64_t size_n) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); + TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + + // Verify B + TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, + ", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit); + TORCH_CHECK(b_q_weight.size(1) == size_n, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not size_n = ", size_n); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + torch::Tensor out = torch::empty( + {size_k / gptq_marlin::tile_size, + size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit}, + options); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const *b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const *perm_ptr = + reinterpret_cast(perm.data_ptr()); + uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (has_perm) { + cudaFuncSetAttribute( + gptq_marlin::marlin_repack_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem); + gptq_marlin::marlin_repack_kernel + <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); + + } else { + cudaFuncSetAttribute( + gptq_marlin::marlin_repack_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem); + gptq_marlin::marlin_repack_kernel + <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); + } + + return out; +} + +#endif diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py new file mode 100644 index 0000000000000..dc027697ffd4d --- /dev/null +++ b/tests/models/test_gptq_marlin.py @@ -0,0 +1,93 @@ +"""Compares the outputs of gptq vs gptq_marlin +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. +Note: This test currently fails running with --forked with the following: + RuntimeError: Cannot re-initialize CUDA in forked subprocess. + To use CUDA with multiprocessing, you must use the 'spawn' start method +Run `pytest tests/models/test_gptq_marlin.py`. +""" +import os + +import pytest +import torch + +from tests.models.utils import check_logprobs_close +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +gptq_marlin_not_supported = ( + capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability()) + +MODELS = [ + # act_order==False, group_size=channelwise + ("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"), + # act_order==False, group_size=128 + ("TheBloke/Llama-2-7B-GPTQ", "main"), + + # act_order==True, group_size=128 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"), + # act_order==True, group_size=64 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"), + # act_order==True, group_size=32 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(gptq_marlin_not_supported, + reason="gptq_marlin is not supported on this GPU type.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + model_name, revision = model + + # Run marlin. + gptq_marlin_model = vllm_runner(model_name=model_name, + revision=revision, + dtype=dtype, + quantization="marlin", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + disable_custom_all_reduce=True) + + gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + del gptq_marlin_model + + # Run gptq. + gptq_model = vllm_runner(model_name=model_name, + revision=revision, + dtype=dtype, + quantization="gptq", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + disable_custom_all_reduce=True) + gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) + del gptq_model + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=gptq_marlin_outputs, + name_0="gptq", + name_1="gptq_marlin", + ) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 4fe6daec02520..fa846d43d0e88 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -10,12 +10,12 @@ Run `pytest tests/models/test_marlin.py`. """ - from dataclasses import dataclass import pytest import torch +from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS capability = torch.cuda.get_device_capability() @@ -55,43 +55,24 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype) + marlin_model = vllm_runner(model_pair.model_marlin, + dtype=dtype, + quantization="marlin") marlin_outputs = marlin_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - - # Note: not sure why, but deleting just the model on Ada Lovelace - # does not free the GPU memory. On Ampere, deleting the just model - # frees the memory. del marlin_model - gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) + gptq_model = vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="gptq") gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs) - - # Note: not sure why, but deleting just the model on Ada Lovelace - # does not free the GPU memory. On Ampere, deleting the just model - # frees the memory. del gptq_model - # loop through the prompts - for prompt_idx in range(len(example_prompts)): - gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[ - prompt_idx] - marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[ - prompt_idx] - - for idx, (gptq_output_id, marlin_output_id) in enumerate( - zip(gptq_output_ids, marlin_output_ids)): - # If sequence is not an exact match, - if marlin_output_id != gptq_output_id: - # Each predicted token must be in top 5 of the other's - assert gptq_output_id in marlin_logprobs[idx], ( - f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n" - f"Marlin:\t{marlin_output_str!r}") - assert marlin_output_id in gptq_logprobs[idx], ( - f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n" - f"Marlin:\t{marlin_output_str!r}") - - # Break out since sequences will now diverge. - break + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=marlin_outputs, + name_0="gptq", + name_1="marlin", + ) diff --git a/tests/models/utils.py b/tests/models/utils.py new file mode 100644 index 0000000000000..3e49dfb331176 --- /dev/null +++ b/tests/models/utils.py @@ -0,0 +1,29 @@ +def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1): + """Compare the logprobs of two sequences generated by different models, + which should be similar but not necessarily equal. + """ + # Loop through responses to each prompt. + for prompt_idx, (outputs_0, + outputs_1) in enumerate(zip(outputs_0_lst, + outputs_1_lst)): + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + + # Loop through generated tokens. + for idx, (output_id_0, + output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): + + # If generated tokens don't match, then + if output_id_0 != output_id_1: + # Each predicted token must be in top N logprobs of the other + assert output_id_0 in logprobs_1[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + assert output_id_1 in logprobs_0[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + + # Break out since sequences will now diverge. + break diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py deleted file mode 100644 index 1310b4da218b5..0000000000000 --- a/tests/quantization/test_autogptq_marlin_configs.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Tests whether Marlin models can be loaded from the autogptq config. - -Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`. -""" - -from dataclasses import dataclass - -import pytest - -from vllm.config import ModelConfig - - -@dataclass -class ModelPair: - model_marlin: str - model_gptq: str - - -# Model Id // Expected Kernel -MODELS_QUANT_TYPE = [ - # compat: autogptq <=0.7.1 is_marlin_format: bool - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"), - ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"), - # compat: autogptq >=0.8.0 use checkpoint_format: str - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq") -] - - -@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE) -def test_auto_gptq(model_quant_type: str, ) -> None: - model_path, quant_type = model_quant_type - - model_config_no_quant_arg = ModelConfig( - model_path, - model_path, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype="float16", - revision=None, - quantization=None # case 1 - ) - - model_config_quant_arg = ModelConfig( - model_path, - model_path, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype="float16", - revision=None, - quantization="gptq" # case 2 - ) - - assert model_config_no_quant_arg.quantization == quant_type, ( - f"Expected quant_type == {quant_type} for {model_path}, " - f"but found {model_config_no_quant_arg.quantization} " - "for no --quantization None case") - - assert model_config_quant_arg.quantization == quant_type, ( - f"Expected quant_type == {quant_type} for {model_path}, " - f"but found {model_config_quant_arg.quantization} " - "for --quantization gptq case") diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py new file mode 100644 index 0000000000000..6820b2728e3c9 --- /dev/null +++ b/tests/quantization/test_configs.py @@ -0,0 +1,73 @@ +"""Tests whether Marlin models can be loaded from the autogptq config. + +Run `pytest tests/quantization/test_configs.py --forked`. +""" + +from dataclasses import dataclass + +import pytest + +from vllm.config import ModelConfig + + +@dataclass +class ModelPair: + model_marlin: str + model_gptq: str + + +# Model Id // Quantization Arg // Expected Type +MODEL_ARG_EXPTYPES = [ + # AUTOGPTQ + # compat: autogptq <=0.7.1 is_marlin_format: bool + # Model Serialized in Marlin Format should always use Marlin kernel. + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"), + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"), + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"), + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"), + # Model Serialized in Exllama Format. + ("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"), + # compat: autogptq >=0.8.0 use checkpoint_format: str + # Model Serialized in Marlin Format should always use Marlin kernel. + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"), + # Model Serialized in Exllama Format. + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"), + + # AUTOAWQ + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "ERROR"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"), +] + + +@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES) +def test_auto_gptq(model_arg_exptype: str) -> None: + model_path, quantization_arg, expected_type = model_arg_exptype + + try: + model_config = ModelConfig(model_path, + model_path, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + quantization=quantization_arg) + found_quantization_type = model_config.quantization + except ValueError: + found_quantization_type = "ERROR" + + assert found_quantization_type == expected_type, ( + f"Expected quant_type == {expected_type} for {model_path}, " + f"but found {found_quantization_type} " + f"for no --quantization {quantization_arg} case") diff --git a/vllm/config.py b/vllm/config.py index aedb589247646..a5512c657e038 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -9,11 +9,14 @@ from transformers import PretrainedConfig from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + get_quantization_config) from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, is_neuron) +GPTQMarlinConfig = get_quantization_config("gptq_marlin") + if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -138,14 +141,34 @@ def _verify_quantization(self) -> None: is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin" or quant_cfg.get("is_marlin_format", False)) - # Use marlin if the GPTQ model is serialized in marlin format. - if quant_method == "gptq" and is_format_marlin: - logger.info("The model is serialized in Marlin format. " + # Check which LinearMethod the GPTQ model should use. + if quant_method == "gptq": + # If serialized in Marlin format, use MarlinLinearMethod. + # TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod. + if is_format_marlin: + logger.info("The model is serialized in Marlin format. " + "Using Marlin kernel.") + quant_method = "marlin" + if self.quantization == "gptq": + self.quantization = quant_method + + # If convertible to Marlin format, use GPTQMarlinLinearMethod + # unless the user explicitly specified GPTQLinearMethod. + elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg): + if self.quantization == "gptq": + logger.warning( + "The model is convertible to Marlin format, but " + "you specified quantization=gptq. Use " + "quantization=marlin for faster inference.") + else: + logger.info( + "The model is convertible to Marlin format. " "Using Marlin kernel.") - quant_method = "marlin" - if self.quantization == "gptq": - self.quantization = quant_method + quant_method = "gptq_marlin" + if self.quantization == "marlin": + self.quantization = quant_method + # Verify quantization configurations. if self.quantization is None: self.quantization = quant_method elif self.quantization != quant_method: @@ -165,7 +188,7 @@ def _verify_quantization(self) -> None: raise ValueError( f"{self.quantization} quantization is currently not " f"supported in ROCm.") - if self.quantization != "marlin": + if (self.quantization not in ["marlin", "gptq_marlin"]): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 70e0a7cfe3e3b..1c652e347d4ad 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -6,6 +6,8 @@ QuantizationConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig @@ -15,6 +17,7 @@ "fp8": Fp8Config, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, + "gptq_marlin": GPTQMarlinConfig, "marlin": MarlinConfig, } diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py new file mode 100644 index 0000000000000..7bff0e834483f --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -0,0 +1,444 @@ +import enum +from enum import Enum +from typing import Any, Dict, List, Optional + +import numpy +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4] +GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +GPTQ_MARLIN_SUPPORTED_SYM = [True] + + +# Precompute permutations for Marlin weight and scale shuffling +# +# Marlin works on [16,64] tiles. The goal of the permutations +# is to reorder the weight data so that it is compatible +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the +# kernel will get the data as it is needed for tensor-core +# (without the need to use ldmatrix instructions) +def _get_perms(): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm) + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +_perm, _scale_perm, _scale_perm_single = _get_perms() + + +def get_pack_factor(num_bits): + assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, ( + f"Unsupported num_bits = {num_bits}") + return 32 // num_bits + + +def marlin_permute_scales(s, size_k, size_n, group_size): + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] + else: + s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +class GPTQMarlinConfig(QuantizationConfig): + """Config class for GPTQ Marlin""" + + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, + is_sym: bool) -> None: + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + + # Verify + if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + raise ValueError( + f"Marlin does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Marlin does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: + raise ValueError( + f"Marlin does not support is_sym = {self.is_sym}. " + f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") + + # Init + self.pack_factor = get_pack_factor(weight_bits) + self.tile_size = GPTQ_MARLIN_TILE + self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N + self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K + self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL + + def __repr__(self) -> str: + return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})") + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + return cls(weight_bits, group_size, desc_act, is_sym) + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQMarlinLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def is_marlin_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits", None) + group_size = quant_config.get("group_size", None) + sym = quant_config.get("sym", None) + desc_act = quant_config.get("desc_act", None) + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies marlin constraints. + return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES + and sym in GPTQ_MARLIN_SUPPORTED_SYM) + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +class GPTQMarlinLinearMethod(LinearMethodBase): + """Linear method for GPTQ Marlin. + + Args: + quant_config: The GPTQ Marlin quantization config. + """ + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Validate dtype + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_thread_n != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {self.quant_config.min_thread_n}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_thread_k != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {self.quant_config.min_thread_k}.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}.") + + # Detect sharding of scales/zp + + # By default, no sharding over "input dim" + scales_and_zp_size = input_size // group_size + scales_and_zp_input_dim = None + + if self.quant_config.desc_act: + # Act-order case + assert self.quant_config.group_size != -1 + + is_k_full = input_size_per_partition == input_size + + else: + # No act-order case + + # K is always full due to full alignment with + # group-size and shard of scales/zp + is_k_full = True + + # If this is a row-parallel case, then shard scales/zp + if (input_size != input_size_per_partition + and self.quant_config.group_size != -1): + scales_and_zp_size = input_size_per_partition // group_size + scales_and_zp_input_dim = 0 + + # Init buffers + + # Quantized weights + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + **extra_weight_attrs, + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + }) + + # Activation order + g_idx = Parameter( + torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + # Ignore warning from fused linear layers such as QKVParallelLinear. + set_weight_attrs(g_idx, { + **extra_weight_attrs, "input_dim": 0, + "ignore_warning": True + }) + + g_idx_sort_indices = Parameter( + torch.empty( + g_idx.shape, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs(g_idx_sort_indices, extra_weight_attrs) + + # Scales + scales = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + }) + + # Quantized zero-points + qzeros = Parameter( + torch.empty(scales_and_zp_size, + output_size_per_partition // + self.quant_config.pack_factor, + dtype=torch.int32, + device="meta"), + requires_grad=False, + ) + set_weight_attrs( + qzeros, { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + + # Allocate marlin workspace + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_thread_n) * self.quant_config.max_parallel + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + requires_grad=False) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + layer.workspace = workspace + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.input_size = input_size + layer.is_k_full = is_k_full + layer.marlin_state = GPTQMarlinState.REPACK + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + + size_m = reshaped_x.shape[0] + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + full_size_k = layer.input_size + + out_shape = x.shape[:-1] + (part_size_n, ) + + if layer.marlin_state == GPTQMarlinState.REPACK: + layer.marlin_state = GPTQMarlinState.READY + + # Newly generated tensors need to replace existing tensors that are + # already registered as parameters by vLLM (and won't be freed) + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + cur_device = layer.qweight.device + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int) + + sorted_g_idx = layer.g_idx[g_idx_sort_indices] + + replace_tensor("g_idx", sorted_g_idx) + replace_tensor("g_idx_sort_indices", g_idx_sort_indices) + + else: + # Reset g_idx related tensors + layer.g_idx = Parameter(torch.empty(0, + dtype=torch.int, + device=cur_device), + requires_grad=False) + layer.g_idx_sort_indices = Parameter(torch.empty( + 0, dtype=torch.int, device=cur_device), + requires_grad=False) + + # Repack weights + marlin_qweight = ops.gptq_marlin_repack( + layer.qweight, + layer.g_idx_sort_indices, + part_size_k, + part_size_n, + ) + replace_tensor("qweight", marlin_qweight) + + # Permute scales + scales_size_k = part_size_k + scales_size_n = part_size_n + if self.quant_config.desc_act: + scales_size_k = full_size_k + + marlin_scales = marlin_permute_scales(layer.scales, scales_size_k, + scales_size_n, + self.quant_config.group_size) + replace_tensor("scales", marlin_scales) + + output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales, + layer.g_idx, layer.g_idx_sort_indices, + layer.workspace, size_m, part_size_n, + part_size_k, layer.is_k_full) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) From ac5ccf0156e1772f3ea89c205704a31219442c55 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 29 Apr 2024 12:50:01 -0700 Subject: [PATCH 151/413] [CI] hotfix: soft fail neuron test (#4458) --- .buildkite/test-template.j2 | 1 + 1 file changed, 1 insertion(+) diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index fb1086db77823..5c9515840bb03 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -25,6 +25,7 @@ steps: agents: queue: neuron command: bash .buildkite/run-neuron-test.sh + soft_fail: true - label: "CPU Test" command: bash .buildkite/run-cpu-test.sh From f4f921b7f12c67d3c4b7575caf5ddd9bd4b0b787 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 29 Apr 2024 13:52:22 -0700 Subject: [PATCH 152/413] [Core][Distributed] use cpu group to broadcast metadata in cpu (#4444) --- .../tensorize_vllm_model_for_testing.py | 6 +- tests/worker/test_model_runner.py | 23 ++++--- vllm/distributed/communication_op.py | 69 +++++++++++++------ 3 files changed, 63 insertions(+), 35 deletions(-) diff --git a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py index e4b15fd57add4..0e113ab647e67 100644 --- a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py +++ b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py @@ -6,14 +6,14 @@ from functools import partial from typing import Type -import torch import torch.nn as nn from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, TensorSerializer, stream_io) from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor from transformers import AutoConfig, PretrainedConfig -from vllm.distributed import initialize_model_parallel +from vllm.distributed import (init_distributed_environment, + initialize_model_parallel) from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.model_executor.model_loader.tensorizer import TensorizerArgs @@ -226,7 +226,7 @@ def deserialize(): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "8080" -torch.distributed.init_process_group(world_size=1, rank=0) +init_distributed_environment(world_size=1, rank=0, local_rank=0) initialize_model_parallel() keyfile = args.keyfile if args.keyfile else None diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index abb401f25c100..56fe6db589f18 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -2,8 +2,10 @@ import torch from vllm.config import ModelConfig, SchedulerConfig +from vllm.distributed.parallel_state import init_distributed_environment from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -249,19 +251,18 @@ def test_empty_seq_group(): assert len(return_prompt_lens) == 0 -@pytest.mark.parametrize("batch_size", list(range(2, 128))) -@pytest.mark.parametrize("enforce_eager", [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): - - def get_world_size(group=None): - return 1 +@pytest.fixture +def distributed_init(): + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}", + local_rank=0) - def mock_get_process_group_ranks(group=None): - return [0] - monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size) - monkeypatch.setattr(torch.distributed, "get_process_group_ranks", - mock_get_process_group_ranks) +@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, distributed_init): model_config = ModelConfig( "facebook/opt-125m", diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index a3e93691a1e8e..8b2c26c3a8afb 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -4,7 +4,8 @@ import torch from torch.distributed import ProcessGroup -from .parallel_state import (get_tensor_model_parallel_group, +from .parallel_state import (get_cpu_world_group, + get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce) @@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any], TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"]) +def _split_tensor_dict( + tensor_dict: Dict[Any, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note(youkaichao): currently this only supports broadcasting + # tensors on cuda. In the future, we can add device as a field in + # TensorMetadata to support broadcasting tensors on different + # devices. + assert value.is_cuda, ( + f"Tensor {key}: {value} is not on cuda. Currently we only " + f"support broadcasting tensors on cuda.") + metadata_list.append((key, TensorMetadata(value.dtype, + value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + def broadcast_tensor_dict( tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: - """Broadcast the input tensor dictionary.""" + """Broadcast the input tensor dictionary. + `group` is used to broadcast the tensors, while `metadata_group` is used + to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, + dtypes). + """ group = group or torch.distributed.group.WORLD + metadata_group = metadata_group or get_cpu_world_group() ranks = torch.distributed.get_process_group_ranks(group) assert src in ranks, f"Invalid src rank ({src})" @@ -161,27 +195,20 @@ def broadcast_tensor_dict( assert isinstance( tensor_dict, dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - for key, value in tensor_dict.items(): - if isinstance(value, torch.Tensor): - assert value.is_cuda, ( - f"Tensor {key}: {value} is not on cuda. Currently we only " - f"support broadcasting tensors on cuda.") - metadata_list.append( - (key, TensorMetadata(value.dtype, value.size()))) - else: - metadata_list.append((key, value)) + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` involves serialization and deserialization, + # all happening on CPU. Therefore, we can use the CPU group. torch.distributed.broadcast_object_list([metadata_list], src=src, - group=group) + group=metadata_group) async_handles = [] - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - tensor = tensor_dict[key] - async_handles.append( - torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True)) + for tensor in tensor_list: + async_handles.append( + torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True)) for async_handle in async_handles: async_handle.wait() @@ -189,7 +216,7 @@ def broadcast_tensor_dict( recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, src=src, - group=group) + group=metadata_group) assert recv_metadata_list[0] is not None tensor_dict = {} async_handles = [] From d627a3d837976a23f89ba808f5ef6908fdb65bfa Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 29 Apr 2024 20:05:47 -0400 Subject: [PATCH 153/413] [Misc] Upgrade to `torch==2.3.0` (#4454) --- .github/workflows/publish.yml | 2 +- CMakeLists.txt | 2 +- Dockerfile | 2 +- pyproject.toml | 2 +- requirements-build.txt | 2 +- requirements-cpu.txt | 2 +- requirements-cuda.txt | 4 ++-- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4b9fc3d04d872..d79681f03b003 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -49,7 +49,7 @@ jobs: matrix: os: ['ubuntu-20.04'] python-version: ['3.8', '3.9', '3.10', '3.11'] - pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt. + pytorch-version: ['2.3.0'] # Must be the most recent version that meets requirements-cuda.txt. cuda-version: ['11.8', '12.1'] steps: diff --git a/CMakeLists.txt b/CMakeLists.txt index 1558dbf313ce7..f817f3382c5e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # requirements.txt files and should be kept consistent. The ROCm torch # versions are derived from Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1") +set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0") set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") diff --git a/Dockerfile b/Dockerfile index d1d29177b0f44..e471a6e93b963 100644 --- a/Dockerfile +++ b/Dockerfile @@ -85,7 +85,7 @@ FROM dev as flash-attn-builder ARG max_jobs=2 ENV MAX_JOBS=${max_jobs} # flash attention version -ARG flash_attn_version=v2.5.6 +ARG flash_attn_version=v2.5.8 ENV FLASH_ATTN_VERSION=${flash_attn_version} WORKDIR /usr/src/flash-attention-v2 diff --git a/pyproject.toml b/pyproject.toml index 2e026c1ac8911..6a448defc16e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.2.1", + "torch == 2.3.0", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-build.txt b/requirements-build.txt index 2bc07fb152aac..1a07a94e82e04 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -3,5 +3,5 @@ cmake>=3.21 ninja packaging setuptools>=49.4.0 -torch==2.2.1 +torch==2.3.0 wheel diff --git a/requirements-cpu.txt b/requirements-cpu.txt index e911ad03295f0..b739642d8d344 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,5 +2,5 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.2.1+cpu +torch == 2.3.0+cpu triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 1bddae4c6f40f..6548d7a6684b2 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -5,5 +5,5 @@ ray >= 2.9 nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library -torch == 2.2.1 -xformers == 0.0.25 # Requires PyTorch 2.2.1 +torch == 2.3.0 +xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 From fa32207842f1ed5a966372ed0513914bff8426c4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 29 Apr 2024 22:05:40 -0700 Subject: [PATCH 154/413] [Bugfix][Kernel] Fix compute_type for MoE kernel (#4463) --- vllm/model_executor/layers/fused_moe/fused_moe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d37837a0b2ce8..b4f81527141a8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -433,6 +433,8 @@ def fused_moe( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) + compute_type = (tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16) invoke_fused_moe_kernel(hidden_states, w1, @@ -447,7 +449,7 @@ def fused_moe( False, topk_ids.shape[1], config, - compute_type=tl.float16, + compute_type=compute_type, use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -465,7 +467,7 @@ def fused_moe( True, 1, config, - compute_type=tl.float16, + compute_type=compute_type, use_fp8=use_fp8) if inplace: From 26f2fb51133c85ad8a57a87c8037f750dda757f4 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Tue, 30 Apr 2024 12:14:47 +0000 Subject: [PATCH 155/413] [Core]Refactor gptq_marlin ops (#4466) --- vllm/_custom_ops.py | 16 ++++++++++++++++ .../layers/quantization/gptq_marlin.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 5ba104bada7ac..4af8b09b1e16c 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -167,6 +167,22 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) +# gptq_marlin +def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int) -> torch.Tensor: + return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n) + + +def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int, + is_k_full: bool) -> torch.Tensor: + return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, + workspace, size_m, size_n, size_k, + is_k_full) + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 7bff0e834483f..efbffa0878c4b 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -6,7 +6,7 @@ import torch from torch.nn.parameter import Parameter -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( From 4bb53e2dde809ea5727b8cac95a080893733a1ef Mon Sep 17 00:00:00 2001 From: leiwen83 Date: Wed, 1 May 2024 01:12:59 +0800 Subject: [PATCH 156/413] [BugFix] fix num_lookahead_slots missing in async executor (#4165) Co-authored-by: Lei Wen --- tests/spec_decode/e2e/conftest.py | 125 +++++++++++++++++++- tests/spec_decode/e2e/test_compatibility.py | 15 ++- tests/spec_decode/e2e/test_correctness.py | 25 ++-- vllm/engine/async_llm_engine.py | 6 +- vllm/executor/cpu_executor.py | 4 +- vllm/executor/executor_base.py | 1 + vllm/executor/gpu_executor.py | 4 +- vllm/executor/neuron_executor.py | 1 + vllm/executor/ray_gpu_executor.py | 1 + 9 files changed, 163 insertions(+), 19 deletions(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 59fb8311fc5b7..5d3469c4210ee 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,10 +1,127 @@ -from typing import List, Tuple +import asyncio +from typing import List, Optional, Tuple, Union import pytest +import ray from tests.conftest import cleanup from vllm import LLM +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.lora.request import LoRARequest from vllm.model_executor.utils import set_random_seed +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.sequence import MultiModalData +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter, random_uuid + + +class AsyncLLM: + """AsyncLLM + + Note: Current LLM class in vllm don't support async mode, for test purpose, + we implement async one in here. Maybe we could move to + vllm/entrypoints/llm.py in future. + + Below AsyncLLM is directly borrow from vllm/entrypoints/llm.py with changes + to make to work in async mode. + """ + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + enforce_eager: bool = False, + max_context_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + self.engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + engine_use_ray=True, + disable_custom_all_reduce=disable_custom_all_reduce, + **kwargs, + ) + self.request_counter = Counter() + + def generate( + self, + prompts: Optional[Union[str, List[str]]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + + llm_engine = AsyncLLMEngine.from_engine_args( + self.engine_args, usage_context=UsageContext.LLM_CLASS) + + if prompts is None: + raise ValueError("prompts must be provided.") + if isinstance(prompts, str): + # Convert a single prompt to a list. + prompts = [prompts] + + if prompts is not None: + num_requests = len(prompts) + + if sampling_params is None: + # Use default sampling params. + sampling_params = SamplingParams() + + elif isinstance(sampling_params, + list) and len(sampling_params) != num_requests: + raise ValueError("The lengths of prompts and " + "sampling_params must be the same.") + + async def get_output(prompt, sampling_param) -> str: + request_id = random_uuid() + results_generator = llm_engine.generate(prompt, sampling_param, + request_id) + final_output = None + async for request_output in results_generator: + final_output = request_output + return final_output + + outputs = [] + try: + for i in range(num_requests): + prompt = prompts[i] if prompts is not None else None + res = asyncio.run(get_output(prompt, sampling_params)) + outputs.append(res) + finally: + ray.shutdown() + return outputs @pytest.fixture @@ -36,8 +153,12 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs, def generator_inner(): print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') - llm = LLM(**kwargs) + use_async = False + if "use_async" in kwargs: + use_async = kwargs.pop("use_async") + + llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs) set_random_seed(seed) yield llm diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index fde950c14382c..60c20ed7db7a3 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator): temperature=temperature, ) - with pytest.raises(AssertionError, - match="Speculative decoding not yet supported for "): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) + try: + with pytest.raises( + AssertionError, + match="Speculative decoding not yet supported for "): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + finally: + # we need to free up ray resource, + # so that latter test could use the gpu we allocated here + import ray + ray.shutdown() @pytest.mark.parametrize( diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py index 0536cc4ecde76..ab8d913fb894a 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_correctness.py @@ -40,17 +40,24 @@ @pytest.mark.parametrize( "common_llm_kwargs", - [{ - # Use a small model for a fast test. - # Note this is repeated in the test body; to initialize a tokenizer. - "model": "JackFram/llama-68m", + [ + { + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", - # Skip cuda graph recording for fast test. - "enforce_eager": True, + # Skip cuda graph recording for fast test. + "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True - }]) + # Required for spec decode. + "use_v2_block_manager": True, + + # whether use AsyncLLM engine + "use_async": async_mode, + } + # Try both async and sync engine execution + for async_mode in [True, False] + ]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7c1eb2ecbe550..4aceb19b50776 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -211,9 +211,11 @@ async def step_async(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): # Execute the model. output = await self.model_executor.execute_model_async( - seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, + seq_group_metadata_list, + scheduler_outputs.blocks_to_swap_in, scheduler_outputs.blocks_to_swap_out, - scheduler_outputs.blocks_to_copy) + scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots) else: output = [] diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index e4436b2144bd3..da1b500cddaf6 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -109,12 +109,14 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int, ) -> List[SamplerOutput]: output = await make_async(self.driver_worker.execute_model)( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy) + blocks_to_copy=blocks_to_copy, + num_lookahead_slots=num_lookahead_slots) return output async def check_health_async(self) -> None: diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c36aa18fb25bb..96cd18250bb37 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -112,6 +112,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int, ) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 5ac62f02b99c7..489e66d586028 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -163,10 +163,12 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int, ) -> List[SamplerOutput]: output = await make_async(self.driver_worker.execute_model)( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy) + blocks_to_copy=blocks_to_copy, + num_lookahead_slots=num_lookahead_slots) return output diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index f406287f3c1d8..8a3b9cde84311 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -84,6 +84,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], + num_lookahead_slots: int, ) -> List[SamplerOutput]: output = await make_async(self.driver_worker.execute_model)( seq_group_metadata_list=seq_group_metadata_list, ) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b6bcda4e6b18c..3eb3726bd5a6d 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -196,6 +196,7 @@ def execute_model(self, "blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_out": blocks_to_swap_out, "blocks_to_copy": blocks_to_copy, + "num_lookahead_slots": num_lookahead_slots, }, use_ray_compiled_dag=USE_RAY_COMPILED_DAG) From b31a1fb63c98fa1c64666aaae15579439af60d95 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Tue, 30 Apr 2024 10:41:59 -0700 Subject: [PATCH 157/413] [Doc] add visualization for multi-stage dockerfile (#4456) Signed-off-by: Prashant Gupta Co-authored-by: Roger Wang --- Dockerfile | 4 ++ .../dev/dockerfile-stages-dependency.png | Bin 0 -> 118207 bytes docs/source/dev/dockerfile/dockerfile.rst | 50 ++++++++++++++++++ docs/source/index.rst | 1 + 4 files changed, 55 insertions(+) create mode 100644 docs/source/assets/dev/dockerfile-stages-dependency.png create mode 100644 docs/source/dev/dockerfile/dockerfile.rst diff --git a/Dockerfile b/Dockerfile index e471a6e93b963..e8a9842c089dd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,10 @@ # The vLLM Dockerfile is used to construct vLLM image that can be directly used # to run the OpenAI compatible server. +# Please update any changes made here to +# docs/source/dev/dockerfile/dockerfile.rst and +# docs/source/assets/dev/dockerfile-stages-dependency.png + #################### BASE BUILD IMAGE #################### # prepare basic build environment FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev diff --git a/docs/source/assets/dev/dockerfile-stages-dependency.png b/docs/source/assets/dev/dockerfile-stages-dependency.png new file mode 100644 index 0000000000000000000000000000000000000000..b016531f1e0a06bb38b01b1989df932f161e3aee GIT binary patch literal 118207 zcmbTe2RN2}{62h1C8|^?QXwg#VPrLwB%92JC`B0=S&4+K5-BScNf|A& zw`65+-t&9Y^ZefTcO38Wf8V{vgFT2Cn}{-bFh+EwdtH_RqXc-qcx*E0rR z1*scb-ImW|e79nKG=0?kh4TUrS}O+Sedv)8-0^79$yTXH1;v&x-mxlr(6O(x*|&Gz zysH<^6)mscx}5c}FO|AM||lCh~bKJ?{fr#c3|hGm`9(zD5ljp!Z7IutIbwko#1 z{fL%OU@;vA_4gGi7JOXKdfwmvl7FW|zA*fc|0`^X9{iu5C|fLmDgO6L+pWU=-~T8o z@$Ub}k9T`5oBezl+cy0i883F>l4qxjpILW&iE5_O1hllYI=&|yyFBzeAjf&)@qQn! zz5B^{-Hs}E(9=HHz?u#BY}e8+T=Kz_Ws~JQ zmi?z{URp@@ylnpRaF0inF`1M?WBkWf%gkTPXGf6u`SHOA{l)szqVkiC%mbbJGg4cQ zC2f^r&EtB68%vJ0`0X&vzkZurUaq_9#TF@6^0AeMQFAlS)7h$>wz+=%M@D~buZy_r z_}7T^%c{L~*;Xz0_CJ37m`~PmST*Bft+oWOLUqRZ@9|qPU-B+d{{`QPi0)c zDD$^g*H7EtSi+%UAVnK-4BxnlAo-3m@Uh1H+dpjbdAd>#Hx+kkf0sZRZ~?>t#DJCn#g$yHw8pWDKe{U>r=oaM%f4z)FT z371s8F!kfvb7M)-HDE5gz*A^}nl-NM&ST$T>XxlG8uI@6FpszKpla1XBb)6o| zoa`@tZus`b)fKdT8%PigEN+e3L+(LQY1iL-vA3svRVzMnNLI_ChkLq_+j45CBKAO_ z&;mYs1IMS@F9ahdTllRorDNgJFjbl{W?vo>+|EI+uruK%-`~BuEM%1=uP@+a z-BTl(QSaiIGxphB(1}VDseNTroI5j-+g>8w`*c!#@+N=yhssprfmr#exCU!-sWKvPAyCtHp-da#EsaSPMZ*N^ymNM4qST}ah; z-K9J_2iUXdrG~t#e#D9L)OO5yVrBL=I6!luPx4B+WfBc;GULAy zko*;S&6mfjz2q~-6LjpPLrlXdGJ=$WIs102{qZH{z>bSGFUcM%`?74>cj}|RfaIS% zdM34n3RY`GPbm@OS;8r`Q@OuTep)0p!sYVOB;6C2M9GCJCFg=s-o-eWDEDLo!@%Iz zL;T9iWd7vUxXw&#dabhds=){zud#dGFl0bFLyu zuP1kA`u1A!q+iErAI;_qnTx+&NAGVv;d?+hd#AZm@AswRHIaYlQam`sFY4xuz8B_G zUbA*BalJER{jCdPX(6&Z_TP_AcSKAJBxm)j__jQKUh8Kj?@lz{THkvhk-Kc1e5i;EMGSVip( zmUnfz!OR`5m2foX)hj7K{zE0Hg-E{T(-(OTS-4i}}p3xgSrAhQN_nwS!0qSmDju z#>U3p5>1|;Ld3QxTh=>|w0=8sXd6a^*M${+v+>d`-F%?!b7s%0PQlJ(r`9bnMjmB2 zb)f51g=^wd`H{@hRY4?IQ|Uf)5umJ1cC_JI@Q$-;zAf3k&gN=zV=Y*cvXD^c;h&Pf zpFHyA`TF=^@rpw->X$Z>Pgd@lGZlpgcE#H=G`aL`zH2Suql*+nIQkeXHx&1AoP@6{ z=g7MasZ~{m`z%F( z3{eY{I{oJAf{w4Thmh9riWK;FX{*v+_vT2$e3X5?gmaEJ(z_-8{uXycDC%PyrptR6=wi_d`=jVsEO z$(Fi-@DWX}(jbvXc|7Fme@gy_j9r}_u@Ol!a!sC_7OT!v(CPaz;+2u(L*Ard`WL&T zoj}w{l;M)HHlbba#?rM%9VYXcgjWX5C-1Ba{=5BWw&}gvsr&_4WAOL)god3|+5+eC z!8jS0W@p+vwYS7DlTT2QpS9C3obZhXFo=tf7nE2^nf+7f&G|Lcs-4ID@BL%u3_0xt za4$V~F@)CNr8GG6PT>5yL ziMpCpN3?ZsoosukmD}rLIM9-FleiTc}8UL+kxs=S$EnRJH z`>5gNGvmoKAs$?pH)9$fOI@c2oRCC?X8{m(@bH{fxydRT8#l9c9(=HCl|OB~Ze|bX z224463OhA2Nj6?zfe;@UYFl+~>|3mymNgH#85L5N^NaQkcCGkHTwKv69~lknCA1%< z-&6iSgqYHejUL=3*BMt>{(9)~_-tu-M_sn#jb)r6KF(dn!BRHo%lj6KIrO*aUhasI zY`MGVOQKE&wt36j8%vApXQt}g5xTZWagnWAYl(fP3c&)KN9`d0U%>}I!^)oJQ zpMw2}-QwpCJ(=y6@cWh^WACC>$8_((Yt4t&iPza1!%Nb~ci?u3J(8VIpQm4*{D*-v z?Rzu3knU)D#cFBW3(uzp0{ZvL)%|Mg+>w9|N{Lzwq;!%y6YK zxxHYa^F(!h3GsR0<%bR5wi$m85#x||adP={jaxGu@Mb{c*s(w|ulKiB1&f|m5y}~D za_V)gFJ8c|d)PpmEQ{NZgMY0e{j2L_p=-+(2Cff2u9M9*SOpHT^O~=0I`4XS7i4d}CEJ3WMqbx{=CrH1 zr$5=mK34AIl6GJRNnuo))?8$!Qo>HUjJ48seWb>!emIDv9f1IqfFY#OyCbI%>gf^(f>8apSr#8+JUFV_iuSIEmLE*t)AEv51%JwjS#sZFd z7lvK_9Y=!aX=nH&@U_LUeOo>e)9})tjC$5n{%C*QMBPYEb57%1kFy9CuV)dVLiOJe zgR6ON@~(W`e0lKKy=%*)o;*95X?YsK`$dA$s`E$)KVMoHpO$qR{j+Vc@Iej{qjyQS z7La8s`#0w1*~Z7k{cJO?*B)2$MGD+y7kac2(Y4Mfdz;_ky0To?nKo}4=}OW0*XO&x z!HLY|(h&!Bo(66`@ukK>+r8qXqnVctxJHIU@9?A}3GGyu@?REXAgy)Ge4_9|n6w=( zZQ1&<&HJk$!0&mlp_7i?FFrd~BUjW|kPlI|O#f~oQ=;B$TQ8fDo?0fAXQorFo!S`} z_5M+weH-_`Zws4#_MzM#&SRchcQ(Q~WLsguowX0Tb)^&Rdh0O6Z93_n9giiSto`42 z2qU*QlN%=X%+?9?d12@!+1ej#cU`bAw3<2nFj~oV?A1y8Mt8Hi*RnI=w0k?U&ZPvZ zQms^G;L|S745zXFZ7STdmj^!}Amw~Fer?_H?EPi9Rj>Ut;g;3@k9cpd5fcOldipfN zWxT^IT=7}r-|MaU?|MT+Lk>gVkJb0RbnML*f3V9e{;7QbJ)h`Tmpi%3hdREV^%|NS z8+afyLU4q6U6$Rl>fNubJAAE5Bx7j?zy+*nS2KdFJ(pRtmi~8jT<@cJPTw6qAA!K) z%*GpQ1v_oKt6H16XGW{JaTq z`?8%Fdx9e;Z7~wv|BFiGS~u2S{P31pelopK7i$|V=Y07zT*g((gVVOMnyVgEJk;`| zK)_I~U40n4zLP~lq1@27Bbj{#YX_3-vP~Jj^h)jLcD7E+>i_t09E({OHQFD@cQC}? z+-2Ya!OI8U(Ja!$>Obq!b7&19G+0e`=-cF1`KcoWeu$a=IMw)CmN*6PX#(p9z2&Z^PDqsL8Y%_@3+{#vg8P#HSi` zcZHgS;j!#@oVLw(#M_*E9rbB`&c}^N(cj{C=nU?2BKxw93)orviilCg5~+Bn@ zWVZMoSgcCRYGs%IBAvXWH5P@Oz!341Sz1sF1_&K>CW7`zIqy_FtGdnZ<;S-l_o$Go?qpUtlqCW7M-4Mdn_yZ z%vT_c8i)2ULdOTVXZpAu2Fv$TY%TpxzggW+Q1{Cqp40rNHJYcVo z6b-jfdENwsCQb6?zY|BdOipZq71nVlrp zq{Tl<;rYMt`zhcrpE6+YkxQ&pm2be7V|>c5?0Qe)e|VYvmEg+G3;~@eKSLV73Y>oQ zLd{E&gk#BS{cRI8c%F82F6t>f`2PHCDUZ%HoiMetHF577@w)xwN|7%#{v!UTQzP}c z{iC3{!rK`sv-djn!gb?vw2>v=`xpJ*SuGOPZx4J!7))Z2i1EJGKArm6Yl;1Lt&Hi^ z9-SFacB;;M_M>Ju|Gh+ZyKJxWYuj#Lp4{o7WF|Fw1)V?0bG*t@wp~wXXz)L5m3CsH z>-nL5K3riu4lqVh*Xc<4{AFXc0Kz$a&!VdS}g=ex^ga`a=fvLD~7 zp!se;dcqxh;$-By)4qZOkcn6C`^!DPmLMd5Z0>a#c}-(wTM+4{hO7N}5WvRtXF6S1 zh0@mYKML>|p~+AZu*r{S$C*zb{0_;T0M*leMDQ;ZByL|G{xE9|B1g?KK>M@*j>@hz z;Mh?v*XFhLo3Y4=-^FvsK7VX37JB`3mbRpH|K9{hX=6AK$4-;OY}>T`VVgv}eCxq1 zEX778zJo{<0Jo3T6*^0pp|ssaDeGT{XNk2v|9|W+3uIvnVb;Eet);rbDe zH+SN0YL}mJaHnyWswIDUacw?eT0dkDX`&KhCAcpCdYCmc)k|)@V4_vnHF!s<%noJR zTb`f!JiGIeiQDH^wNke0Ma9XnCidJQ9+EnU>;R|XVMl47HaQ!*ID@o3z*n=l!=&0(Ci!~?53&K zdp*du!Z(D-7SQBeXs{3@Nl>!<9T&^S$>hdnJ5p=c-u1i`;8G2_frvk6Kiu80w~!cx zTenDg^elNZ$e_(kfBZ|JFV6#U^AaH3yJG|G>}wahu?xJUCCo&m&Q@@1^|h4ey@d-=B{s1TR)|DUsykU!=4+b8uG3Xt6aErjsF#ThBLWS!&{{4*dD8e zg^<(;ORJK9jC3=a{Biv?2Ii8;f3AH2&P>Df-#_EB`KC_Gayv>SHjxFv zUiXW;4v3Et3k9Wk{oayFG0htPogg_i_;mmh5U()^$)vC&1ezX^{2`js2egx%3qV*%5Y?60 z$^AU|R7i@(g{ZP&2pO%v9{C7$jYHiiM%E1^w69ugRO#dl-zdP_PAI*nTQdoOMQ|%6 zKw;_HT|4b`sgO4aLzVzhXGZ@t>a3JYdi6VgvQ0c!eUz-BLB?PCrt2;OA-iSm^Af2o zrE(J&rd#!=dAqAK>K|GGp$gDSF3%&Twu*^~DfvULGBtr7-;i7Af${K>U*j#%Ykf}RoOobG<S!wo zdqB3pkMH0|h%3H#P+jSpAdW?%peT6ZRaSCG1+S3VEDgp+0>^1c5U;x?jSIa$Z_B9h z1(tUT)N2-&MSdL1of(;F1G=m{|L`<&6o;IX{l43)@6pjOj3(j>>UL6gJvBxgbXBjw z=+2;abpo0h+t%ZE!7k4LWt=lH8BH(--3$DxCcQ0Us_U#(S85^KT{f}~iyzHtM;Bff z6cl_4-oXzx1dv*M2&(W_o%E+alb=J4Zbt?wmY;5xfBfpztDmL0+uA-q2<&xgmmX_m zb`2uD{q7RV@!6Ye3?!%%2>uKqik*{z+~Zc{+&!|g-I>_WfbInN7LpR!Q>X%gddX1s z*mW=cbMaNjHsQABBxR!Bw&}qAtpX?!UW`2yAt#C0OUf4{I!bn> zyuIa3gOhvz0>6_PE>q)^(<8a01UEx!8xH+N{Ph<3dt=qYp5b?AGN<1#aEB7-m+X7D z;z6s^>o0`bnf`d1HC$y%mXDYNa_mY@rpTYQHWhjfW37RHJbJIJi?Cf}AG?SgY*PnC zHX1bdCpI0bal|@+8sfrMibOtgtNjbuy5iQXStGbYhVM{#==k{fQ=N>YB3U>R^Igy%V zbyM~SSLXbhL)1wb$PlUzc5NfOn4DbN>?0oz#+OFNO+hUcF2Gfj_=JQ9 zykDVezCKtOpb}GaxigVybkB9!$yF%>yT2l5Vr748@eZ5c-&Dh8GBx^-zp_5(=SaYU z`lEd@nu7S-3eEZeWMkVC1;8odq$--_Fpxezf!YDT{?zYdR(8txeBw@u67rtLD@L-PCPNHemX594dT@$aS0El-FLKCIHm z`925?PxW4>|Fvr!e;E>&cHgY8BkG3}YD`a`GRqE@5en@9!BfXTkkWpDL=p>QKp;5i z8x2KOvibH7O8|Tt2BK!d+9q|vQo-l0DYX=c0?43Rp|Pm!m}!LvD=TfL$2dTmSHg1U4G>s$@&o_ zRdI-(@|?ZV1@}2MJcViJ&YjUoNl6<0qv$sXgzzk$8Te(~xmR zR#g9`+uF|wyNHVokN1}eNBM}v#DhH@m<>T@-+ugf3|mc4-$hs@6E96ChmrrSLCoaN z=8;&@H^%)fMZvO;wjpamYGG9CP{qZ*y?F70W7mb_&{wp#OHy4wPeQgu*w6U2nhd0@ zYffUSG~TBBayu*{(O<>|qB@w5*ufrxA>%YLFyib$#c79)s0LxNUAxGbpjQWPH{2}; z-9*rhTr8#fiJbQ^X8Q0Qm6!gq zH)q1K>RhbW52Eb=Is_BsPb7_5bY>8HaIWy?R;z(FnA2Aa*Yf(M;uKCNMYjK*T*qAIdVjs zgh>yeQ({+UtWL&7Nz~v3d3_Mrm60pCjoYnA=fAACy;=@p0|VB2<~iu8TE6@SV21dzJ}Ek?t$1 zYwH|4)6RgHat2#~Jnk=Wpl7gv=4sVnIqw2rLWyX!;|iFffuL=-BI7Y%>qh3}o!E!-X_qQwSDd~R#| zD~)|lmq+cz{vbQU*8o9oBQiTr44YwZsL!UT@SQ*aL?%3Dy^^fl@Ya8$k=u{cJvEsK z8-Vdd*hT}!gBq(mItEQX``et-yEg6jR>iM~8&tj^WigQ6Rh%^9Ej!mGm&bJhMF9uM z7nmOUFhrBHhqe72FAsMGKk}A}*XVzenD|+4qU$SSE4e#mP&yaPQanshZh~{oK{tdz zC`Gy5;RwX-P>_>U|B%oAn@s7bsvMvsenOA6XJy8!xw?g0?KdKHBKZ?Q0mmjy96&uG zHAjvd@zaRY#G-KQ;KOE~Ba8Anx%%SkZNw0WtgRZfeM57>ow;T>GX<^8xa3PTkBB_@ ze;m{($NS(FGLVqspL(E~7@qauv3S&*?U9yPm*E5O zEj0<;z%|GQ$!sc1sFZf?*S2PWt^@tW!e_C<%D!k6l0v6i#%vNjC^~BpkpVi;=}9u0 zg;VW)h1{jU7L_eDg#rb&C;l=-cqghcF*CQs6DT?X!i3k@R+~AXMQY$0_vCt_{NJLv znL?+S`rm7V-BcP~Ekr0FJz;z!q<1PVG4WuWrV4P5NZ~p^&`pvisPqc#B>F%M{3^&U zoBXGhRQe_AW{N_G#6U=RMhh9TJc@0EY!PaX>_B!0+4v&aot+*~8rV3AwYcq4m>Y@v z|KFfS?VYRhRqiF)P~4H&Gf12Wt*GHYsQH?#pG&G(=*+=i1bJN`{O`n6O%CibtMVW2 zcJ3m2TZ;66t1Foo2tU@kzs9OGhK5^$kbyt8_E-0iy2UcN(FUYHL`W|Yb-O9jaZ!nALO!)o^YxWG_$e%zwx6INbhk-V+Owz4KlID% zz&lLJB0q&miz^c;r)9h|`H&BnY=35{bJA?%-A?woTspn`|0uAz$*M!l1Z2vAJljsX z4@mX{aT0-gB|J-bJ{!;Vpr-8C`TqjfsD|WhXlJyVSczo&N>1W#Zf$eCyXyor1D)9n zfySfQ)KqXahIN+|fGDsDgcfQZJC;H0m%I#^#ZsNiX$l5q8}R>z(^QHc{XZ}oW$!oZ zm?=8{JFLYS782b9yjT~5LH_<9_CdSzh1!dHkX+GFf z*oj}fxG_QZ`8251F7$}vOM{C7*wXo(#+*2mV+?j7Dpq|&0eqQN-H&gb06Ma&& zRd`V2Q<2GSZb@67k}3NmrZ2fuHW3F0{FxlD$Spwt618gk6ehHt?$6_pkPyOr_j5U% zqBWIXU!kT1b>o)zi?8G}=!8rHbU~USncvS5MUA_tw4-#Ft}!S9&4fTaDoqyEzY$CH zGx8~uY?wP$o%$D4e?yggXKgpjC0Bvhr2-&F~|@>#qMK!0_ke>hDtRH zHqFISwN%HCKRcyNYR9CKgYg!DH9rTM?y)9u((Q*KX%>VwG=RoGVe~hH+by4x|KnEdO3IfRzo(;1(N-N8}Gy5`2H@B&9K?gOpEM0p<=5KsnBlb~z z1wEB#JNnt&<%mNT+5<@X$V5rK9|A}FpSR4y1om{D*ma@uF3~7UNTW+VrbgH&!kw+o zKznH^O8wB1M_THqyX&7}`C8C8T0|lk8k5j{`RbRaK-!=PW;TiLm9wPukBXs%S3POV z1t>j)h77yWNpw3UI58D@S7E2^&btN*>DfxX1x&;DzM*JCk+unMTdmXlzg>Z51Si{F zvf?}^`hYPAu>p>jw$U42hp6~DF{3VoU;_-b{DOwkJhA(_ry|du6VHWj7KnzNDY$qe)!U|E^?`-PW?MxT z85c<1DB{Ww4DywIT@W6nsqL>4b1_%`vb9TCwXJzic5LgAlZ*^kyH!tnXz7IQ>vrAR z(uskbYvSg0o?Ozvmfl>lx3GK0(Cva`97Li2@*2J~QdU-eY17FzkDi`o_3B0P^77ZO zUtei|``$ekG=HwxQ1l3s+b$;dUE`f1dBuE1(S7^rk1H<#&%qwpeSLimOihcgGjAlX zZH!DzSmL#R(C)rbP*AW`4vK%uix+!=JrAQLXEg`MIZ2g`QNF?(HqiC;_1Ta0^Y+wb zdyS6TH>Orr9`CAs%?`Tg9uvchN|+C}_6>ge%-%+H$BrE@WPd?fN;&tHZ~pxGp4ZCD zuLK0FlCt@I9Ycyjek%LAh3&nPCwYC$3JO{+Auat#H`8(yI7q~@MM-&P$Da2(8XBxX zPabI5-5?~SHW+mCCSz!5XhT=m5(Wl_pa%~QtmI`;?-&{%GS$s-;^0>c%`5WbXJKL4 zaJlBgheZI~@1c$!Mxse9War`%if^!(8n=m#iCJZ@iitNiG*Hsg(v?4dJ~XRNFRQ4i zm|M5ghjeww);T& z>$H;~Vtk&(iiEBG%#Wzn_Li$9j^4cJSl{FB2Eucp7%N%|h`^YWzxMz;_V zXLpHVf>y!}RIr+!m&#oW4h~Mw-$2`XIYzNPnK+m??{U!MXRy<}nW1|uuWv^x9i*Ds~j*pi>V;o-~7&CS6^IgoVRQ6%6> zI1;-|Mn>ij>iVWTc)EF=LuDLTp%%YEr@Ol|h+P4qzZWEqD40RjqH8qMxV2eIY7NA~ ze5{S)7K`ets{GHNm$I?3vF_Xd{vQ*Q%@98+ZfJ3v6I12_W53o}k<2u)Wa&~6$Wu<}e1^{jY?s=& zah@Chp>UGoMmv*p@7}w&p^JN`$#TG;Cy7_tLk^8Asi@>blWOSf+|rc*v`pLv(eZW> z7j1cYRZpGbNx$%O9+vLm)d*ysoSA9ImoHx`t>aXFUtG+DwiA{_xdAL8A?qp6p80MU zg~b&kN~M~VZ=A$0@~3L@@+dAYt{|v5_ewh8E3U4t<=N|n0VQ^0Da&%49U&#t4G#}* z7Znx!67=8pD)N?t{cSh>vGt);hAQ9Nmln+pZEcSe8Q&=?MqiDLTn~iFfY#v^!&-^j ztEBCE6j6#4lg&EgCx=GL8{DpwEE_g#$d}u{fBz&h#gn6nA&+nH&c>RT4g-8R7I3w! z?64tRW@@`#kFLM4{>l6gN)Nt#{VL)*?ey~XYbL)NkX-IS0XKmD$hv>OqM|$FqD6T? zaeaRp_ZZd#F{PX?W(!y$LSl)V8&xhRD=T~a#0kb#GF@F=WlvA+xc`u!+){8)PkpWr zZYLornHt4SJUGDN=FOXEXWcs7TQA&YT9+l6ZdQE*JoLb74i*-Hx7F2aSMaMT9zVXC zQrN$r0UG%mG&Vkrjg5^n6nik(*xT!aa?r*dJ61rRJDy@dht+x<*GMxErDBaEM^X*n zE}TUBO1_+_sj0oaecrowikRS9R#w(O5H4>DN_(#snTJL#H$c?u_gGl&32JC)Y(HDF z4%G}br3Ta5S7K`6GI1Imquef|bR@Wjg@tV{dbHbpameB@+Kp7)P^vaUZ{hYGJKo~J z*!FW@ct|MZWFu3&_ru-Y{Y_C3<4A8k!(B1;(L!$5cM$p?CMOG9SYsQXBJAK?7}v>@ zC-ZaQZTZE;2U^{S{G)%ON&p_b0)pUtJ3G7kp|pJyj4271^+H73Bq_;39Zo(ei$r~I zhGFx2BV*&nuCCjtvEwk|k|imYttzOQ9BldkM|nch7XSnw)j^VNm}ADGnKO002M4XsY+$9` z$@0ycH$hT1f+urbmQqiU*JWg74Gj%9Se%N!*x1;J`!b*kHJ^H-{%B!Vc6ND^p7hO4 zn>I%I3?XJbguOzYTR`P~`m_X+(;meCn57s3UMf79y2~I>kqiq#VBh7-Qd9$`P#CY3 zz~k*r-l8N;4Dn?4vN5=xT%wQ_uM5!64S` z+O-S6QB?fhkF*HeZ$ur6a)Y<5U%#G$^d2aIkU{*TDF5wSl0^_9^9u_Z5EwN)Xt5yL zqBKH&Kl0M`_bbdUT(~qocrGY7Sa}_MB(4~(4-G9X*HN#dqb``4^`R6Wr4k_cd(Bm4 z?%7YD1XLiHg-gz95rr6)kJzx@uLzqacN7Wz@<`7e#4Gv_+E)rHD%iop{m6(kGJqW< zq@*@(-@cqOFf&^UbaT*uBOe)Rl;}R8zyWi2a(zuce2GMArx`^g&M5uq>-z_+hnE<| zw`~#U!pPMyTy3x<(D)YHj6(*7`3x-T0*G{CqobpdXw;zG0J)ikgoKn19r||_J4Q2hlQ9Z1!T&`r zH7MZo*hpXz1KQ_L!}_7)JwkG#+|a>j0GxYXT7HZ3ax$2E`?6f7*5Ev%JA|Qc5AUu) zbMa!ROY_Z&j&uWxaFhi{Rtx8vEDU%XSN&1d$@DszOcci;(h zrw-1~vG3dX>C-13?^~XpJG(&vh=BvvGSEDnPc3vbHCC@AdKLz56ucV|!42+m6Ats? zl&HO<<0DwShR0q-MQzLODlt&N1*jR60mB6bvcvzPmwC}_^JB`lZ?=1e})lZ%UsNp*U6VW_2K=VO|O zvZ&|y`uX*t1T^!ikuc_%jH*fV50r&1*K}q8w zLd2ip;e{2MJj7USfHi4tLsp%>r_q-3b7$bycnqW&O;{3+!&?_F zS?vZmWdNJP+wMv2B*w6jKr!s<0?G-;CQ)`uOQXYP^BjG8{0eX-9fe?X9c0`2+nBXz zX4it*Eo^PXPMOr_$cl=J8km@PBEgcN`8x%J*{H6rZi0jr_2>~D1(2Z#9g3H1FJ2!* za;L|cU^V%`C3`1)xaF6fJ$v?M6m2B)maUVZu=11w>Q^i)S6&y?idR~BEUfu-^}NCA(>_ix4huck;u}N z8{E4QRp7Gk-vcBBXNN=kdJb&kAu^%;(C<~&*4CF=KdwqMt3FEh^hpp-p2%;v?K(~= zKzX}_qp@0tQgt8yX2HQqXo;z^htWJF(Tqacs>yZLZ?$^j zKghz)?t#e==fkNV1%ym3uDdkr^1LN}=+GfU()V>&TV#Ff-(G+s&6C1zNPwAWC|Fdj zlYZgecGmwn@zzNf8w4uWeb|Th_pqW6g*T>o^-smC?64#FAT4btm4`06$mHbYez&=m zcFK01V4E1}WuOH>eV4HqG9Pt|mS4X-K~?8bY7ugG8X649E|QnvE2KqZq)nxR2Lq&5 z(tc!k_3G7z-d<+P8XU6rX61#kgEMcV)}y7EN*TypS!ABgjI zv#On7-Uyp4FavygF;7UgLafJ4*;38Z5MS{Yk?iw7*f*G-l3v4 zhf7qhhq`iud)%G<koL?%LZpgYCujAS6_{Pc@5H{$oIbZP5fk-P0@ zr=B@BHf9L6{Ojj)e!|RbB<43j4FpA9NDX%VT!COqb9%uZl=GP=^yJZ_muZ8=&e@Oe zH9mXR1E*)!0}Pb;6cxt($La zD05q8uEQG{C>&Kfw)vb2Z91Bo=+|INkC1WLXvpK=jWnKbTYP^T@9pdGzFVOqUZEo_ zvwP{Q3QzsLoZR`z{DzSPpawx4-Zs?fR^}EF5rJkQOXcO~?|xr4J2*>IoUcpaWMkvQ zg||uni`aT`HR!j(?7grKy0RbIrZ-Sd&X)M_5w+m5EKy9d_p0$P)!qy{$3yJaxQSn zt`8Y1MH|6Q;63P5rOU){;+H~!fWdn~L7OiedwN{m9l?{l`j5qF@KQ~RR=#S(u6(CO zIc9pn7gdT+^6U$cSDA$s0k#m+ufKjRZDeE=z093_Y$wk|Zxd?o_a$fpi9ebscp(8{ zNh8i3Q6WG5;lqcE4ev?8ewM&_KMtfq(|@$ri%r|x^V|FP?-u}vWe#_Cb~dB5#Hy&s zBby^ESBJW!#pGz840rXj(@Y?uY3tO>{qMK`qv%jPJ+oIs!}s02<$q=vXZG#ecelu2 zdz0If<4E2lCl5G1ML(W|&2Rtvn>Gs8KJ$|}kvR*mHhTe=K`y(2rUyfusb0O^fP(ye zlMq(UHDuM|)Q3kBbPxM@ditL>Frd%|fOHc#h`A6`(yjbaUal@HwUKuzhv@u6;W9{H zAP*)#pCPE;dnMyKy$(kg-UFS|QJ`^(5M8~TaJBb~c4=rq` zdEGjI>p<(>FLi~$LME>eMFlpC-2YB~W^&UB8IUOni>CSW=FNkhU%7sLeo#MZJcb{_5;}q!FWnq7)qxTq8IJYT4}7B69nT zj5ear;Tqb{n_xc(0{dcA0u0WdzlEY?-p`*LZEbDmx$PiDj1E=kyW1j{HuUr)4O=Ys z4MYECaNKtb0rQ{9^fUFaj`)VDP+|n0N?qt9yIXYXZgDG8hdY36R#sL+clVt`@>80> z|IprY@Y#$|$$}xgdDYJWOU*GCt+;4rjn61&r2-G3scaz<6fY#G$om5&Vbbha%>@)$ zJfKgl@#6@!AkrWYc{Mfbi0rmzr%B}fm*040WOS8Ap^j(=$Lm`DY52!*0@Bx8T!HuY zzA4}}H>hD7#l(glmG}h*-vp>@j&PlR0}_rj8h$nxYU#Uo?;5atl^N#TTeoc^;z?y? zrPH_3co>!2g9mF7G-JM73~UT61{_<4Q#lJbI5?sTL4WU{zXO13Eus3TnlVAExw&!5 zHD=2B;ZgvAcQ9l+3cA>8d6%p-x>f4u9JqMy-@ku@FnB0X%v>@bo*zP_2U=(0;JArH zK211CwI9K=51qUZfrH)~EwMqacnj6RDxkACCt|^Z1%?yD-K)jUuY};mEGs+Ybe)-* znVfd~ZYcyX_83bMCGU;yIxZl$fbFzshA)B#&`Yuu;FYlJ>U6U`s;W!Qe|vHTTv~b7 z0pJ4x3aaB}Zde?G$SF7Q1=N+7`5t>pq)!5VsZnrLScErDdn>Q;i-PFSR1MT*JZvhk z9Q>Jl6n3mtJcjGK0~y^MhE%o~Dk{;!@$KF4aBqyO5pa_%IN>RS9Y>r7K8r&Gy7J>_ zi6nRc&Qsh%9y@nB>*|MA%SjYV{tOK*xU0q33Mp_O+7S(mjIO4nY+om7IUkF$2S$zM z_X-r&H-g5`x?fE(<|hP<)Kf*Q4-+3fx<-UhurQM^(Tun6+_~rPf4$I$yBRtdVYFw? z@J&yR=-1>puZ2Yr9nk5v&kG#Of}p(5vlsOq4=*pK-#-$)p!Vp#yhKM9&&f$9-Mro# zDg}r?J&G{gvctZ-2kwT3PL{XN%FSfqqiWXdy5I?*ln15Nak%S>uF!tK!I(MDhn%xM z##auAfJZYZB|ZIzi9!tC!v12_vG0Lbjh4LMz_<67GPmY4J>-70@8^M6P!t~@AJ~z4 z=W2wq_h=w@7u`<*_c3qTvIW{5(W{8?hV5;!l*B_CZljm)O_^rX$_&LhmnyR0$gR5!*df>D-Zw_l2nOm4gcGVrGd7ijs-VT@4M!Cs{nJwEa>iP z?HdGdVcqYai!NY8w95bmE$*G7`Sk%>HVq=Z5T11;6F&8L^9IS_YuihyEH5h{N)y<82)bW-+SP+vLP%hFW ziKSe;-pOwZ#~yVg7ZEJWLr!*fhC_!+jl^yc=`-cRPtgr0DvOE)>d$LJ*2g)% z8)b6#f7V~D&(Uj1eH0aya}bU-$GGS6G$Jbw)FbI z;__sDc{HQR|N0W+RaHoOSv96m1o1>zW9$a)n7?2_oGcf0JmVtg?%lhuE22wdYyoVM zEC%86jmVjZV3oZ zz9XWE)aANCN=n-VpG(S{2BO~s*m`{!NzhAX7*GT7#}i&>P?dIWuN-q+J3)mgKd?h7L5LiTy<js#y3FL2hzB^89uY#1|FSYgGE%q%#+=^ayWo}+P`O(b2x|Wur zN&30VD3tCzu3fwKj-|YCySR8Ebd7Qpqh3V3MnDHUSqN_MP(xH1#=f^ZbW1|QS2&$m z$gik)1dN+iaONYwUxrRWq`%Z$u=GX2Ud8zo94a5?l7fcs0PiEMAxU>cZZ$zkwI67^ zQ+BaBJyD}jV0d6?Nb6h;bW2o<*H8_hsOS!ix0X?E21FYL)GCkshn}cs2-0gNw(g<9 zV)EsE1;9snyw(BpjH>77=U*E9rLLQ0PoN+`%wA}z<>*Ky4oWo8Ov~1^HD4jv5^13(+mRi} zR!K?8?aCE8!DC6c(clz?l!)S`P=Dz(vbEBI1N2ndxv%@;1|U+>LkGx1=W}1%=e1BF z@=#Dk(7YB9unO?3975SQNZd*y=kfuM0evktoI&mE+Y@zqh&qS@@_6RKlzKbK^LSvF zd`*_a3LsEY?Azd1WN_iaZJhUMz=_p2cvR2M=^b&X740k4Ux|`sRC;*$86js#)fFEc zygt^~>k}M%1p=u?;m^uS&nl>gD*U&fsC?~ZqA=#S@C1v#roA$k2N$8B03%hvKn*Wm z^u{i{MQEwuF^}83e}6OD+g@6Cu%erZW$juffMuew9zA;WS7quZVAE^ZlLMURQ||)p z5Uh!}6V)Gzn|h9u0~GLwFqhS+=fTBemK_m@nR;GQsFC&9w8`q;RcIDh3*~^PKOO=A`4P^B5T!;WUe(SH?KyX>J8b00w?rU(XHi-2-ikB$Op< zb}mBZ^_P;_sEjw*J2)7^dr0YQy2vg>IQ5fS`Sw#E?uxql&}4Kt)8{cz<@AfSw}3=Q73ku{i=*iL5+0%`r(Xzu za^c+W-SbGf1QJj)T8&Zq{$$lO1eNFy($zfUpj~EbAjxc~8o{$?@@i{k>y1&%^FeqP zSmkO6Y1JJT%fz_+C!AwB8{0yZK6je_SOO7X^PuieNY~3*S$*4$f#y&pzZmXQ)M^J$w~Aka^i5H1~BX`+I@UYXON?O@R&F8ifPmk zoA6$D8R-TPi{m*?lBfU%ccy|8&1YaBHAy^nLjch-c6Vv)FId;w{6|3*bNc+ ziP$cMfnKi*I0l7!f*5SfN)}Z_1ldI-@&LINAXAAzJ^gU1Y2)#b9M{dh#jWK2sc;3| zgq<-lF&jigScvS4F879xj++nzBSB{u6vQecw-o|gTXhE(oqCC0-ta;^Rca|Ju$jZm z>m-)L^aKn4^tT=$>IgJzw(dg*4nSak6pq;3;fzeFilOz-2qKvwUE*i z+barI-^xw6RZ%;Mxjl+)0y}+Z z3?MCAkv)!8FT~lhfi80+qx^XJxm)HOIXr2t#N5wDwQx^2Y9(aPr}+u%^} z(*96kCn!};EBFCxNfSo9V+m4&5!xBh*K<0lkc7oGYi?m<8_^F`03-{N+vL>FdsZ6s zCI`>Ulxc6-yLayfLBYjzbaaGkLoB<34L1OR$!yz$e)Kg60iSHV0d}1tda+2TvfY^1 zyh$)g<h>qMNKJZUjkM@gZKv4xL}kT|A4~S*oyv zg#eKB5V2Nsa5N$%kdh0I3teGc!PkJ(gI`T#aJV!d2M@2?+V(4=E#+{k(O#sh^_w>@ z0!V#{LKu)}s==FiSrC>%47DzY0<{v!Ujc;-6uVohN5Bh6wsvPs#~lh`KnCXKK4`hx z0ErcEeut74@oee@BMXs5fMWteOdefCphG+g*{mD`izkgs?u^k;?+WpIi>^u$9$wxz zH8r|1BH+3F7ajx$6YbcfB8gK~RrPL2$SnY-Mg(06pj`;BDX7Gv-Hx@VrzevyBk(A? zN+MFs3=G^Lawx67V~ddF#kx>4*xJmnXAXt1{N*Mf<8)b!-QAv|M5*<|Eb zIX(J@Nw?K0Q4WCAOUTVgBpbGETMG8VA1>#73n@lHUw@`K!h{9tuMx~*cT_MGdK^Sv zjLOlsgk#K@oBimY4b#&T=15bU(1nAR!#(!)5-H#-V>c2tFAug^6;L(imE32{yu}G$NB+va>5hMeo+& zOgO;5F`m32fu=jRr}2kK)IvBmJa_J9OUo$&TM)Y6LVWH+IUSGlTq%4qD-S*J-^QK( zqHV}JZqPwU<(j+~c@xgZfFtXBf6of7PyF!&J`RL#{xk+=uG{eTM=1V*>Afc{a#rnX z*f~b5I8q=7FeZbrurQ)ii-Q_m27PgxuyvxrOGt3Tg$>Qj?qCO-0J{)X*Px|}*d|g1 zIgo@JZTV={xu~Y3%1`qwt*x&=di00}SGPBgBD;`j!x{}uO}7ljVZYE+d=n_n2q#M@ zx4UYYF=RMMmP&iP$#;l z{{VzE;Gvg-Uyk!CU%0Rpz(qKXVc9xQtb)qwd$ya;yg~HW>oCUQ8g+CbLbanlglI+t z1wBcRc(ALE9k3_rI%%A|$r?>X;uMt$iY+)c z_!hwtd(X~=ens{qD}qrFH|Z#Rqlhkp#j2rSR|0jH+m=k6aIrjU550gpDk|!h%rxx2 z4~E%Q_2Pi$)Nc3%vaBL1ID{+{z)z8G&7EddD^U#nHa<``c$Ra0!k0oQf>yFVR|W64 zd_#9tM*7{{xsGTeDWEZey$brl77OF(*jTT4W{x#$;x!w0 z_gyA8-|Z(?Q(H?ajaVNS%^@Da>AD=%kxMUd3g$y0wfNL;jR+_B`=`4s^=HNI<|gN z{`ljSjcMA&+BM$Z-h@4%3f>5r8F=YdTW#AQ8rex`rcwYW(a(8c@FkjDQPm=ykVar7 zuyegDB0ZI<^b}g^8dYdh~LV$8TQo zdxz&7-9h6mJ=J2FX~KHIa8^I){IRW64g?LA`lP#m3{Rc9j)TYSFIi_G4e$m}hJMA$ z?2WM|4@odJa#SlRDM9sLu-{Y~{cY;`9vjhXi$gdnCTNV%i+jg|CPxIfM@WOC*vjjS z2YBD3%kz5s2~>4(D(z^KXWAy|U_l*Hin%gvXOJ!{@N+XNPm)V&|2#CoVR;mSgUe1nJH?0|oq$_$J8#~; z{SH#8A|tnCM`tJLrHrcYKuBKontAy5g12lvaSX0sr@3|cL*GEsO0N*?4+6IYx@jf!{tiL)46&$p{;2z`Jl{(;!6iN4`a zNA2c{-Wx@%tgK)xUS8R=T|`yxH*Va@skXj!DGx@6C!-^atA9)B$Py zrQ=9anUG1{`nxYaoKt64=NG)AqWES1|Hapt$JLyFfBa(~nHgL5C?rXi>{*iC&DNgm zOSY_K-2R97gwRjj(scT#}*P_J2@~M1UK-BTJt#gJL?A%;> z?fUg+c!~|zZPI;weFL~_RX*_?f1EgR;vRA!H?#{!2H0%w**D+I{_EGT{k`?bzt6

W`L(y* zQ~7wEr-&u2pO~->K4z=WT4d9ksl%01fBez!V7(U8I_(J>#ZvkmXiW{(26C9{1@xP2 zQ-lRD&A}ZdrDs+A2mjOjf6$yU$i-%^xh34z%P)Iy*IS-vAzCnVI6&TD)yvH6;^l>} zAm9++-rP9C_Eny$saq$BYY2@SJUj}HcI6Y$jbM=2?&oQ#@SuFL`u5+7KYXY&BGA_B z2kaLYvhc%!?c28pg%4k>wP*N7Zs*?_8Dr6kfv~eJriB|9D zP+5O%+g3$^3pmJcEdO-Nv7X|V)XJGJSOwjW%6{Qm$m{mw&gh-owsq@zXTvH{Cr_^5 zv!@YMhIg}J-l?khGrl0AfT8T2^|fEWela0`^j`MPyvy>lJ;t706inO%QAh-|yGM(H z_&tS+LL_OEw=DLn`-`?(bI*4W=>swQ!SQ5zhwBs(1tbPpwUj_od3Whsc}+IzmYvm5 zn_@vGwx}~ex58~*#{wGT;RmbM$pWwW07f*bX9^F=M_emmP$0zQY=||oq^gKL* zdxrX}`AN*(Z#8OcsnflxXW0tK^z=9vyNCqh-hKciH+Ofa*}Z#r@9R6R2=_|6lf@0y zTGgv3es0Kj!Z_Ea8&88CJ$g*yCXyU)Ae+7(XHYsC5ln;C2CpB4@3r2dEf?_Q@!|a< zCR}hQA^W8O=j9$UH@EeSKl^xk-`n9}=K;$lVX$)i4Kcv$zlo(k_IZmc(*?li>1d5_ z`}coFs78&7BgEKniaHZ4pUq1_AJ)-Q+cI85Ok(l|ZP{ydcT*X+@)%#R;)^vHP+eBa zp=eeuZ;Mtf@(z&+Zun%iZ#_(w*V}LYJ>!mQU_=$w-$ti9Gw%yO_x&KJD8Bv;zv)ElPOoP=^4t z5;s^;NH!@7ABFoF=F;jj`CrhY^HaVNPpn-uJbl^3FF-c|#@Pb(^%)HNvmtQ%>--bp z;o(8skG#=!&wWj-A8dOYQh=&}K?wTY!UDsnH+V^9_~KQw{9ms=fu9KYgtG)Icbs$6 zFSP8$$XQUePoFl47Q|Pj?6PKssM5y5RL|ubvTLSV0t49Bdn_ z)N0kHlH*}_Z!4s&lI!u_l@$Se^78Ux%(igEc7H}mSG#G`-}Lg+9W0aL(9luM@CQ|K z;--Q}|6^UeyuG|CLq-K()o@(_QzQohj;xyw9A&rS?>v4{OQlO{Lz#tkJ7?0Vg`-E0 z7F~{V3_~wj3tgLlpzW7GwGB8v*kG}86Akv@gR8Y$->Y}Z=mH%Dj+GO%LYwMmT_mGd z6c}sUh;cFZMtju*uRuk1a6Ad#xcIsyR?`n#EhaW!hp4N5`*D9k1P5>&MN$dDH|QHx zZf$#x+!Lahw!WXUyHM_z@8$6X&Ep~$yA*P=t$NL9<#omIW0~J>tldF)=YoQw<{)G}G4hn$dT}qrS=YmReY7`So2JQ5bdj zkK@Hx!5{vndSPa3eSLwFe6@y6^ZguVGX{Lo!Y7Bvn<$XdK#v^-QzH}HBUesHeL=!1 z=#shpvuuc|;3;0V5AYLDNGZ>(DL?!2TkCN%6^MIX;RzuYF-e%jwv zkSU2#Hcd#_S+#1_04+WG&cu7)Ru`ZRdw6PY6UDydA3n3_+}dwuJl+NzaW}fg@-7 zk8%WH8kxGW@vq8)m-m_rp44eikD3MmPJ~^`wYsZQ`nh;p+1cmcHFdGG`BGjk=mH-d zv}(@ezTdz&MIWLZx4)Y(|L`Y9!nmimnl@57B|5H&}9iZb-l8gD%$hr#CR%$_0 z*_J%Aeow6yBq5iPdbk>>;X}GQh6Y;hs~q+!X8{i4(r)tADNg4&~k!c?6m;91ry7uTP=b;r(UMZ+^w zXDIL9y)z#;urfM3ti`VUQrmGwNNWb_7dt1BkS;)(7QSr;L;+PrzXAjGg9rBy4Xr^; z9KLN-piykXjF44#9X_2nI>j0nGQfmJo$~Y;dxzW}icb%h76h*C^mFEuh$gtVATk?> z2?Hb82=f#&)I<|+k}u@Wz;F&7KKvTSKWc79H_Dc45_i=Ozcx&B7N%`^R(zPyIF`TP zLQ~}D%>mnUWrO>G!wY<^5M+%@MZrsXQ0;~c?x$){l!7ImVB-5h8hn`-ZU#%}ps-M(wKf-UM% z@uAyr#Ym#SNTgx0!Tg{!rg8s zwV*Da-%X}GCTc1<$ia<4(T9kYgojU_Jk%NM+IUoH$BrFE{`~aSD=ofIwC%?hNxR?| ztj2oxeUpb@O6f2o_}O;vtQiZq0ff8NfJziM?#Jb;S6e#dMuN9gPz3q=+I;5^+n;>* z+O^88$-pU7S|Hl+(CzC1Dbmu|IBa0YY=TTe>#AQ zF2dV(C)!2|*dc-oAO-}t8$}k7w!Ws)-}M0@O9gYEaba5#uUbvN^yAB?1{wB+GuI-(GfV8Non~qtyAuK?`686dISMIGYTFJ!4 zB^Yp+X@^%2_r#|<*Du?S^uyEm3qPZyN6r=MiinIV`f2;ew)aYi%*tf$jQE@0PrsdSdZyq(G zjo!wiM_WrI$@4q~mbwnox!w$!U&9^?^aP0q9Cz0LJwV+*IB+}b`;>>OsBn?B@T*BP zEz_;3BoF+>5L0wJvBI@j7HKv5G}eC@O@Y@SC&u_EBBxh`|>IX8jV( z_02g~q!k1GlYRx^zQR#uo%>_5xS`fPd-fEyG@ejVL>628X1{g4Uevn|Dmbq;!(Ie^ zs5}WrtsU9OHzddZP1&YX&+MvsF{4? zb%ANmo+0h3@BMKxhwcCq6;z!1<0)MF7^4i$_FJU|guflO5FDZ(b}*4DQ2ImYkicBU z`3V}P7y2Pi{2OTc?z|0`H+zk_+kNx-s2smpLz>gS?FBsq*-!9~)pT1u>mczNv+Pg-&;K6Mc*v;kV616_#A+0jr z`h$oC5!`L+zul@^-ar@Yk#}!jy#gQBHsvibOTekK6OXE&@ij?>@b2A$!*Q10J>Tl4ZymeS zb%Fy!W`d>+Ut`v7IA2=y?-J=Y&T?2(nzv5nslm%#Wkr!~-H$EM$*SPL5eLB#mb zPYILK8y@Ot8ku1J;KFd>g9Z*Byz$=M{L=SqA%r(S>RJZYNqSIxxFit{fzc5gp=#ah_gl+JK4#w8%JHym z(vVpX(A_6GjTQ9f$dML=+-T15wSU9zMeK;WMad6ZgzBR5f-h{dILv70Slh!rV#e(8 zzURHC*4_ee-z!UGhQ;f)JLB&mR!Qu^m$!A^M^ySSBLfn)nPkO9(1*E$=1#L;^V@H~ z**U-FO*h*=Uc)G&q|gIx%hC<>_5VtDpBDW=jTARIW|%vlHd!!qgab&{=o>ikiP0IJ z?bzwlL$F}c4G%A(7_E)-#0x9zg0o#eb=FAqE{u7Wc)jdOdU|7=j&?e@oONDl7d-IO zvsR<6-PvOIzEh{wUE~&)k7A=G@jlh5VTCTf{;w8y?bOLS&nfPUVcEHZpF6h;y6L5z zaHvNoa+OwVfY%1gnv8vEBG)~y>bQMxqBl9`s?vjN=}Umi?##K z;x9LR;$T+={L$R!4V0s!6k5Mz+_-%&OSJWg&om`qAVzKMOa_lH-Y(~^&QJ(l{SXT) zPdriew`+r2yS1I*P-E}7;McQlJYT1Vbu_K}`yY>rv-SlpYvyZd*{a>N?Fq{xt)4eq z#`?9`mP^?WSgh0i z___>sYbvVuI(7cUKIMkA7+{;uri>_TBp2yh+~W&z#pAwx)4lZeyy=0t8%KBD^ZZ^$ z@LbpF4{O~yt+S=Y>i!4%`;H1&WtsC)W%bfuZLQy}_SG5byE-nsVY+^f&()1l*()#g z{c`^J?Kw-9&0TnILi_P6E=*W?!7rgKWA@5Rr;E!Z_S-N!25@PvXxkXQK(@D7QLMD}1sD*Ut)$8228JbqZnpQ=MYH7LY zA*#ol5zElE4rKMmz0&w zXq5b!V$O!Q6x84KoU@=~Z2*I}mJ>H;xfo!X-}Sm{en7W-q?zlXOTEavZrNpE#ky79 zUq&a?#KiX=sYuZbZcTd0eK*85#?3}AN`q@_l{Ig4yy-9F$!(0wztQB+^i&CDc`s%i z-ddqjRi{18%V~^uCcCS&BwDPyVflf1_3P9LpyViRoq3-4!xZ&CtZeicSx{ zw8Z-(@-d|j5plX}#}PNA!VSA#XfbHAg@uLnf(6H*7^O{*HEGJ(?5+ehoi5f)*k^hO zD&_eydi;Le_V#bBdPXl9L3DTNo4xfuC2Y`VLaODon_ecL8qh^LG+!FDfYRQX*|NLV zhoY&JEL~sX>gpQtb#9n_sEOgdyLU&9>QiLvS86*O4qoGjXF0QGWkqN&8DF#Atz6sG@c0L+TH$?GzvX6mlU>f26^XLkG@;D z?ziloUdf^>K;E|Qk3CbI?Z{KQ(tIHWd^4fICo>N5L^6@=VkWLRWYwi(yCpxdtsKhW z#j>w(Wj?uCOHq_9)ve#keuGxMF6h*lZ)DG9O`_1fgRbrYx_K?Oe*fA)7Y=80R0}DtKfbEK+hKxaMqA6#&A$jqpyylT zZ05w}dT018rSfpB!Mi(md7=}qyU$H8JoV!7mt~zhb@HO?Q*J&w!Ry(p0lMBk>~5)J zRd_J}L_)R7l^d2(@sq@^+H!HEarlPx(_6M|p=x3N>xl$$7%yuWiBm5Q^jvHqBR7g) zXPN@C-LNkDl;SW^n!}y$vzFzHnQ$=YdFxkuXM9Q63{XlqHvQLU7fs%kza4e6JL^(T zd$exTr>~`s^F+J$QKwI*$_#>583of9B3+8S2;SfOw_0;q_FayogqgnCdT&4l9Hn_t z(S_*I#;!}3cKH zo;=ZUIA)UVqP>yS-zlvx4EL_`CRx4Ig^WdY%D8?fD-Ou~(%kn{Wdumm;1~Vf2^L4} z>tSUdY6mWL4=*=Ae*eTliE+5|X!Jwl&o8qAOq54_JK8ws*|PwKF5OwShQ4S8b7utf zUGm!JD-&27mT_g4?Q(o+brwD*6lGB9n{#g$A<>|>)le)zyJOyG+|>q9{uQ^@zO=?u-;OTcGIunl(-+nOgG*T1k`MU(_Waed;9iLyW$pR&(DXm zs=w@y7MdGkKKbVD+g_LfG!HvNQ?6FvPqqE_w!YAZB(2@VOUu>c=g)s@JaOUDr3sx9 zKY(5q!707y@pT^Jk5c-Rw^e=k_|b+cs)e_Uv5haWYNyLrj2Yt#6*cAc$`vb~_h|O^ z?anV>FHF)u8dt+hTYL4FwXhPWZ+m6{z4dPwaYiC?Cg1)QrHd(@;kSG+X`;pO;kM4s zA)-%``v#}viH{B>2)%gkgOoXvGj;w1PkDqyU^Z25s$iC=p)mEG=~(6G2i*RFPCTgeD%+O-XVEryZH9AU6Os2tf(5ElvwG|UA~Ufm-~Y9V z$&`T+rKU&<+YK7T=eK_sx6`3}&3n57sW2XQC*&K$(i&|rArMNU5lkWMi&3|3_S^tj zL(;==Di*REoM2%?F(OsI>Div?+n_shL+t2i)_2>1x5PR#xk_U$t+@bsb(H=Oq-iRC zHU%0n1Yy?==jad!Ip1z`$N2e`)xdEhm$g1>uyijVHX9M_u;LAMXV)V?1YpK_YM#8b*bm3sYgR={s{)x`>5ZXw7rOl|&N^{eW z52Ag0N3HQ4IH1zwHZ>tH#xeG&VX5c5)obp;5~t)(ubqa*f(BskH3{T6rvQKeF8&bVRxN`JqVSFZD!VMX%Bb zX0`Q4D=~dYrQt80Jn=_{(5T8gR_4{is6))9T04t7X)$u7Jtm!GDU#eUAHBo_rhtB% zR#M)%JFe~!fRaj$vLuE2mEWNMXfbpT2QF8E3NU#z@Op1fpkZEVmOsOue420t&gVY7 z)7DL6J`}IYDDFx*LR~oTyv;;eXY?4iHO*ol3C_5!0C8M?*0-;x<*=|F?IxUm*s$)$ zmxosM^7Qye3$T59uY;M4ggOkCklL_P@quq&h5#pu*{;{1!BMW8SW@vET3SiWZfZ#g zLr+t48PV4?PMd0oi8N+NL}Isy(nwS*ZQHdQfpYQaMIxG#&!1(x)q#M!d0d%>Q zv#k9x%6p9+n_=v|ckNQ-QGB;v!=A0-u5NjgE zJOu0^w!yA1s=e!8vb_BBIUp-e>S!t=Hp3gLvtz(c#Sh8i%~>0~A5)F^`R&6OcOt_} zrbUnG3WPs)SNXX;WayxR^=j4f=hDqMXEB5Z&D*?3x~|Y+;~Xq3+-13_3V3vC6tFrQ zf3xBFS=X~WW-r~3@V!ZF)3m7Q=oEC&DS(3+ekX`*E)*sW8=Tm02`w*cGoNJuDbK~Z zk0H`#loYSc4?lg9gPn0|&sg+L$rKnLFCSx7*jy#oS%z!pJ`~88faFC;Imp^9PAmTP ze1R!FNgMZA;4kw&SB1ZOhhNC;(Q?fa<|ZF_aV7QqOHKtXH$rI8YSg>;`Fo{psx!zf z5W>H$q2W*}U5_};Ob*&1`fJjdoX>r#`xIb=4uUO5cYigY}Tl0>(6-@jvKY<^tfIb9QbEi%y|)6E=aAkPf`vJs8){bj;d-) zfCntSEsOJx9j}J@2_QVg>f7;C_Usp) z#||}X*O~XsrP&SX+Z*R;3bl~oqoev0YFj%w1+Mv4zEmOrfLU&kn!!QMqp1f2+alz; zVUzJ&<1Y6kV<5zpKpSK0;8AJauq_bYVmV3sJgA~e286{4U4A`!1#NM~9}XHk_~xBE zBcKqSBecu{^DiBfaRILq4{hS&)oau!7$-IU;wneG5L@|{vg}|8p)jNyz6{m+_TpJp z2+ZhtnPa#JmYAtEWd+^GpIOk4s`Ji%oiQ>2T+mR0tEHKv`dwUuoEN0Lp9moZ!Nu-3 z?X$vG@*mP)(8}DWPY;TILr?F?sZ08MV=<$-5E4Fp3W7Vkw~*3GSPRt=fF6~qah$Q?^% zBeq`I#)Sz%`9a^>qvUmmbTZ`1A&@wNg{J$e1xeFA7Ss-*uwW}%zy4}VHtHJv%3)SY z)YKt;Xa_B)3uLn^bFs%YST)odAU_33g}5Z#g#;8xr2H`<@hK^G{HeXUxf@mBWB=M- zF}B|X`&$*>EJnR@i9(-f3UAvv9v?{M(>+43tuOAzHCs8HuZ*Pn~)QWYSrJ zLsD*8MegD&>LzdAxM9hC%+1X`j4_yvaz4c7=-u%}HVk~C&BQgIhJV><(ZV@%<`jZ*%zk=hL+SIq^9!l&ZUZVK2hbtX zkKs3A=zcP>0bI5e&Q-gK7sug@ACL!x5P}=FX`_;Cjkf@Z(B;EkdTeOKZ#V9_C`j-G zj`~Yv$&E))0j}1k zw1+-4v#4${9cl5q_R`qthiiJxBX;Dd@5Ga|Kq&+g@+o5aeQ-1Lyek&Eu{AnJS5n?o zjnoX2blcm?z~B{Og;!{3+nt$oZnoa$&EhOw=WS*=M?MD%bcg=(l`C!DNxt^uW!}b? zB9;+89shp)eeQ=><5SbCUA{0GPUHF#0?^wvFqr@QWnO(psJ9ubsnx5RMGSF2trPcZ zMmWR{2U$afU-A2o-#WAVy_u=ve=OJea80zLQ_DzWHq@6bxqKm)pC1$*~tn%VR7ZUGCkKc+M}u(Fb}pfz3MFyTAAPA51l zv@3L3xURd|HBG1XFK1Vk9M`Y9xqbt%JGDMaP%Cec7hVS^IjpTv?WVDhmOu zi@HFaaNV7XV&T4bVuyKaGy@eBy}o{{KRGfK$l=4MPu)Cd8Z?$mVQ_?Vb0Y*8>pFVr zb0gSN!!~a|eM}iwPcuZxyx+Oyt*X15xmR=NlrA5XsQFC)j<7wUXbRFVtUOHyA{PyH z(?~`e2c;#l9=~>P5cHwMNYU#)uYH}sm*VElql#&JVWH+g=H%v1b8S4-8r0PG$nZug z>(QfqXtg+EWM#$hjl5l6X?{%X6cba;;AQ_Va=FAnVTn&1;;o#(Cu&SByIw};u}34J z-e{+5s0tY!6`*NzK5>}lDcxE$(AAAZu`pCk`-&hoYtO!4)Hyjhf#ix7YI4}vIh-~f z8u-kD^Hl3^iEV+>a%Wx5y|Xw$*yAHhr?-WozD9L*1`isP2LfWyx}u^zgWfG$aNDg_ zABu~2U%0S3BJ%L}Z@Ht4Lb`2s`m0v29x-y=ULe%WhZCZmy*H~kRO3Ot-miEJps@2k z-$V1VESz_~uOEpr3+O9*zK_!Ph8fGtS9u3$2EGouo0cX`|JC~Z#N9_7Z=k5NBv4WZ zO1^;9QS(F&63a$q(q>O~{`+f^$KEVmiMch~dKYa&p)`q8qh2FmIo*5ws2>ZJdST^K z32GRgd66)9(h{9V*D^E5PaVj(hqdXTmhqeU)`{qYB^_2JNua_$&2FEz#iv-RY?R=%F9Q2b$$g zVG!G-Iay!M?ikx`&zDvjQ*cJb`JnyBV(Q{mNx1lwBO{vZw5g!DZy}w11=nFuZw;(} z7)XuCA*XBR?J7@BNpuVeW!R_Bcy&%eLAW?SDxnZ9cQ$RRD292&HMR(sB}(BY32MQ! z%~)crn43uBn|R99)yWi?6=a@sK+U>*)|8%YzfUGR4taN<+c57nD9BzPHvJLpArl+w zcV6>lnbvn-89RXpje`qn3g(4yD-rTzxD1uC*g{J!@j>c9qyra9p6}J1fCkiII9|`B z$gEXl6Hw+20)%=nL43d8Tj{i z8*}YgRFu`ENs?{^Q1_Kj5_s?lpRDb}d1~A_W`>i~OwNSGuwk7|>nWFq3la1(?1h#cwZjp^Sh6)>vC9VfXFlB`Odo96V`$2U>| z$xlnSrr#;A!-L2W2!x%&KD@^&%!V_SmDi5u1_gtQ5jc60P=dxvlcYUOH3yd8bDJ`a z4|7fo*ua%_nm<1hIZ_C}BS%v^N+N9rDeTh^rZI+GGP-xToW3)ngQwPS04Efe%%q^! z=yGGt#c_+LCR0f9soyXKn=taTr?gPat;)#pi*k03GHTg!fVvA!%YPg@KK{mFLf=9< zT3QHxRK#H7JW78xPkxx-i5S)D`d2d14>s?btBgDYwxf=l2~1lG4s=u?2t8koWzeoG z|3b)Rq&n-G>kfK-h%dcopO#{&O<8C%*~$JsEA7?E*u(6rMzUh^S`Z3(6N-D)3@#Ln z3s6@#w%FfrpV{g$lyr9WD;~0~@nL^||5c~Q0*Js6d4(4g9g$FNH4IsmL}Pg)e;+AJ zGOoWxU#8K$dVJv=EqUg+8Y$*-XlyJkJ!yNuvoS@ZknL&mwvuMo7IfI+;qd^L=#C6E zbB6`(wVEOv2|o2_5rkZAQ@}YY8fddC7 zta=adN3g{yI(Y2Zu|H>OW_()yN|n=WvW^f2JL=Ek4mwV!`KDKUjQI+x`lAt>sH_R8 z>p@bIBlv!wlkE0vGM#Eyy!*SFV0v7ugx9pP^7r8!ghbQjr6tBCO7N{@i}jYVFa9oAg0eEKrQ87hQOl&FGFKn(|tCj+c)$Ji^(M`%81r zaUkT~di|~y&&)ZBk?n(NnywCn)rp>;H35j-grLh`PqbLLI)NiwsYZ?*DXbSQb*55Q zDlhl#Or*4x{%ts6{qedjPmm~;qvp|k&)5+}-zNc)k#-DK3#C=`I!*F?`CIMvlJo;; zS&>QW2a|=dy9bh~`l(nDXF$Qj3B4v@$jQu{)}3l#u9Eb3!pefKhsF?`lRhw|sP;D9 zz7r^A+^rvT_N+6p?DZxNhO5P7P<)j~_NCURm4F=35AWa0_jxi;xQJ1`->;VfIF%8z zh6Pf_N&kC{NOx4gRjGg?PP1o+0{h+o@TTdk;@>d+*yMHe;|`}E!B9=MFX+F{=gMefl7}(v{31ydSPsMx!lY|E19J)`h zA3EW}S)|AU)#RMrM%Fr!CIN&K-g58IdYIgk#6npx3LVh{z!#}ylz7o?iZN;*R5xdr zy${frGf>Lp>*S-ObmfYCkZx9lGeYWfty_(c^#05y*)ove%FE;L&@QwwHXaUybp**H zMAMQb4`5c3FRfjR+pA}PT0YYKGE_q$arsOrx$99x4Ksn@xw*SG$6IiKNJ>rF=<46O!L}sus*>{yLfa3^aUN*y=l`BzjPwYR2J0fbw)mAq3cnUx2(oO5Qc5s98X(X7mg4&q4%w2 z!6fJT@Mdjpzj%M8U8!dvYd$}%c4vPt2Vkde z9>pALm=JMSoAV}ZgccG4jC8w;{YqmHElxooCi51-bk2Ik?GhT-``E;=Cr>7k;x~-j z^jni7HO)SogN!VncY>n}GU4R$g0i?qZU5+9`ts26(Z(FrIes)Q`2O+_6sj376)C`} z@cTz7Qp;y+c2RS&IZsj2RQv!2DwIEOO9-Ks!%Q499`~h289I5%#Lt6t)sS0VIk%%A z%Dq#Zz4d7rG!vw3xt~cbKFQ4#v4xN*gg7D>NROHTl)_b%(HRnr@utJU&_#xA9pr( ziWP|J=zf}sua+TFmqnXdUn=|OFJU<6q(6Zq8J{*As%qh`Y_7?x^ypCeco8Y#my`2|9-4oj zct$t1-lnqmBn%-C+t)9GmlJy)hWBYu^2E|>oY2GQqqGTE3JVLzUVQWFl~59)+zJVC zxKn|oEWxBE27y)F)bQPo$PX=s4z)yY%Y=Y%D&J-}9?Y)_e1^u{k(|>3pmX@pm7%m= zqUSF&f5vc$gvX<-z%RQi!I2Hu+$m~w4yX-5MhFq~pzlL*-TtH>fiVpaFh=pIsSZ-6 zLJV|AXv^WpXghRgpQ&K8)=RFge0%P#S=V`wMiZNext1p-ySRyIUw=}u>;xm7@_AJe z#m$0r0L~vqHSpLqV%<88lSJ+@0+yk3^YB3RxWXu47ix8kC(b>OSmzb<>My3hr2Z27 z|F*Og3;7blD&Aoux;qJ(N}Q$FV#tIGrQ;_ilz*;KoZaEk_WU!tYoE$KxHQ%LJNppP zwtaC+I2rCG0Oj6hYrnXp(B*Gp!PCdEQf*e#z_*mcyodz%M7DCd&ru#z(h?cTTqmk@ z9!m^Ccp~f>$!OT%!movxGxm3yobA3pNL@L0k+vCM~ zeY60)empptEILG+yHl%Igy*whC`zdwfWmcnSzbXoj!#SWY(Ah1o$*(V`g5FeBxCP> zd293cMTtXngE;}m(V_R!vgkp?a^d~;CM&YHbw3FNOR2gMJ{A#liBVxi#VL{#>P+3H z=n($}wW|mMu2iOSpqp;}0x83PQ}i>+0T* z4AuP$Md`d2v_?V}CfN{;zY8A4*r>JZ)uU~UzC;0viXdXY->MI)saaqW2n=lz3kX;Q zzIiJ?UWIPY3aOj?33n?LjTMoAJ@HsB9Vx&|;VC3`lyHlZ-46Uo_Bc{z`QxObzpoSj zMi8P%{$$vti>`6 zdh1r(EYo?nSFJkDclfaXAcCxt2%UCNt`{Tz|Es&Zj8-Elt4>$=G9LL$3XlW|{ zr=nU6MeR#q>QO*SIvdOF0=kw*8|8I1wFfM^Z1A zji^3f1c=D#7(+N|Nf&-bK$G8p|D7~or#?B1I4cT*jlI6^+rOWc4v@=9tU7vfAUl95 zw8Kyn%$q^|{LMvWYc{B>C>DpB90<2(>G_D=r%#_IL!P`OOec-@^3#2r{u5z*piFfD zy;(ZCn^Wv22a(=FM+q)~i?&s)1ssUDtxmL&#$@{!MvVz4(;pY5yQngdcp#}PBHGU3 z-`Js-geki(UAi<24EM;Mk>5d5Vhd~HjA5dM;r+FnHVwzd^PuQfjqJmM2xA)#(j%%1 z0d=a)pylr`?IDB@iMYbydFMzvN|9vczIx?s@l8?gTNQFr<@MyT3)lV7FZu+wT$|%# z!cX_fx_fsdY=@C@N1puYQvzwAddyb&w~SYui?>TE)%g9B*yGItLIS;DVz5VxIf= zmsSQQmuUE8|F?L1vr;fj1b?~xY42#+!t$S0P)t>XnZVZDT6od}gu7)NH~Z$UcA?_Z zrbH24y^M-$W<#FnZqQNQ_=#V4WPc)MA&v;7C;M%WN{V@~j1TKW47%aLMVdHBu1=MZ z_848otKJ0a^Ke?cbfcaQ=KDK_21Jxyu<<6|ZIg z?uq#tn)K%5`1!;x_Lb8vt-a*BqAZ0yq;(ROoWKGc-MwR1`BeWd0#PAOJ&^lQa$aJz zc!b)4wxA2E#DRc>5xZ!e%!KRC6OtS&r*(2Xi_q_%Q>$oiq!5r;@{VCex}qdY!@l#^ zVGjwScaweafjCdh!{HyU!>}F4yO8X;EtXB4gJ20{H3a>Un9F3ea0Ut9ZTa+AKt5XFf=I1x|(ZowM zso`zVULHhh@sfbibb<|%6Ts6ZgY?iVg%I=nIQCux;k9U{)uk8LHr5({Bq*hT9DK2R z;8!R5L=uol3`f6le-<;-ZXy^kw!#ZH(5hOsO<%56RUY*+ER}8J<@St@2M<_ruqVP9 z(IjVg{lDSk=gX3^$ZK9UR0L9qLwrGDsHK)F<0FMi+SFdTxhdS#n?#0m+rCG(`r(n( z<`x0|2ayGpY^$l9IsNk%#X15Kk(h{u;)c$BeBKdkB>ufGZaA0_j?_$apv1IxoU|XC z&NIX&knssb#m;WnvqkFzj@+F)oWS=q)J2Gi-?bh1V~)JyKa<4dCH3r9xO}35 z!*ubOv2p@F4)thv;naW_LM#tnX__&Q$p$RvlDLlqKO({s&xbT-`5L00zy7`4h9#u_ zJZ$?ziaz%n68+Zu#$fXj8e+@_oBTye{;$b%`*XLKS0L|X#DNDWIbRSG&sJ&u@{1*| z4c@f%%9Ulki#|RWJAW2+5vG3G;p2 zoyj~>AC3Mk=Xt$;G$(bJfx_JDGfaoxt&Qvm+G-D@uHD8;L!-Xsg*)W{&=8x%=CP{> z#Tf#I(0hu09K&k%DYU#;%0Zst5vwB%0{nolXI|5&pD*-nkjY zrsP#iMG71^8oM>H zy{K^sEN2M9HLE_!H1VIrh3NAVsWq5q8b=VyX^|Y1)P< zp0ltZ5!|Y@*FyR@30Tz^{wEvjXI*^#p3Zk^$cMC)QertW8HS#K9S0V0#Se7GB zqziTbHEK%E2(k|riaRY5C>p)!M7k_o*xB72SRR};K5VHhl4V7q|68_PcL`Xh3O_$U z_*H`WXj-5(?YYP>adCv4%V7t5C~Zxp)xv&N1d z`DwdrM(N4fe{~o zrWU&2Va%BHb}xiQ2bKmEx2CZc0<5NW)qK$F|9sFY#3e#4eqTg78%OSY9B}IN>D`PV zwW>TNYP1N4{tfe<%ws#rmm6+qM!KBl=YVpE%7}h+SrFF`#AA6)RY5WQ^mB`RS|YB& zwzDUE#$q}-vPHC{=L>vauVe^Q(wN|Xqq8&nK|!3c0Rlykk&#E_%}Jn7qviM0nsRFH zSkQ277i+jm?H?o@yBs?JDpuSf7aCJzU@+J|wCGeS!FX1_R_pB;`e;yXx{aqkeu(?TO%+ zgdyj|5D+Yiw421|)_dRoi%R{t8G5Z*f%{BtugQ0L3Fw`SY7?25)y(qDhO1uVk7lkW z519V*n$S|wE3;!VGMOOKV-af+ZcEvaW>by+v#E}ZVjyN^9f%A@5K5R%$Yz!Vx-u*X z|G|@4kOlp~GZw5Pg%N$*)GQKFAxe=%wFC6(9qo-)ge0L28h>u*^>1izo#%X`ajl4* zz{^J8t64>P(};{o^omk*1OHPF*_GzUau&;>6$6N|H9SMCHt9dlmebd^Bx-EJz#TALQU*s*qir=s?lJiTWVoywXb{P*o@zzTjv*&kmo33) zHGDzPpGRDeIK~S%=3VfK2ykJ_QJfa-P=e89ow00ISMeImzTAi943dbEvNAUA?x)@~ zen5@c2JOT4-XDhUdf|%M(EGW+US|JK^x(5i6E|ta(-1|(?B&ImzNIke#0h0D?ZFQD zXAkaoBNIWR=-9}5&^J+GRft@n{=h|wVYEmAa30_kucRL|vBG>)T22ijG+_Defy_1s z(vWRlFzVpHDUW3-if@f6sxTBgfiv8l-IZ;1Mx%o~WZRTlHEVAF5l$6vgpdg(Jaq>L ztyHtOKdveBePa-Wm~-cDktUFKd4s+>r$&Tp#6>7{Ir)hLYW5L@#xS)&*eq(jPXj8t z)se*Z9}7xw2n4P7UkfS?^vUu&JPX8H-~YLWkM@$b%WjdRNvcW%@VMFY!m7G{bk{?4 z{o^=VoEE&Dx+--J<<#&EDYS{_#CCG;MK4cLhKxI1BBgK8dFnu-U5ReQA$om5&_uW0 zS&$A=0YzZ$n&IsYL?eL>Z-kWp-J|M?+wz}}WPaA{#8YYxwrVoG0xAJaFKnAPe|`s= zT^&-ck>oL~n813FkA}Fg*hIr4(X$d^ctU~CseK3Y~R~3 z#Zs-)Ysqk?bU1J|h&GNUCnt}+D~SfRzvAbhFvgK)8R&y#r1}55vu+tNbh?E-4#AR0 zPKZ_$E$NN-zTgufBwgfXTnu1ADm5ER^_&Ro za=IHS+G8m=Bde#5P6Av3@A^c7o=Now^~Ch&>R=Z(I!b(!+qnNc)Oi{f2jTcJa1_%x zhCj&goBxw%7^ElBXR&`Pxu}NQU_J1^vwpSw3egXn^=;MZ4C-PqszcR6(gAI@%bHu! z+(pK{vXE_vPM{4?P~7^|8*CMn%8Vu6`$v@pJSkyi4%&rszzEBpe(LO0T z{NTZj+x|I+j-&9P7~bd4r)&Q)8l%f(T68>vyn1|NaRjUb1<(+qu<`D}II z(#Ok(aE5O!-AilD^C;>&s8*2F6j4+0;@k+u(3xnstNll&G+QqjV1CQ1%^`MIR!f8m zv@a9iWM1K?p>VsxIISaK;N+RLziq!zw$&qK+&8nYdkZbq=AN~j%~#B*5=erbF=GGm+usqMOZ%$#sD#6n5y!IJ z1d#Up(Cy!=vjXPLpq))b2>^dXN|3~etaelo4!^UUCR27nmi;iAq;5OqEQ6#wSzS{+ z`+$B+VnyKQjCh?ItNzE!-RrjtflQddS9aj-&$J=0*#5ivyYPum>Ub3`=SgZKL^?%Z zvs5mw+?iJmagm(7(QW^P-5LUvZL00^iVcXQ%6oK1hf80MQ>fDhHDcXEg7SIpmXd2K za>Dgz6~0U?Wdv8}J}QK(x|%^Sl7nvQMhk@>FV}UH=rPfEqQIfiEp9-jdZqWxrf*s>TJ;C|o|250 zmA#mh{H-%mi`D#Sw>a{2sj_zl*v08|Y}H=gYN*4G zA2|DCS&mI zHZVWO`)E*??Yl(yLxQUgRA9gUau~>1;PoFS2jq5(2i5w`4=dY@oZG`@Dk z-Qro);_ySrV)l5{Q<7X?+-*Ad;jw=5%)rVt7WADXbHaiVAibjLEd`Kh;*RQgh>dBX z{||)9nZv-m0})ha4%g-ET8nU-C;^D09KknlbOk+;ljYIYCsSzhg#;b$dBIM}8&gJ#p#55@GfAC#m?tXD~cy%FUt}~Z#k6&Fyp8*@AiPa&YH?S>7bK2j(eW33pK!J6txN@x~SDbHUU z4trMR-7>-E#1d(12T~^c68U2Mr;or8m@qLBeZGD9C{F9CtZ}n}|5@WhM01Wi&EGgw zK}oJ!ycPCMl%3(f5N3`Cv%DD>H-u8=I}41ZasjuNLRr?Y2?F*%Nlb+imMT}Zh<9H` zeX$ng~9lm{M*%ulSH012J%$QisZAct?g+k&jXR;-7)-^Ae zgKP<0BFC89r$=60H=r^WNayFxGk4#-nT}9pg@iy%s2DGx0%@}6tA<3n7_`~6qj0P< zK(uer@JB{l>7R~S{_zOlvNn(n4Ih1){;uS-slR0q$`WxF5+oBhoFG1p8@z1U8TbVO zcL<54E2hj*5=ncZy|yBmbc>KT5oAt z?OhO}_+xLTa)&|y;)2ljbX!dnTJ&gPf<~dAb@76GgAVhoPqLD=`~BRKp^piGqh};E zn?1-=I610q#-UYq@9VH2Hy7lFgG;9!yO3W$x z69*(Fv!XB_%s3o?r50ErS<>;IJ$;+jQF0oQCCc5Dia1hjrS^Z`SdOnNtX!67BQ!J?~7HdR>{Jmxv_`baR4?s(Y3y^Oh-@&@|`g;u7^ddMlbEX zFX&iQL?zQJqJ-9IzbnlcLv;jWKZ`hw`*>=3#K6sUoEad;Sw9MJY0qYMZDXMJHBFowM8nj4kvTG;V*?TQ z@kJA^b}S?<%L7P>j=6Q7Z4`c;)zK(>2;z4|Fg(qXnsUJ!L@|Y3!<+tD6eIQMxCsZS zb}*sgTk9ydy*)ks;nRmhB-(Q5iq62|GPmjo?4-{*v`*fNe98vFcu`E9 zRu(U%zlTQJzaBGOv&DCzI<`)6x=oFJ0ONf+4@tefyTbvQqbq_gqM<#0Z-EZPXT!8i z>GK>j(0sr~SqLMmfPqI+`Q^Kuj3NoE5GI_ZwSCbt(fRO|gW-w|23J#_OCe6xUmN-p zcaGhT0T?CHE>X^CBF0L}&?KmvuCN3&=D1<78ByO|q*LAO23AnqE*x3T< za$9r`X{N<4<2~WC>8%>feK2CJIJQzBv^SD(a_i2WcH#ulF3)$afnq83KTIB z%&DAPkd}juk%N!vnBRxUwWTz-lt_X{yU`3Cfi1^SG;ylp_S$PGtrjInoyVNP7?|oP<&M;_cIvtbtk+_9`*O=po5=O~r*O`; zQ{MjebOZ~NfR?yD@z2Wp*p)Xu8e?rL@Nugj;1_g zPwCUcp%q<^O1){*CNYA<@q2WTiXQq(QO#@=;N;*t^`+ui?&ppcgNA#+W##n$j|nJj z(rZbW8O;oA%jgK++v~>rN#4ItoS3EKZZF`QZ}<5T9MoZIRy-I!CKOw3TECK_xAHal zme40efch%1u=}XE?9MKlwXi*&2| zyhon;&^%nH|32Jo{2n=`G`%2%3Pt_UZDR5E7gdRXgkW4m!+AES>I71I z5#;dMijy}|Q&Di2;-6b(o- zFu;X6kdrYJKOpCUwPss?MwUua2M|;swdtF793gfIg7xf)t>E28g;7Z>YU=nMNSx9! z8d6?M=Z0}=!RsYMpyHwn?>O|5L?AqeEU|tbf_BmA<>uwhaz!j3wuF8^gbKsCRXg=X z=!;Fh>DB$az|+ z(of9TF0CD0bO$QSg=&$gV1&#OmmpS!(dZo@dI-+gozl1UNF1b#(g+VS!W@JDY{`%$ zS~VQSj}23kO&*aW$;TKBsiNE%32~Q%W3mezT3cmJc0pH*0Xq*>bSrE`4UF{gQGGjC z>kQ-YzW+GIoJnv3VRo-Xu89Tp)u~4?(ayN)>8?M-D;NSF+yaMOwQ=$d)c>rwEzJJy z-)ktI0|bWSzip`j1&tYlW}GuOdC6$S^9bl|2nmup5FklP_F$2Z#9LKRTyKKqCR4-K zXk&9KbR}Se;p$AdrA&$lw*48#Z1+O@FXhL@Z&!+30IU4O<{H z4S|dHMfZUa&@AB#7=KZBtEGDqS*bj2Z(}pF$443`lFg_-D{gV)a8%-jhN<5?z7Pt8 z7MJrE^yr?TpT@OWbs@CB7tIjnA`ES}C3cGD*`!V!NKU?^KK|7LX^@Gjk@EQf;!X=t zLvMY7@1Qffno`{~k2!*exFJ%nsSd=Tg^C)5JynLAAO07*@hqkm+xg1+>JL7_Jq+cx zn5%OD7P$P586w!*oT#{klL1KQ6g9e&uJem}|%q{~l=)@y^K2WVVYX1568#nI& z`e_M?K(n``^y#`a$PD7AG~76O$v-+0LVFY(Ezow!3Ii8G@@z#ap}C8<7Yd|{i#YQ6#J!WVY))LgWiM?_+#^?G#vPAeCj+$^ zHx&$pe?AmJ&a4h3soSGd-7U^;-s89yFIJHUw7Gcn5*@+fJp+~HG{qV%5dIR|s{{pC3bFqJ-d|3lW9z~!8GZTvP1Gou;LGm|9@W)u<= zSwqWMyCaf_RF;G!Tcu5E%ws%=xXa##WJ`#$L|MjCQj|(1%!CSMX`@u{_uM>VdHcMd zdFL^?@Bi}qo!>dvb)D;s6E89e`*0It#RgR-D;t#GK;vSqgsWr2(ydtN$GrWd^EB_L z^V;ZE`{V{S3YL-srwh(H!8)pf8ThYtI~ljMDWjc(Bx7e=BFnwrB` z=RKn+i_hzmSH$6v(?DO*)kj~s1qFu6(RP3QukS;2Ko0}>S!uRYX(OqYY>O;q0Fi&n z3^1?l^ml4I!@#$t&Jd>>5|g*L_I%&&_oa|&t6mzzpvmJCQ2Ux)xZ)^<0(yg-!+;o3 zzUz6xx6>F|kud+inpzaL)3es&ovQo2B47$dQ4~4fjvv+X-21(UefX&Z-AgurHpi5K z!fkEyZ9q&lE{Ss~j=Eb$27vDDgdV!{bOVHeHtZR7%S0|Un?S^Xj$uR-bzoE|yvBTx ziEj83g~aqMOi1Jc^3LOYiVa%4NHg$018=&h{=iNYDD*5!tK#*e9MjG`jV>eA7wKj9 z(Z_+DEy((XC9fnHC^{Vj|Bitu-Ve^5o+Lj9SLLxZ$GFHx}=HtUrUXZ<48rFOCOqa*@N6=`z#~%=V<6Y(M0_|^8R?g z;RA+<%8VTI5Vn^4MbHgaMea>HkDqo`U-BhUD3pc|fK^4n$b7s)YfO~{-vJ1aUqgFR zCE_~HO#_Y2S%+nage(Ed(TPW<|0=)9JG}*7#X~zkCL>}q&~|r}Xuv;=IEtr1s3*$_ z%;_3bxU|oJw-RGFlpTE_bh-G6>m<}4r#sk7liA$$A=m#MTmU;~VH4mV> z!vfIftzu&jq_@h?$!oRz9c{ z!lXbU{65Txm_|w<2x;HNe-95c@i#84m|a|a82y0@h?;b{WBuNlUqx{x&6$sB)W+Da zVATR!)=q$3`r94NKdDzo*mhQ ze-O>=T>$8EzWVmt_$MB3_^%_W-pjV10P?s@^J^eV_b9_>ZW%^%e#}L_cd?g75iL5s z@yD1DdXL^b@zk(iyw?UHS#SCkcf(jMJ)6UwY0+S(T%1Hv8ug^kzfpvJIfhBRa`&K` znwk@e{`24q6}0p2ek%(}9Sc46K1r-UV4tbx7@BPn|tC_cfyRAdLY7PFQDH zPztQ$dfHG{CKx6$K{zpoIi>+GFgU$k)kpkW+fZOzEDh(u?;Tj)@T{KPGLUs2eflz} zSJ|4u2e;NYH0f0H8lpx6cd{MH7VPR1-uK2nX4Q@4{Bpev1mA{5rgS@7=X-i@xKDSz zX-(s9liy@5i-5uLk8=!e7-IBF+(l+x8FG6L$+h$40g$abdZ!#MKI_8^{{y6PjWEFT zj$L88@wSb?C7N}KG=gGx=zKXsOPMhdK=B27iRjw1q6(v@C~)i6t$mB{(f)OlsRK~^ zQLp>lN124o=<NTU?(q6aTErApbYRc%$QLT(`@mQrjYVOpqvr8x14&nl130$qz>nt+1)SvI3pGKojHbt>i2HMU}L zC-@iZ#69TCfY9mVJc2?jt3D4}!K~?_3zO>;8i3Cp^KDXqBia!Ubs)CUD&Zg?QcBs0 z!YT?B%_%49t2M|lc-+ftOS^8u%UUgJ+v4D(JIr@gAbEk2Po=1~#Tks{t_O+CumWi` zq}gYFS&hKQgi+^m-yL|$aXYEKG&c2|I*B)%3qmOVr5HBO;}!Xo#pg5=L~gthI5 z4V_m;)Onu6D`)Rfph(8zUSLp0FYxF<4vHytR@8x723CcgbxIG_o)62fkZD)DIc{>E zxvu_R=V|`4c7MzWHS*1|9pOlpw!hl%d2!c+-TkyTxd(+yh&s4LyHM{t{D$PF*v&iU zH&UTpL|U2e&k?!wQpmqsTt+9Wdy&@guQW8q9Bo9V>Dp&DW-LpuPeUnh-Pna$H8-P& z7Jv?iKR%OJ^;#&moFuZE%dG9qGw($e{|(<^;j!RBWc`KE&bMZQ>nTq+(N(rqUV z1()+3h@i~-P|d#w2r80P+E4_9ub!5=D@b;Kd<0DZ99smfnhgi3gsHl)Y z)2f#ZW;dkH_sdStu8mPgi=vye0Yu;wUFyJ>#URn%u^8Hg7jA?BzC`e)E4LC3e-cGq zeC;qRvC<@AJtY29vN!3eR3?aNzsU1kd`TLnv|IFzsKljhYGrE&v!fZ(OA!#DaPh_3 z8(3>RR`q6-?Zwn_3>6x(bCmj!;ko` zgQUImXy`?hSGVo1ZVN?fA@qArHHj54;M`mvC=VO*)PA^_%!*jZylvLYoLoD_hKZC8 zfJcYXuuKV(3VYQ0_D5ZUHf(m3{}*?N9WHzBn{U@t@FB5Sua#G))6Nr_vr73Sy$$m z*;U{-WTKgNb>|;c0oO~DLrV3v*o#@OL~KLy(B4F5&_`Qj_aCcb=#qHtwWqH@~$z`@h_wePmNaXZd zy3Pfm)!6VMx=K>{7=)|yJU_clAv0lFIt3-&7=seA;F1K4{07xoOx*x;c!Y6$qabloAmW9;kE zZMy#Q*pr?g{%%>@smQB#RD+w8nb$I$H{IGJGW1`x7}2_$-`)jl-MFv+Roa`CwCX?> z+_jA7nNYESfeZThTMxV9iFE=3!9f=*crsNxD>^A7Fo-*_dT@0~-m&6gp8aL+Nguya zFQ|;1iIx4Y2%!r9&bm8y?%b?rtvZ4}3&HAjmr)+GGG1^p%=i3k(h}}<>0Ok%bBV*U z0!OZI?_4opa%y67N9K9tqLy^8ul0EHg0jj5!NpSh_F>}%<{-`Nhf$WxrXcYbrLbLU zV==VcVCJ7+xn!<#Kb(I>4(XYV1a)&&ZOyZV$4;d6q)-i>J#=bjs-xh+@&3rp=I0w8 z-8=5IjdC?~EuH^|$?KG-TTc}|S>67SfrAHMLkv{w*`uq8zs^}hQ+TeYZEXMrAj(8D zzCl~S4pjdcnHd4BhpU6WoBfYCPobz2(-F~FP~Zv?DONs%QIH7MB(5+tEzBNN9`*H)Da!4ds=T!(6PEBiaF8Fxdz}ENhj8~N z;)b#NclR*leX9{b|El#YJi-mGZNMMTC)1-a+rG=-04^G+;N|gzWc_}tYvB6&0AC50(N8QoZ z`7+~wS^%FkI*frKHlH}*nQ-_Ss0&EqZR_tDQ#?r7(c2m{+amJ;>IbVkgXst}Lxw;z z@=fA_TlYWRkIo4SKmN@~TSrw=&79ouhF|keCEq~?L$a|4h7H~4OvdB*m$?0YY35Py zFvHic^_)jssI>O>SrpWo%7?9=ZzRTthHG!x@&D960YmN+VcJ4-w(2pt)5Tty{rbl@ zaOt8(mZh^J(fc0IwVWr|zkm&kQt}A}!$fp4ecZSFef+>OvzlL4#9d*gLB)-*E5Rf; z6N#GFdtdC;mq~#ZjDuIl-a}DE72Dye_U3hUUh?ab!+l7SR+D3)v5Lf(jX;^udup%kxQxSZQx$bgwvMTHZCL7zf+o&8kd*QQhJ(}$BroJ1f z8Ka6upM8FzxtNf=D@0}3knoP8XFx;i!5URp^rT^O1&2dJl=a zOPMU0uCk`OM4m5sSBJ*~=-g3Su@R1y&FOK>AntI+fi*hs1hJVrfocNqGp)q(}jRl7V!-1MDK$DNE%dWM*W%o$vcmBEx6e0IMuHB9ue&`fxK^V zX2*ChHWWx$tX1Rfk<~NbY9#A?5`|UA(+UqlO8;!*m%E}oXQzMr)S`T2hq0|F_R}6O zTC%d4-XgQo#I2ckpDmsgo`{qzu@;5loV@LMdteb{0?X!%0={GZ>%mhvP3NC#lZ8ph z)b!uA&sL-l&NEm@2sq|+7Y)Z8<%Qu^vKVNPI6675%y+=dKmR*_TW-v~AW)9Sc&`d- zAkGKU`cLW43&|400um_qv!nKO!G0ca=@x$!gSk#*zi8};-| z{I?ywf|}Sg{o3|%2j6W&Vq*ViXDZfPT(|y}pe^5ygMgOH>tj*Q}{lX)#H_H^5yrXz}^fY~e_ozU*l zzUe|+ffbXlCa7g3HRW$i(9*GG#{n>1%gj}{vHQ{sl6#^9(H-$uz>+ze^@bILR9UVUp;Bik>Z)<)Mq0QAp+uA96JmaF2bE1 zjwsBI>a{wC`Yg}RG-6Znze+t3CM6!6fb=Ujy8iykg)KNKPblA$fYiPoq#R1BvkT>>=kw*9pm z?L1nH8a3*G?~pEHwd>a&r!+tRGt;Velis(OIK#C1h^L8^Iy<6ZC0ejduSbs?Y3yH! z2kqvkuik9xtE*GmVAv1UG^mDHkid79bs$j+2>oaq@VI~4u@q9ozKEqdy{0WIg zFdp;sf$L8Ae%l8&CL?-KD$mjCY(Pa)W$3U`qpoON+MMn7K;-#|t6aC&i%gkrt;q#I zWC_YZuMOb6*PpF1@*H?K&S`Z1`A3mV87VlZ+S=a#Fx8Zy0otw z14o@W<@i!+%(7|J4o`}(o9vvO*mL3SBj+Fel5y%+4yTFuHKs&!5R%$=Gk;|7g*6M z%?S9K*e=XHbhU|@Jf1K(wr0?5Es`ROI=VO}pv;W(qy4BKB?q=kI`%Fhh$(9amVRlg zIu6sL6!$17_18Bls5Xu}cMU5Nn(3QX;-FzO|I}){I7Qe6d9(NS91d)x$_eoVC_hMn zZrR=)V~j7y)Eirn6AczGs;kFV)iqTFt*D!5B0a}sbIRIj*+Ei^+ZJA= z#!tI3V+G2VeCv?JY0ALv9Stdf#?z1@b^U~=Bqq;Fh#)!D9Myn8jh4x*KU@DA9cMB- z2>k_XDp34ba+WPW?U>GDmu;jp6)8SkL9u4>STnO5(gU~Jg;HxGJ%w3yh<_Cue*Z8M zKIB8ADHkUfh^Zrma}3Y0+mIo4!qm_x1IKz0{>q&8(&PIF%?RS7_H)$HfwWn$0W1_s zC>Z58g-_+mUC8T5XLj|Ne58|o>m57F$n&#>Idttv3o*eY2if{=Q20ygxwv-HTtvE2 zNmje3uKm)I8|Q1iD7I=5ZF$f@Ao&VJ^9hGp)fjwr9h@}x<}QpaV&jd#H{~Rfq7k4& zTxvlY24ow)TtPxB<7(9;0;jDD^on%#Q&4~a@1g|W$6csJRYFjR1^*#WyKzbur8U)= zOQL^cgSvkK$94fP$cD42TSS(|fN5u7>b%-PeRdI$)#ARhbU~7WtEbhFiQsK{kS{L< zd76YBoO{vbZ+EP1QYIV!okOsy-2Q3_s2*F_N)AY&zY?LOVCgEe76Fs4Ua_wJn&e)r(R5!~to5@teGJ24Iv zffL-gSUCtKCo=v1uG{1NK^rIB8Agu$1SOmbsY2ClHe(gCN*HPT({9DrdAxRRFZ#Is)74hX24w-M+>A&FJ_>A!}_t zX?)n$uv~hXA+XHcL@24zw(nIh;QI+r-CjOOS@kFx_?XMt^7szRvp*Yo)675;^L|49 zggFA)^mhaRnw@Ld9?iT*8pMNZ?pA`!oD#bTL$@xM3}L)x-r;~L;SaYz88GkG)q$Sn zp@mNx(EeYhpKIgeP!Rype?7sdne>@>q*peCGskj3;k}?J^yiKj3k!|1Ux>V&Sg}w9 zK$nc^xiDM2jwPW*tUjifLrFHLw#GsA*Ry8}YibMPwpKf7i$ZQ zHY)m&yN5?VZ7ms!+-k(v$6%ncAJs1q@+GLA{#ewP3zlUrMDJ2+E=j53|Z%AZu;M;dA|Y%@>dE*NZ?d zb5&`g(Z6M3sDH-DJ`u0Qw+1D*K;6ZbSZM|0A=90xB8SSDWkiEAzPE3&JGdJJb#}79 zzeEEubC4vNSeRxzF_Lbv?(|2Hk?kg55KBJq0~fCRE~k>{OJtNOQ0HELTP$rb#z6FL zMWhrPAQ_lsgzXIVWaxqewh% z;&vTF>NELS?1CV2lTywCF*=lvg0O*$hk_wZl6Ehm34DXNl$6H)t;)K)til3y(ZqYp zD%MGpCi`BO#8bd7Ku}S5o)w0QEa&lK&f>^ZixDH0<53J zK0zvIzH$w(Kie!IW>NBEsWN%QO+*y8KPjTF{4$}NUsV}&` zbkjk&h*ycyijTa7`2dgjUU}*ePK+$+C-QGtCX!{um7=quWA_01oNk>unet`{!e<#A zJTo^mul_ZO^IUj#oMV%vK@aGs8(nVv%0vXF*wjR^6T~^~alr}kgK)|{!{X`{ajBQa z#iyO#E;RmWaDlcvoF~Md1;Eo$^ngQF)U;5{c4H9A{a?Cz>}7F=z7mGd zz!0gj1LhREN@?0l*{JZcJBmYSlnrL`gk>H1ANLnmG~Q0EICpB{_1}$p}ADj-HX?QBi=&!9oGqTbm$9CE3SbJ+s>g}lL{tt)PZrpzNp!3f&X3k#V zyy8^A{UpWDCnGOBe7*b1V1t?4K3ny$V)@x+XVdG$?!69sR=>)3WrfkI_m|o21O$MN za>|q`O}jQxgtSI-mPydlsjqYULSf$S9umsroUjP3Wn5$AlAZ4*>#03qsFf@A?{n?N zcf60@Da-h`qMPdp{{}m{k}#}kx|h#u26d6(S^>fskh=22Z>240VSzEF)Ne|K;^C#R zh}NFV*JBBGy1t?q*hZ!y4E$CRIYb)QTz5P}z-j$U$Uim-rl&Z-n-^5G&opFklfJ-K z1xXch%@e@4T?^!T@3ctby;ypWul?r@KFkSgZFtA~1yD=Pg-c5msVnHmDCjS@SyPm* z>LmZN&#v2h$^Vjrl#y5!Y!F>Y-hGs~FFT=|5eJww{=P;9lSJ|pUUz7&c%}3D^=oy9 zlE}FLNSbzKQS4d$;a!VX8Vq%`N;VE7Y#cvtqDVas1cDo{nTnv(1vS|3KQ!4FgP@Ny zw3|`}!G$o>YZ2o9DG82OzQYr9w2HK>D^(@Gi3Qm3_d@0n7m~K)t_3K#wlYj8XWRAJ z1=RsfK9`5o*kBkE5g`**+RCt{%`en1W!(ptmrI^Ct3bxe%S&;Z^6Rg^L2`6aGWaJr zWtPFfe+Q(l=S>9EEvCs-5Dc|6RrdV-;W+gj|L{4BV}gt>e5?ECy5(9^cR{$w_(9dJ zM9+ovthWjM%G9Yf<7x@3qgA9Fo8V8?w<@i8BH~erbqp~m>;HenT`;-%1*D5jyRgzG zBA7hlx@zT0E$+ObEIneBG%~1fq@}#o0DHt6zuo-lpQX@Yn$w?hi14BgMn;;_48v*}RI7<;(Aog19=H4~n zJqJnFr;^DKfKXPo3d?BOr{GN1Olg5fk!n*UgWQYacfVJEt@5Gf&{5KEa+7*MmmOhg z_~LzOZP^=zwoGcw9OrovN5(T8$Fmo8R|ITiwe$m*FL*+{9{);V`N+%LJFomfIBT!9 zgO@m^7z|`;k(qQTldR73H3HrKhFEMp^Meh(vNv)88562<$9hHTSXls_r6EP`Ph0d1 zZH)0f&AosY?LyIiw@Hb`aOqsh`!M_ft#%cK5-1(lc_XGoC>FhK>>&V^Q%kdS>0XUn;sy zL>6$9tMaJp@Lw4!uDeXptpdff*taRI92~+$fQJsNolF$kE}jsOhj|W-MI9R>Wq9HT z5y5)k;WIDS(Bb*BGL?BNutuALgB^KMrp_nPOD(V1aj)Cp!BbUzmpm|oM7Vq3TCvC< zxvwl8WhPCg!uLFU^r)?TRa=U$yWs8)JLc zCmy72olG3+otKveL0^LeSp$?EdOn^$Mts&bv*VWR1=N%x@d0*pN`oEOQ8fl<&+9jkIXT*PXMHAi z2~L^xij4UNSv&7Tp=kGBQyFg)e zFtqf>{aK3V)5S%O6_gilRdc>sM^PMe2y8x|f@dHY$Y{xT#EnyFMII$C82cz@ns(*% za{BsxuYl!FBZNrutkiNx>A$Y42H4%AzTNMn{g>iSh8ZcXxH?oWLdWym&Rs0!fy!`t zI;E7Eh+UaxZHoI@HWwH0_acBY;Z%ijRCl)0Mkyv<-|`d{sck>J|0lq8GK)}h0_mI= zFIMCjTDPUE{8GtJ=ebI(zrEaqpEh8VXPo(7-Hp^WO36!dn zY*NCh_I`5oS}NLZzffoV{J~+fX(*hBwA_OdjaD)0LCkO^gsDH{zarzJ=n^v>Z82uR zSV!bfJmg3(0nelf6XlJz5}ad@-3M^lbg#UH zmLI$Y=i`_cQW#7E@YhjNc4rBV2`~aVoO$1!b8{nOqe$zWf!$8icdd<1YOL6cpbGkS zY%_(9SUXUJh-6$_$#Pc@ZpCo|=A=#NcZx+Ocua9|`t_yb{3Q%%?+3GxVG?k}=SwNY zr;L1m&{%ip2Y6jv zg;4@W#1)KMjzshMCy-;{M^VmVJ~4;ilfPF@8oS0+=-=!)j+ z5>KB#fCf{jYB7eO7S#&jhLa?vp&A1P68nDq;9Mg}CdBsyJ&sKKX&Oz<^HB5s!Vehi zT6lUQC8^Yz-`Gbj-hpHS`6;7enu{AZK7|Y!xwGd7#qyPcgX!swCW}Jk5F(bSv}Kf_ z@bSFv!H`!{Rk5nZxb=JhQ%Ha&1v_VRWlZ{gYRAxt14zf)qSu#6?_h&&Hm zFh9;|tt1^zPEJ}R9o^eY?+(=kV3t_vBDJ}>k!^oz==*JdKUbbC`s+CF=ohr`U}HO8IH!KB|JOI@jdd2!GE&nOgJIg$pVf(AZkG*5IQafNlFoq{tz3vUujb4-_DPBxcCGszm8SiUd6 zBotV_z+g@7hKgw;FkG}qP!fVVs+)H59{KqA%%g&n96)1?M&udsbP))j7^rzFnX;-a z%bL#@9|f$dQ@_S?>Pyr2_b*>z`dTWTsip z|5S8u%Zxsb%ar;3W!V$yN2DKgID1+;w0I}tqQ^4Z6ROFsobbV}l&=Y~r{Y@~etYqp zQ$;{0H4~qnY|RT2D-)0m|K44iT5iLg{4t4WNZ$M$JWHlL$moW^QZ(;ff|cOEljJxx z!<8WQ>-Wd0e6}LcBk}h5TR#l$kZ^mkKEfvfsfG?4w${&;KqvhNB%Fk0GL_WS)b2|h z0E*u^(&LaazPh^l?KRK%``lLtWQ^mzmYuLxnJYX?U<4z(#_ntiu-AaEgGvNSZ_bfG5yc);S z;zIdY&LbNKhr9M}()5foo0tYM%2dp9HRLg9d~krtQwTUp1sYA=u+utSyr$9UhoqiT z++WT#YoP&Tk6<@N&9OGLHMXUIQQO9JiARymt*M_tn}GK_Q9e=O$W1;0$IZoPck-^{ zZGD|pxF3zKY)_`b?0R}>IvvZ4VPia@XJkaW*qotil4on|S9|r+r9B~HKWg>CmXl8n zcpo6fwa)yY9*}D_%{wA>O@|>TC+9+bpWxcWaNon8qKV59tp`OATX`(KDSH1~ zT0aM;M3K}ogl50`39R~mpL&yeU*<7hikrkD5_hx%gttp^9TXuWFm@#^3eKCq8SNS% z@S0E?S|}x?pqCVLBYK_94S5b;icg+H$dA;(ldzoV;1h;Qvs4HUq;kEZJ_~r+Sx4v6 zBt;$v6?w5oZhJ5hBgVFhJvVv@$sp7ZpzhQ7kT*S^MxyNQB1uB5?SJ2c)AfKPo2Y9p zZT(J>8i8TIz^Aj1O@W=uf;v|(ux>mA!U6;3ucfEQs-J^zd%R=G#Cog9kMpsOsST;NRk?jf=%4tsfNKqe4P70e@ig=E{DyJ>+>E z_xTAN7-%ujh@0>H4&-Rm$6MI6lul8)$ZyhIcAVT`XdOYe>TPSCUHg%T&Vhfrrz zSQ;5v_Vh&R&6!gykI!_^Cy&oGW$n6k*~p^iNUO)n%ak>T$$LSz3O&ZKOQpaD5AuzD z^5{`CFxdgp!b{;=isxAwe~Ybfk5({xC$4twCBvbR0ohPP^2UIsXPa69#%O;M!Oh1~ z2w`C*_I*Q(FxI*O3m=X6CM7++Jp2j+uFM;uG}qiIMj1xFXa{!ZzL-|#JHu>kCEL>` z4sm8ZAw2W?@aSd-IT>v0JNHH>3fr`RC@_;YpX?URWB~3baIJvLb~_1k{xaa8R)4e6 zg3$~Ibqq3!CV6Pt){|_D{}{vNv~*dZI9D(Jj~I;o3125Ny}4uY#v$!(F447gSx?K;ljp)IRXv=>d*Ie#gdt#bl%h3 zM=z;(tc&O#uA&Rz+Wh`{UAl~<=Huu^fh`lV2K@l5a9`9g{TGFzYWQ^ic-d2Z?wG04N5S2&3ZN89=C8E>*B|B^ICfNG;k_4Bd3 z9^57I?QYfaE=rN60x8L`9H5QV+}vDwup!=`JlG*8kp}!K-Xg|zPx};l>#7)gMW@{X zQj!#=2=2`2L!<2V2$N}27oj}s=wrkn+pAZvS_1R4T-gzOQD3w+)Te`;eL>t4Qtl9R{`Pk~j1XpxSAWkHPI=QPy)hPnZ^_f|~kLM;MmTF(nxtV&t zca}nI_ndshlyB~lfZ5OUC=pnD0|ve2c}@3_lm8?W*QLjh7w1|;PMk|DK5(zT+|w?!x#FZg z8Y0eP4tfEtt#9cZf^8cQ^1%Ax|AcaqXgq!EfUd|s4#<-Z2h^h;TcPdjl`xjHM@zrfvYWT~`OxSL; zGt#QuyVogeOUuZ6ZypVhS=xP@DI)(aQg-Y%$?R8tP-uTm8j` zcgN)|gXIJq0mQ_{5ja9v|2<0=5Tag2@4or9U@N-IiA48u;Ol<X<=RF<2w?gA9iDKIaw5U> ziGdY$N)$~7(ip6xA_%nhIZIQCKv|6an7%r+#mCf*TumgU_B@2c;&aPIY}&CyPbtnc zw5bD)@U=_=9rru=S@%Bv{{AQnN$d5nz?UPOX(8Vj<#jeOnMUBxg#;GDl5SBmZc4Nx z+4KBLG^qY=;3n#y>&1bavKbzx4X*>*$PtzmSo;GTs1ZAqj+3UoG74 zQdkwb-D7F>qlsKFXiBpd$6>%ccYU01SPumj;BPh=0i+&~N`qrROMHqEUdsb{ zehrLYypW~;0Uf9n!ejQ+UYb&fUN`&e8yVRXopN~@4~vR+1DCtd%{qu<)h~i~)A)C& zoKP(POgQZhCG|{|Z1?eb9{GI=05gHFCBP1j8BTNNj8@Slxyxpn-$cZSxg04Ikr=A4 z=Du+1Sdxk2nj|v|vRra^Zy%rhM+`gfPG&$~nz(w2Wgjq+NC-faM#TEUr-L_CB^e-L zu#z5TdggfvYajYGQFQP4@zG?evmq?(5Q|J&{sa~xKhajwbt*mUJnky=3+Pa7n@R-s z5>H!-(Hz=!4uGz^uo=XV7rVZxa~+rzhVZ6RLJI^o3Z^ZTbR04B-9Rtb4n6Fz1d!u5 z=qWust*S2yhy}Ypc~R3z3~`UyZs4>Acf>;a+UcsT~f06I-J_K)| zs*i%PIOi|T(GEVPI#Qam(#2|s-hA+su$hBLLEN$1as3|kBRKx;oG%sDb{`+l0pHwD zj~m5API<8PiBA`ETB-dwXs$a?x!SExoI6O~kxZ1bMy2@z6WGb~pFBgQ#h%QjVRo%6 z)k!8SlfcZ~f#nB7>U*H{nJKiU%kZjUTYQa@&^QG1kelnt&DZ z2B`#|3TCTge#DxMB;bk6u-RDEMDg&`4b=XWP2R=4ylruWUY$SZ_uo=d%Y+<>4iGX< z@wTHY?Ld|;-7NXwjqIU0ki4xM_I4i|!@HB)Jk}6Z0 zZpt@oqHt>cNgU8HA2&{V1~mWmFUx&T|H7{p(*Y&%X!lca16|7f-?P6FSjJ`Lp3A%0qxHf=e*W`a*lNpXJLU9Wy|r4d6(_&>?*~*h*+*^L>=xVLju4t zG_F6F#vB^JYl}@9_d|WjNy6F$Z?T#*Y4grJMdX>I%%cJlMngf=6f2kAo$G)8I0gm$ zChgP7^<@j$bU6fcR~d3G>}o$cp#VtR0{I4{>T&u4`6my|0HI?use43*3_&6zhBMN> zB!7s%CCCYUTo;%YQn;(QyVE9&oPQ%76VN}1Y;5QDqz=Z_t{V`MB&-6n63@0=7~Cu% z^*qHAOt|2VC}4NfI&hV?{gTNmKY{5q7GVH?7B?CnJi3~1GNPP7`_)DmYc*Ok1~OH0 zEfkHZl!uVoH9h8Osx2W%G$c!8q|P@+OKSmg-ieihE%3pkx9(EA6SL`{# zT=#^3fG5k=uv~$qPAtd5T*Pf4?8nymdU=&dr*A|6=2Q zyU-43?QW0>RpM$+8Jd2M@$N4~>?+wj+F}NX?IrD(;2@JCz#AC(R3#h?WPN~w5R8SnH;8Gm+&x+IEbW9x?{`Zef!^zXcxs4&?b2$eS zE@1l6uP?lcsOgv|ypz+#R7aF(B5s)UiVQWE`NguXz0w=bJC zzHu|cLzpCbkLvf2;X_LeIgd#l4#EM7!vd*?j*@cxmJI%)Pk>wnaJ-!W#%=zBQT3s# zs2`Of_|$0CEnM>8MJI7#0$vPrsq#miGAi zlcR3c5j65r9Rn$Jv#v%Ny1w5;3CEZ+0+h8z;WkwLDpA5vGKVaBg$D=8?q_nEVv z%v*%Ug7v-qW?@<(V?^I$2~u-JP+ z;NPDS2+4>(^k~`Y^>ixyULeh2K;;;pSnUvtaZ&6`zFI^79Nrp5On8 zbVI@SRF$2$mtQRvi8Ow)#9DUfFg6dC`F1EmD6v$BbVQLuvOWC|IDLfr-TwggscrbEp@#y`H{LFNU>=D7`R1JPV9pNk!y^PL&(nQ&0*!}A$r98o{ z*`|lfgpU@9=^jsF@jGjjOt5!xvnoise}B3_*sQFH0Cj4PSa_B=pj;PTzDk>VSvms+ zfJcxN*%$q>l|Oax{Kh%U9&WOkuXAqIuO9=Xn{sFg>Psv^1T>en0b$x)rva);2*>>h zNdNN*0PZjE!)QwB;Ch&DPZ>)4IqEm_0_aUKw<_Jcqj_KlBU@fli})@dpE1^hNCL#& zQQ#dOiKHv)-3uZngM>u9fWMODw5Hy@DGlNGl&Wy=M_N4MJNJhx6rIPg)wuJ`+Twag zXXEmHRUhY7k(noHsUokEdBqq(b|bG*%>uNQ@3?z*DwsFHsn0+_+iv3o?ULuHdO%r7 za1+U*hD;$6vp#mTDvDF6geI^SXF7axD_(-0P#9I0Vscq zHcb%ttdBQ0H_xQ|UIcGaaR$7cC+s_FQf+bP6{WG@K2GQT8*BWZY#h~Lqc(WC9~^m$ zNC%Gp8#%L%3bYgyca`zOm$QNbz)p3FOa~xc8_nYUG^= z4qq0JwhjHt+E8Eod5{+7I}CsiC$X!`ehsMh@cCWeXJT(8S{8vGfgpBwep{7BuF!qU zH!z4|KUA$|elcr|(OM;&Y#bz?{goubbVL=Oke7fXj9Tiu7Cjg$$AR_Op>5XY2sC3b zg`otdHdGr6uXV>bG? zO$F%Wn9_C#Uq0Ua7@4iX>~UFoyLZ34)%M4157{O!KT9T7UORQFscC?W0U{XitWjqu z0@n1Pr~x|*uFSHS07A7;ErS&^Cu~-xa&_6{hy7>;rWAOX6ttJ~+4O7iyXPc*am;Ob zdFyMO&GD|&PAwTB_ztR0DRBr`95Dx3@*_I96Fj_Dbh9 zUu;8EFW>f4&z=?_$Yu78TJg4rWPY4~^;z!aO{i>MX=)?ik&){>Pgt}^zlG0#ixdZ> zmP*{G>aE(6{AL7#3EkuR7?AqEy4bO@;9kR|_3Q20a;TY$4XA5@xaz7!y4y(81=HYy zB@@QwE}Le#XLVk6-={SNDHTI1Fo?#d$)c>h>iA>(_TOADpo}L$ZP~J=jxsepJqOZ{ zMu$C#>#t{ImG z&fB?*X6EKrU@hnpj6xo#p7n(a73E0Z zl_+#A|CaWjf^x_`^OUsICPSIlbiJnwvM@-iKw=SxlfNaTH_wt1xyny z2g(`YYRNc(QYD8W$gaaOTYdtpmNvZceuprL=OIf!Dx#r7D}Zh!xj2#iig@$BWHoVl z)LZI|9e8?xG#JFVLI1E{w3U0uMNG1_&G+j#nnNaF$)99%34SQdhSo{{zut-=Z_7#D zCF^8t6}<8b z-$DAi5?Zt$$Ws?VWMuo4>(}F5nZejHA40WvLE+(!zE6+(#{PBxek3K5naaw_>YGE} zdvO@NhV2a4XPmvYcU?*+;G(qOlfUL^Hj`E%>1;%q#iNfPPKm#$J2NHZM{_r78JQ{R z^YrMI<9J&Z+)L(gBd>|(dYvgpg&|PqCr&lYFTI`cIqF7rlScWQGq3vXX>mpM)#nOR zr!CvIIr~1od=`E8cq-B4d4;U3yo9L<^QX+6Mtbdhe+#y}gfa z(>MUdYEMmCGNDkcjqxhng#fp02%G%=f)j_c7vl&D&v$eNlDDC#KV1uWF*SbeCDSoK z$6jFgMomPR8-3WEY{f~rXV_Y^Gs1gH(?k}G+T^+Zg0mJ}yjV0R7k>f$mnR#=RCOK% zKrOu^A>h-Yq3t2Eq<*RiYsX8BqLiCge&>{=fgG(2V3e7m;nQN@E|`$TNWkRpR*Dle zOKoK}>Hj!rjM~Cy%ex8#!hxc=QvlhjUGMHVanz_^u6(I@6UP_iB^-;3E3X@9>_l$Y zL>M7zjpa+z&xXHb?^>47+J5R3iuGH#-=QNC&9(TOfEE}{EIA2YR{wYY36$LV2{1`M z8QzeX+*B^%K-wsjnd0|z%w@Mb1HcPQK#1;~ke7Sb9ef23Nc_aSn}m}MEQ1q@MDp` zEQ@qJG|6S-xUUg9WDzpOn?(e+o@L1u>zgoeeNF~B#hikrOO~82K7;=(sg*O0a-iiD zHX=PNFLsUmPhry3FL*ZF5x>FPf?)|>igV~b1Q8=?Cz79iw>Wgf*in*B*q4>nqB!CO zyd$8eTMlI-)p?3No+z%r9lK&$5onJva~tcVC?{q-jeP;S4&f7h)lrgcir1Z!`+TCGM1@e9XTJ%mN4A-5J*A9{a`m@yZb_6TuM3RV;axUQzwkcNXv) zxn7lMw>U;UjaC}NI)La9berzsS$p+@gHdVICGmSOwbTy`pI0_Ev_I9+?ghD?RuNG~ z#14*9WK(o8{%JSj4i|2_@ycDlD>YsTjUD~xXstGTdfT_}n(uKFTC-iZ_t5=+9+;bb zs5~)g-#lm6dD`=|;|mj1wj=&~Gi2AVI+xeAUH7lA|2Cz{S~ z&Oky>O2McCn5jL~*>h^W~Xiac;!&gvwf0qGIJY)r498Ig{b^^l}WM6RuU~~O7Lj6hpBs{ z`{LE}znCBQAFV#Db_Mg_#N=n;J!7TX+KHvj9JE_$Q-0F*F=tqYN)k>+;@Z2Pl~j^p)*tKMvNHYOuK|M z?l5CE(2FG}1MU$E$bimybcx3qQKQUhPfKmXh^YEB=p=U_ZAUvax(`vO4D%I#ST*R% z{GWZmw_3J*?%{v{g`?emD;;Q9r@36k^Z9JZb?ZBT8w~IqBv-`)XL`vYf`TXEF46gOiii6#)NuZMkaWqK2aM}H&3oKg@+GX>%X z^5a;-wU|SVd_yXc10Gp*|NQ5oC9OqlFlJ3>vXK2iq@z_YA6?gDm$z={5I&Me3yn25 zFeAd`FmAt2bA!N$gq2rC;d$|xD9%k0z&QKc6v2I|$L!|I3$mllSomOBGI+Z_qW?hs z(>M(>U7BzqSjNJ0^=qhLA{nVWZ zCtxDkIX3-Y4p8qA9-4T8h;-mMJvfjeUho0APb0Wc8B(nC1HU1WAmQ9wfjL0Otv)#%1QA%l=%kMrXX!q(TO zo$Bmz)C}!eU(fHoJg`PF6^XPqy)OWs#s+R9)eqIJ4qRL|gOYc?2`uYpQW>ehF;X$BH3G77Q70p8f`bamVau(3hbD}rK?q+(pvdo|H01)7b0G!d5fdI+aaz-z+T zFf`ykpjH^lEG&0-&=T_#rpn&vntn7sKAT_=e6~R*uo4P&#hGm#RAunlySRQqo|&uo z^MmZhAusN{J+0REMq9K_mf3sB`TS40CdFw)dK`0zyky zDdU(4xK;GC=x5&%GkJUyjC(iEsUl~6d=Deu7rlK|Og$Tgd4a{65r4Dkt?*H?va2Z> zgyX{q88yKR+ws^?E`}zP=(?Skpd{s><-9!6w>W|1N}Laf)mwTj1?z?1GP>k}fH#vY zNL)?Wxl`2#z5^A390Q_QxNav@4IelUaNYb#jz66z!ZmSUs(6po`kyj`O$^47=Lqow z7_s+>+xW}@o4195knt0MzXcvf`-oD`RNL$Y{VyrmYo8O7Y zqv%EuIKGWW*st^fQWul{6uw26qh-+vG4T ziO8%)1i^{w^mH)e=1{@vWWD?js)|PCg__GkIB^7qVLM50LrDjc19| zMH2Ft8?kg8*@QLWR?uG7S$8J#K-JwM;hp8=q*EI4b0`8)z3WCbWNXmDZ3g6#)-h&! zl!&1~IO367QhZ;2HMklboXq>SPr?}ZwIEHRzkwO7qSH-}&+R%c{1#%Ub)RjbU>9{e zMzrJHlEt%!H9}osjZuWy{G;O-F3SqwOk%h+>%}D`DP36p2$E{gL?zLC# zWb!nM*2TkQ1Tc3RGOu{W74Btx7JbmV#n_=p!=xWZ zy(y2v;bQQ98NY1V8FhPa9E2iZf~d*!Zr|RTmh9h{B$_4OiJ3}1jeOF$cf{U*gGR6k zpP;N#MF|YR!b@eV#GRt#m#~6{7BTAEdK0(e&5s*`rmZ@0*Jm`SHbU>~Vr{7D|6Lwb zsxVJ@a+vwNS}bd%N6!i6tDehtlgWI`IdQ@~@n`;9vvIv20GI?Ex)J-l%!AiP1ZG=XK8*ee!r9=7V%%1pj(yqO=!|C zrKiOj5jwVO%U$o<5=uEz<>4=5rO&KZ^D8_Rl27@&;m@6EeUdKL-S;o1r;lZAjy`&Sn2Ok{!(^ADXMd6o<{*3yEw7NI zi~uS9kyiDFQ3OXh?|L^No@S(CoOMM<+7ez=jOEbj2Vra7xZWOvO_JK1+7YBudj>bM z?7EYDURu$j^sjz2{%-1hmm_WbgW%1KD0sBmzBRf-IbO7i_?~tvdwZQ4{}HYFe7hNz zUL^X`H%3perF1$#Xh*FYU?Ovf2^UNqxUL#-m9*l1)lEvk;0r6rJEKl@A#YH_YQZ!A z8xtjdFs&91u|i%Kd=$sR?%3Zyegi{OFl!lJrE^C2S}nT6fdg;1fsC{#*Y^5*J0gbd z?!yhCL~zfDD0lPm@`}pKcETR1%Lsy#v2Hq1WCwY%Q()npJ0T|i4PDC%>FL|_i+z5Z zMmR%MlV;-{pZ2RuK#O~^Yx;irMi;jy+q`k(M&Fuvw4j2ehvuT&Dtbk6no(cT-84#mJQ;M+cq(fp-z!{~&P-tXRbll}d- z54f_BI)d6t0;gyU@^p?Gosqt@DGkuS(B_PE&d zUMwyHc$ON^5y?pC1ZHQQWw6%ef%Ghp;rPaR6ZZRQop8?TVsD7d`qscCQiGHXOJ~_! z;cfiSL=IO;CHCiQ(Spg^>+zqB9hP?YZa<>q;ihZ;aVn8Jbr>%aN8xNGCzl!E++Uy! zauEDXT(E^4VTn3j9>a?cc%WMP#Oz?fD%6kK``LS`^vlfNQ#A&KQpKNMna{3eZC^-v zWsn57nWeLN1%9HzGn+MU-kipSe)cXM^{;M-!aCgjkej=^HR(lGzX8sk7Z|3NJ*zzh z4s}X9NAgGV`kwZ_#QGh9K@8G}^&l&8UOH6S78}qpxpzxvxVrY_$aGxBCMr%Ckk++q z$-DFK?dv;Z&BTpE+`0%QSqr@!P8>gOEp-^OIO%W?RHuy22O(l)=(g-P6+P13xp8y0 zEg^>kOXdx+heS5nI&9u)!+`^b^lo_dZC&QYpI3Cnz^S*EiP!V{rQeS}kBoE78nbU& z{}Mxn8|8v@)O!=Id4;FkMyUjTAD*$SjqCfxqvunBu6Ay5T*jT=x zvsPX^{brx6%9$K2x#+8;8pDS1lJ+Fb*}ahwQGZTKTJ}0kl1#<(%6$4{Ku)y=RMyqj zcG0W-FD|iwq}oP3zCT~_!fT6$vb6l=#{4%|nwwe$-yPdtw~UI@$zbgE#$#r-is0>0 zVYcBDA;}T0t;_Itad8pV7f$AuWc#qr`YhbwemJFh#r`;L`f|ZT-v6BW&-xBV2T0me zGMFdz?B0OYLg#>#(Ww`_Zr!?V22z@Z+k-S^6CI`fLM2r`FjWe^>su>GKy3(l6U9Rh zme_4iZK^U~*?G{QLBpN|r)SLFyCJYVqVK(^A{bQZYs%=iDoqX`T0z==*|AltR-&~f zeVh3x?fI@xqPI$-Hz`qH)^geHqZhSjJ2~y)j2HSDy;AtQi%aInmqB-=AF3!`{8mjr zrdgP)Bz+z|di1deTeF?Wj;72xcU-%yRhxhRd*jY3BH4kwSyrSb8GOkGme2W6Ne?Iv|=6<+PKa{2XNqRdY5d(ijzK`g@K3h$td3y$O*Mw z(@(RTr71eQv2KM(12S=zswlR&EK8 z+~J>}PHUx0Pl6&5k%*9-5SF|_wKlhNcl zN-8lBFsIYGM$#kPrRd0!BYku2m^Q|TZQJyH+JAPbakO_=AJC%{{RP-Uh(b>k`0xf# zpn}wx!9F-n7*y6`Q+S8`pvo3DdteG^7=>OdB%=C|=+}87)4l+X)2(Oe< ziq&85^d~5@IzM>k-@$ATjBP9_%Ly&E%C49(mP14f64_cLTO|sq*%$`pw2-9+Ehltix;RP#K27(o1>chJj-J%Ta=FG!N6wKAVSXwV?RjN}b@-a^$$ zIs$pZP}uYRm}OdlF_la%J?@SY4(jVh2kbnj6HYAh3#bLfe5P#Y8iMZQYkTaNJkOy; z$=TfuCW4ETo0@1@1VLD4#swapopu&X?I0pkkf) z@_=7CyYFBlE=@&*)tE&nKj%NS1;`f`T173nfiy^SMR$#>J@^>q`L|LKsSibanq&a zV>Y#wW%ZBC%NMo1JmN5%nY@#X^9w2x`ny83wB+DP`LiPag_l(Dd{4|$EjySK_V2*% z=j_I>Hqsx6I21eag{{RF6CJ1tTHbbSxPAvEZXkkXFnri z245CT8Rt^Gd*Y@`VVDLdPCUs$1Q^Y0N7i~nd$43{Jks82Txv_ua0w4$&M7{g1UbO* z1DwK)3m3Wlgl8how<6V7?T~R?j@OeSRe_~O`l|gN)}mNZSJ!N-tTr| z#M1^Zq>2|+vf!~?E&~Fwi1a1)bIY%!q@)y#P-4!l;yDJ+CEs= zJl>x%tGTTcsklighz%k|nDiv(nEE&G*N*;go8_?_kX0m-_k5erilayCj6K6oQpL$U za@m%_cl%qdZ+WHT|Lc{;pMCKK^Y$lc%vUrw`2^0O8I%56bR=-x#_HFl8ImqbCc_kM z`hR8lpRu76$E!)v#)#Ho%z3CeX>Te1 zg=%bA5ift!&%C{smLZxGX7AZmnSG$D8E{4h%8Z@==_vg+pTccR;*E$BsR9OVWNP^e zLlQz@NH~rXg#LN=Avm^S*~1Q+0jn^r+x`QGz>N}XN0ld`c~tvv=wc~tL?SC@)FPnk zK4#MaU;(j?B5uCER13Tz=nJ8e&)uP2-&r%zgNT;Ia!H#mydF5I=bEAm7rqyW0sveV zH^%_r%6{ea-3IFQOyc_8x9>H@G^~`t(}Y<{IDUMf`gVD3 z1JYI~`Sw%i&+lGT`rG>yPeUAO`61ulwl_>}rLwhGhchwnltP9~AJh zA*V>@)>WpTKiP&p*M&vXJ5ATDOqjtnq~|$=JW26;pIN7-NO=j!juzmoLp4pRi10Qh zq3tiD2NqKeU={P??YRUl@eC6n6kJM7{|MxhJ;OzTq(;XfL(aATSO+3n?2537wAuUp z$6_4Cb^{f~i6RD=)ae?|N%t@R#*ht*fi~HdqqOAB^3JYVg}_1&T#bm61C^Wy z%m^-j<@W8Nc_X+V`(@Q-hjC-ZlnJ~mW2T_r5?W6qDWZ^f$P7i2C=1uAQWiq#LaEqo zTEpUH;YL;3C-(0g;x<2JfwvOgVWwGqm-G%U=VLn+QCf^Xdya12);YT{;h0Yum9U*=2v6D zl@je`m5JZCeLFAnR`8P>`$A}Gc zmW#I)aloGBE_4&bQSf3`r8-LrR;f1ulPSvHes;Ey=h#_mdW_}e&7pjEhc45mOK!?vX=4af>eEl;?m|&wWPo=e+(nH@JO;3P)bU2>Y%2QU}6AdG6Pvt;1lAI z?CS~R@23#rGNqc*@am4wHSRkSE@KX|I5*YqLD{b?=+S^T95t8Yd2AhlZT z4nMZ=K5&c_l!_RjV~}i>-G)`dvvH#XxEmOUT4l z*&Lo_D=ul|6>C~z8Ho@gjZ#aF4=uOP^S?e7qHe62IgxDOb}pz?HlRm2JP5)$i>DU4x^KggMcWc^ znmT5SwUSvsIxL>PD5aVUWwcgXi41kS+gciShiY<~Sf7h;8tM+Y3{pRnj3k=rDuqv= zm<}(0+wXVK^hMwxdMTE%CB8M6hwCz$LhMvQq6?-^@mRdz%_*X_fuUh$L~oJwaHGIy z+!EgK39}KR)Pki}KuwspABmS{Plp6c!nCJ>mUzV z`Sjyj#9*-^R0Q{y$(zHpPiY9GDO#ooB=38!pZLuZlx575)HDe^a?sz`+EMH$N&haK zMVaZV9{!mUW<%ooK%AQh|5d^3c68)j!u*NbGR&>1J>ojSL<0} zkWr~nfoXF=ML2nb1WuENLJHQlP9q*xRV{71R)&8B$3y@iLyFvM1Bx{orF?F%rD-YE zwG4$WLbo|Kdb9bE!LkaVF4FTa4(B=^ZQY~aI^aZ4tf znwm<-B)UGDRK51PE$xK@8^?h(oJw>Nq}Dx*BY(C3eh*rC(eDu}5A)FvP++MUsIIfY zEthHaV;zD7^n?To5E;r+L7WZ}dU$k2nPhKwbR~4mRFwP5wt>)ePK6yA?-|CaLvqc1 z?t|qg>z{0a?yOOkTb8fBog0~#T z4yn)1yu4;LyA-qog3U!#Gct40x`9L<=y~fo>QuFG8>UcpLTuTmqZDXEcDI1li)b2+ z<#Z~Pz&$5<%!HKG6%zxNJMKHm?EOm}rDk|$5o1=ehLXiR!bu7Z`_lObK=R|X;;rJj zhVfC{*9;U#2w`NA?lG1Q$OIHrtBOs`3&o6q&!}ECOA9H&7HTAwz^sWab(3v6J5Kg- z@a66jg^V`__nTu6Yb*X-#Phi=<+YJ;u#!0Fc19yCkYMT-4n*lMn=4LjILUK}M$LEq zO)I$tU$ZNFg>xDU^sFi=ydAYl)3z0F^w7DiGbd1+M>Gc=DAPwhL{lpgDyuysk&w$A zB?XKOujwiUdqjJ%bMMQCIyB&@9WScdtsSqKe94drk&uvPld734p;n|{xiSGnz(ntL zKm;?oYA+EeVc^O?Wb@CH%onG)6`fN#Vu2KvNQ7nIU#Iwy`pcE5VkuF?534s-I>mST z^9J*npK<2#t9jm_;&sCOOT>b>5pycyu~LO#;!$a7FN`{Qu?4qy8s#wYztQ2|m(&Hk zIDa(hG^F30!%~(Y80a|wRG7Vz);t@8n>Jlc`K{-fG|(6)>M99%N5x9c*qNZQh}z9T zi`icXFH@@Xxu4K9rBvsv@D0#Dv3m8R>RnA%FP>;Iq6hbGUR`pMzFMlL4)>B*S@^P` zmnXYd*}Mm78=K-h`5o00vg78#b6W>JMJlv1m+I>U076)Eg- zR`?v)x0JUL=-?J;5reJ}Yq%+_gAK2)FaCHUZr>_}m?+X#Pb%)B3rcvRRlIi*Wy8bJ zn=Dot&j`l4C{sa)5CH#oZ34jy$E^e1G+b zOSN8Rq-V)|&^(K{U?5S1JwG2OoZNfHCnBrC5W<%Y2s*{$s1FS-~l5ggl}@Mvu7e&Es!n3v2ep zPh%K#Qe0f4?g5^^zWTy{lfKavGWw8pXMR)YB0EpnnH#oJOItf^{)SHwlnJQSjC_$d zi8OiHX@u7ivGGPgC4OSa!6$lLeJMU$#W7h*`LDPqc|WY@5XDfx=%s*i*jm5v3TF7zB+MSw6J5g_3`rITqN9ks&{c!c$pDU_052ucvE!Zk_P!rBuE_GK`NcK;*pc44fa)sC z<{gU0sE0QO`n;{BxE5w(1W#_*25#Mwd*{@)IpAx7rf02G9Q%9`rpJWDx=3$ZqOawV<6_j0MQ!Rnp&q z*^J0DlkXqO=EO<#Nmf`DgV75w-|A35&bixLn#9Oq{knI*@T4ccF|9~KR%9mQpwZK$ zWZs{sv~=-&6QBU;V$5^+!FpJ<+h8)T#tE z<1>9J$LZTDJ+3~s`4+^`ydrR=hj%BLQl`-e_YQYv;<>@OX@AI*yhlzr0gf|uTDvi44x{kYczmYUzlpCuT+|r&>v9v^cC;yW_DdN;4sT$ ziAa+VvOP3GPD}-U72h>8dRcqhIlPQ&L&`Gof=8R0^RZvP4c7(O+j%pT&+FW}{?GG1 z(#RCrf_fla;R1+<+QxPLO^EN#sgw~DJnfe+U%tuJ-UgwJcT_D$#>?gRIV69_B})!b zHNcHNCL6W;LRJeWm3)mRZc%hlFFor6*FCNldj|Bf(DsP#%4IBd|L{d(8iZiTsSiRc zs=v-J)|ut+yIdZq5xXQ6^6Z4jZkj#Aw;NV5LMt?jBw4$YCEH8Ym*LEOK^Fl|7IGlc z*9{JkVuiRtDgcS>sOiUV`Q2utLKzyCT>7%MOlpBI3ArcF%-o)0V8kC(s7q^ZQP}E! zRX@>qyD_E#%H2bW26x{2-_2+=xb?$j9~w{1B+t9WVzp{EJHbnj2kZb>x|b_^eESsnezJjka6h+7JD7+Tt0mn4({fyy3rS&q{}nKmIkIK<;gRXT>?Rd@zD;jgkXLS8IS#;chY~=Dt1^xYFd22c+9FcZ>~+`6N#sQlz>8S= zOh}J$Ux~VxW0*&8CbD^u{J~)2g%0sEj((3ezGJ}x%5f?6)T3NnnzZZnxMBtV z6oLoFZeJ^Q4>=1{N;z;yXgKt+&do zL9_FpI^ovQ9z|w?Gdz~?AbYT6xPj9vDm-*NR%~p~+|xMYz9{!TvxD`c{cgd&I zo_GElMZO>;G&<3SosG_5*r34i-NL7Tm!T{76 zEjUs3EkX|BPX{KZqMn!c3EstUlCmNB|x=;M?2;DYi`^WSO4Xnc@3|8#dOTb?S(0IACSWba6D2(Ia1m z%B4D&B#r#PczXc_<-`-`Xw^glPlxG=Fb^5)Y|414Y5|7Oz#1q?eSuU3)!*OioGI1V zOweE9FsO>+(0b^z`;2){8IoX2gN@3XHc4Pi4xzC}3O&E43?=~eNzc`J(BdhJ7cWj2 z%6kxQgOiho5B5+;ndh4X*;-d$*6Zr-p;lB(<~M6|5T{GQiBZksB}=Xb(V=Ms3;+ZZ z$c)$%BO+7A{Lu>LLd89q_Nhkv5&Qsz;8pARX=KMCFJ0Gn&C9$GH8KzKU8TtowWmP7 z93hEUn@>TGbCshROo}>nj$svjyAua7$&IA^(+4X2O-3mw2&6vZjv6XG4Ycjo%MMAXye&W%~Sf`qly9ztn)%ETS`ZXz63@ zf8jH`fNA0nG(dYs?QcxDGTtDr7_pFpEr-Hu5D8SC`&^m(N5(@OQ?H<8Y|BIETa7b6 zQ9VcF81sbDE>#4hKC&t^tMacxHH9$O)iq8=s^_pC@UZ>aSWc=qEJC(T1=JRP=~68j z=LUa7y3E-GP-H*=IF{)u-SAU=>1)NMEMmjRg(DV?@xK4J7NDO5HX0M8GtGss)^-$4 zJuQDogh~87ORmavYC;~fm)50+8HtpkBI3Iq@ZGIAKwj8zr6;XoVq3)U5V||gi0~4C zduQR=P?AmY%V=)6MX>W~rcwVKf}Mhl5cq=_NoW0TxS+AP*XZ#h(?Ad>QQCoA-*Qkfx8NOI%6V@1xkSFYB9T=g`(hS}3S`Aw4ZrE*-?0mXw(%<2_eYGE6?|EhT5Iz({0yU{^==$l4|?Tr zVEmsHdc7xjg~77c6xOZn2uC`B+^{M2WJr4kwoZ-9;&K_S&=71QgQYW`AyjapEJOG zJvDEw)Ban3651DmjIfaek37n)(wi4rbm|VK{;O(dU@d@F*8HIyFB(ihnm-(hQb`DY z(jcX3S7F%)8e!R#Q9|kgb&OU$V)&4DErxTD)i{f@EI)qc zd5O*zJ}f*E(htr|e!8{+2ABv&BOoSN)TH@vSE3=NSCuquYv++HQP^IG6p%xhMzS+p zUIOAEwqgT`4YJ~+WU{74zz0x>;#DP@d-9gjlBbWUhIpPU=PN_s^l3?gY51K68_D>d z`u#=cMW=_9Q!s7iK@Ne;qoQqp-9@@y$RdA86QS5cEXvfF$rz@7G6; z9m@qn6PrErKCc7*zVqt0)hKIA9nH|Dj2s9)jo{=PXF6&8^RBzw(7<+(IS-=Dz+4t&hU5!yH&WyM-QGi$wyQd-Tj@U3}Ewz3&iW5hd z&^NjrZ-%b>q~#;1|L^l`X}l1Q?3JQ{e#oNzDmS+q*rVV|H+GxE*MP}n@Lbf|*bw&< zxt0pgS#hyr`S)1Se}Ag8wkyV8#?D`8#D_ClDVj&|U!$FycV_8=1^dBjtx=#)#KYk9Z7LjOzhPIs8*>eDYNsx14TzF^~t05@A?RFGAIQkfg8Lp*9Br6CEwo3QvTLyKLC<4RcrsGO&|3mXv?=ghQo@WAR_A!;VfJKh;M1w@$smuBG8iw4}ZP$ zvNL;0ox;9efue^;Tw1aUy2zh$xq@6LuAI=LBL5R6oR4>V7Tke6+1A@O?Y znD~+3^f_Dv@sX%S9YdX4GI_&)Bxd&G06t_TgY^bmKG{JIHj?nK0Dp ze_%T#!;|%t^v9rT#o9>5O?b4f?$Opb<1R}-iHK74uqf;hFPU2vWG-0CKREL?RNn12-FaFOJ&j)$)5zPV?q<`}W&!MKlWQql(}$r^uMWsJ6(~rPrc_k#pEj zlVenTUWUtPeWVph%30n0KZ;qe%TfdomC%^w(X#|z*72-G@{}yf?e*%(zld?B;U?X)f83eZy6u>l zeXkZj{4Q@%WS_{ft^*I5jcvoO!N`F#FZDb$&D81g;)B<(|MFX~XV}PDL$#D?W|tZ~ zXCJO8tb2CO_xv-Ts&5|EojtD+P&EGE($4z>f`fu0fG35U;p9(az{fTjd4&|KQrm+j zw%_z_T5{;G?>8axg|v-@NvE(UG}SP~u!UkGFdXvNT-k}7P#At=p*1}ZvJG7bXMQm!e9pQzx)mRgX&+4$CxzlIgO4$W zK5>Kk+l&g4vVkso3KgQLo#@+l>aO+j@>+=4x}K3Z92k1S7;XJUv*jy{?~arz0y$@* zvLM};!VC*?L-LV~3QEV5=ufp8`s&NiJ{#|m49&9ZL}#=mw!5OEqhk&cjEy>NiFq{u z`Zs4K_4C!j?>@ZU6*mzbpPFq7u|G*jxKgAul<|V?d?!YTjE;q)JB5SLSBgl544c?C zm3~P|7}DVr5JKDXFKN+S=9+X8;ftWYBO7Z1rSG6L)Cac9-m4C80rSw8l*0fq1{Q@X zhD92eYSx$m1;{9uFf?>H15uq&)n2*qO-pfd-)(gErvb{b`=^SmAJkmDqQuz~x^Md0 zwQINf`=E8&^gtRtd#DdZAPkZ zh{Z23Ok}JI63ldNrhL3a-_q6I-m(r)G)yvd4)btjZvPqm=eOL%5!X1wMt`!C!Nhhr z{#dvP+~m0s`HwQ2E>z0gSsNVzoRW;b3g}368|AQpYOmyVKYe-#_BwO{B&gndXExh% zfbFl7DRfTnAxJTWuKUhl$1>`ntU7bKC$eH%q$K-laK`(}<-VT&#DnpwR-cGR+CheQ zCY)|qWfr?m{o%Byv$2BV|EW}7HjEj`{4oeZ)JhICb(a8*GTS1`7BJ8EUrBF=_ZJST zzvjxybE${v87Io4p`ey6qo%rc)D3Sdg;IJ1&iB?WA3BTyG@~}^hJsjDFbp-ji;VA( zWai`f_BnRT&PIaOh8|B%O;tP6)1gz4<9&5`+^gHvy;k>#XVE#N-RB;eO)P#^Y*Du58(d45U!=_GOkqSM?V0d_E8xf&<3AvQ2V# z6suWPRTqn~Ma>iUa3>vO2?l;=t}HZ>q))1m+BYIOy_T!UE}H` z*+EH*))W6w=znvb?vuAL(&8s)k9yT5tLGsMQwK?+y$$Q^bHDsj4RbfN*@})~&UtI^u%E(c+Pcq%#%?oA9>pc1B`# z<_e~Z(r||~yUYm6aDmG!*`XlH5%Ep?qP?NugHlxsb<79PE_?=1Zq%U_|09e_cH!Y4 zHn!%!L`D43Al6Y)W+IC4ZF(~f4kAPVxza4rX2(!x30h-Wxdw@PROA0X`I+$)mql@a zShlMpPSIcwLcOb23@+<&QgC9v6Y#5A#!YXbREgs+_4Ic_ut2avyoFSm zfj=531+T^ZPCKcQ?kw}%OKSXF!*icz38k!ZnnMdhB<(+GCHQJoPm8dFjE2}#UbQg7y94=Q7GXnxna;E00KNj+9LL*A>cuV1vWS{{HQ!bg1$1CQnSFU=WZp^&K5Xu! z?}hyU6dKRCwNH82XqwmxRot4_JiJ6`Rh3Phlum(xR-1fyhN2`f*fhSzfrFM~Ql2`u zj5<67WSlt=YqT>0&J*Ztkcg1n* zfD%)=JvbRU-L#++5`MIi{jTC&CD1lOI%-%6GA_&)=FPlIV<7i{dMNRS2$sFboq+zb zUD>5CG3nywvHR($BdMH1o3xkq?@%K;K@Mb0&ozbplFG;NfrUAs`q>q`PKqYxqK#}m zRA+VS}gPSF- ziBttD6PP!xE){N`L=I7xis>Q=Lsv;fCWaO!GAPGvjlko&N|8&&y!avEQN%VvwTLBX z!^I$W1@d=-dPdN}uklC3!C@FPNl;k&wh0mNRIw!Ra%wQ~8#{RDkTfooQoS#mp@Ry% zvCyP36c8k)lJo1dglx`8eP(rWKm?u9M3oH5@vS#Qq!|P_43RFH zXc>^^mviIioNs$4g_@DBRMw&rsu;cKfx(Wnp+XLUleEmXmhjgcb;M^RQk zoP*1{)V*9SJw&I>U~OA5_&dp}Q;n>Odi1ewR6i1rj!8VB{bicZ3gGZ+YcnVZNs8hR zBsvZnJmovzpix`9m3^4@iBvZnDut4H4zaV|>VNy>vN_b-WvKo{1r%0XY`TO=IK$xRndyFGg$wgGVDv4rP0yDN?<_siTp|L01?Kr${&+gsk%2^`p zbT@3;Fci>-SP+5drf4e2+uH!Hh4PmvTo^;AVDt8UXC0t=8-IGu9(HQl%kO}=vv+zR z<5yM*DU|v=(TA;5#j8*tE9s3<4oB5;X00g97z2p~2cMvld{2!r@k&w!aNCF^9!^=h ziA5_qfQhH=iY`Jjh}@kkXq30iJ4$eSc@rzn zn^YoP@gUn|pf>qih%O|1c8ndRTC0V$+>9N=qK?c*W)g+Oof#$b4tn)8wKA7CBO;JBSWKW)H_V@3S&NR$`#ph5$ z{@>*$pj*_+>C>l+s|i(T4)RNxy=D~Q)=z)81ulWlZtT(+e}dM{4Ku_cii_;2T|-AkSNsNhye3Co&J?4EU_$D!KvHkpbtx)?!A6s9$YF4UbA^96)< znKQ3zP6a5m&hORh(0EG*sZVSjO1m^E)JWH^K}^^jO*-i+iLh60p3_MoJkr$o7hVHF z=TK!$xtE^0ZFi>X4q-(W_a0sns8ZEHYLAyWg#F+!5vOPKVr%Rm2|2s;6G!wuzD94@ zu!6-tFB)}=l2hFup8mcZuz1J(*(6<=W-_cz2LElV9Dxa;`RVwJz<&vsSNPn0)-=H= z^c#pptZN5cuBT4gX*$}toP1CJJa$tpch#CdVN?{az-{2-(^8cJ7PFkVmu7ih26ki= zQf9`*Ey}jMP67=bo8`f01f!6iuH(>k60Lw&>pNMm+k3^+xxu(!;59P_) zoeP5c&Yp?5NjT~|y?R9iAMt3ezI6HWV;+_V#FVJ4+os#HELDM8biVC$biyBY%5Z^C z1e2bWe}Znpcv9@zTgUel^3TFbj_(Ou^Ry|e?%sm5E)>#wH$iC?8G$RQCSt^=VDW4Y zXyS+%AL0=+VCF~7;zlNmV@vre{QnybCsmz=wTf+A2O(%1mWxSRc=V|{A#WcmE5Z0~ zQmAL%q?~v5wrt>=ST**Q4l~lZk#@tf@+7?hFT2+0_HA9NX`xQtqmo9qk$5x7_EK@& z$R6U4dH6(u#Yq+U>LQ?}ojpk<&ep$SkDsX1l3(f_c&KB17~-1vAVhs@t7|}_m(eOo zHtY!h8<5c`IcV5La=?RZZcW)l?-vi#i=1Y?HEVpsQ=e zxTOpFcQ9|0Cwf0!U0p}7-DkQV`ob{1s@b9eMtX-?w9V(;lCsotf0yn)lhZQ?z8JzX zm;pXpH*P#SSFx&>d#5ca*1ZNf{qVw#(aX}+7d9xE3^FXb6IEAm?!4*Q{FP&?q{B&n z`%DvV1T@F$u)-*uL^j58I}Be8;O0;>wK?Sd>IEj9z;wToW{zVk3ta({=GXf!p`KIk z21M?_)fP!u0DCvoXE^nXU^72+MvsD-_@cPeh47+JbfIS#rU(`7yc0{DdTxqd6|QZ% zTJ1gH<#h@yYx@H?qd53saZU)S|4dUeD`bXIc~k2`qy*+3!+ESk)~`aRV>}Skl=)Gw zGX5uXnFO+J4;nXp+wU-?Z)Cx=i1FWmF9B8H80a|gQtE~C%}oYM{h7xOj7edp!=mm) z+1V4%mSsIXz9(e&slx%Msfr?JOh++hST`JkiW=tX;@NP;iG}fKo`vdqyiQvboYv>R z?AI^eu+|O=#k%1Ty1tP5V43ibIIz{~j8`?UXE+foAoLK>P=@lG&zh(-oDJD2{b+oR|=ZPfp8VFB^jdhw!5LGkS|c8`@t&o z$Whrvz9)s06Ajdh$HnJI$u+<4(i(D?>TePJ34Op1n_kfs$}BYOzVmwfKPklI=g}ai z5gNDWQeU%JI9$B4g^ZKkdME%URn|=(`J2DLMPhivz16$!e*O#U7;8X{b|{W0HzZ?+ z4*4nici8o^m;8gA!YZoTsE92Mb=xun-BrDUOnVb#E?+u%2HD+AIWY7G1e3*O%M-@D zaFe`uVcUqAT-nD=f^eKi)D(nN&2G@NzvJcpDEH+#Q~^V?mr;NvIDY^7XYOcNpEehC zSY@((!M$qunB6@Zk;4oe*K;ZYgkg3)e;IjzdT;2fE*`#xm)~7ZA7AkoMuRMdxgs!> zijzuD=51sHQKIi)$IFyExv-m*_07hIUP37FK#3!_Lx#F_^GltB6|D4La!q-$-ntFq z7$ckur^5zv$MA3D6v;%$R~l(wax4Mbdn;GUT`gT_vynzI;5whjHTPnCr8e!Z?LAC10C6+>*6qWkx%5H zo^oEgE^Clk$gjWdihc&vX99R^f8Dh+z`CVaX3UN20~GYbg%*{CMp$o!kH+l?`X~{3 zC|r_C%0*|>TEvyQ5&rcw_SfjOopkou5Odhvvl6`fyv7N*VO zJE?ngl|1r!Afi@_J!rCI1~R+oBY$%7&g1B&gxydbL)vdOCZIubZ1QN|0}b~c`2eKM z>e8?-J-><%h;*%}cM9?EMZ@lmUoExxW4mIo%*1 zhTc7v{u1m6j5cU?n@@iJE9jW&8{4yok##I;f04stT*x=JFo05*>Ig&6{x)?(yq^Cd zx*0z|jl|duWu9zi$!SVYzZtVd4cMmb5>W^!xg;6dh#NF~S7T-V#Fc7Xx_8upy2!@i z>;3@`yLy0!CTyhgL?qO$Z0SJ%PkY&6-eaj+LC-4w;k2Up)c2yqk`=C`dlTKokcBcd zk#YTXdY+l|F?plC_1E&y>r)RaWRsh;>wq&WfLLS$2znCf0To4ix;|qlcAw0wj91qZ zp!empwAqrn7y27~&}@>>XH-w{z>_LDBr?7wvL-4<8Jo~&n*8G)@7U{vtx_e)2c`s3MSwb45;lyE zp9N+_LKurwY6YYJ@?T}08r}+FYpRCkbS`Ef;=gRQ~ z=(lfD!77vhRdRLEwwv30u8P09TOynQYN{H{)E`~hG1QSesfs{GaA3;b@7r(f=|jt0 z19ZoqX6MO-Y)VF(DD#fDZ{L>Z0bn?I#Yy*cT4qD7J`i!893#t933y_UJ+(-E6vK>! z7Z<>cFU7f+^*A!@8FLmqNMJ}6DOp{2OuDB=pUMM-pnem8&_Na9)}FK5q<2}R1Oq_C;>sp%A&q? zb;agZbYI_rALej?`uzM?KzWm-A+9d~zO2E6`o^8q5_Ck?aSN!MF-{??0%#Fs4g-DZ z{G)RgD(3vo!z%ojIX z(Fl@WEeJs;=3bJMj4&(3jp2AU}zz3sPt=BN9^xN&iq@yjX< z0+8r$CT<9tCqY{F{$dG(FT+oRf1=QO6o6uePhXDX?EsrMM~)u7@h-m!pb!7vUw6&; zRdH+R0Ky&oqs=+4lY%S0(lI+oD?4LIclj(syD`-w-xc_>U7I!wZXy9N6BJw&8^ry- zTY}934G@DQE7(OUQ=MAy6UnCD=3V}du zV9N_Q&Wi%%V|+uOe@>*tT$~EXz>_P6bTp!NV_T2RJ3)c*Pxvqvzmiy7`S$8;}!4UHqMgsmNV0Y#-4d@9N?!j(l1zxJ3APCgdkMJ0BFQy=sAb^VqnRWhSf@CcQ` zkH4&6zf=E5d3%kpSyZaY(xU_UaoMJcL`Ltyy}A7beabs|Z=Gxi6bpKqR|use%X8JQ z)9`!ze37}vsI3%b{*>7Z%K95!0`!6%3;Bw7ydP5nhC`cW1LYb?0sfsV90lQdfWj>T zAX+xhB82kYv{IhGszO`l*S3lRffvNy4s-_|3O|RItNr1Z7Zp?0%4d%1`TNc zH+-ux!uCJJw!=)3MMFjbCldzxjY*~=rF!3o;wKIwxBYLT@a4zLecQj~BWVGUm~0ZZ z!GB~X>y@aruNG9Y<(0;dNkbbzuI&W7U%R6zm7lQ16f^ICvL8zt6b=%@QiLH<>fT(u zHMxp12l&tXrrG`H^F{z-%oM3SuRTUh-nJ6BVzU_JNgv2G<@dtM==79lLdQis0I3IIK-ZWUrwvQ1*!(X?dj3ro!Xw6wVIwN%sH{%eNl zta3A}O3x@QhV*p$TnA0sqPu63B#r~8jaN0jeQ9AeoL{e!->q|~ZSm*!i*EDo=zwn+ zHr)SN6J1o8VeL^E2d2lJRyTw5$-sy0rg_vsB=p%K4R|_St7+mmSTQ~jGJnR2#i88r zZINazAL7B51+95Z@I$8 ze}oRBLd2*mLq6U8YAgODqqeiC^|JTAyXh9X??o84_|5bY$eaA{zI=I_wGL}cZ8z-& z#w+LCh15TJJj~-7~9U==aY>kl*u5oq;q%oSYe&DIlB%WT|}GVXT5u*uRRZLkOI(VqKpqiQ5kBrQ@>|gpMRd3KdlVuSZT!;-8*-- zr4N)}8S$ssvoHb$@NTFS*C}9483qRMupKf?{Z|5V{&|G_qI1a_CWuuh{L-}3J=cu!Ur$vgOQ~|Qc(o$x}xeFN=TGnCv?f&7P=W4{SvRN7ip^=GKl>#6` zQXVyG`>oWS8lx6r{{Qy|c7jr%Lds*>6e}K_U-zraCyjrpfxNEF!cSPp7m|XmlI)1z z^EB$0K3X#T?KiHB<^p?`ZUi)Cw&+s1l>igAL-2azC!o#0z9|0Uo5jCE#%@#-f_m?z zI6&=s=fXCLwd3PKI;Il5^qJ~c$WrUoB9q4;j^Tg%!ctWVO`so<66RK!u!YndD|I`P~+#^UB(x0!a2&A z=7R(3zaeuqA`g~UQ6=2Gu-EOkw&;ExzWwJ(?CQ_s9WL*tWc9bGRw-iVz565IqF^95 zPLRf1c9j9{&lsXFe_tb@%livG;wj}_C)3r(fG%~FlslFjmOknO?Dw<;?i0LEsS-p9 zPqDwByM)nZ6W_JDT%DoZBWo1hiWo{}Q~-&~e~@=^zd~@dVR%%09OzF%zG``i!-xO+ z0>T2x9lrkOcm#|S!5Wy7^d6C;)u~+livgbtobmtt z(R&{W-v7(z9nf68|F7@<)BOSOpSJh^^3Mgo{U_+O{2Ra2_n+zg?`rsW`{#1~>z`_m z{^v;k>z|HT|Idy6*FW`Y_5c5YuQ-?e>$f--99uminE;mvCO4lCa>ql)(@@W{aOg!@Dl_H?31taF_-TbB$)($@SD3JhjcUxW?uU;*o>1B4M$f_P91P4V{vvW75Z@o4H^+N(dA9bHzQ%J z#t$hME5}y;+uNZznu;YWn;=|H+_hBdh%lWAffZm3CMS<3B*-RX5g4w0<6hOc6jhUY z##P^E#37j4tJj7Hn}yACb8{2vja10Qzlb&Gr`vmvfph@@f3je}*p)|eJqDpLW{ubGjG`auh62!7(9HZtniH-?fq0rX>pfRZv+xA zNzxpFSZpyw(V~g~-iZU#lIG4(E6)_;fT&;mxz!_sKiBb|gn&k*A5gveEEBh@*Fr=| z$vG(lh)1U4WC)r*=n7CWJca1SYxv)C6hPoFH_-{mngW^3L;rKTb;iZrcou?X9xV$? zJO(7cqxCD8MyawT6f=S@ui|i;&~}62PNMOZf$DSh;ht)16sX<{}ND*94(Qpv9}2sqTNS^q0-Cg9snN=>!H6bcvD!{@tck1Ch8hHfs7OoBUDc;rRz-r`q; z=^3jpHv{46lIyu{V&@2+DNZmZ>?XFgE!s0dl->~3&3l6f7!DXa*7n8FWMY|=Teej# zscppB_aS?PC$Mt_`8ZA-98sT%7e!fsLMe;NLX8~7Gzy4DHd-EAG2W@lqTY*7(dXOG z5*G8o%Rymexi@!EIB~wr)tdv5$sM4}H(=JsWvOf~77-(VJ<5}tZfp~jzLb##v33#F z8SN{Ol?#A08pX72)!D^7wDrhG%$0M%No|0YEQ?MsA13vYi#OeobjpV>ecpB16Y);F zd8py~lC?fQdFU>4us1u#d3ZBp%eHSq)?ORS_yC_!6(Pv}>6SzBd_|9ca^~|~9_ei6 zN5$Djj9Fmn_>!s!k;2IY03SuBv?gv?P(qVD3@=J``Nag%iRC|}Z6dVj6CMRjCVmK` z!7XvF=s=^o`n3MSw))m-i(f1eDq~4O(4nJ>1{5-4Mu)eUIH<|khhTX<7DS3IxF_cO zISJw6_eP&&{y&K`?Kl%`G~UB-kcRi~|r|N@kP=4c5;vfqf+SdfIXg(eUS$$ylM$}-O>}g z1M>`^vG+j)2p8xSxTViu3r)&_cn0@D`Y5oAe_=dzaa0z_jP;BeGqzs(wCgfG^cp~@ z*~gbKOCv~Gjyn<`%DCDYo7J}hZ|KPkMj*QV=p$@J^)9|^?Z>|_Dg0L4V$`(dk>q$? z=90P_md&H~$E7)5c2HIMpb1%YVr_Hh+lf}pj0qBus7wi2!76&zirhZ-7K}C0_57zN z?|>LHmu#fg~dsv_-bi*0Ppo->GehT8wQk90q@z*pqw ziH@B+@6zfR4fhIFC}Tu9i-(@A6OoR%H-PXcv*|`f%(mFS-#~GGp#_&wEFhb5CeaD4 zM<|RAN);iSI$}QxvdZcQ^(6Qh6^gij)-b#Lgx{snpMr?T1`Ul2J0r$NTD4?-{Z=dr z2vhPX34{&KhDsucic9j=c$34E5-lcLF9z9Mh2fivR7^0*rYN$k3xiP1ZdjV%{;onfp=!gzWv~ zHID8U1H3I%gfhUevUrvTGou7@SWL+vs(2EW!U1h- zEb7OV6Ok0>+>_d?Z6%Z1!ema065|z~f@lbK=Y_j_-TrymN*o~e{n|Z!HbHQ(-Q;B- zx27%^jzVTd6haGZ^7qd=>d>*Q>e5rljDjhQ`1&yj_CjrNvU@RTL1jAj{X^E@wn9HHL$au!coVf^L+Ae|C;)8ClAxc&t z2^Drc7g%iZM6Dy2605&MhlMuHVN-Y#3bz{wR*xZFl=}wPIoUYfc<|O@*-@f7@@+u5 z87roy@Qp=ss`>R z2Ie+qX7f|ORcn2({)r>aEMC1T2cN^$(+B>lnL7hU+q zIg=wRYtkls0Y5SoU{MxApb>(UvJo|V+sgv9GNWnlcM$)j+3-5U+v5GhSyR2 z%BK+Vv+P$xnv&4yr8ACe=#-(EeC>q_uqd=obg#0mwrqN( zAc$=-qzUovKXc}cl#x)5q&ZLVDMOSm4I>no-jE@$_ioe~TKiViejNGzJZwtQA?7hC zj@<4FY>7yZr6pQfIa7Qqh4=y_Ux6E`)lYM0;x_0S5_FwKwjBGZw3@P1PkK<=#zi34 zautO;V9W{g(ry>?e$XB#rY87z0W3w4y<$<6aF^6)o)Sx@kvPOMQOmG$oG`+oN{GBT znUj$NHotUBm*LaE1A}7?Fu4&N`b_k=TtZ{8Kck^`gNCiO!$ zZ@ZIqo9bm}m@a2K-!=Yh$X~y;6hM~m0g`~Fb?3iX!^d|DyvGBo&y*_C$;*Zh9#Gom z(lF(r@yoe%4!2Q>R5BZ}Bp5-4vuQ)2jcX9pGTwppO@LPG%Jtw(x(N{0;^{3D8oXBH|PHoP6<&mrhx_lA9&a<-+J9^pQg*7abV<2|Rr+N+_LBjqOEQ@hCj3E{p?I ztRYtJY&@ukK&Z29lC2@x6JD=?#u);&jf|mOc@XXmlWOI zBs|oKcO%BV|!+FxZF@_{n7c4>jXG{2gr*<7ARKkAFpac|`=7{9dvqU)NZ!89+@KG&Iw;t0E z8IcB|x?&a@JZmapog@nqy}6BrjC!CrHFXcq6n}{_=lu$Z!#5i{Uu>OqTeNBfEGkC^ zbJAw~sgC6uE@^M_H@7?%kZWIn?c}qS*J{swR~SVi%bjCu$?Mb?ioA8S?$oxfGq(G7!bGWzX7vu48+Y(KWm;vyT-5rzig6-$2+C{CoR-D%{yGkR}&?H8SG_Xox6z@=zjgU~MMswWk>n6J~;Awz(^4Ifien&^J1%x0N^^C0EOc!BJ0zi7UbTz4S0{EUCGPT$u1?tiRfc}C$q zv2#DYm-qp+j+eS^=Vp847 z`>z*=+kVk|@W(+f-~Q#E@xisB4fls?Ioypu?sTYec|0hctx+32vN7GMnQdh=eZHk%U9scVa)Y$Y%v{v)OYX*j#LJ$F%`2x(3n7-< z***w8L8W_^qJG-vI=ZHC`zc2+&mYiwVsp;hiaE_jTZgkL9fjOI_XXCkhT)Pk4X9D6 zLZ^hE#KT;#mUDDM6aXzh%QpfI+%6oYrzG_rU4&KNb*Fpdw7oCIp8kuz=)v;0SOWSq z*5<1sP?{Ic^G)RLM{y?#7dE_kW4olJj|Crcb1%2)pBV2_Al+@JrM37NEc_a){Ju#) zp2pk#@hhc;h3vcEyb|%d_!)1=*QOUllx=?b!B;Y~Lcg(DE=T23v+WTCwp|4Zdu`*F z?`9$;MWk;AXm)}3!Y`+c?dPnYM5{&34H;>hB}o#viar>|^D zY&UyC-74qo4c2czS~t#SIFO}lu)|S2ayI09*4En?=^6CumHOcL)q^&&>A=SQ7A==n z&z^jpE3eNPw*Or=%x@^R_WbaZR!4Sa_7V^GX}}*_R+1a0dPl|{n4Nq=cKp0-q!EmJ zG;)n^y!CX^-EcP*?cxUlk@U4liQ7Ct-(%0!Yu9el_4QGnCi;nUT6BtfNL%B`hJP@9 zzllQ%hTNgRio3d&j_>I0!M%Hrmt0UF<|3qvO8b$xA$9E7QOp|YJ{B4+N~K&lIOrSd z6s9}>eBjMzTxr|tgdwFCEpPorou;|5zVz9fi(xv4KbyJG&yk zcEV6qO_py{&OD<`;2TKSGP(&_Xfgy7|QZUPrbczWLxlK`IeX zp(JB?y43!L%H8hxv(puQtxCyeX^URE*-Uw{F)uVU^s5HuXmY$JyfUpT+)pzVa;Kl3 zUV3b$dG}O6B4Y5PZN26{yK^!tcFCYB@>Nm~tJqUm9pbmz>+-nab28h6P9feaGt##R;Amko;`c!n={~cYb~u=C1#0{HJ!ZCO>Vz5 zcV(k~Xb_eZpUd=$v^p`dI**6K3)md@qj8armdD3|B^FJN4 z!t}Jt26Fzv0^RUI&*bdupSf|XBezn<%|xXrU;*Nv(`(-(6ghlgZ1VxLnTxgV#NC~` zx`zivme!}UKek@6w7KGZQ{4B((oG1K59%gEmX1zN`hy2gr+FTfl|)JX#g|_$>1p() zzP=pYp$*KjRK_Hi>3|OHHpLr>H8d4&f~Za@nMC6=xpA@0rw~gCi%LsNhkLFo3xANE z?-9uX-%(|k@cmELXUy&tm)FX-lca8GWyRDZU59iU;sn94Aifvq9g=C6mzKXQ%&ebg z@YhYJ7A5@kTsgtgN$9h($5bpoINI5UlwTA&;?IaeWI!TG~`~d-~lO~ z!AGx^Jbqku=;U3MLKJ^66)TM1nO~&-gAV&G+Q$h)Xe^t4#-$8?8RYG3uzW<*>tbP` zBOYGqis#*R__LW&_iFXKf|Z9;`aR9G8*aDd!u+G+*d%`6s1}~)COOg52p+rQ zjc3tglAkad0Fk;>98XZtTM^+#beR@uaMo-{^3gLZnzQo_diS0}BY&}?5RgL%o~N(e zocb7?)f;AD5LdXaYUX+H=RfG#<>ra9%T-}RFzZc=3rnC*`qzD-QRaD{;RTU!U^Aq(m;@6sJOP;tKNSibb34#w=PCD_&y9j+t;{ z7BXfax|8UOFMcDO_r{{q#l=On$v3JW6A7yqGyET}pd^e{8(Y4*rw)*m|7pIXx> z<+QLyrYzH_#|m6_Ljob$)edF3rT{a7)BD!vgwsWxav$kzZyy#j;A-zZU01~oXc&Gq z)AYWw>h+q((@x?*DS{PM&7U7?in(jw=FU+@DZQ^2RXhblIPP?{lTxI+3{sc;bf6{v zTzTIuPH~jmVq|0~*+h@SWP7^&=SNMo6{ksOE*J~3mS*4NaI-^vAW+Ghi=O`U)BlN*QO3y`B@A^9 z+UiRiNb&COeK2oh=o$^1R2rOPTt%g&zqYBkvY#H&alcW@k;{83o)V5XPw20+_W6&} z=SI(`@Cd};BI8UL`6<^^9m7Y}j{8UT^T&kFm@hzy&RZO7hrgvwc1t$&H+{9kZqU~^ zW2{nR_XGwms4vX6`F<9Cs(}b7ZX9N(Ul`Ns+qX_ND%(#rIH*Ca{y4=~($f0dXFw#p zJZlpoy?=QrW4^ep6SdR5`V1d#$1;nrvOJyanLqJ}d&Tt1TbK~xVlT2`9E0`=#p{|mdiW|%jSQCC&#FX2o_~fFzE3rP{O*y3@=Dn9k0yoao)kf8u z9@+UDptTjUiQ#hRNXDt=oot@Hs>m_V3~Pyu`PDC87;R-ySeI9;dg?{*^Qg!?(}WH< z#SL2&?#q(;!t#gIBI+|0HbRf|d$u~Ib2;kQ=!>zf#m6EKY-5n;-3J#p9Ez=|LG&Hz z7F?2P>RVsq61TD}+;~;|94FxW$ruadTUnSo)~*Je4m}r{f3m`Le{_<^l4F;j`rLn_ z-8;Qa_Tr*U(R03@S)62C}Joz1u%|@2#X!GzDLz8XN)c=0v+moQ4b#-cHXzS z>G7);0;l(TXQffWbxZ`XZ{p*R)&2IvtG}p_%=;NxELasx6EX8Rt0NuC&m25xh<(jp zUApNR7-YmQd8JpbJGfs#BCcu0rw~eh)v*2Wt+}s;t0L&l49LkC9?wmB*3i(6_p2lquF!FPmS{&n_%K!{o=z0b{w;H(AN&&RjYOtY(|~4 z^5kVSVhi)G@hE%V1$kZ?J10HDA9QNi)8^}=HEoyd?(2Qdl!5nb$Pu+pO#h?>Kp0W?nnQxS-9)KWAo|TkX1h^_rE9FSPAnXo*- zwE9vT?X9tkIN(D2ZeHo`o)9ysk|L(ROccZu)=ul5s~`#rBfOz8J3ITJ%8a8X3tDY! z-RkOf+oifXjXnc1ocAn5_4#M{#BbS-x6s+iX{Wt2k~mCvdMlC4ALmBTFCSrCm^7%s zy)khJI+^yxIOp4OYZGqgw+X-4<;>Y-Rn05wV3X@__qUqrgigL)CszofokKmMsOE9_ zl=3~!d-<UQtj*S+y&c=HN7a}vR^2?@6-Goc@!hIud+Yh@8C8?38M>cG(OS-JaQ zD0u40f+^lOlUQ}W1ZuiS>&}Eij2nX8kZB%*9JzJIo}LLQcbBbeiUi9bC3cW`lpzQE zZX&0+vS|2H#A^;0j58eACWP^%nWs4`ZU zggVa}mV%U^7L$Pkj3~+8`fZ~gxrk~VphPzPh`9f&hLN?oR>1Fr3M2#=!JRVP=D$3k9Me}JPR%>zE&G2%J~gGl<|kEKs*bA?kpn#0+{tL;Ln@1 zyGVrF@`awvqj_$%t(w(&i6nkmuX!>H^x#nIrLyiSc!=%JUjI*X=O0(|y~pu$@9dgw zGSVL9M>aa7n8{UsMD31LvXYLpagxX=$BeL)@?(AtJ+$IlB;0-|y$g`~7;o-mmw>&Eq4J=PqO0vU)&m zU>_sz0Hra`BR9sd44!)CWw8Ww)F>&a7d?x_dYc@Alf4Gie*co|WLEi-z!!J6)8{UG zxKG(XeP*f*kJ{Cx&p*B{eLK*?;!*cFt-8Tohp~4vGt#PugqHWh3LXD!^54SzqM&uR&P5l8iZ`RK!hpLS8bD}SM8sw_Iz+bP5OPe1bQ13c^`Uv0wVqCdwl)zv1DemHi@bu1xOKZENS7ci69zq@{FgIhTAjEF;Xcg?F z!+7u(S_2)ZQy0tJrUByOew1~ey7-(35lk*GyjG*!h>Hxplv(m=hjSGK9QYlrqV%!FC#31@l|twE)zLi}m*rWsHU z^(T-&-guE*Hv&8mgyo_aBJWEE@UXPA^A?*Zl*i%>*41>iy}QOEWW)-S5BFM1 z3DIM_wnn>}^3x*UNMvwWuoCm_8gRauk=^y*M}#(iRanbZ`B(}klf1=Zh(w6IYok!` zdVJgV^b7ek(73-g^u1o zzVvDnpTbAlOM*v6dSvW^F%>~)SM1{i4-YI1+-co1b5lg$bR{c{IZ2^U&ORYoAeR1W-G{j~ zCoeBAHza27UMx@ninkKYi;W9iU91qr27XQb*(pLyF$qOr$?jo^z{P!;RW|rp-6J#mRjE6E zIlDGoomydtX%QiITtY%EnMg5^7`;Q4PB~eI1k%Y#mL5?>cMSj$rKy(EXh=XXpCjo0 zvGX!q2-!?Z%t@ePFJHd=d1_A+6D`zM-;#wwtM8|v0RxjPM3hjBaYQ^dB$YDo9#cgw zGWqzQ%8&VOR);jHDi@kdShj7iB@3r})0~C+-0l9u<7I2Qv^&8m6v_;=z7Ukl%a$@M zxww%Yvvck~%T<5U+JM13Y^_0RI!x(~0d}`Og)#V1^Mm~&sq`aU=>H^|d0U^=)@qe{ zVC-@SpjZQ&JyY~zG7%&$(PWc6h^eBK&z7p!6#BTVy7fgu$&ANJ>J4TDp^%G;h?w`= z?9k9q8L^`I*%j>OjvbgdODvc_zY2?E4Bp%lCvfs%b zP7=Y=Y)+FlZvFU)=NZo_Nju9Pf$U1ZSzYtCLA?j3c_|ccUUZuDfycX7VnD{(8@)39 zj?H@ft6iHHx`dNosL7ZFw7}}|2fqWzkfkUv|Ki1OuP4ao!rDR%%J4|pUH)s7r}_#P z(&yWSbQ30lypa*_#S&J66x^+TP1MLuEBsSV7&V79ve#)5Y8#mv>Vj^MsUtl)jCCul zkP?fWvv2gqDcg;wYb$A*HHiguCJldq<^!p?X84JZItl7JfHc$YrnY7E4uPKHbuTW| zRJLMt@f8 z=3rv}La1a;IXyvtB4Z<{>fmOVrLRdwFTSH}VV2Ml!p7w1yCb=lXjM+ShH38M}9{P~=x(Yd^b&TzU?UX}G=K{P_pCKQU!vDkbTsS82 zvyl2oj4b#JTmKGxd3ta$I-=BBcR8V%G(k06t|)VZL>&zL9QqCk6;;C z%Oas>Q0)fNZ+&{+?wea%h8cO5qQ?_*o7k5IJd5O_k|ZofwIu>zG(+NczPvKQ+?(&> zKa1;;3o_=b>%m1=4^9 z&_iM+DzV|ol0jy*s$Qu|6~iRa zn6WybU)aesIfywL`^Pk~k7s?mSzCc-8o)cMwJ9}pD1fQ}@SHVkd&!~6Z%lY}VURQg zIM?Odn&fe*C-iyaDYp($uEH@wz54E}R75yi+;X#{q`gOb^bl z%n1mVb*@2Il=E2}sEmYCERQ0zOK30482nN}iWK7k{5eO6&YIDqe+SQJx7%90fq@e_ zg4F*0(MttRtXkLAv|&Yi@iAv*{(X7z3rbY?d9o#~P`SCi`@^_)pJYYi zV1s^Rq{fUQ+bJY_QgWN{p%8j*uhI&__&nqP$y_u2fxJ41CHv7xf?ijenK?#?nupE# z=DD8ou8FVHA!0`*8*>i_JuSHJm%2{f`_buEd6l4%rrW#l$M!1sM_E0H#;U$|)V)L4 ziI>rSlBC>+RKZLz7W3?|Br~NlXUu!exApsKY)z+jI%zB8&*Z$)CQ#1dY0*rpq5Im& zNKI!6qW$PA=+toPof^U9P=1y3)7>rdwyQs8?~`YoV0BK_pQ$u>ag+Cb&;c5H7>dvX z^uf}Jad}u@mnEHyNkk*eSO6|yz^lIo3dO=sR=UJdnm<&6m0sGgOWrNK`gh%>w$(|w94f7K2>P^S?vpQ|sp@@s^gLGxO z)SwL|1%Ng%)zyK9B?Qpc6N(^0SKtFdu76&k81q#JTbt;ICPvdQlwd-CU&4k+ubte7 zzTp$@wci>0$(X~J}I(4n+2leL|dtBNlkgx1Ay9v)DqOi?_^n<|Yr3*Pq zVrs+_SJdPdT@TA0yVH#|^$LaZ!&##xaAxs|F_62iP8WJ9Mcx-gUF5g z38C?MX^58~Cy}b`i2286iT%KfQ6VqNHu)~5qcrZ$L}F$YP>Rl**xdX=@m={F{nuyf z|MUD`JkR%8TOU4IArs-hzRWxCP3LdIEdH=&?tlJln8hVMk2T9D{J$G)x$X=YoGLu z6>Z7)Rs2<0jU^Ka1Xo(JN5#ZwGU`bjc$HYoq}81&w>5rOcd*3y9YaQuFNJYJLP({K zMRU;wNEVp}pwMECEmCmP?NU#8rb!rCwtMk2I>`(=^0U7X$M|}JW63^eYpm2AzTvHo zIj%UR95O_@rHElUTg}XSErN|9@MQMLcK7F3Ucoptky(E8s^OO9%a=C+Y~6OLlDba9 zkEJzzKyE$RpS~T3EPI&lqZ!bY78ubh9m%W^@D` zQV3@7Je3n?(>~g;TEJ8Y3ya|I%%)-H1W}c)^?$SfI1d7el!j)Kv=WL5mRf8g8}r8C z_^j6d{Q)_D5AUmc%)5BLyF8|hyjS^`dzG7__lHZWhKhDXV-RVU;Bin79_Yet3NM_+$^7+X#B#kZ~+b3#Fsj&U~#jIy+i z%0*;GNG9CBEb7s%+u3Skk+4S(zfl<6+jpX7t~|Dl9ANXlJla3KX5~aGu~NYd@%V3x zKdL{IB)bs>g{@QD9glRMxc#5m!*k1vd%paP{{OMVItN$hZ;Db4M#d%p literal 0 HcmV?d00001 diff --git a/docs/source/dev/dockerfile/dockerfile.rst b/docs/source/dev/dockerfile/dockerfile.rst new file mode 100644 index 0000000000000..a07463392dbe8 --- /dev/null +++ b/docs/source/dev/dockerfile/dockerfile.rst @@ -0,0 +1,50 @@ +Dockerfile +==================== + +See `here `_ for the main Dockerfile to construct +the image for running an OpenAI compatible server with vLLM. + +- Below is a visual representation of the multi-stage Dockerfile. The build graph contains the following nodes: + + - All build stages + - The default build target (highlighted in grey) + - External images (with dashed borders) + + The edges of the build graph represent: + + - FROM ... dependencies (with a solid line and a full arrow head) + - COPY --from=... dependencies (with a dashed line and an empty arrow head) + - RUN --mount=(.*)from=... dependencies (with a dotted line and an empty diamond arrow head) + + .. figure:: ../../assets/dev/dockerfile-stages-dependency.png + :alt: query + :width: 100% + :align: center + + Made using: https://github.com/patrickhoefler/dockerfilegraph + + Commands to regenerate the build graph (make sure to run it **from the `root` directory of the vLLM repository** where the dockerfile is present): + + .. code:: bash + + dockerfilegraph -o png --legend --dpi 200 --max-label-length 50 --filename Dockerfile + + or in case you want to run it directly with the docker image: + + .. code:: bash + + docker run \ + --rm \ + --user "$(id -u):$(id -g)" \ + --workdir /workspace \ + --volume "$(pwd)":/workspace \ + ghcr.io/patrickhoefler/dockerfilegraph:alpine \ + --output png \ + --dpi 200 \ + --max-label-length 50 \ + --filename Dockerfile \ + --legend + + (To run it for a different file, you can pass in a different argument to the flag `--filename`.) + + \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index e8daa5f052754..e0269987ec5d8 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -102,6 +102,7 @@ Documentation dev/sampling_params dev/engine/engine_index dev/kernel/paged_attention + dev/dockerfile/dockerfile Indices and tables ================== From 111815d482ba2b724541994da12736615101ef5e Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 30 Apr 2024 17:46:12 -0400 Subject: [PATCH 158/413] [Kernel] Support Fp8 Checkpoints (Dynamic + Static) (#4332) Co-authored-by: Philipp Moritz Co-authored-by: Woosuk Kwon Co-authored-by: mgoin Co-authored-by: Tyler Michael Smith Co-authored-by: Cody Yu --- tests/models/test_fp8.py | 90 ++++++++ vllm/model_executor/layers/linear.py | 58 ++++- .../model_executor/layers/quantization/fp8.py | 199 +++++++++++++++--- 3 files changed, 307 insertions(+), 40 deletions(-) create mode 100644 tests/models/test_fp8.py diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py new file mode 100644 index 0000000000000..e87a1783a83f1 --- /dev/null +++ b/tests/models/test_fp8.py @@ -0,0 +1,90 @@ +# flake8: noqa +"""Tests fp8 models against ground truth generation +Note: these tests will only pass on L4 GPU. +""" +import os + +import pytest +import torch +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +MODELS = [ + "nm-testing/Meta-Llama-3-8B-Instruct-FP8", + "meta-llama/Meta-Llama-3-8B-Instruct", +] + +EXPECTED_STRS_MAP = { + "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [ + 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' + ], + "meta-llama/Meta-Llama-3-8B-Instruct": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' + ], +} + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +fp8_not_supported = (capability < + QUANTIZATION_METHODS["fp8"].get_min_capability()) + + +@pytest.mark.skipif(fp8_not_supported, + reason="fp8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_name", MODELS) +def test_models( + example_prompts, + model_name, +) -> None: + model = LLM(model=model_name, + max_model_len=MAX_MODEL_LEN, + enforce_eager=True, + quantization="fp8") + + tokenizer = AutoTokenizer.from_pretrained(model_name) + formatted_prompts = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + tokenize=False, + add_generation_prompt=True) + for prompt in example_prompts + ] + + params = SamplingParams(max_tokens=20, temperature=0) + generations = [] + # Note: these need to be run 1 at a time due to numerical precision, + # since the expected strs were generated this way. + for prompt in formatted_prompts: + outputs = model.generate(prompt, params) + generations.append(outputs[0].outputs[0].text) + del model + + print(generations) + expected_strs = EXPECTED_STRS_MAP[model_name] + for i in range(len(example_prompts)): + generated_str = generations[i] + expected_str = expected_strs[i] + assert expected_str == generated_str, ( + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 4d43ed4c5f14a..289b317cc991e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -246,6 +246,10 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) param_data = param.data @@ -254,6 +258,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -317,7 +327,12 @@ def weight_loader(self, param_data = param.data output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -331,14 +346,13 @@ def weight_loader(self, current_shard_offset += output_size packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -353,15 +367,14 @@ def weight_loader(self, if output_dim is not None: shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -370,11 +383,17 @@ def weight_loader(self, start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 shard_size = loaded_weight.shape[0] shard_offset = loaded_shard_id * shard_size param_data = param_data.narrow(0, shard_offset, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -455,7 +474,11 @@ def weight_loader(self, loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) if loaded_shard_id is None: # Loaded weight is already packed. @@ -473,14 +496,14 @@ def weight_loader(self, ] packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -502,6 +525,7 @@ def weight_loader(self, shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) @@ -509,8 +533,7 @@ def weight_loader(self, shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -523,12 +546,17 @@ def weight_loader(self, start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 shard_size = loaded_weight.shape[0] shard_index = ["q", "k", "v"].index(loaded_shard_id) param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -611,6 +639,10 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) param_data = param.data @@ -619,6 +651,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ba9f3149649c1..b57e1dde81a5f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,23 +1,36 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch.nn import Module from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + class Fp8Config(QuantizationConfig): """Config class for FP8.""" def __init__( self, + is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme @classmethod @@ -30,10 +43,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - # TODO: PyTorch 2.3.0+ is required to run FP8 on - # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to - # be included: https://github.com/pytorch/pytorch/pull/118881 - return 90 + return 89 @classmethod def get_config_filenames(cls) -> List[str]: @@ -41,11 +51,14 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) - return cls(activation_scheme) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme) def get_quant_method( - self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]: if isinstance(layer, LinearBase): return Fp8LinearMethod(self) return None @@ -56,8 +69,12 @@ def get_scaled_act_names(self) -> List[str]: class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. - We now support common FP16/BF16 model checkpoints ONLY. The weight - scaling factor will be initialized after the model weights are loaded. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. Limitations: 1. Only support per-tensor quantization due to torch._scaled_mm support. @@ -71,6 +88,24 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config + def _create_scale_param( + self, + scale_name: str, + layer: torch.nn.Module, + output_partition_sizes: List[int], + **extra_weight_attrs, + ) -> None: + scale = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.float32), + requires_grad=False) + layer.register_parameter(scale_name, scale) + set_weight_attrs( + scale, { + **extra_weight_attrs, + "fp8_scales_shard_indexer": + self.scales_shard_indexer, + }) + def create_weights( self, layer: torch.nn.Module, @@ -81,46 +116,150 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + del input_size, output_size output_size_per_partition = sum(output_partition_sizes) + + layer.process_after_load = True + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, - dtype=params_dtype), + dtype=weight_dtype), requires_grad=False) layer.register_parameter("weight", weight) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - set_weight_attrs(weight, extra_weight_attrs) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }) - w_scale = Parameter( - torch.empty(1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("weight_scaling_factor", w_scale) + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + self._create_scale_param( + scale_name="weight_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + # ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + self._create_scale_param( + scale_name="act_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + def scales_shard_indexer( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, int): + pass + elif isinstance(shard_id, str): + if shard_id not in qkv_idxs: + raise ValueError(f"Unknown shard_id: {shard_id}") + shard_id = qkv_idxs[shard_id] + else: + ValueError(f"Shard id must be int or str but got {type(shard_id)}") + + return param[shard_id], loaded_weight def process_weights_after_loading(self, layer: Module) -> None: - # Although the quant_method is propagated to all layers, - # only linear layers invoke "create_weights". So we check - # whether "weight_scaling_facor" is registered to determine - # whether the layer is a linear layer that requires quantization. - if not hasattr(layer, "weight_scaling_factor"): + if (not hasattr(layer, "process_after_load") + or not layer.process_after_load): + return + + # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, + scale=None) + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.logical_widths = None + layer.act_scale = None return - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight) - # torch._scaled_mm requires column-major in the second - # input (weight), so we transpose the quantized weight. - layer.weight = Parameter(qweight.t(), requires_grad=False) - layer.weight_scaling_factor.data.copy_(weight_scale) + # If checkpoint is fp8, requantize the separately quantized logical + # weights into a single fp8 weight with a single weight scale. + else: + # WEIGHT_SCALE / WEIGHT + # Loop over logical weights, requantizing with single scale. + max_w_scale = layer.weight_scale.max() + start = 0 + for idx, logical_width in enumerate(layer.logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize(layer.weight[start:end, :], + layer.weight_scale[idx]) + + layer.weight[start:end, :] = per_tensor_quantize( + weight_dq, layer.weight_scale.max()) + start = end + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # WEIGHT + # Transpose weight for passing to torch._scaled_mm + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + # ACT_SCALE + # Dynamic: set to None (required input to ops.scaled_fp8_quant). + # Static: set to max of the act_scales (since they are equal). + if self.quant_config.activation_scheme == "dynamic": + layer.act_scale = None + elif self.quant_config.activation_scheme == "static": + if not all_close_1d(layer.act_scale): + raise ValueError( + "All the act_scales for the logical weights of a layer " + f"must be equal. But got {layer.act_scale}") + layer.act_scale = Parameter(layer.act_scale.max(), + requires_grad=False) + else: + raise ValueError( + f"Unknown scheme {self.quant_config.activation_scheme}") def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qinput, x_scale = ops.scaled_fp8_quant(x) + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.act_scale is None and x_scale computed from x. + # If static, layer.act_scale is scalar and x_scale set to act_scale. + qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale) + + # Fused GEMM_DQ output, _ = torch._scaled_mm( qinput, layer.weight, out_dtype=x.dtype, scale_a=x_scale, - scale_b=layer.weight_scaling_factor, + scale_b=layer.weight_scale, bias=bias, ) + return output + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +def per_tensor_quantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fn) + + +def per_tensor_dequantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight From a494140433be496a0321999955acf7e6387986b3 Mon Sep 17 00:00:00 2001 From: Florian Greinacher Date: Wed, 1 May 2024 01:28:46 +0200 Subject: [PATCH 159/413] [Frontend] Support complex message content for chat completions endpoint (#3467) Co-authored-by: Lily Liu Co-authored-by: Cyrus Leung --- tests/entrypoints/test_openai_server.py | 19 ++++++++++ vllm/entrypoints/openai/serving_chat.py | 48 ++++++++++++++----------- 2 files changed, 46 insertions(+), 21 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 68332228ace08..a2a98abe7031c 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -786,6 +786,25 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI): assert "extra_forbidden" in exc_info.value.message +async def test_complex_message_content(server, client: openai.AsyncOpenAI): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": + "user", + "content": [{ + "type": + "text", + "text": + "what is 1+1? please provide the result without any other text." + }] + }], + temperature=0, + seed=0) + content = resp.choices[0].message.content + assert content == "2" + + async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5ed042ef386ea..599f99e56a726 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -55,9 +55,16 @@ def _parse_chat_message_content( if isinstance(content, str): return [ConversationMessage(role=role, content=content)], [] - # To be implemented: https://github.com/vllm-project/vllm/pull/3467 - # To be implemented: https://github.com/vllm-project/vllm/pull/4200 - raise NotImplementedError("Complex input not supported yet") + texts: List[str] = [] + for _, part in enumerate(content): + if part["type"] == "text": + text = part["text"] + + texts.append(text) + else: + raise NotImplementedError(f"Unknown part type: {part['type']}") + + return [ConversationMessage(role=role, content="\n".join(texts))], [] async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Request @@ -122,11 +129,12 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id) + request, result_generator, request_id, conversation) else: try: return await self.chat_completion_full_generator( - request, raw_request, result_generator, request_id) + request, raw_request, result_generator, request_id, + conversation) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -139,8 +147,9 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: async def chat_completion_stream_generator( self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], - request_id: str) -> AsyncGenerator[str, None]: + result_generator: AsyncIterator[RequestOutput], request_id: str, + conversation: List[ConversationMessage] + ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) chunk_object_type = "chat.completion.chunk" @@ -179,12 +188,10 @@ async def chat_completion_stream_generator( # last message if request.echo: last_msg_content = "" - if request.messages and isinstance( - request.messages, - list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] + if conversation and conversation[-1].get( + "content") and conversation[-1].get( + "role") == role: + last_msg_content = conversation[-1]["content"] if last_msg_content: for i in range(request.n): @@ -279,9 +286,10 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, request: ChatCompletionRequest, raw_request: Request, - result_generator: AsyncIterator[RequestOutput], - request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: + self, request: ChatCompletionRequest, raw_request: Request, + result_generator: AsyncIterator[RequestOutput], request_id: str, + conversation: List[ConversationMessage] + ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.served_model_names[0] created_time = int(time.time()) @@ -322,11 +330,9 @@ async def chat_completion_full_generator( if request.echo: last_msg_content = "" - if request.messages and isinstance( - request.messages, list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] + if conversation and conversation[-1].get( + "content") and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] for choice in choices: full_message = last_msg_content + choice.message.content From 715c2d854d56f2026c31f126a90e6e7859434a50 Mon Sep 17 00:00:00 2001 From: Alpay Ariyak <98838263+alpayariyak@users.noreply.github.com> Date: Tue, 30 Apr 2024 19:32:13 -0400 Subject: [PATCH 160/413] [Frontend] [Core] Tensorizer: support dynamic `num_readers`, update version (#4467) --- requirements-dev.txt | 2 +- setup.py | 2 +- vllm/model_executor/model_loader/tensorizer.py | 17 ++++++++++------- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 324039186142b..e6d375cbafa39 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,7 +14,7 @@ types-setuptools # testing pytest -tensorizer==2.9.0a0 +tensorizer==2.9.0 pytest-forked pytest-asyncio pytest-rerunfailures diff --git a/setup.py b/setup.py index 6ba36b85ea318..a47b14ffcfc6e 100644 --- a/setup.py +++ b/setup.py @@ -408,7 +408,7 @@ def _read_requirements(filename: str) -> List[str]: install_requires=get_requirements(), ext_modules=ext_modules, extras_require={ - "tensorizer": ["tensorizer==2.9.0a1"], + "tensorizer": ["tensorizer==2.9.0"], }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 2d654b2fefb8d..0ce9fa95aa7e5 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -44,7 +44,7 @@ class TensorizerConfig: str, bytes, os.PathLike, int] vllm_tensorized: bool verify_hash: Optional[bool] = False - num_readers: Optional[int] = 1 + num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None s3_access_key_id: Optional[str] = None s3_secret_access_key: Optional[str] = None @@ -104,7 +104,7 @@ class TensorizerArgs: str, bytes, os.PathLike, int] vllm_tensorized: bool verify_hash: Optional[bool] = False - num_readers: Optional[int] = 1 + num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None s3_access_key_id: Optional[str] = None s3_secret_access_key: Optional[str] = None @@ -125,8 +125,9 @@ class TensorizerArgs: the hashes stored in the metadata. A `HashMismatchError` will be raised if any of the hashes do not match. num_readers: Controls how many threads are allowed to read concurrently - from the source file. Default is 1. This greatly increases - performance. + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. encryption_keyfile: File path to a binary file containing a binary key to use for decryption. `None` (the default) means no decryption. See the example script in @@ -199,10 +200,12 @@ def add_cli_args( "use for decryption. Can be a file path or S3 network URI.") group.add_argument( "--num-readers", - default=1, + default=None, type=int, help="Controls how many threads are allowed to read concurrently " - "from the source file.") + "from the source file. Default is `None`, which will dynamically " + "set the number of readers based on the available resources " + "and model size. This greatly increases performance.") group.add_argument( "--s3-access-key-id", default=None, @@ -337,7 +340,7 @@ def deserialize(self): per_second = convert_bytes(deserializer.total_tensor_bytes / duration) after_mem = get_mem_usage() deserializer.close() - logger.info("Deserialized %s in %0.2fs, %f/s", total_bytes_str, + logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, end - start, per_second) logger.info("Memory usage before: %s", before_mem) logger.info("Memory usage after: %s", after_mem) From dd1a50a8bc520b0e52ce7914f0263ebd576c197f Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Wed, 1 May 2024 07:33:33 +0800 Subject: [PATCH 161/413] [Bugfix][Minor] Make ignore_eos effective (#4468) --- vllm/sampling_params.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 0ed6a01a62212..f6e7a3ca792e4 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -275,7 +275,8 @@ def update_from_generation_config( self, generation_config: Dict[str, Any]) -> None: """Update if there are non-default values from generation_config""" # Update eos_token_id for generation - if eos_ids := generation_config.get("eos_token_id"): + if (not self.ignore_eos) and (eos_ids := + generation_config.get("eos_token_id")): # it can be either int or list of int if isinstance(eos_ids, int): eos_ids = [eos_ids] From 6ad58f42c59eaee0a57c89f1feb08757524b93cf Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Wed, 1 May 2024 07:38:50 +0800 Subject: [PATCH 162/413] fix_tokenizer_snapshot_download_bug (#4493) --- vllm/transformers_utils/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index fa4693cb7dac1..9066db5a9e7f1 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -79,7 +79,7 @@ def get_tokenizer( revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, # Ignore weights - we only need the tokenizer. - ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"]) + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) tokenizer_name = tokenizer_path if tokenizer_mode == "slow": From ee37328da085af14f89ad1af8eb2c359ae2f46a1 Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Wed, 1 May 2024 08:42:09 +0800 Subject: [PATCH 163/413] Unable to find Punica extension issue during source code installation (#4494) Co-authored-by: Simon Mo --- docs/source/getting_started/installation.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index e7826114ffa9d..0c81f7ec6d2a9 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -53,6 +53,7 @@ You can also build and install vLLM from source: $ git clone https://github.com/vllm-project/vllm.git $ cd vllm + $ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability $ pip install -e . # This may take 5-10 minutes. .. tip:: From 2e240c69a9874743abc8b0b681e8c13d675beda3 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 30 Apr 2024 18:06:34 -0700 Subject: [PATCH 164/413] [Core] Centralize GPU Worker construction (#4419) --- vllm/executor/gpu_executor.py | 83 +++++++++++++++---------------- vllm/executor/ray_gpu_executor.py | 32 +++--------- 2 files changed, 47 insertions(+), 68 deletions(-) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 489e66d586028..527a14ff6c67a 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -6,6 +6,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -23,30 +24,47 @@ def _init_executor(self) -> None: else: self._init_spec_worker() - def _init_non_spec_worker(self): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = Worker( + def _get_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Dict[str, Any]: + """Return worker init args for a given rank.""" + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return dict( model_config=self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, load_config=self.load_config, - local_rank=0, - rank=0, + local_rank=local_rank, + rank=rank, distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - is_driver_worker=True, + is_driver_worker=rank == 0, + ) + + def _create_worker(self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None): + wrapper = WorkerWrapperBase( + worker_module_name="vllm.worker.worker", + worker_class_name="Worker", ) + wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, + distributed_init_method)) + return wrapper.worker + + def _init_non_spec_worker(self): + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = self._create_worker() self.driver_worker.init_device() self.driver_worker.load_model() @@ -57,41 +75,18 @@ def _init_spec_worker(self): from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker - from vllm.worker.worker import Worker - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - - target_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, - ) + target_worker = self._create_worker() - draft_worker = MultiStepWorker( + draft_worker_kwargs = self._get_worker_kwargs() + # Override draft-model specific worker args. + draft_worker_kwargs.update( model_config=self.speculative_config.draft_model_config, parallel_config=self.speculative_config.draft_parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, # TODO allow draft-model specific load config. - load_config=self.load_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, + #load_config=self.load_config, ) + draft_worker = MultiStepWorker(**draft_worker_kwargs) spec_decode_worker = SpecDecodeWorker.from_workers( proposer_worker=draft_worker, scorer_worker=target_worker) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 3eb3726bd5a6d..16d239b9ab580 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -153,29 +153,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) - def collect_arg_helper_func(**kwargs): - # avoid writing `{"name": value}` manually - return kwargs - # Initialize the actual workers inside worker wrapper. - init_worker_all_kwargs = [] - for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): - local_rank = node_workers[node_id].index(rank) - init_worker_all_kwargs.append( - collect_arg_helper_func( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=rank == 0, - )) + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") @@ -201,8 +186,7 @@ def execute_model(self, use_ray_compiled_dag=USE_RAY_COMPILED_DAG) # Only the driver worker returns the sampling results. - output = all_outputs[0] - return output + return all_outputs[0] def _run_workers( self, From f458112e8afdb01bd3cb2e435db314c6bc227973 Mon Sep 17 00:00:00 2001 From: harrywu <63134210+HarryWu99@users.noreply.github.com> Date: Wed, 1 May 2024 11:21:39 +0800 Subject: [PATCH 165/413] [Misc][Typo] type annotation fix (#4495) --- vllm/engine/llm_engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 835803fd4e75d..4caecb8a51598 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -8,7 +8,8 @@ LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) -from vllm.core.scheduler import Scheduler, SchedulerOutputs +from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, + SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats from vllm.engine.output_processor.interfaces import ( @@ -485,7 +486,7 @@ def has_unfinished_requests(self) -> bool: def _process_model_outputs( self, output: List[SamplerOutput], - scheduled_seq_groups: List[SequenceGroup], + scheduled_seq_groups: List[ScheduledSequenceGroup], ignored_seq_groups: List[SequenceGroup], seq_group_metadata_list: List[SequenceGroupMetadata], ) -> List[RequestOutput]: From a822eb3413087062a38cea495564ec4a7093c3e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Wed, 1 May 2024 11:41:32 +0800 Subject: [PATCH 166/413] [Misc] fix typo in block manager (#4453) --- vllm/core/block_manager_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 1fac2636e86fa..73e7dafb72c7f 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -391,7 +391,7 @@ def append_slots( block_table.append(block_table[len(block_table) % self.block_sliding_window]) else: - # The sequence has a new logical block. + # The sequence hash a new logical block. # Allocate a new physical block. new_block = self._allocate_last_physical_block(seq) block_table.append(new_block) From c3845d82dc3d1831714898114f87d9c103e2dd41 Mon Sep 17 00:00:00 2001 From: Robert Caulk Date: Wed, 1 May 2024 05:48:39 +0200 Subject: [PATCH 167/413] Allow user to define whitespace pattern for outlines (#4305) --- tests/entrypoints/test_guided_processors.py | 4 +++- vllm/entrypoints/openai/protocol.py | 10 ++++++++++ .../guided_decoding/outlines_decoding.py | 8 +++++--- .../guided_decoding/outlines_logits_processors.py | 7 +++---- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 30f0ad5d8272f..41c871ca40bc8 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -57,7 +57,9 @@ def test_guided_logits_processors(): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer) - json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer) + json_LP = JSONLogitsProcessor(TEST_SCHEMA, + tokenizer, + whitespace_pattern=None) regex_LP.init_state() token_ids = tokenizer.encode( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0a949f9867754..731596e80bd71 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -146,6 +146,11 @@ class ChatCompletionRequest(OpenAIBaseModel): "If specified, will override the default guided decoding backend " "of the server for this specific request. If set, must be either " "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) # doc: end-chat-completion-extra-params @@ -285,6 +290,11 @@ class CompletionRequest(OpenAIBaseModel): "If specified, will override the default guided decoding backend " "of the server for this specific request. If set, must be one of " "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) # doc: end-completion-extra-params diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 53efebb604048..8403604286903 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor( result = await loop.run_in_executor(global_thread_pool, _get_cached_logits_processor, guide, - tokenizer, mode) + tokenizer, mode, + request.guided_whitespace_pattern) logits_processor = copy(result) # reset logits processor's internal state @@ -117,9 +118,10 @@ def _get_guide_and_mode( @lru_cache(maxsize=32) def _get_cached_logits_processor(guide: str, tokenizer: PreTrainedTokenizerBase, - mode: GuidedDecodingMode): + mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None]): if mode == GuidedDecodingMode.JSON: - return JSONLogitsProcessor(guide, tokenizer) + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: return RegexLogitsProcessor(guide, tokenizer) elif mode == GuidedDecodingMode.GRAMMAR: diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 25ab5bf8b6a9c..a131c6a1b92b4 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -18,7 +18,7 @@ import math from collections import defaultdict from functools import lru_cache -from typing import Callable, DefaultDict, Dict, List, Optional, Union +from typing import Callable, DefaultDict, Dict, List, Union import torch from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM @@ -80,10 +80,9 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, - schema: Union[str, Dict, BaseModel], + def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, - whitespace_pattern: Optional[str] = None): + whitespace_pattern: Union[str, None]): """Compile the FSM that drives the JSON-guided generation. Parameters From d6f4bd7cddc9546c38568c92c3772d22940a09f2 Mon Sep 17 00:00:00 2001 From: Jee Li Date: Wed, 1 May 2024 12:18:14 +0800 Subject: [PATCH 168/413] [Misc]Add customized information for models (#4132) --- tests/models/test_big_models.py | 15 +++++++++++++ tests/models/test_models.py | 15 +++++++++++++ vllm/attention/layer.py | 7 ++++++ vllm/model_executor/layers/activation.py | 3 +++ vllm/model_executor/layers/layernorm.py | 5 +++++ vllm/model_executor/layers/linear.py | 22 +++++++++++++++++++ .../model_executor/layers/logits_processor.py | 6 +++++ .../model_executor/layers/rotary_embedding.py | 6 +++++ .../layers/vocab_parallel_embedding.py | 8 +++++++ 9 files changed, 87 insertions(+) diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 504eaad43c8d7..3dde498bcd639 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -43,3 +43,18 @@ def test_models( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + del vllm_model diff --git a/tests/models/test_models.py b/tests/models/test_models.py index cfe2539e3a052..e4609620387fa 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -49,3 +49,18 @@ def test_models( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + del vllm_model diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index fc65ae108dbb1..ee7be26c0876c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -47,3 +47,10 @@ def forward( ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, kv_scale) + + def extra_repr(self) -> str: + s = f"head_size={self.impl.head_size}" # type: ignore + s += f", num_heads={self.impl.num_heads}" # type: ignore + s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore + s += f", scale={self.impl.scale}" # type: ignore + return s diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index baf1d4f266181..d101aa323b0e1 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -67,6 +67,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ops.gelu_tanh_and_mul(out, x) return out + def extra_repr(self) -> str: + return f'approximate={repr(self.approximate)}' + class NewGELU(nn.Module): diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a6619714b8aab..8de0794158986 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -64,3 +64,8 @@ def forward( self.variance_epsilon, ) return out + + def extra_repr(self) -> str: + s = f"hidden_size={self.weight.data.size(0)}" + s += f", eps={self.variance_epsilon}" + return s diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 289b317cc991e..7726dcb9a5fbd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -181,6 +181,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output_bias = self.bias if self.skip_bias_add else None return output, output_bias + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. @@ -281,6 +287,14 @@ def forward(self, input_): output_bias = self.bias if self.skip_bias_add else None return output, output_bias + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + class MergedColumnParallelLinear(ColumnParallelLinear): """Packed linear layers with column parallelism. @@ -685,3 +699,11 @@ def forward(self, input_): output = output_ output_bias = self.bias return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 22620d9fc86d9..91eb96998c3cf 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -70,6 +70,12 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, logits = logits[:, :self.org_vocab_size] return logits + def extra_repr(self) -> str: + s = f"vocab_size={self.vocab_size}" + s += f", forg_vocab_size={self.org_vocab_size}" + s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" + return s + def _prune_hidden_states( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 25365a9b50a1f..857d70fadcb57 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -156,6 +156,12 @@ def forward( self.cos_sin_cache, self.is_neox_style) return query, key + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + class LinearScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with linear scaling. diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 088c0849243c0..4585b1679cb5c 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -105,6 +105,14 @@ def forward(self, input_): output = tensor_model_parallel_all_reduce(output_parallel) return output + def extra_repr(self) -> str: + s = f"num_embeddings={self.num_embeddings_per_partition}" + s += f", embedding_dim={self.embedding_dim}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f', num_embeddings_padded={self.num_embeddings_padded}' + s += f', tp_size={self.tp_size}' + return s + class ParallelLMHead(VocabParallelEmbedding): """Parallelized LM head. From 6f1df80436c46175e09f660a99075a5eba3a2273 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 1 May 2024 21:45:42 +0900 Subject: [PATCH 169/413] [Test] Add ignore_eos test (#4519) --- tests/samplers/test_ignore_eos.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 tests/samplers/test_ignore_eos.py diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py new file mode 100644 index 0000000000000..864657a3c2b28 --- /dev/null +++ b/tests/samplers/test_ignore_eos.py @@ -0,0 +1,31 @@ +"""Make sure ignore_eos works. + +Run `pytest tests/samplers/test_ignore_eos.py`. +""" + +import pytest + +from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [1024]) +def test_beam_search_single_input( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + example_prompts = "1 + 1 is" + + vllm_model = vllm_runner(model, dtype=dtype) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) + ignore_eos_output = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params) + print(len(ignore_eos_output[0].outputs[0].token_ids)) + assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10 + assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0 From a88bb9b032d75aad74b2e1bd3d97b8e8a24e8b9d Mon Sep 17 00:00:00 2001 From: AnyISalIn Date: Thu, 2 May 2024 00:11:03 +0800 Subject: [PATCH 170/413] [Bugfix] Fix the fp8 kv_cache check error that occurs when failing to obtain the CUDA version. (#4173) Signed-off-by: AnyISalIn --- vllm/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index a5512c657e038..db4398addae3c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -353,7 +353,8 @@ def _verify_cache_dtype(self) -> None: elif self.cache_dtype == "fp8": if not is_hip(): nvcc_cuda_version = get_nvcc_cuda_version() - if nvcc_cuda_version < Version("11.8"): + if nvcc_cuda_version is not None \ + and nvcc_cuda_version < Version("11.8"): raise ValueError( "FP8 is not supported when cuda version is" "lower than 11.8.") From 4dc8026d8614185ece28dd3fcd82aa0dabb4f79c Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Wed, 1 May 2024 12:14:13 -0400 Subject: [PATCH 171/413] [Bugfix] Fix 307 Redirect for `/metrics` (#4523) --- vllm/engine/metrics.py | 2 +- vllm/entrypoints/openai/api_server.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 45bfad03ec867..3c4aac91549a9 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -119,7 +119,7 @@ def __init__(self, labelnames: List[str], max_model_len: int): buckets=[1, 2, 5, 10, 20], ) self.counter_request_success = Counter( - name="vllm:request_success", + name="vllm:request_success_total", documentation="Count of successfully processed requests.", labelnames=labelnames + [Metrics.labelname_finish_reason]) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index af9ba7a3bc825..40103f70a31a3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -2,6 +2,7 @@ import importlib import inspect import os +import re from contextlib import asynccontextmanager from http import HTTPStatus @@ -12,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app +from starlette.routing import Mount import vllm from vllm.engine.arg_utils import AsyncEngineArgs @@ -55,8 +57,10 @@ def parse_args(): # Add prometheus asgi middleware to route /metrics requests -metrics_app = make_asgi_app() -app.mount("/metrics", metrics_app) +route = Mount("/metrics", make_asgi_app()) +# Workaround for 307 Redirect for /metrics +route.path_regex = re.compile('^/metrics(?P.*)$') +app.routes.append(route) @app.exception_handler(RequestValidationError) From e491c7e053e5d774f321612b3a400ca2fb424d32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=CE=B1n=C3=A7ois?= Date: Wed, 1 May 2024 19:14:16 +0200 Subject: [PATCH 172/413] [Doc] update(example model): for OpenAI compatible serving (#4503) --- docs/source/serving/openai_compatible_server.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 388b5daa79a92..c157d8ba998da 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat You can start the server using Python, or using [Docker](deploying_with_docker.rst): ```bash -python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123 +python -m vllm.entrypoints.openai.api_server --model NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 ``` To call the server, you can use the official OpenAI Python client library, or any other HTTP client. @@ -16,7 +16,7 @@ client = OpenAI( ) completion = client.chat.completions.create( - model="mistralai/Mistral-7B-Instruct-v0.2", + model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ {"role": "user", "content": "Hello!"} ] @@ -37,7 +37,7 @@ Or directly merge them into the JSON payload if you are using HTTP call directly ```python completion = client.chat.completions.create( - model="mistralai/Mistral-7B-Instruct-v0.2", + model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], @@ -87,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode a chat template in its tokenizer configuration. The chat template is a Jinja2 template that specifies how are roles, messages, and other chat-specific tokens are encoded in the input. -An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format) +An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models) Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat From 69909126a7f6fb1e3254dc0dec87dc6e78e1a0e2 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Wed, 1 May 2024 17:41:17 +0000 Subject: [PATCH 173/413] [Bugfix] Use random seed if seed is -1 (#4531) --- vllm/sampling_params.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index f6e7a3ca792e4..5fa94eb149ffb 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -139,7 +139,10 @@ def __init__( self.top_p = top_p self.top_k = top_k self.min_p = min_p - self.seed = seed + if seed == -1: + self.seed = None + else: + self.seed = seed self.use_beam_search = use_beam_search self.length_penalty = length_penalty self.early_stopping = early_stopping From 8b798eec75cde6eb6fe65b5d673dd9bd4eaef799 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 1 May 2024 12:01:50 -0600 Subject: [PATCH 174/413] [CI/Build][Bugfix] VLLM_USE_PRECOMPILED should skip compilation (#4534) Signed-off-by: Travis Johnson --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index a47b14ffcfc6e..d534cec437261 100644 --- a/setup.py +++ b/setup.py @@ -378,6 +378,7 @@ def _read_requirements(filename: str) -> List[str]: "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } if os.environ.get("VLLM_USE_PRECOMPILED"): + ext_modules = [] package_data["vllm"].append("*.so") setup( From b38e42fbca978d62cc8330bdcf8da91c72cb2ebc Mon Sep 17 00:00:00 2001 From: leiwen83 Date: Thu, 2 May 2024 02:13:03 +0800 Subject: [PATCH 175/413] [Speculative decoding] Add ngram prompt lookup decoding (#4237) Co-authored-by: Lei Wen --- tests/spec_decode/e2e/conftest.py | 58 +++++ ...tness.py => test_multistep_correctness.py} | 60 +---- .../spec_decode/e2e/test_ngram_correctness.py | 172 ++++++++++++++ tests/spec_decode/test_multi_step_worker.py | 50 ++--- tests/spec_decode/test_ngram_worker.py | 206 +++++++++++++++++ vllm/config.py | 87 +++++--- vllm/engine/arg_utils.py | 18 ++ vllm/executor/gpu_executor.py | 8 +- vllm/spec_decode/batch_expansion.py | 4 +- vllm/spec_decode/multi_step_worker.py | 209 ++---------------- vllm/spec_decode/ngram_worker.py | 190 ++++++++++++++++ vllm/spec_decode/spec_decode_worker.py | 45 ++-- vllm/spec_decode/top1_proposer.py | 200 +++++++++++++++++ vllm/spec_decode/util.py | 16 +- 14 files changed, 1004 insertions(+), 319 deletions(-) rename tests/spec_decode/e2e/{test_correctness.py => test_multistep_correctness.py} (88%) create mode 100644 tests/spec_decode/e2e/test_ngram_correctness.py create mode 100644 tests/spec_decode/test_ngram_worker.py create mode 100644 vllm/spec_decode/ngram_worker.py create mode 100644 vllm/spec_decode/top1_proposer.py diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 5d3469c4210ee..0eb784a9c5ac5 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,4 +1,5 @@ import asyncio +from itertools import cycle from typing import List, Optional, Tuple, Union import pytest @@ -185,3 +186,60 @@ def get_output_from_llm_generator( del llm return tokens, token_ids + + +def run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + print_tokens: bool = False): + """Helper method that compares the outputs of both the baseline LLM and + the test LLM. It asserts greedy equality, e.g. that the outputs are exactly + the same when temperature is zero. + """ + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + ) + + spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + (baseline_batch_tokens, + baseline_batch_token_ids) = get_output_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + assert len(baseline_batch_token_ids) == len(prompts) + assert len(spec_batch_token_ids) == len(prompts) + + for i, (baseline_token_ids, baseline_tokens, spec_token_ids, + spec_tokens) in enumerate( + zip(baseline_batch_token_ids, baseline_batch_tokens, + spec_batch_token_ids, spec_batch_tokens)): + if print_tokens: + print(f'{i=} {baseline_tokens=}') + print(f'{i=} {spec_tokens=}') + print(f'{i=} {baseline_token_ids=}') + print(f'{i=} {spec_token_ids=}') + assert baseline_token_ids == spec_token_ids diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py similarity index 88% rename from tests/spec_decode/e2e/test_correctness.py rename to tests/spec_decode/e2e/test_multistep_correctness.py index ab8d913fb894a..f99e0f6778e59 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -35,7 +35,8 @@ from vllm import SamplingParams -from .conftest import get_output_from_llm_generator +from .conftest import (get_output_from_llm_generator, + run_greedy_equality_correctness_test) @pytest.mark.parametrize( @@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) - - -def run_greedy_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len, - force_output_len: bool, - print_tokens: bool = False): - """Helper method that compares the outputs of both the baseline LLM and - the test LLM. It asserts greedy equality, e.g. that the outputs are exactly - the same when temperature is zero. - """ - temperature = 0.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - "San Francisco is know for its", - "Facebook was created in 2004 by", - "Curious George is a", - "Python 3.11 brings improvements to its", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - # If the test requires that we generated max_output_len tokens, then set the - # sampling params to ignore eos token. - ignore_eos = force_output_len - - sampling_params = SamplingParams( - max_tokens=max_output_len, - ignore_eos=ignore_eos, - temperature=temperature, - ) - - spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( - test_llm_generator, prompts, sampling_params) - - (baseline_batch_tokens, - baseline_batch_token_ids) = get_output_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - assert len(baseline_batch_token_ids) == len(prompts) - assert len(spec_batch_token_ids) == len(prompts) - - for i, (baseline_token_ids, baseline_tokens, spec_token_ids, - spec_tokens) in enumerate( - zip(baseline_batch_token_ids, baseline_batch_tokens, - spec_batch_token_ids, spec_batch_tokens)): - if print_tokens: - print(f'{i=} {baseline_tokens=}') - print(f'{i=} {spec_tokens=}') - print(f'{i=} {baseline_token_ids=}') - print(f'{i=} {spec_token_ids=}') - assert baseline_token_ids == spec_token_ids diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py new file mode 100644 index 0000000000000..44ef400c91d34 --- /dev/null +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -0,0 +1,172 @@ +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding, +and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775. +Since there is no model is needed for generate the proposal, we could make +the testcase much simpler than drafter multi-step one. + +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various ngram sizes / speculative sizes + +With those tests, we can say at least, ngram spec would not break the correctess +for the target model outputs. +""" + +import pytest + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-68m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize("output_len", [ + 256, +]) +@pytest.mark.parametrize("batch_size", [1, 64]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_e2e_greedy_correctness(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model with different batch size.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-160m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 256, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 3, + } + # Try a range of common k, as well as large speculation. + for k in [1, 3, 5] + ] + [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 1, + } + # Try a range of common k, as well as large speculation. + for k in [1, 3, 5] + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that ngram speculative decoding produces exact equality + to without spec decode with many different values of k and + different ngram_prompt_lookup_max. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index e7aaa1ff4eff8..98f2731de9aa3 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -6,8 +6,8 @@ from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplerOutput -from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer, - MultiStepWorker) +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker from .utils import (assert_logprobs_dict_allclose, create_batch, @@ -117,8 +117,8 @@ def test_same_output_for_single_step(): zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) - actual_output = multi_step_worker.execute_model_multi_step( - **multi_step_execute_model_data.to_dict(), num_steps=num_steps) + actual_output, _ = multi_step_worker.sampler_output( + **multi_step_execute_model_data.to_dict(), sample_len=num_steps) assert len(actual_output) == num_steps actual_output = actual_output[0] @@ -200,8 +200,8 @@ def test_same_output_for_multi_step(): # Run multi-step. zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) - multi_step_output = multi_step_worker.execute_model_multi_step( - **execute_model_data.to_dict(), num_steps=num_steps) + multi_step_output, _ = multi_step_worker.sampler_output( + **execute_model_data.to_dict(), sample_len=num_steps) # Run single-step repeatedly. zero_kv_cache(worker.cache_engine) @@ -266,7 +266,7 @@ def test_same_output_for_multi_step(): @torch.inference_mode() def test_draft_proposals_full_speculation_len(): - """Verify DraftModelTop1Proposer correctly handles case where all sequences + """Verify Top1Proposer correctly handles case where all sequences can speculate. """ k = 10 @@ -275,13 +275,13 @@ def test_draft_proposals_full_speculation_len(): device = 'cuda:0' draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=2048, vocab_size=vocab_size, + max_proposal_len=2048, ) - draft_worker.execute_model_multi_step.return_value = [ + draft_worker.sampler_output.return_value = [ SamplerOutput( outputs=[], sampled_token_probs=torch.rand(batch_size, @@ -294,13 +294,13 @@ def test_draft_proposals_full_speculation_len(): device=device, dtype=torch.long), ) for _ in range(k) - ] + ], True execute_model_data, _, _ = create_batch(batch_size, k) proposals = proposer.get_proposals( **execute_model_data.to_dict(), - max_proposal_len=k, + proposal_len=k, ) assert torch.is_tensor(proposals.proposal_token_ids) @@ -315,7 +315,7 @@ def test_draft_proposals_full_speculation_len(): @torch.inference_mode() def test_draft_proposals_no_speculations(): - """Verify DraftModelTop1Proposer correctly handles case where no sequences + """Verify Top1Proposer correctly handles case where no sequences can speculate. """ k = 10 @@ -325,11 +325,11 @@ def test_draft_proposals_no_speculations(): prompt_len = 10 draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=prompt_len + k - 1, vocab_size=vocab_size, + max_proposal_len=prompt_len + k - 1, ) execute_model_data, _, _ = create_batch(batch_size, @@ -338,7 +338,7 @@ def test_draft_proposals_no_speculations(): proposals = proposer.get_proposals( **execute_model_data.to_dict(), - max_proposal_len=k, + proposal_len=k, ) assert torch.is_tensor(proposals.proposal_token_ids) @@ -353,7 +353,7 @@ def test_draft_proposals_no_speculations(): @torch.inference_mode() def test_draft_proposals_mixed_k(): - """Verify DraftModelTop1Proposer correctly handles case some sequences can + """Verify Top1Proposer correctly handles case some sequences can speculate and some can't. """ k = 10 @@ -374,14 +374,14 @@ def test_draft_proposals_mixed_k(): for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len] draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=long_prompt_len + prev_output_token_len + k - 1, vocab_size=vocab_size, + max_proposal_len=long_prompt_len + prev_output_token_len + k - 1, ) - draft_worker.execute_model_multi_step.return_value = [ + draft_worker.sampler_output.return_value = [ SamplerOutput( outputs=[], sampled_token_probs=torch.rand(expected_num_proposal_seqs, @@ -395,7 +395,7 @@ def test_draft_proposals_mixed_k(): device=device, dtype=torch.long), ) for _ in range(k) - ] + ], True execute_model_data, _, _ = create_batch( batch_size, @@ -406,7 +406,7 @@ def test_draft_proposals_mixed_k(): proposals = proposer.get_proposals( **execute_model_data.to_dict(), - max_proposal_len=k, + proposal_len=k, ) assert torch.is_tensor(proposals.proposal_token_ids) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py new file mode 100644 index 0000000000000..ee4135015713d --- /dev/null +++ b/tests/spec_decode/test_ngram_worker.py @@ -0,0 +1,206 @@ +import torch + +from vllm.spec_decode.ngram_worker import NGramWorker +from vllm.spec_decode.top1_proposer import Top1Proposer + +from .utils import (create_execute_model_data, + create_seq_group_metadata_from_prompts, create_worker) + + +def test_ngram_algo_correctness_for_single_no_match(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario cannot find any candidate in one single batch + """ + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window (0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(0, 3) + + prompts = [ + # shall find no candidate + [1, 2, 3, 4, 5, 6, 7], + ] + + proposal_len = 5 + final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + ngram_sampler_output_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens)) + + proposals = proposer.get_proposals( + **ngram_sampler_output_data.to_dict(), + proposal_len=proposal_len, + ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([1]) + assert proposals.proposal_lens.tolist() == [0] + + +def test_ngram_algo_correctness_for_batches_not_match_all(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario find some candidate not full in batchs + """ + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window (0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(0, 3) + + prompts = [ + # shall find no candidate + [1, 2, 3, 4, 5, 6, 7], + # shall find candidate 12,13,14,15,16 + [11, 12, 13, 14, 15, 16, 11], + # shall find candidate 23,24,25,26,21 + [21, 21, 22, 23, 24, 25, 26, 21, 22], + # shall find candidate 34,35,36,37,38 + [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33], + # shall find no candidate as exceed max_proposal_len + [ + 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37, + 38, 31, 32, 33 + ], + ] + + proposal_len = 5 + final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + ngram_sampler_output_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens)) + + proposals = proposer.get_proposals( + **ngram_sampler_output_data.to_dict(), + proposal_len=proposal_len, + ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([5]) + + assert proposals.proposal_lens.tolist( + ) == [proposal_len for _ in range(4)] + [0] + + for i in range(proposal_len): + assert proposals.proposal_token_ids[0][i] == 0 + assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1] + assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3] + assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5] + assert proposals.proposal_token_ids[4][i] == -1 + + +def test_ngram_algo_correctness_for_batches_match_all(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario find candidate in all batchs + """ + + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window (0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(0, 3) + + prompts = [ + # shall find candidate 12,13,14,15,16 + [11, 12, 13, 14, 15, 16, 11], + # shall find candidate 23,24,25,26,21 + [21, 21, 22, 23, 24, 25, 26, 21, 22], + # shall find candidate 34,35,36,37,38 + [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33], + ] + + proposal_len = 5 + final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + ngram_sampler_output_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens)) + + proposals = proposer.get_proposals( + **ngram_sampler_output_data.to_dict(), + proposal_len=proposal_len, + ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([3]) + + assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)] + + for i in range(proposal_len): + assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1] + assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3] + assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5] diff --git a/vllm/config.py b/vllm/config.py index db4398addae3c..257d49b6e804f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -682,6 +682,8 @@ def maybe_create_spec_config( speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -708,6 +710,10 @@ def maybe_create_spec_config( use_v2_block_manager (bool): Whether vLLM is configured to use the v2 block manager or not. Used for raising an error since the v2 block manager is required with spec decode. + ngram_prompt_lookup_max (Optional[int]): Max size of ngram token + window, if provided. + ngram_prompt_lookup_min (Optional[int]): Min size of ngram token + window, if provided. Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if @@ -742,39 +748,57 @@ def maybe_create_spec_config( draft_code_revision = None draft_quantization = None - draft_model_config = ModelConfig( - model=speculative_model, - tokenizer=target_model_config.tokenizer, - tokenizer_mode=target_model_config.tokenizer_mode, - trust_remote_code=target_model_config.trust_remote_code, - dtype=target_model_config.dtype, - seed=target_model_config.seed, - revision=draft_revision, - code_revision=draft_code_revision, - tokenizer_revision=target_model_config.tokenizer_revision, - max_model_len=None, - quantization=draft_quantization, - enforce_eager=target_model_config.enforce_eager, - max_context_len_to_capture=target_model_config. - max_context_len_to_capture, - max_logprobs=target_model_config.max_logprobs, - ) - - draft_model_config.max_model_len = ( - SpeculativeConfig._maybe_override_draft_max_model_len( - speculative_max_model_len, - draft_model_config.max_model_len, - target_model_config.max_model_len, - )) + if speculative_model == "[ngram]": + assert (ngram_prompt_lookup_max is not None + and ngram_prompt_lookup_max > 0) + if ngram_prompt_lookup_min is None: + ngram_prompt_lookup_min = 0 + else: + assert ngram_prompt_lookup_max > ngram_prompt_lookup_min - draft_parallel_config = ( - SpeculativeConfig.create_draft_parallel_config( - target_parallel_config)) + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. + draft_model_config = target_model_config + draft_parallel_config = target_parallel_config + else: + ngram_prompt_lookup_max = 0 + ngram_prompt_lookup_min = 0 + draft_model_config = ModelConfig( + model=speculative_model, + tokenizer=target_model_config.tokenizer, + tokenizer_mode=target_model_config.tokenizer_mode, + trust_remote_code=target_model_config.trust_remote_code, + dtype=target_model_config.dtype, + seed=target_model_config.seed, + revision=draft_revision, + code_revision=draft_code_revision, + tokenizer_revision=target_model_config.tokenizer_revision, + max_model_len=None, + quantization=draft_quantization, + enforce_eager=target_model_config.enforce_eager, + max_context_len_to_capture=target_model_config. + max_context_len_to_capture, + max_logprobs=target_model_config.max_logprobs, + ) + + draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + speculative_max_model_len, + draft_model_config.max_model_len, + target_model_config.max_model_len, + )) + + draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + target_parallel_config)) return SpeculativeConfig( draft_model_config, draft_parallel_config, num_speculative_tokens, + ngram_prompt_lookup_max, + ngram_prompt_lookup_min, ) @staticmethod @@ -842,6 +866,8 @@ def __init__( draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, + ngram_prompt_lookup_max: int, + ngram_prompt_lookup_min: int, ): """Create a SpeculativeConfig object. @@ -854,6 +880,8 @@ def __init__( self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min self._verify_args() @@ -877,7 +905,10 @@ def num_lookahead_slots(self) -> int: return self.num_speculative_tokens def __repr__(self) -> str: - draft_model = self.draft_model_config.model + if self.ngram_prompt_lookup_max > 0: + draft_model = "[ngram]" + else: + draft_model = self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bd6437ee44c28..7637616ae6089 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -75,6 +75,8 @@ class EngineArgs: speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None speculative_max_model_len: Optional[int] = None + ngram_prompt_lookup_max: Optional[int] = None + ngram_prompt_lookup_min: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -449,6 +451,20 @@ def add_cli_args( 'draft model. Sequences over this length will skip ' 'speculation.') + parser.add_argument( + '--ngram-prompt-lookup-max', + type=int, + default=EngineArgs.ngram_prompt_lookup_max, + help='Max size of window for ngram prompt lookup in speculative ' + 'decoding.') + + parser.add_argument( + '--ngram-prompt-lookup-min', + type=int, + default=EngineArgs.ngram_prompt_lookup_min, + help='Min size of window for ngram prompt lookup in speculative ' + 'decoding.') + parser.add_argument('--model-loader-extra-config', type=str, default=EngineArgs.model_loader_extra_config, @@ -502,6 +518,8 @@ def create_engine_config(self, ) -> EngineConfig: speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, + ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, ) scheduler_config = SchedulerConfig( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 527a14ff6c67a..a58856a12f0c8 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -73,7 +73,6 @@ def _init_spec_worker(self): """ assert self.speculative_config is not None - from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker target_worker = self._create_worker() @@ -86,10 +85,11 @@ def _init_spec_worker(self): # TODO allow draft-model specific load config. #load_config=self.load_config, ) - draft_worker = MultiStepWorker(**draft_worker_kwargs) - spec_decode_worker = SpecDecodeWorker.from_workers( - proposer_worker=draft_worker, scorer_worker=target_worker) + spec_decode_worker = SpecDecodeWorker.create_worker( + scorer_worker=target_worker, + draft_worker_kwargs=draft_worker_kwargs, + ) assert self.parallel_config.world_size == 1, ( "GPUExecutor only supports single GPU.") diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index c29b838f854c0..8b113e93474ff 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -333,13 +333,13 @@ def _split_scoring_output( sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens target_token_ids, target_probs = sampler_output_to_torch( - [sampler_output]) + [sampler_output], True) # Convert non-speculative output tokens to tensors. sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens non_spec_target_token_ids, non_spec_target_probs = ( - sampler_output_to_torch([sampler_output])) + sampler_output_to_torch([sampler_output], True)) return (target_token_ids, target_probs, non_spec_target_token_ids, non_spec_target_probs) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 7cf338bbae5f0..d031bc85af160 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,12 +1,11 @@ import copy -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import torch from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.spec_decode.interfaces import (SpeculativeProposals, - SpeculativeProposer) -from vllm.spec_decode.util import sampler_output_to_torch +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker @@ -26,29 +25,37 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Lazy initialization list. - self._proposer: DraftModelTop1Proposer + self._proposer: Top1Proposer def init_device(self): super().init_device() - self._proposer = DraftModelTop1Proposer( + self._proposer = Top1Proposer( self, self.device, - self.max_model_len, self.vocab_size, + max_proposal_len=self.max_model_len, ) + def set_include_gpu_probs_tensor(self): + # Need include_gpu_probs_tensor for multi_step_worker + self.model_runner.model.sampler.include_gpu_probs_tensor = True + @torch.inference_mode() - def execute_model_multi_step( + def sampler_output( self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - num_steps: int, - ) -> List[SamplerOutput]: - """Run the model forward pass num_steps times. Returns the list of - sampler output, one per model forward pass. + sample_len: int, + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass sample_len times. Returns the list of + sampler output, one per model forward pass, along with indicator of + whether torch tensor in sampler output need to be transposed in latter + sampler_output_to_torch logic. + + For multi step worker, this indicator shall be True. """ self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) @@ -58,12 +65,12 @@ def execute_model_multi_step( copied_seq_group_metadata_list = self._shallow_copy_inputs( seq_group_metadata_list) - # Assert enough KV space for num_steps tokens per sequence. - self._assert_enough_kv_space(seq_group_metadata_list, num_steps) + # Assert enough KV space for sample_len tokens per sequence. + self._assert_enough_kv_space(seq_group_metadata_list, sample_len) - # Run model num_steps times. + # Run model sample_len times. model_outputs = [] - for _ in range(num_steps): + for _ in range(sample_len): model_output = super().execute_model( seq_group_metadata_list=copied_seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, @@ -78,7 +85,7 @@ def execute_model_multi_step( copied_seq_group_metadata_list) model_outputs.append(model_output) - return model_outputs + return model_outputs, True def get_spec_proposals( self, @@ -206,171 +213,3 @@ def _raise_if_unsupported( for seq_group_metadata in seq_group_metadata_list): raise NotImplementedError( "MultiStepWorker does not support beam search.") - - -class DraftModelTop1Proposer(SpeculativeProposer): - """Helper class which separates out sequences which would exceed the max - model length when speculated upon. - - This allows combinations of models such as JackFram/llama-68m draft with - meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of - 2048 while Llama2-13b has max_position_embeddings of 4096. - - We treat the sequences which exceed the proposal draft model length as - "non-spec sequences". Essentially they skip the draft model and go through - normal decoding in the target model. - - Currently, only proposal_lens of 0 and k are supported, where k is a global - batch proposal length. In the future vLLM should support per-sequence - proposal lengths. - """ - - def __init__( - self, - draft_worker: MultiStepWorker, - device: str, - max_model_len: int, - vocab_size: int, - ): - self._draft_worker = draft_worker - self._device = device - self._max_model_len = max_model_len - self._vocab_size = vocab_size - - def get_proposals( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, - ) -> SpeculativeProposals: - """Get speculative proposals given the input batch. - - Sequences which would exceed the max model length are skipped during - speculation. - """ - - # Split speculative- and non-speculative- sequences. - (proposal_lens, nonzero_proposal_len_seqs, - nonzero_proposal_len_indices) = self._split_by_max_model_len( - seq_group_metadata_list, max_proposal_len) - - if nonzero_proposal_len_seqs: - # Speculate tokens using the draft worker for the speculative - # sequences. - maybe_sampler_output = self._draft_worker.execute_model_multi_step( - seq_group_metadata_list=nonzero_proposal_len_seqs, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - num_steps=max_proposal_len, - ) - else: - # If no sequences can be speculated, set sampler output to None. - maybe_sampler_output = None - - # Combine speculative- and non-speculative sequences into the same - # representation. - proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( - batch_size=len(seq_group_metadata_list), - max_proposal_len=max_proposal_len, - maybe_sampler_output=maybe_sampler_output, - proposal_lens=proposal_lens, - nonzero_proposal_len_indices=nonzero_proposal_len_indices, - ) - - proposals = SpeculativeProposals( - proposal_token_ids=proposal_tokens, - proposal_probs=proposal_probs, - proposal_lens=proposal_lens, - ) - - return proposals - - def _split_by_max_model_len( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - max_proposal_len: int, - ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: - """Determine which sequences would exceed the max model length. - """ - - proposal_lens: List[int] = [] - nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] - nonzero_proposal_len_indices: List[int] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_data = next(iter(seq_group_metadata.seq_data.values())) - seq_len = seq_data.get_len() - - # Currently only proposal lens of 0 or the global batch proposal len - # are supported. - if seq_len + max_proposal_len < self._max_model_len: - proposal_lens.append(max_proposal_len) - nonzero_proposal_len_seqs.append(seq_group_metadata) - nonzero_proposal_len_indices.append(i) - else: - proposal_lens.append(0) - - return (proposal_lens, nonzero_proposal_len_seqs, - nonzero_proposal_len_indices) - - def _merge_outputs( - self, - batch_size: int, - max_proposal_len: int, - maybe_sampler_output: Optional[SamplerOutput], - proposal_lens: List[int], - nonzero_proposal_len_indices: List[int], - ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: - """After speculations are produced, merge the speculation results with - the skipped sequences. - """ - if maybe_sampler_output is None: - # If no speculative tokens, the sampler output will be None. - # In this case we return empty proposals. - proposal_tokens = torch.full(size=( - batch_size, - max_proposal_len, - ), - fill_value=-1, - dtype=torch.long, - device=self._device) - proposal_probs = torch.zeros(batch_size, - max_proposal_len, - self._vocab_size, - dtype=torch.float32, - device=self._device) - proposal_lens_tensor = torch.zeros(len(proposal_lens), - dtype=torch.long, - device=self._device) - return proposal_tokens, proposal_probs, proposal_lens_tensor - - sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs = sampler_output_to_torch( - sampler_output) - - # Now, reformat the output GPU tensors such that each sequence has - # a proposal. the proposal can be empty, e.g. [-1, -1, -1] - - entire_proposal_tokens = torch.full(size=(batch_size, - *proposal_tokens.shape[1:]), - fill_value=-1, - dtype=torch.long, - device=self._device) - entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens - entire_proposal_probs = torch.zeros(batch_size, - *proposal_probs.shape[1:], - dtype=torch.float32, - device=self._device) - entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs - - proposal_tokens, proposal_probs = (entire_proposal_tokens, - entire_proposal_probs) - - proposal_lens_tensor = torch.zeros(batch_size, - dtype=torch.long, - device=self._device) - proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len - - return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py new file mode 100644 index 0000000000000..696ca964328cf --- /dev/null +++ b/vllm/spec_decode/ngram_worker.py @@ -0,0 +1,190 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + + +class NGramWorker(LoraNotSupportedWorkerBase): + """NGramWorker provides a light drafter without need for model. + + Current NGramWorker only implement prompt lookup decoding, + and in future we may also do RAG type drafter and other scenerios + which don't rely on LLM model to give proposals. + """ + + def __init__(self, *args, **kwargs): + # Get local_rank/vocab_size from kwargs attribute + self.local_rank = kwargs["local_rank"] + self.vocab_size = kwargs["model_config"].get_vocab_size() + + # Lazy initialization list. + self._proposer: Top1Proposer + + def set_ngram_window_size(self, ngram_prompt_lookup_min: int, + ngram_prompt_lookup_max: int): + # Search valid candidate window between + # ngram_prompt_lookup_min/ngram_prompt_lookup_max + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min + + def init_device(self): + self.device = torch.device(f"cuda:{self.local_rank}") + self.load_model = lambda *args, **kwargs: None + + # Current only support Top1Proposer + self._proposer = Top1Proposer( + self, + device=self.device, + vocab_size=self.vocab_size, + ) + + def set_include_gpu_probs_tensor(self): + # NGram don't need gpu sampler + pass + + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Optional[Dict[int, int]], + blocks_to_swap_out: Optional[Dict[int, int]], + blocks_to_copy: Optional[Dict[int, List[int]]], + ) -> None: + """NGram doesn't depend on model execution, just pass this function""" + pass + + def determine_num_available_blocks(self) -> None: + """NGram doesn't depend on model execution, no need to check blocks""" + pass + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """As there is no cache need to handle, just pass this function""" + pass + + def get_cache_block_size_bytes(self): + """Return the size of a cache block in bytes.""" + return 0 + + def sampler_output( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + sample_len: int, + ) -> Tuple[Optional[List[SamplerOutput]], bool]: + """NGram match algo to pick proposal candidate. Returns the list of + sampler output, one per SequenceGroupMetadata. + + For ngram worker, we already done needed transposed internal, so the + indicator pass to sampler_output_to_torch shall be False. + """ + self._raise_if_unsupported( + seq_group_metadata_list, + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + ) + + arr = [] + has_spec_out = False + for seq_group_metadata in seq_group_metadata_list: + seq_data = next(iter(seq_group_metadata.seq_data.values())) + + input_ids = torch.as_tensor(seq_data.get_token_ids(), + dtype=torch.long, + device=self.device) + input_length = seq_data.get_len() + + for ngram_size in range( + min(self.ngram_prompt_lookup_max, input_length - 1), + self.ngram_prompt_lookup_min, + -1, + ): + ngram_tensor = input_ids[-1 * ngram_size:] + windows = input_ids.unfold(dimension=0, + size=ngram_size, + step=1) + matches = (windows == ngram_tensor).all(dim=1) + match_indices = matches.nonzero(as_tuple=True)[0] + if match_indices.size()[0] > 1: + has_spec_out = True + res = seq_data.get_token_ids() + res = res[match_indices[0] + ngram_size:match_indices[0] + + ngram_size + sample_len] + res_len = len(res) + # pad 0 towards output as sample_len tokens required + res += [0] * (sample_len - res_len) + + break + else: + # if no candidate found, fill with 0 + res = [0] * sample_len + + arr.append(res) + + if not has_spec_out: + return None, False + + outputs = [] + token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device) + indices = token_ids.unsqueeze(2) + + token_probs = torch.zeros( + (len(seq_group_metadata_list), sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + token_probs.scatter_(2, indices, 1) + for i in range(len(seq_group_metadata_list)): + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_probs[i], + sampled_token_ids=token_ids[i], + )) + return outputs, False + + def get_spec_proposals( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + max_proposal_len: int, + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + + return self._proposer.get_proposals( + seq_group_metadata_list, + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + max_proposal_len, + ) + + def _raise_if_unsupported( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> None: + """NGramWorker does not yet implement support for cache swap + operations or beam search. + """ + if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + raise NotImplementedError( + "NGramWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in seq_group_metadata_list): + raise NotImplementedError( + "NGramWorker does not support beam search.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 4e70ea9686005..e33bb4f3f6337 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -12,6 +12,7 @@ SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase @@ -48,8 +49,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): """ @classmethod - def from_workers(cls, proposer_worker: MultiStepWorker, - scorer_worker: WorkerBase) -> "SpecDecodeWorker": + def create_worker( + cls, + scorer_worker: WorkerBase, + draft_worker_kwargs, + ) -> "SpecDecodeWorker": + + if "ngram_prompt_lookup_max" in draft_worker_kwargs: + ngram_prompt_lookup_max = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_max")) + ngram_prompt_lookup_min = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + else: + ngram_prompt_lookup_max = 0 + + if ngram_prompt_lookup_max > 0: + proposer_worker = NGramWorker(**draft_worker_kwargs) + proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, + ngram_prompt_lookup_max) + else: + proposer_worker = MultiStepWorker(**draft_worker_kwargs) + return SpecDecodeWorker( proposer_worker, scorer_worker, @@ -59,7 +79,7 @@ def from_workers(cls, proposer_worker: MultiStepWorker, def __init__( self, - proposer_worker: MultiStepWorker, + proposer_worker: WorkerBase, scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, @@ -134,8 +154,7 @@ def _configure_model_sampler_for_spec_decode(self): """ (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor ) = True - (self.proposer_worker.model_runner.model.sampler. - include_gpu_probs_tensor) = True + self.proposer_worker.set_include_gpu_probs_tensor() def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. @@ -183,8 +202,8 @@ def execute_model( "speculative decoding " "requires non-None seq_group_metadata_list") - logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d", - num_lookahead_slots) + #logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d", + # num_lookahead_slots) # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. @@ -216,7 +235,7 @@ def _run_no_spec( proposer and scorer model so that the KV cache is consistent between the two. """ - logger.info("run proposer worker no spec") + #logger.info("run proposer worker no spec") self.proposer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, @@ -225,7 +244,7 @@ def _run_no_spec( blocks_to_copy=blocks_to_copy, ) - logger.info("run target worker no spec") + #logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, @@ -259,7 +278,7 @@ def _run_speculative_decoding_step( sequence. """ - logger.info("get spec proposals") + #logger.info("get spec proposals") # Generate proposals using draft worker. assert blocks_to_swap_in is not None assert blocks_to_swap_out is not None @@ -268,7 +287,7 @@ def _run_speculative_decoding_step( seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k) - logger.info("score proposals") + #logger.info("score proposals") proposal_scores = self.scorer.score_proposals( seq_group_metadata_list, blocks_to_swap_in, @@ -278,11 +297,11 @@ def _run_speculative_decoding_step( proposals, ) - logger.info("verify proposals") + #logger.info("verify proposals") accepted_token_ids = self._verify_tokens(seq_group_metadata_list, proposal_scores, proposals, k) - logger.info("create output list") + #logger.info("create output list") return self._create_output_sampler_list(seq_group_metadata_list, accepted_token_ids, k) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py new file mode 100644 index 0000000000000..6766a2deb8eb8 --- /dev/null +++ b/vllm/spec_decode/top1_proposer.py @@ -0,0 +1,200 @@ +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.util import sampler_output_to_torch +from vllm.worker.worker_base import WorkerBase + + +class Top1Proposer(SpeculativeProposer): + """Helper class which separates out sequences which would exceed the max + model length when speculated upon. + + This allows combinations of models such as JackFram/llama-68m draft with + meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of + 2048 while Llama2-13b has max_position_embeddings of 4096. + + We treat the sequences which exceed the proposal draft model length as + "non-spec sequences". Essentially they skip the draft model and go through + normal decoding in the target model. + + Currently, only proposal_lens of 0 and k are supported, where k is a global + batch proposal length. In the future vLLM should support per-sequence + proposal lengths. + """ + + def __init__( + self, + worker: WorkerBase, + device: str, + vocab_size: int, + max_proposal_len: Optional[int] = None, + ): + self._worker = worker + self._device = device + self.max_proposal_len = max_proposal_len + self._vocab_size = vocab_size + + def get_proposals( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + proposal_len: int, + ) -> SpeculativeProposals: + """Get speculative proposals given the input batch. + + Sequences which would exceed the max model length are skipped during + speculation. + """ + + # Split speculative- and non-speculative- sequences. + ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len) + + if nonzero_proposal_len_seqs: + # Speculate tokens using the draft worker for the speculative + # sequences. + # If sampler_transposed is true, then maybe_sampler_output's + # token_ids is like [batch] format in proposal_len size list, + # while if it is false, the format would be [proposal_len] + # in batch size list + maybe_sampler_output, transposed = self._worker.sampler_output( + seq_group_metadata_list=nonzero_proposal_len_seqs, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + sample_len=proposal_len, + ) + else: + # If no sequences can be speculated, set sampler output to None. + maybe_sampler_output = None + transposed = False + + # Combine speculative- and non-speculative sequences into the same + # representation. + proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( + batch_size=len(seq_group_metadata_list), + proposal_len=proposal_len, + maybe_sampler_output=maybe_sampler_output, + proposal_lens=proposal_lens, + nonzero_proposal_len_indices=nonzero_proposal_len_indices, + sampler_transposed=transposed, + ) + + proposals = SpeculativeProposals( + proposal_token_ids=proposal_tokens, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens, + ) + + return proposals + + def _split_by_max_model_len( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_len: int, + ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: + """Determine which sequences would exceed the max model length.""" + + proposal_lens: List[int] = [] + nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] + nonzero_proposal_len_indices: List[int] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_data = next(iter(seq_group_metadata.seq_data.values())) + seq_len = seq_data.get_len() + + # Currently only proposal lens of 0 or the global batch proposal len + # are supported. + # If max_proposal_len is defined, then we shall no exccess this + # quota for nonzero_proposal + if (self.max_proposal_len is None + or seq_len + proposal_len < self.max_proposal_len): + proposal_lens.append(proposal_len) + nonzero_proposal_len_seqs.append(seq_group_metadata) + nonzero_proposal_len_indices.append(i) + else: + proposal_lens.append(0) + + return ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) + + def _merge_outputs( + self, + batch_size: int, + proposal_len: int, + maybe_sampler_output: Optional[SamplerOutput], + proposal_lens: List[int], + nonzero_proposal_len_indices: List[int], + sampler_transposed: bool, + ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: + """After speculations are produced, merge the speculation results with + the skipped sequences. + """ + if maybe_sampler_output is None: + # If no speculative tokens, the sampler output will be None. + # In this case we return empty proposals. + proposal_tokens = torch.full( + size=( + batch_size, + proposal_len, + ), + fill_value=-1, + dtype=torch.long, + device=self._device, + ) + proposal_probs = torch.zeros( + batch_size, + proposal_len, + self._vocab_size, + dtype=torch.float32, + device=self._device, + ) + proposal_lens_tensor = torch.zeros(len(proposal_lens), + dtype=torch.long, + device=self._device) + return proposal_tokens, proposal_probs, proposal_lens_tensor + + sampler_output = maybe_sampler_output + proposal_tokens, proposal_probs = sampler_output_to_torch( + sampler_output, sampler_transposed) + + # Now, reformat the output GPU tensors such that each sequence has + # a proposal. the proposal can be empty, e.g. [-1, -1, -1] + + entire_proposal_tokens = torch.full( + size=(batch_size, *proposal_tokens.shape[1:]), + fill_value=-1, + dtype=torch.long, + device=self._device, + ) + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens + entire_proposal_probs = torch.zeros( + batch_size, + *proposal_probs.shape[1:], + dtype=torch.float32, + device=self._device, + ) + entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs + + proposal_tokens, proposal_probs = ( + entire_proposal_tokens, + entire_proposal_probs, + ) + + proposal_lens_tensor = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len + + return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index eb6d4ca1da8e6..894d2fd915948 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -49,10 +49,13 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( - sampler_output_list: List[SamplerOutput], -) -> Tuple[torch.Tensor, torch.Tensor]: + sampler_output_list: List[SamplerOutput], + sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]: """Utility function which converts a list of SamplerOutput to tensors. + sampler_transposed here is used as the indicator for whether + we need do additional tensor transpose logic here. + Returns: sampled_token_ids: torch.Tensor shape: [batch_size, len(sampler_output_list)] @@ -68,7 +71,10 @@ def sampler_output_to_torch( for sampler_output in sampler_output_list ], dim=0, - ).transpose(0, 1) + ) + + if sampler_transposed: + sampled_token_probs = sampled_token_probs.transpose(0, 1) # shape: [batch_size, num_sampler_output] sampled_token_ids = torch.stack( @@ -77,7 +83,9 @@ def sampler_output_to_torch( for sampler_output in sampler_output_list ], dim=0, - ).transpose(0, 1) + ) + if sampler_transposed: + sampled_token_ids = sampled_token_ids.transpose(0, 1) return sampled_token_ids, sampled_token_probs From 24750f4cadd15a2b3a52f982e39eb9803749efbc Mon Sep 17 00:00:00 2001 From: leiwen83 Date: Thu, 2 May 2024 02:20:32 +0800 Subject: [PATCH 176/413] [Core] Enable prefix caching with block manager v2 enabled (#4142) Co-authored-by: Lei Wen Co-authored-by: Sage Moore --- benchmarks/benchmark_prefix_caching.py | 16 +- tests/core/block/e2e/test_correctness.py | 146 +++++++++++++++ tests/core/block/test_prefix_caching_block.py | 125 +++++++++++++ vllm/core/block/cpu_gpu_block_allocator.py | 12 +- vllm/core/block/interfaces.py | 4 + vllm/core/block/naive_block.py | 11 +- vllm/core/block/prefix_caching_block.py | 172 ++++++++++++++---- vllm/core/block_manager_v1.py | 2 +- vllm/core/block_manager_v2.py | 31 ++-- vllm/core/{evictor.py => evictor_v1.py} | 0 vllm/core/evictor_v2.py | 122 +++++++++++++ 11 files changed, 584 insertions(+), 57 deletions(-) rename vllm/core/{evictor.py => evictor_v1.py} (100%) create mode 100644 vllm/core/evictor_v2.py diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 1f3274a28cad5..089966986984f 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): def main(args): - llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat", + llm = LLM(model=args.model, tokenizer_mode='auto', trust_remote_code=True, enforce_eager=True, + use_v2_block_manager=args.use_v2_block_manager, + tensor_parallel_size=args.tensor_parallel_size, enable_prefix_caching=args.enable_prefix_caching) num_prompts = 100 prompts = [PROMPT] * num_prompts - sampling_params = SamplingParams(temperature=0, max_tokens=100) + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) print("------warm up------") test_prefix( llm=llm, - prompts=prompts[:1], + prompts=prompts, sampling_params=sampling_params, ) @@ -45,8 +47,16 @@ def main(args): parser = argparse.ArgumentParser( description='Benchmark the performance with or without automatic ' 'prefix caching.') + parser.add_argument('--model', + type=str, + default='baichuan-inc/Baichuan2-13B-Chat') + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix caching') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='Use BlockSpaceMangerV2') args = parser.parse_args() main(args) diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 0ee78a9b0a8ea..c3666da7542b5 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -300,6 +300,152 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Allow only 5 sequences of ~1024 tokens in worst case. + "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + + # Enable prefill cache + "enable_prefix_caching": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size): + """Verify block manager v2 produces same outputs as block manager v1, even + when there is preemption. + + This constructs two LLM, each with limited number of GPU blocks. The limit + is decided such that as the sequences in the batch grow, sequences must be + preempted and removed from cache. + + If the output token ids are equivalent, then we have confidence that the KV + cache is not corrupted in the v2 block manager. + + NOTE: We want a significant number of generated tokens so that any incorrect + KV mapping has time to build up error. + """ + output_len = 1024 + temperature = 0.0 + + # We want to ensure equality even with preemption. + # We force the total block size to be 1 + cdiv(output_len, block_size) + # so that only one sequence can fit at a time (once the sequences grow). + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids from block manager v1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids from block manager v2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Allow only 5 sequences of ~1024 tokens in worst case. + "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + + # Test APC in v2 block + "use_v2_block_manager": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "enable_prefix_caching": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_auto_prefix_caching_with_preemption(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify block manager v2 with auto prefix caching enabled produces same + outputs as auto prefix caching disabled, even when there is preemption. + + This constructs two LLM, each with limited number of GPU blocks. The limit + is decided such that as the sequences in the batch grow, sequences must be + preempted and removed from cache. + + If the output token ids are equivalent, then we have confidence that auto + prefix caching itself at least don't cause result error. + """ + output_len = 1024 + temperature = 0.0 + + # We want to ensure equality even with preemption. + # We force the total block size to be 1 + cdiv(output_len, block_size) + # so that only one sequence can fit at a time (once the sequences grow). + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids with APC disabled') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with APC enabled') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 5f4d58dd5fd39..c4c680e109a84 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -358,6 +358,131 @@ def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, i) allocator.free(block) + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_get_common_computed_block_ids(num_blocks: int, block_size: int, + seed: int): + """Verify get_common_computed_block_ids could get correct result + by create two immutable chain sharing prefix at specified pos, + and compare whether we also could get right result + from get_common_computed_block_ids. + """ + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2, + block_size=block_size) + num_blocks_to_consume = random.randint(1, num_blocks - 1) + + # Create token ids that will exhaust all blocks. + token_ids = list(range(num_blocks_to_consume * block_size)) + blocks = list(range(num_blocks_to_consume)) + + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # mark all blocks in first chain as computed + allocator.mark_blocks_as_computed(blocks) + + # After zero_point, second_chain's token_ids would be set -1, which + # make it different from here comparing with first_chain + zero_point = random.randint(1, len(token_ids) - 1) + zero_point_blocks = zero_point // block_size + token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point) + + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + first_computed_ids = [ + first_chain[i].block_id for i in range(num_blocks_to_consume) + ] + second_computed_ids = [ + second_chain[i].block_id for i in range(num_blocks_to_consume) + ] + res = allocator.get_common_computed_block_ids( + [first_computed_ids, second_computed_ids]) + + assert (len(res) == zero_point_blocks) + + # Test case where two last accessed times are equal + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_eviction_order(num_blocks: int, block_size: int, seed: int): + """This test case simulate the two chain created and free in order, + and together they would exhaust the initial freed blocks. + + So the next block created after those two chain shall use the block + from the first chain as that block has long access time. + While first chain has two blocks, it shall pick up the last one, as + it has larger token number. + """ + + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + num_blocks_to_consume = num_blocks + 1 + + token_ids = list(range(num_blocks_to_consume * block_size)) + + num_blocks_in_first_chain = 2 + num_tokens_in_first_chain = block_size * num_blocks_in_first_chain + # First chain takes the first block + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[:num_tokens_in_first_chain], + allocator=allocator, + ) + # There should only be one block allocated at this point + assert allocator.get_num_free_blocks() == (num_blocks - + num_blocks_in_first_chain) + + # Set the last accessed time of the first block to 1 + blocks_ids = [block.block_id for block in first_chain] + allocator.mark_blocks_as_accessed(blocks_ids, 1) + + # Second chain takes the rest of the blocks + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[num_tokens_in_first_chain:-block_size], + allocator=allocator, + ) + + # There shouldn't be any blocks left at this point + assert allocator.get_num_free_blocks() == (0) + + assert len(first_chain) == num_blocks_in_first_chain + last_block_id = first_chain[-1].block_id + # Free each block in the first chain. + for i, block in enumerate(first_chain): + allocator.free(block) + + # Set the last accessed time on all of the blocks in the second chain + # to 2 + blocks_ids = [block.block_id for block in second_chain] + allocator.mark_blocks_as_accessed(blocks_ids, 2) + + # Free each block in the second chain. + for i, block in enumerate(second_chain): + allocator.free(block) + + # Allocate a new block and check that it's the least recently used block + # from the first chain. + new_block = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[-block_size:], + allocator=allocator, + ) + + assert new_block[0].block_id == last_block_id + @staticmethod def create_immutable_chain( block_size: int, diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 3135e194c5937..23e1a4cf91266 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -190,10 +190,18 @@ def clear_copy_on_writes(self) -> Dict[int, List[int]]: device = Device.GPU return self._allocators[device].clear_copy_on_writes() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, only use for prefix caching.""" # Prefix caching only supported on GPU. device = Device.GPU - return self._allocators[device].mark_blocks_as_computed() + return self._allocators[device].mark_blocks_as_accessed(block_ids, now) + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_computed(block_ids) def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 50ce922118124..440d6a4b04d3b 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -81,6 +81,10 @@ def all_block_ids(self) -> FrozenSet[int]: def clear_copy_on_writes(self) -> Dict[int, List[int]]: pass + @abstractmethod + def mark_blocks_as_accessed(self) -> None: + pass + @abstractmethod def mark_blocks_as_computed(self) -> None: pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index f8e9265bb2d67..a0bf33912d935 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -174,7 +174,16 @@ def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: """ return self._cow_tracker.clear_cows() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """Mark blocks as computed, used in prefix caching. Since the naive allocator does not implement prefix caching, we do diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 6aa75a8abb80a..292a750146ae6 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -7,10 +7,16 @@ get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor PrefixHash = int BlockId = int +# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME +# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, +# then we know this block hasn't been accessed yet. +_DEFAULT_LAST_ACCESSED_TIME = -1 + class PrefixCachingBlockAllocator(BlockAllocator): """A block allocator that implements prefix caching. @@ -27,22 +33,19 @@ class PrefixCachingBlockAllocator(BlockAllocator): from 0 to num_blocks - 1. """ - # TODO last access time / evictor integration - def __init__( self, num_blocks: int, block_size: int, block_ids: Optional[Iterable[int]] = None, + eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU, ): # A mapping of prefix hash to block index. All blocks which have a # prefix hash will be in this dict, even if they have refcount 0. self._cached_blocks: Dict[PrefixHash, BlockId] = {} - # A mapping of prefix hash to block index. All blocks which have a - # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset - # of self._cached_blocks. - self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {} + # A mapping of blockId to Block to track those cached blocks + self._blocks: Dict[BlockId, Block] = {} # An allocator for blocks that do not have prefix hashes. self._hashless_allocator = NaiveBlockAllocator( @@ -54,6 +57,10 @@ def __init__( self._block_size = block_size + # Evitor used to maintain how we want to handle those computed blocks + # if we find memory pressure is high. + self.evictor: Evictor = make_evictor(eviction_policy) + # We share the refcounter between allocators. This allows us to promote # blocks originally allocated in the hashless allocator to immutable # blocks. @@ -72,6 +79,7 @@ def _create_block( block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, + computed: Optional[bool] = False, ) -> Block: # Bind block to self. allocator = self @@ -82,6 +90,7 @@ def _create_block( block_size=block_size, block_id=block_id, prefix_caching_allocator=allocator, + computed=computed, ) def allocate_immutable(self, prev_block: Optional[Block], @@ -109,14 +118,12 @@ def allocate_immutable(self, prev_block: Optional[Block], cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: block.block_id = cached_block_id - self._incr_refcount_cached_block(block.content_hash, - block.block_id) + self._incr_refcount_cached_block(block, block.block_id) return block block = self.allocate_mutable(prev_block) block.append_token_ids(token_ids) assert block.content_hash is not None - # TODO computed bit return block @@ -133,41 +140,67 @@ def allocate_mutable(self, prev_block: Block) -> Block: assert_prefix_caching_block_or_none(prev_block) try: - return self._hashless_allocator.allocate_mutable( + block = self._hashless_allocator.allocate_mutable( prev_block=prev_block) + + assert block.block_id not in self._blocks + self._blocks[block.block_id] = block + return block except BlockAllocator.NoFreeBlocksError: # We must check the unused cached blocks before raising OOM. pass - if self._unused_cached_blocks: - # TODO policy for selecting block to remove - content_hash_to_evict = next(iter(self._unused_cached_blocks)) + # If the evictor has blocks available for eviction, evict a block + # and return it. + if self.evictor.num_blocks > 0: + block_id, content_hash_to_evict = self.evictor.evict() + + # Here we may have scenario that several blocks have + # the same content hash, but due to the latter coming block + # is coming from mutable to immutable path, their physical + # block is added into evictor. + # However in this case, we shall not pop the _cached_blocks, + # as the same content is still used by others, which means + # we need to check ref before decide to pop the list. - # Clear content hash mapping; the block will be overwritten. - del self._cached_blocks[content_hash_to_evict] + _block_id = self._cached_blocks[content_hash_to_evict] + refcount = self._refcounter.get(_block_id) + if refcount == 1: + self._cached_blocks.pop(content_hash_to_evict) + assert _block_id == block_id - block_id = self._unused_cached_blocks.pop(content_hash_to_evict) - refcount = self._refcounter.incr(block_id) - assert refcount == 1 + self._refcounter.incr(block_id) + + # the block comes from evictor already contain computed result block = self._create_block( prev_block=prev_block, token_ids=[], block_size=self._block_size, allocator=self, block_id=block_id, + computed=True, ) assert block.content_hash is None + + assert block.block_id not in self._blocks + self._blocks[block.block_id] = block return block # No block available in hashless allocator, nor in unused cache blocks. raise BlockAllocator.NoFreeBlocksError() - def _incr_refcount_cached_block(self, content_hash: int, + def _incr_refcount_cached_block(self, block: Block, block_id: BlockId) -> None: + # since block is already computed, mark it + block.computed = True + refcount = self._refcounter.incr(block_id) if refcount == 1: - assert content_hash in self._unused_cached_blocks - del self._unused_cached_blocks[content_hash] + # if block get referred, then it shall not be in evictor + # and put it into _blocks for tracking + if block_id in self.evictor: + self.evictor.remove(block_id) + self._blocks[block_id] = block def free(self, block: Block) -> None: """Decrement the refcount of the block. If the decremented refcount is @@ -180,6 +213,7 @@ def free(self, block: Block) -> None: is not None), "freeing unallocated block is undefined" self._free_block_id_for_block(block.block_id, block) + block.block_id = None def _free_block_id_for_block(self, block_id: BlockId, @@ -187,15 +221,21 @@ def _free_block_id_for_block(self, block_id: BlockId, assert isinstance(block, PrefixCachingBlock) if block.content_hash is None: + refcount = self._refcounter.get(block_id) + # We have fork case where block would get more than one ref, + # so we cannot free it from tracking if ref cnt large than 1 + if refcount <= 1: + del self._blocks[block.block_id] return self._hashless_allocator.free(block) refcount = self._refcounter.decr(block_id) - # If no longer used, add the block to the unused cached blocks. + # If no longer used, add the block to the evictor. if refcount == 0: - assert block.content_hash not in self._unused_cached_blocks assert block.content_hash in self._cached_blocks - self._unused_cached_blocks[block.content_hash] = block_id + del self._blocks[block.block_id] + self.evictor.add(block.block_id, block.content_hash, + block.num_tokens_total, block.last_accessed) def fork(self, last_block: Block) -> List[Block]: """Creates a new sequence of blocks that shares the same underlying @@ -230,9 +270,9 @@ def fork(self, last_block: Block) -> List[Block]: def get_num_free_blocks(self) -> int: # The number of free blocks is the number of hashless free blocks - # plus the number of hashful blocks that are unused. - return self._hashless_allocator.get_num_free_blocks() + len( - self._unused_cached_blocks) + # plus the number of blocks evictor could free from its list. + return self._hashless_allocator.get_num_free_blocks( + ) + self.evictor.num_blocks @property def all_block_ids(self) -> frozenset[int]: @@ -266,7 +306,7 @@ def promote_to_immutable_block(self, else: self._free_block_id_for_block(block.block_id, block) self._incr_refcount_cached_block( - block.content_hash, self._cached_blocks[block.content_hash]) + block, self._cached_blocks[block.content_hash]) return self._cached_blocks[block.content_hash] @@ -293,29 +333,60 @@ def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: """ return self._cow_tracker.clear_cows() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + If the block is added into evictor, we need to update corresponding + info in evictor's metadata. + """ + + for block_id in block_ids: + if block_id in self._blocks: + self._blocks[block_id].last_accessed = now + elif block_id in self.evictor: + self.evictor.update(block_id, now) + else: + raise ValueError( + "Mark block as accessed which is not belonged to GPU") + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """Mark blocks as computed, used in prefix caching.""" - # TODO Track computed blocks. - pass + + for block_id in block_ids: + if block_id in self._blocks: + # only those full block is valid for prefix caching + if self._blocks[block_id].is_full: + self._blocks[block_id].computed = True + elif block_id not in self.evictor: + raise ValueError(f"Mark {block_id=} as computed which " + "is not belonged to GPU") + + def block_is_computed(self, block_id: int) -> bool: + if block_id in self._blocks: + return self._blocks[block_id].computed + else: + return block_id in self.evictor def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: """Return the block ids that are common for a given sequence group. - Used in prefill (can skip prefill of some blocks). + Only those blocks that are immutable and already be marked + compyted would be taken consideration. """ - # TODO: Track computed blocks. - computed = lambda block_id: False - # NOTE We exclude the last block to avoid the case where the entire # prompt is cached. This would cause erroneous behavior in model # runner. + ids_list = [ - takewhile(lambda block_id: computed(block_id), seq[:-1]) - for seq in seq_block_ids + list( + takewhile(lambda block_id: self.block_is_computed(block_id), + seq[:-1])) for seq in seq_block_ids ] - return commonprefix([ids for ids in ids_list if ids != []]) + res = commonprefix([ids for ids in ids_list if ids != []]) + return res class PrefixCachingBlock(Block): @@ -345,12 +416,16 @@ def __init__( block_size: int, prefix_caching_allocator: PrefixCachingBlockAllocator, block_id: Optional[int] = None, + computed: Optional[bool] = False, ): assert_prefix_caching_block_or_none(prev_block) self._prev_block = prev_block self._cached_content_hash: Optional[int] = None + self._cached_num_tokens_total: Optional[int] = None self._prefix_caching_allocator = prefix_caching_allocator + self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME + self.computed = computed self._block = NaiveBlock( prev_block=prev_block, @@ -398,6 +473,27 @@ def is_full(self) -> bool: def num_empty_slots(self) -> int: return self._block.num_empty_slots + @property + def num_tokens_total(self) -> int: + """return the total tokens so far. + + Here we iterate the block chain till to the first block, while + cache the result in local to prevent repeated computations. + """ + if self._cached_num_tokens_total is not None: + return self._cached_num_tokens_total + + _block = self + self._cached_num_tokens_total = 0 + + # TODO: current implement here take O(N^2), we expect future + # we have O(1) here + while _block is not None: + self._cached_num_tokens_total += len(_block.token_ids) + _block = _block.prev_block + + return self._cached_num_tokens_total + @property def block_size(self) -> int: return self._block.block_size diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 73e7dafb72c7f..4a9a2999e3913 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,7 +8,7 @@ from typing import Set from vllm.block import BlockTable, PhysicalTokenBlock -from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.sequence import Sequence, SequenceGroup, SequenceStatus diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 6339a6baf4161..0857605e2d005 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -72,14 +72,12 @@ def __init__( self.watermark = watermark assert watermark >= 0.0 - assert not enable_caching, "Prefix caching not yet supported" self.enable_caching = enable_caching self.watermark_blocks = int(watermark * num_gpu_blocks) self.block_allocator = CpuGpuBlockAllocator.create( - # Currently, only naive blocks are supported (no prefix caching). - allocator_type="naive", + allocator_type="prefix_caching" if enable_caching else "naive", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=block_size, @@ -194,17 +192,26 @@ def get_block_table(self, seq: Sequence) -> List[int]: assert all(b is not None for b in block_ids) return block_ids - def access_all_blocks_in_seq(self, seq, now): - # TODO add prefix caching support. - # Tracked here https://github.com/vllm-project/vllm/issues/3667 - pass + def access_all_blocks_in_seq(self, seq: Sequence, now: float): + # Update the last accessed time of all the blocks accessed + # in this step. + # And the accessed time is only useful for prefix caching now, + # as it support internal evictor policy for which cached + # block could be refilled, to keep cached content could be reused + # at max extend. + if self.enable_caching: + block_table = self.block_tables[seq.seq_id] + block_ids = [] + for block_id in block_table.physical_block_ids: + block_ids.append(block_id) + self.block_allocator.mark_blocks_as_accessed(block_ids, now) def mark_blocks_as_computed(self, seq_group: SequenceGroup): - # We ignore the sequence group as its not necessary. After the batch is - # formed by the scheduler, we do not need to mark blocks from individual - # sequence groups as computed -- all blocks in the batch can be marked - # as computed. - self.block_allocator.mark_blocks_as_computed() + # The only need for mark block as computed is for prefix caching, + # while currently we could determine whether one block is computed + # or not by check whether it has content hash. + # So this function is useless for block_v2. + pass def get_common_computed_block_ids( self, seqs: List[Sequence]) -> GenericSequence[int]: diff --git a/vllm/core/evictor.py b/vllm/core/evictor_v1.py similarity index 100% rename from vllm/core/evictor.py rename to vllm/core/evictor_v1.py diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py new file mode 100644 index 0000000000000..b902a39263d14 --- /dev/null +++ b/vllm/core/evictor_v2.py @@ -0,0 +1,122 @@ +import enum +from abc import ABC, abstractmethod, abstractproperty +from typing import OrderedDict, Tuple + + +class EvictionPolicy(enum.Enum): + """Enum for eviction policy used by make_evictor to instantiate the correct + Evictor subclass. + """ + LRU = enum.auto() + + +class Evictor(ABC): + """The Evictor subclasses should be used by the BlockAllocator class to + handle eviction of freed PhysicalTokenBlocks. + """ + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __contains__(self, block_id: int) -> bool: + pass + + @abstractmethod + def evict(self) -> Tuple[int, int]: + """Runs the eviction algorithm and returns the evicted block's + content hash along with physical block id along with physical block id + """ + pass + + @abstractmethod + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: int): + """Adds block to the evictor, making it a candidate for eviction""" + pass + + @abstractmethod + def update(self, block_id: int, last_accessed: int): + """Update corresponding block's access time in metadata""" + pass + + @abstractproperty + def num_blocks(self) -> int: + pass + + +class BlockMetaData(): + """Data structure for storing key data describe cached block, so that + evitor could use to make its decision which one to choose for eviction + + Here we use physical block id as the dict key, as there maybe several + blocks with the same content hash, but their physical id is unique. + """ + + def __init__(self, content_hash: int, num_hashed_tokens: int, + last_accessed: int): + self.content_hash = content_hash + self.num_hashed_tokens = num_hashed_tokens + self.last_accessed = last_accessed + + +class LRUEvictor(Evictor): + """Evicts in a least-recently-used order using the last_accessed timestamp + that's recorded in the PhysicalTokenBlock. If there are multiple blocks with + the same last_accessed time, then the one with the largest num_hashed_tokens + will be evicted. If two blocks each have the lowest last_accessed time and + highest num_hashed_tokens value, then one will be chose arbitrarily + """ + + def __init__(self): + self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict() + + def __contains__(self, block_id: int) -> bool: + return block_id in self.free_table + + def evict(self) -> Tuple[int, int]: + if len(self.free_table) == 0: + raise ValueError("No usable cache memory left") + + evicted_block = next(iter(self.free_table.values())) + evicted_block_id = next(iter(self.free_table.keys())) + # The blocks with the lowest timestamps should be placed consecutively + # at the start of OrderedDict. Loop through all these blocks to + # find the one with maximum number of hashed tokens. + for _id, block in self.free_table.items(): + if evicted_block.last_accessed > block.last_accessed or ( + evicted_block.last_accessed == block.last_accessed and + evicted_block.num_hashed_tokens < block.num_hashed_tokens): + evicted_block = block + evicted_block_id = _id + + self.free_table.pop(evicted_block_id) + + return evicted_block_id, evicted_block.content_hash + + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: int): + self.free_table[block_id] = BlockMetaData(content_hash, + num_hashed_tokens, + last_accessed) + + def update(self, block_id: int, last_accessed: int): + self.free_table[block_id].last_accessed = last_accessed + + def remove(self, block_id: int): + if block_id not in self.free_table: + raise ValueError( + "Attempting to remove block that's not in the evictor") + self.free_table.pop(block_id) + + @property + def num_blocks(self) -> int: + return len(self.free_table) + + +def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: + if eviction_policy == EvictionPolicy.LRU: + return LRUEvictor() + else: + raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") From a657bfc48a11d87de146629a7b6c03e9ccfbc3fc Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 1 May 2024 11:41:59 -0700 Subject: [PATCH 177/413] [Core] Add `multiproc_worker_utils` for multiprocessing-based workers (#4357) --- tests/engine/test_multiproc_workers.py | 176 ++++++++++++++++ vllm/executor/multiproc_worker_utils.py | 264 ++++++++++++++++++++++++ 2 files changed, 440 insertions(+) create mode 100644 tests/engine/test_multiproc_workers.py create mode 100644 vllm/executor/multiproc_worker_utils.py diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py new file mode 100644 index 0000000000000..610ad9732fb91 --- /dev/null +++ b/tests/engine/test_multiproc_workers.py @@ -0,0 +1,176 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from time import sleep +from typing import Any, List, Tuple + +import pytest + +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) + + +class DummyWorker: + """Dummy version of vllm.worker.worker.Worker""" + + def __init__(self, rank: int): + self.rank = rank + + def worker_method(self, worker_input: Any) -> Tuple[int, Any]: + sleep(0.05) + + if isinstance(worker_input, Exception): + # simulate error case + raise worker_input + + return self.rank, input + + +def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]: + result_handler = ResultHandler() + workers = [ + ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank)) + for rank in range(8) + ] + + worker_monitor = WorkerMonitor(workers, result_handler) + assert not worker_monitor.is_alive() + + result_handler.start() + worker_monitor.start() + assert worker_monitor.is_alive() + + return workers, worker_monitor + + +def test_local_workers() -> None: + """Test workers with sync task submission""" + + workers, worker_monitor = _start_workers() + + def execute_workers(worker_input: str) -> None: + worker_outputs = [ + worker.execute_method("worker_method", worker_input) + for worker in workers + ] + + for rank, output in enumerate(worker_outputs): + assert output.get() == (rank, input) + + executor = ThreadPoolExecutor(max_workers=4) + + # Test concurrent submission from different threads + futures = [ + executor.submit(partial(execute_workers, f"thread {thread_num}")) + for thread_num in range(4) + ] + + for future in futures: + future.result() + + # Test error case + exception = ValueError("fake error") + result = workers[0].execute_method("worker_method", exception) + try: + result.get() + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].process.kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +def test_local_workers_clean_shutdown() -> None: + """Test clean shutdown""" + + workers, worker_monitor = _start_workers() + + assert worker_monitor.is_alive() + assert all(worker.process.is_alive() for worker in workers) + + # Clean shutdown + worker_monitor.close() + + worker_monitor.join(5) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +@pytest.mark.asyncio +async def test_local_workers_async() -> None: + """Test local workers with async task submission""" + + workers, worker_monitor = _start_workers() + + async def execute_workers(worker_input: str) -> None: + worker_coros = [ + worker.execute_method_async("worker_method", worker_input) + for worker in workers + ] + + results = await asyncio.gather(*worker_coros) + for rank, result in enumerate(results): + assert result == (rank, input) + + tasks = [ + asyncio.create_task(execute_workers(f"task {task_num}")) + for task_num in range(4) + ] + + for task in tasks: + await task + + # Test error case + exception = ValueError("fake error") + try: + _result = await workers[0].execute_method_async( + "worker_method", exception) + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].process.kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = await workers[0].execute_method_async( + "worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py new file mode 100644 index 0000000000000..0c04796bc38e3 --- /dev/null +++ b/vllm/executor/multiproc_worker_utils.py @@ -0,0 +1,264 @@ +import asyncio +import multiprocessing +import os +import sys +import threading +import traceback +import uuid +from dataclasses import dataclass +from multiprocessing import Queue +from multiprocessing.connection import wait +from multiprocessing.process import BaseProcess +from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO, + TypeVar, Union) + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + +# ANSI color codes +CYAN = '\033[1;36m' +RESET = '\033[0;0m' + +JOIN_TIMEOUT_S = 2 + +# Use dedicated multiprocess context for workers. +# Both spawn and fork work +mp_method = os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") +mp = multiprocessing.get_context(mp_method) + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + + def get(self) -> T: + self.wait() + assert self.result is not None + if self.result.exception is not None: + raise self.result.exception + return self.result.value # type: ignore[return-value] + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = mp.Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for task_id, future in self.tasks.items(): + _set_future_result( + future, + Result(task_id=task_id, + exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['ProcessWorkerWrapper'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + # Blocks until any worker exits + dead_sentinels = wait([w.process.sentinel for w in self.workers]) + if not self._close: + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + process = worker.process + if process.sentinel in dead_sentinels: + process.join(JOIN_TIMEOUT_S) + if process.exitcode is not None and process.exitcode != 0: + logger.error("Worker %s pid %s died, exit code: %s", + process.name, process.pid, process.exitcode) + # Cleanup any remaining workers + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + for worker in self.workers: + worker.process.join(JOIN_TIMEOUT_S) + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class ProcessWorkerWrapper: + """Local process wrapper for vllm.worker.Worker, + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, + worker_factory: Callable[[], Any]) -> None: + self._task_queue = mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.process: BaseProcess = mp.Process( # type: ignore[attr-defined] + target=_run_worker_process, + name="VllmWorkerProcess", + kwargs=dict( + worker_factory=worker_factory, + task_queue=self._task_queue, + result_queue=self.result_queue, + ), + daemon=True) + + self.process.start() + + def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], + method: str, args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except BaseException as e: + del self.tasks[task_id] + raise ChildProcessError("worker died") from e + + def execute_method(self, method: str, *args, **kwargs): + future: ResultFuture = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: str, *args, **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.process.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.process.kill() + + +def _run_worker_process( + worker_factory: Callable[[], Any], + task_queue: Queue, + result_queue: Queue, +) -> None: + """Worker process event loop""" + + # Add process-specific prefix to stdout and stderr + process_name = mp.current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + # Initialize worker + worker = worker_factory() + del worker_factory + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + executor = getattr(worker, method) + output = executor(*args, **kwargs) + except BaseException as e: + tb = traceback.format_exc() + logger.error( + "Exception in worker %s while processing method %s: %s, %s", + process_name, method, e, tb) + exception = e + result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + logger.info("Worker exiting") + + +def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: + """Prepend each output line with process-specific prefix""" + + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: # type: ignore[attr-defined] + file_write(prefix) + idx = 0 + while (next_idx := s.find('\n', idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True # type: ignore[attr-defined] + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False # type: ignore[attr-defined] + + file.start_new_line = True # type: ignore[attr-defined] + file.write = write_with_prefix # type: ignore[method-assign] From 24bb4fe432fffeccf7a27270ee70aff1b1b8a89a Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 1 May 2024 11:47:38 -0700 Subject: [PATCH 178/413] [Kernel] Update fused_moe tuning script for FP8 (#4457) This PR updates the tuning script for the fused_moe kernel to support FP8 and also adds configurations for TP4. Note that for the configuration I removed num_warps and num_stages for small batch sizes since that improved performance and brought the benchmarks on par with the numbers before in that regime to make sure this is a strict improvement over the status quo. All the numbers below are for mistralai/Mixtral-8x7B-Instruct-v0.1, 1000 input and 50 output tokens. Before this PR (with static activation scaling): qps = 1: 9.8 ms ITL, 0.49s e2e latency qps = 2: 9.7 ms ITL, 0.49s e2e latency qps = 4: 10.1 ms ITL, 0.52s e2e latency qps = 6: 11.9 ms ITL, 0.59s e2e latency qps = 8: 14.0 ms ITL, 0.70s e2e latency qps = 10: 15.7 ms ITL, 0.79s e2e latency After this PR (with static activation scaling): qps = 1: 9.8 ms ITL, 0.49s e2e latency qps = 2: 9.7 ms ITL, 0.49s e2e latency qps = 4: 10.2 ms ITL, 0.53s e2e latency qps = 6: 11.9 ms ITL, 0.59s e2e latency qps = 8: 11.9 ms ITL, 0.59s e2e latency qps = 10: 12.1 ms ITL, 0.61s e2e latency --- benchmarks/kernels/benchmark_mixtral_moe.py | 109 +++++++++----- ...me=NVIDIA_H100_80GB_HBM3,dtype=float8.json | 140 ++++++++++++++++++ 2 files changed, 211 insertions(+), 38 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py index 8e976fbcb3028..5280b214144c9 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ b/benchmarks/kernels/benchmark_mixtral_moe.py @@ -1,3 +1,4 @@ +import argparse import json import os import sys @@ -5,6 +6,7 @@ import torch import torch.nn.functional as F import triton +from tqdm import tqdm from vllm.model_executor.layers.fused_moe import (fused_moe, get_config_file_name) @@ -12,16 +14,16 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0' -def main(): +def main(dtype: str): method = fused_moe for bs in [ 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096 ]: - run_grid(bs, method=method) + run_grid(bs, method=method, dtype=dtype) -def run_grid(bs, method): +def run_grid(bs, method, dtype: str): d_model = 4096 num_total_experts = 8 top_k = 2 @@ -34,39 +36,29 @@ def run_grid(bs, method): num_trials = 1 configs = [] - if bs <= 16: - BLOCK_SIZES_M = [16] - elif bs <= 32: - BLOCK_SIZES_M = [16, 32] - elif bs <= 64: - BLOCK_SIZES_M = [16, 32, 64] - elif bs <= 128: - BLOCK_SIZES_M = [16, 32, 64, 128] - else: - BLOCK_SIZES_M = [16, 32, 64, 128, 256] for block_size_n in [32, 64, 128, 256]: - for block_size_m in BLOCK_SIZES_M: + for block_size_m in [16, 32, 64, 128, 256]: for block_size_k in [64, 128, 256]: for group_size_m in [1, 16, 32, 64]: for num_warps in [4, 8]: - configs.append({ - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - "num_warps": num_warps, - "num_stages": 4, - }) + for num_stages in [2, 3, 4, 5]: + configs.append({ + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) best_config = None best_time_us = 1e20 - for config in configs: - print(f'{tp_size=} {bs=}') - print(f'{config}') + print(f'{tp_size=} {bs=}') + + for config in tqdm(configs): # warmup - print('warming up') try: for _ in range(num_warmup_trials): run_timing( @@ -79,12 +71,12 @@ def run_grid(bs, method): model_intermediate_size=model_intermediate_size, method=method, config=config, + dtype=dtype, ) except triton.runtime.autotuner.OutOfResources: continue # trial - print('benchmarking') for _ in range(num_trials): kernel_dur_ms = run_timing( num_calls=num_calls, @@ -96,6 +88,7 @@ def run_grid(bs, method): model_intermediate_size=model_intermediate_size, method=method, config=config, + dtype=dtype, ) kernel_dur_us = 1000 * kernel_dur_ms @@ -105,16 +98,18 @@ def run_grid(bs, method): best_config = config best_time_us = kernel_dur_us - print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' - f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' - f'{d_model=} {model_intermediate_size=} {num_layers=}') + tqdm.write( + f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' + f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' + f'{d_model=} {model_intermediate_size=} {num_layers=}') print("best_time_us", best_time_us) print("best_config", best_config) # holds Dict[str, Dict[str, int]] filename = get_config_file_name(num_total_experts, - model_intermediate_size // tp_size) + model_intermediate_size // tp_size, + "float8" if dtype == "float8" else None) print(f"writing config to file {filename}") existing_content = {} if os.path.exists(filename): @@ -128,27 +123,48 @@ def run_grid(bs, method): def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, top_k: int, tp_size: int, model_intermediate_size: int, method, - config) -> float: + config, dtype: str) -> float: shard_intermediate_size = model_intermediate_size // tp_size hidden_states = torch.rand( (bs, d_model), device="cuda:0", - dtype=torch.bfloat16, + dtype=torch.float16, ) - ws = torch.rand( + w1 = torch.rand( (num_total_experts, 2 * shard_intermediate_size, d_model), device=hidden_states.device, dtype=hidden_states.dtype, ) - w2s = torch.rand( + w2 = torch.rand( (num_total_experts, d_model, shard_intermediate_size), device=hidden_states.device, dtype=hidden_states.dtype, ) + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if dtype == "float8": + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + w1_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + w2_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + a1_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + a2_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + gating_output = F.softmax(torch.rand( (num_calls, bs, num_total_experts), device=hidden_states.device, @@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, for i in range(num_calls): hidden_states = method( hidden_states=hidden_states, - w1=ws, - w2=w2s, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, gating_output=gating_output[i], topk=2, renormalize=True, inplace=True, override_config=config, + use_fp8=dtype == "float8", ) end_event.record() end_event.synchronize() @@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, if __name__ == "__main__": - sys.exit(main()) + parser = argparse.ArgumentParser( + prog='benchmark_mixtral_moe', + description='Benchmark and tune the fused_moe kernel', + ) + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['float8', 'float16'], + help='Data type used for fused_moe kernel computations', + ) + args = parser.parse_args() + sys.exit(main(args.dtype)) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..9287808a94d0e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,140 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} From c47ba4aaa94d067bbb0437526cae9a33c698c717 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Wed, 1 May 2024 19:31:22 +0000 Subject: [PATCH 179/413] [Bugfix] Add validation for seed (#4529) --- tests/entrypoints/test_openai_server.py | 20 ++++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 8 ++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index a2a98abe7031c..1323dba469117 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -13,6 +13,7 @@ # and debugging. import ray import requests +import torch # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError @@ -870,5 +871,24 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, assert len(logprobs.tokens) > 5 +async def test_long_seed(server, client: openai.AsyncOpenAI): + for seed in [ + torch.iinfo(torch.long).min - 1, + torch.iinfo(torch.long).max + 1 + ]: + with pytest.raises(BadRequestError) as exc_info: + await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "system", + "content": "You are a helpful assistant.", + }], + temperature=0, + seed=seed) + + assert ("greater_than_equal" in exc_info.value.message + or "less_than_equal" in exc_info.value.message) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 731596e80bd71..3cd9ddad3b7b7 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -79,7 +79,9 @@ class ChatCompletionRequest(OpenAIBaseModel): n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None - seed: Optional[int] = None + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False temperature: Optional[float] = 0.7 @@ -228,7 +230,9 @@ class CompletionRequest(OpenAIBaseModel): max_tokens: Optional[int] = 16 n: int = 1 presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = None + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False suffix: Optional[str] = None From 3a922c1e7ee6753f41c6cc9d6d47d3b2d0110447 Mon Sep 17 00:00:00 2001 From: Roy Date: Thu, 2 May 2024 04:08:14 +0800 Subject: [PATCH 180/413] [Bugfix][Core] Fix and refactor logging stats (#4336) --- vllm/engine/async_llm_engine.py | 14 +++++++++----- vllm/engine/llm_engine.py | 12 +++++++----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 4aceb19b50776..5591893d267a2 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -8,6 +8,7 @@ from transformers import PreTrainedTokenizer from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray @@ -15,7 +16,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData +from vllm.sequence import MultiModalData, SamplerOutput from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -224,8 +225,7 @@ async def step_async(self) -> List[RequestOutput]: scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) # Log stats. - if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs)) + self.do_log_stats(scheduler_outputs, output) return request_outputs @@ -707,9 +707,13 @@ async def get_decoding_config(self) -> DecodingConfig: else: return self.engine.get_decoding_config() - async def do_log_stats(self) -> None: + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: if self.engine_use_ray: - await self.engine.do_log_stats.remote() # type: ignore + await self.engine.do_log_stats.remote( # type: ignore + scheduler_outputs, model_output) else: self.engine.do_log_stats() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4caecb8a51598..19e7143ac2b45 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -597,16 +597,18 @@ def step(self) -> List[RequestOutput]: scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) # Log stats. - if self.log_stats: - self.stat_logger.log( - self._get_stats(scheduler_outputs, model_output=output)) + self.do_log_stats(scheduler_outputs, output) return request_outputs - def do_log_stats(self) -> None: + def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs=None)) + self.stat_logger.log( + self._get_stats(scheduler_outputs, model_output)) def _get_stats( self, From 6ef09b08f88b675f84b7140238286e5d4c5304c8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 1 May 2024 15:23:06 -0700 Subject: [PATCH 181/413] [Core][Distributed] fix pynccl del error (#4508) --- vllm/distributed/device_communicators/pynccl.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 9434867e1b120..f21fcd262d810 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -200,6 +200,10 @@ def from_torch(cls, op: ReduceOp) -> int: ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p ] +# be cautious! this is a collective call, it will block until all +# processes in the communicator have called this function. +# because Python object destruction can happen in random order, +# it is better not to call it at all. # equivalent to c declaration: # ncclResult_t ncclCommDestroy(ncclComm_t comm); _c_ncclCommDestroy = nccl.ncclCommDestroy @@ -278,11 +282,3 @@ def all_reduce(self, ncclDataTypeEnum.from_torch(tensor.dtype), ncclRedOpTypeEnum.from_torch(op), self.comm, ctypes.c_void_p(stream.cuda_stream))) - - def __del__(self): - # `dist` module might have been already destroyed - if hasattr(dist, 'destroy_process_group'): - dist.destroy_process_group() - # function might have been already destroyed - if _c_ncclCommDestroy is not None: - _c_ncclCommDestroy(self.comm) From c9d852d601ce1a02f6748ab62db8694c22772583 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 1 May 2024 16:30:52 -0700 Subject: [PATCH 182/413] [Misc] Remove Mixtral device="cuda" declarations (#4543) Remove the device="cuda" declarations in mixtral as promised in #4343 --- vllm/model_executor/models/mixtral.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index c5dd1a63e2f7a..9ff9ba298588a 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -96,13 +96,11 @@ def __init__( torch.empty(self.num_total_experts, 2 * self.intermediate_size, self.hidden_size, - device="cuda", dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, - device="cuda", dtype=self.params_dtype)) set_weight_attrs(self.ws, { @@ -114,22 +112,20 @@ def __init__( # Scaling factors for FP8 weights self.ws_scale = nn.Parameter( - torch.ones( - self.num_total_experts, device="cuda", dtype=torch.float32), + torch.ones(self.num_total_experts, dtype=torch.float32), requires_grad=False) if self.use_fp8 else None self.w2s_scale = nn.Parameter( - torch.ones( - self.num_total_experts, device="cuda", dtype=torch.float32), + torch.ones(self.num_total_experts, dtype=torch.float32), requires_grad=False) if self.use_fp8 else None # Scaling factors for FP8 activations need_act_scales = (self.use_fp8 and quant_config.activation_scheme == "static") self.as_scale = nn.Parameter( - torch.zeros(1, device="cuda", dtype=torch.float32), + torch.zeros(1, dtype=torch.float32), requires_grad=False) if need_act_scales else None self.a2s_scale = nn.Parameter( - torch.zeros(1, device="cuda", dtype=torch.float32), + torch.zeros(1, dtype=torch.float32), requires_grad=False) if need_act_scales else None if need_act_scales: From 826b82a260ebb1ea7edd04a3278d5fb9b103a76e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 1 May 2024 16:47:59 -0700 Subject: [PATCH 183/413] [Misc] Fix expert_ids shape in MoE (#4517) --- vllm/model_executor/layers/fused_moe/fused_moe.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b4f81527141a8..3cb0419404625 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -203,14 +203,15 @@ def moe_align_block_size( - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ - sorted_ids = torch.empty( - (topk_ids.numel() + num_experts * (block_size - 1), ), - dtype=torch.int32, - device=topk_ids.device) - expert_ids = torch.empty((topk_ids.numel() + num_experts, ), + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), dtype=torch.int32, device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) From b8afa8b95a4eee008a9b72440620113e5bfbe962 Mon Sep 17 00:00:00 2001 From: Danny Guinther Date: Wed, 1 May 2024 20:34:40 -0400 Subject: [PATCH 184/413] [MISC] Rework logger to enable pythonic custom logging configuration to be provided (#4273) --- examples/logging_configuration.md | 178 ++++++++++++++++++++++++++++ tests/test_logger.py | 189 +++++++++++++++++++++++++++++- vllm/logger.py | 112 ++++++++++-------- vllm/logging/__init__.py | 5 + vllm/logging/formatter.py | 15 +++ 5 files changed, 451 insertions(+), 48 deletions(-) create mode 100644 examples/logging_configuration.md create mode 100644 vllm/logging/__init__.py create mode 100644 vllm/logging/formatter.py diff --git a/examples/logging_configuration.md b/examples/logging_configuration.md new file mode 100644 index 0000000000000..75b4b31a80462 --- /dev/null +++ b/examples/logging_configuration.md @@ -0,0 +1,178 @@ +# Logging Configuration + +vLLM leverages Python's `logging.config.dictConfig` functionality to enable +robust and flexible configuration of the various loggers used by vLLM. + +vLLM offers two environment variables that can be used to accommodate a range +of logging configurations that range from simple-and-inflexible to +more-complex-and-more-flexible. + +- No vLLM logging (simple and inflexible) + - Set `VLLM_CONFIGURE_LOGGING=0` (leaving `VLLM_LOGGING_CONFIG_PATH` unset) +- vLLM's default logging configuration (simple and inflexible) + - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` +- Fine-grained custom logging configuration (more complex, more flexible) + - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` and + set `VLLM_LOGGING_CONFIG_PATH=` + + +## Logging Configuration Environment Variables + +### `VLLM_CONFIGURE_LOGGING` + +`VLLM_CONFIGURE_LOGGING` controls whether or not vLLM takes any action to +configure the loggers used by vLLM. This functionality is enabled by default, +but can be disabled by setting `VLLM_CONFIGURE_LOGGING=0` when running vLLM. + +If `VLLM_CONFIGURE_LOGGING` is enabled and no value is given for +`VLLM_LOGGING_CONFIG_PATH`, vLLM will use built-in default configuration to +configure the root vLLM logger. By default, no other vLLM loggers are +configured and, as such, all vLLM loggers defer to the root vLLM logger to make +all logging decisions. + +If `VLLM_CONFIGURE_LOGGING` is disabled and a value is given for +`VLLM_LOGGING_CONFIG_PATH`, an error will occur while starting vLLM. + +### `VLLM_LOGGING_CONFIG_PATH` + +`VLLM_LOGGING_CONFIG_PATH` allows users to specify a path to a JSON file of +alternative, custom logging configuration that will be used instead of vLLM's +built-in default logging configuration. The logging configuration should be +provided in JSON format following the schema specified by Python's [logging +configuration dictionary +schema](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details). + +If `VLLM_LOGGING_CONFIG_PATH` is specified, but `VLLM_CONFIGURE_LOGGING` is +disabled, an error will occur while starting vLLM. + + +## Examples + +### Example 1: Customize vLLM root logger + +For this example, we will customize the vLLM root logger to use +[`python-json-logger`](https://github.com/madzak/python-json-logger) to log to +STDOUT of the console in JSON format with a log level of `INFO`. + +To begin, first, create an appropriate JSON logging configuration file: + +**/path/to/logging_config.json:** + +```json +{ + "formatters": { + "json": { + "class": "pythonjsonlogger.jsonlogger.JsonFormatter" + } + }, + "handlers": { + "console": { + "class" : "logging.StreamHandler", + "formatter": "json", + "level": "INFO", + "stream": "ext://sys.stdout" + } + }, + "loggers": { + "vllm": { + "handlers": ["console"], + "level": "INFO", + "propagate": false + } + }, + "version": 1 +} +``` + +Next, install the `python-json-logger` package if it's not already installed: + +```bash +pip install python-json-logger +``` + +Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set +to the path of the custom logging configuration JSON file: + +```bash +VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ + python3 -m vllm.entrypoints.openai.api_server \ + --max-model-len 2048 \ + --model mistralai/Mistral-7B-v0.1 +``` + + +### Example 2: Silence a particular vLLM logger + +To silence a particular vLLM logger, it is necessary to provide custom logging +configuration for the target logger that configures the logger so that it won't +propagate its log messages to the root vLLM logger. + +When custom configuration is provided for any logger, it is also necessary to +provide configuration for the root vLLM logger since any custom logger +configuration overrides the built-in default logging configuration used by vLLM. + +First, create an appropriate JSON logging configuration file that includes +configuration for the root vLLM logger and for the logger you wish to silence: + +**/path/to/logging_config.json:** + +```json +{ + "formatters": { + "vllm": { + "class": "vllm.logging.NewLineFormatter", + "datefmt": "%m-%d %H:%M:%S", + "format": "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" + } + }, + "handlers": { + "vllm": { + "class" : "logging.StreamHandler", + "formatter": "vllm", + "level": "INFO", + "stream": "ext://sys.stdout" + } + }, + "loggers": { + "vllm": { + "handlers": ["vllm"], + "level": "DEBUG", + "propagage": false + }, + "vllm.example_noisy_logger": { + "propagate": false + } + }, + "version": 1 +} +``` + +Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set +to the path of the custom logging configuration JSON file: + +```bash +VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ + python3 -m vllm.entrypoints.openai.api_server \ + --max-model-len 2048 \ + --model mistralai/Mistral-7B-v0.1 +``` + + +### Example 3: Disable vLLM default logging configuration + +To disable vLLM's default logging configuration and silence all vLLM loggers, +simple set `VLLM_CONFIGURE_LOGGING=0` when running vLLM. This will prevent vLLM +for configuring the root vLLM logger, which in turn, silences all other vLLM +loggers. + +```bash +VLLM_CONFIGURE_LOGGING=0 \ + python3 -m vllm.entrypoints.openai.api_server \ + --max-model-len 2048 \ + --model mistralai/Mistral-7B-v0.1 +``` + + +## Additional resources + +- [`logging.config` Dictionary Schema Details](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details) diff --git a/tests/test_logger.py b/tests/test_logger.py index 601f72b50811c..74f1125fb37c9 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,8 +1,19 @@ +import json +import logging import os import sys import tempfile +from json.decoder import JSONDecodeError +from tempfile import NamedTemporaryFile +from typing import Any +from unittest.mock import patch +from uuid import uuid4 -from vllm.logger import enable_trace_function_call +import pytest + +from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger, + enable_trace_function_call, init_logger) +from vllm.logging import NewLineFormatter def f1(x): @@ -25,3 +36,179 @@ def test_trace_function_call(): assert "f2" in content sys.settrace(None) os.remove(path) + + +def test_default_vllm_root_logger_configuration(): + """This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and + VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default + behavior is activated.""" + logger = logging.getLogger("vllm") + assert logger.level == logging.DEBUG + assert not logger.propagate + + handler = logger.handlers[0] + assert handler.stream == sys.stdout + assert handler.level == logging.INFO + + formatter = handler.formatter + assert formatter is not None + assert isinstance(formatter, NewLineFormatter) + assert formatter._fmt == _FORMAT + assert formatter.datefmt == _DATE_FORMAT + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None) +def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger(): + """This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and + VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default + behavior is activated.""" + root_logger = logging.getLogger("vllm") + root_handler = root_logger.handlers[0] + + unique_name = f"vllm.{uuid4()}" + logger = init_logger(unique_name) + assert logger.name == unique_name + assert logger.level == logging.NOTSET + assert not logger.handlers + assert logger.propagate + + message = "Hello, world!" + with patch.object(root_handler, "emit") as root_handle_mock: + logger.info(message) + + root_handle_mock.assert_called_once() + _, call_args, _ = root_handle_mock.mock_calls[0] + log_record = call_args[0] + assert unique_name == log_record.name + assert message == log_record.msg + assert message == log_record.msg + assert log_record.levelno == logging.INFO + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0) +@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None) +def test_logger_configuring_can_be_disabled(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however mocks are used to ensure no changes in behavior or + configuration occur.""" + + with patch("logging.config.dictConfig") as dict_config_mock: + _configure_vllm_root_logger() + dict_config_mock.assert_not_called() + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +@patch( + "vllm.logger.VLLM_LOGGING_CONFIG_PATH", + "/if/there/is/a/file/here/then/you/did/this/to/yourself.json", +) +def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however it fails before any change in behavior or + configuration occurs.""" + with pytest.raises(RuntimeError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type == RuntimeError + assert "File does not exist" in str(ex_info) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however it fails before any change in behavior or + configuration occurs.""" + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write("---\nloggers: []\nversion: 1") + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name): + with pytest.raises(JSONDecodeError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type == JSONDecodeError + assert "Expecting value" in str(ex_info) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +@pytest.mark.parametrize("unexpected_config", ( + "Invalid string", + [{ + "version": 1, + "loggers": [] + }], + 0, +)) +def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( + unexpected_config: Any): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however it fails before any change in behavior or + configuration occurs.""" + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write(json.dumps(unexpected_config)) + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name): + with pytest.raises(ValueError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type == ValueError + assert "Invalid logging config. Expected Dict, got" in str(ex_info) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +def test_custom_logging_config_is_parsed_and_used_when_provided(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however mocks are used to ensure no changes in behavior or + configuration occur.""" + valid_logging_config = { + "loggers": { + "vllm.test_logger.logger": { + "handlers": [], + "propagate": False, + } + }, + "version": 1 + } + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write(json.dumps(valid_logging_config)) + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name), patch( + "logging.config.dictConfig") as dict_config_mock: + _configure_vllm_root_logger() + assert dict_config_mock.called_with(valid_logging_config) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0) +def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however mocks are used to ensure no changes in behavior or + configuration occur.""" + valid_logging_config = { + "loggers": { + "vllm.test_logger.logger": { + "handlers": [], + } + }, + "version": 1 + } + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write(json.dumps(valid_logging_config)) + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name): + with pytest.raises(RuntimeError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type is RuntimeError + expected_message_snippet = ( + "VLLM_CONFIGURE_LOGGING evaluated to false, but " + "VLLM_LOGGING_CONFIG_PATH was given.") + assert expected_message_snippet in str(ex_info) + + # Remember! The root logger is assumed to have been configured as + # though VLLM_CONFIGURE_LOGGING=1 and VLLM_LOGGING_CONFIG_PATH=None. + root_logger = logging.getLogger("vllm") + other_logger_name = f"vllm.test_logger.{uuid4()}" + other_logger = init_logger(other_logger_name) + assert other_logger.handlers != root_logger.handlers + assert other_logger.level != root_logger.level + assert other_logger.propagate diff --git a/vllm/logger.py b/vllm/logger.py index 3928e5367d1e6..40c29da2b70ce 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -1,73 +1,91 @@ -# Adapted from -# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py """Logging configuration for vLLM.""" import datetime +import json import logging import os import sys from functools import partial -from typing import Optional +from logging import Logger +from logging.config import dictConfig +from os import path +from typing import Dict, Optional VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) +VLLM_LOGGING_CONFIG_PATH = os.getenv("VLLM_LOGGING_CONFIG_PATH") _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" +DEFAULT_LOGGING_CONFIG = { + "formatters": { + "vllm": { + "class": "vllm.logging.NewLineFormatter", + "datefmt": _DATE_FORMAT, + "format": _FORMAT, + }, + }, + "handlers": { + "vllm": { + "class": "logging.StreamHandler", + "formatter": "vllm", + "level": "INFO", + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + "vllm": { + "handlers": ["vllm"], + "level": "DEBUG", + "propagate": False, + }, + }, + "version": 1, +} + + +def _configure_vllm_root_logger() -> None: + logging_config: Optional[Dict] = None + + if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH: + raise RuntimeError( + "VLLM_CONFIGURE_LOGGING evaluated to false, but " + "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH " + "implies VLLM_CONFIGURE_LOGGING. Please enable " + "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.") -class NewLineFormatter(logging.Formatter): - """Adds logging prefix to newlines to align multi-line messages.""" + if VLLM_CONFIGURE_LOGGING: + logging_config = DEFAULT_LOGGING_CONFIG - def __init__(self, fmt, datefmt=None): - logging.Formatter.__init__(self, fmt, datefmt) + if VLLM_LOGGING_CONFIG_PATH: + if not path.exists(VLLM_LOGGING_CONFIG_PATH): + raise RuntimeError( + "Could not load logging config. File does not exist: %s", + VLLM_LOGGING_CONFIG_PATH) + with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8", + mode="r") as file: + custom_config = json.loads(file.read()) - def format(self, record): - msg = logging.Formatter.format(self, record) - if record.message != "": - parts = msg.split(record.message) - msg = msg.replace("\n", "\r\n" + parts[0]) - return msg + if not isinstance(custom_config, dict): + raise ValueError("Invalid logging config. Expected Dict, got %s.", + type(custom_config).__name__) + logging_config = custom_config + if logging_config: + dictConfig(logging_config) -_root_logger = logging.getLogger("vllm") -_default_handler: Optional[logging.Handler] = None +def init_logger(name: str) -> Logger: + """The main purpose of this function is to ensure that loggers are + retrieved in such a way that we can be sure the root vllm logger has + already been configured.""" -def _setup_logger(): - _root_logger.setLevel(logging.DEBUG) - global _default_handler - if _default_handler is None: - _default_handler = logging.StreamHandler(sys.stdout) - _default_handler.flush = sys.stdout.flush # type: ignore - _default_handler.setLevel(logging.INFO) - _root_logger.addHandler(_default_handler) - fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) - _default_handler.setFormatter(fmt) - # Setting this will avoid the message - # being propagated to the parent logger. - _root_logger.propagate = False + return logging.getLogger(name) -# The logger is initialized when the module is imported. +# The root logger is initialized when the module is imported. # This is thread-safe as the module is only imported once, # guaranteed by the Python GIL. -if VLLM_CONFIGURE_LOGGING: - _setup_logger() - - -def init_logger(name: str): - # Use the same settings as above for root logger - logger = logging.getLogger(name) - logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) - - if VLLM_CONFIGURE_LOGGING: - if _default_handler is None: - raise ValueError( - "_default_handler is not set up. This should never happen!" - " Please open an issue on Github.") - logger.addHandler(_default_handler) - logger.propagate = False - return logger - +_configure_vllm_root_logger() logger = init_logger(__name__) diff --git a/vllm/logging/__init__.py b/vllm/logging/__init__.py new file mode 100644 index 0000000000000..b9aec380776f3 --- /dev/null +++ b/vllm/logging/__init__.py @@ -0,0 +1,5 @@ +from vllm.logging.formatter import NewLineFormatter + +__all__ = [ + "NewLineFormatter", +] diff --git a/vllm/logging/formatter.py b/vllm/logging/formatter.py new file mode 100644 index 0000000000000..b24b4e11d1fcb --- /dev/null +++ b/vllm/logging/formatter.py @@ -0,0 +1,15 @@ +import logging + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None, style="%"): + logging.Formatter.__init__(self, fmt, datefmt, style) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg From 0d62fe58dbb58cfe4132005ce7ff37319d66981d Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 2 May 2024 11:24:13 +0900 Subject: [PATCH 185/413] [Bug fix][Core] assert num_new_tokens == 1 fails when SamplingParams.n is not 1 and max_tokens is large & Add tests for preemption (#4451) --- .buildkite/test-pipeline.yaml | 1 + .../basic_correctness/test_chunked_prefill.py | 1 - tests/basic_correctness/test_preemption.py | 138 ++++++++++++++++++ tests/conftest.py | 3 +- tests/spec_decode/test_spec_decode_worker.py | 6 +- vllm/core/scheduler.py | 36 ++++- 6 files changed, 172 insertions(+), 13 deletions(-) create mode 100644 tests/basic_correctness/test_preemption.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 11cda053260ec..641f366d06031 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -17,6 +17,7 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Core Test command: pytest -v -s core diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index d83416eb51b43..47d582c726c66 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -55,7 +55,6 @@ def test_models( ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model - print(vllm_outputs[0]) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py new file mode 100644 index 0000000000000..1adfc7dddd6fa --- /dev/null +++ b/tests/basic_correctness/test_preemption.py @@ -0,0 +1,138 @@ +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test. + +Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 +pytest tests/basic_correctness/test_preemption.py`. +""" +import pytest + +from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, + ENABLE_ARTIFICIAL_PREEMPT) + +MODELS = [ + "facebook/opt-125m", +] + +assert ENABLE_ARTIFICIAL_PREEMPT is True, ( + "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. " + "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " + "tests/basic_correctness/test_preemption.py`") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_chunked_prefill_recompute( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + """Ensure that chunked prefill works with preemption.""" + max_num_seqs = min(chunked_prefill_token_size, 256) + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + max_num_seqs=max_num_seqs, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_preemption( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """By default, recompute preemption is enabled""" + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("beam_width", [4]) +def test_swap( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Use beam search enables swapping.""" + example_prompts = example_prompts[:1] + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype, swap_space=10) + vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, _ = hf_outputs[i] + vllm_output_ids, _ = vllm_outputs[i] + assert len(hf_output_ids) == len(vllm_output_ids) + for j in range(len(hf_output_ids)): + assert hf_output_ids[j] == vllm_output_ids[j], ( + f"Test{i} output{j}:\nHF: {hf_output_ids}\n" + f"vLLM: {vllm_output_ids}") diff --git a/tests/conftest.py b/tests/conftest.py index 5c50fc2d1bab6..671326915b22b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -296,6 +296,7 @@ def __init__( tensor_parallel_size: int = 1, block_size: int = 16, enable_chunked_prefill: bool = False, + swap_space=4, **kwargs, ) -> None: self.model = LLM( @@ -303,7 +304,7 @@ def __init__( tokenizer=tokenizer_name, trust_remote_code=True, dtype=dtype, - swap_space=0, + swap_space=swap_space, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, max_model_len=max_model_len, diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index d24d726c9c0cf..91315df9b5e60 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -33,7 +33,7 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) execute_model_data, _, _ = create_batch(batch_size, k) @@ -101,7 +101,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): proposal_probs=proposal_probs, proposal_lens=proposal_lens) - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' target_worker.execute_model.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): @@ -197,7 +197,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): target_worker.execute_model.return_value = [target_output[0]] - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' rejection_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 024b7e7013441..b17b6cc7fe733 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,4 +1,6 @@ import enum +import os +import random import time from collections import deque from dataclasses import dataclass, field @@ -15,6 +17,13 @@ logger = init_logger(__name__) +# Test-only. If configured, decode is preempted with +# ARTIFICIAL_PREEMPTION_PROB% probability. +ENABLE_ARTIFICIAL_PREEMPT = bool( + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa +ARTIFICIAL_PREEMPTION_PROB = 0.5 +ARTIFICIAL_PREEMPTION_MAX_CNT = 500 + class PreemptionMode(enum.Enum): """Preemption modes. @@ -286,6 +295,13 @@ def __init__( # Latency of the last prompt step self.last_prompt_latency = 0.0 + # The following field is test-only. It is used to inject artificial + # preemption. + self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) + @property def lora_enabled(self) -> bool: return bool(self.lora_config) @@ -386,15 +402,13 @@ def _schedule_running( # groups to preempt. now = time.time() running_queue = policy.sort_by_priority(now, running_queue) - while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( seq_group, SequenceStatus.RUNNING, enable_chunking, budget) - # We can have up to 1 running prefill at any given time in running - # queue, which means we can guarantee chunk size is at least 1. - assert num_running_tokens != 0 + if num_running_tokens == 0: + break running_queue.popleft() while not self._can_append_slots(seq_group): @@ -449,9 +463,6 @@ def _schedule_running( if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - # Make sure all queues are updated. - assert len(running_queue) == 0 - return running_queue, SchedulerRunningOutputs( decode_seq_groups=decode_seq_groups, prefill_seq_groups=prefill_seq_groups, @@ -545,7 +556,6 @@ def _schedule_swapped( ScheduledSequenceGroup(seq_group, token_chunk_size=num_new_tokens)) else: - assert num_new_tokens == 1 decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) @@ -868,6 +878,13 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ + # It is True only for testing case to trigger artificial preemption. + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): + self.artificial_preempt_cnt -= 1 + return False + # Appending slots only occurs in decoding. is_prefill = False @@ -1116,11 +1133,14 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, if `enable_chunking` is True. If a sequence group has multiple sequences (e.g., running beam search), it means it is in decoding phase, so chunking doesn't happen. + + Returns 0 if the new token cannot be computed due to token budget. """ num_new_tokens = 0 seqs = seq_group.get_seqs(status=status) for seq in seqs: num_new_tokens += seq.get_num_new_tokens() + assert num_new_tokens > 0 # Chunk if a running request cannot fit in. # If number of seq > 1, it means it is doing beam search in a # decode phase. Do not chunk in that case. From 5e401bce17ae9b327020ade6ba0ddceea2853451 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Thu, 2 May 2024 05:57:12 +0300 Subject: [PATCH 186/413] [CI]Add regression tests to ensure the async engine generates metrics (#4524) --- tests/metrics/test_metrics.py | 94 +++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 0ab9c63ce4377..311e60ba60f61 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,4 +1,10 @@ import pytest +from prometheus_client import REGISTRY + +from vllm import EngineArgs, LLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams MODELS = [ "facebook/opt-125m", @@ -68,3 +74,91 @@ def test_metric_counter_generation_tokens( assert vllm_generation_count == metric_count, ( f"generation token count: {vllm_generation_count!r}\n" f"metric: {metric_count!r}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("disable_log_stats", [True, False]) +@pytest.mark.asyncio +async def test_async_engine_log_metrics_regression( + example_prompts, + model: str, + dtype: str, + max_tokens: int, + disable_log_stats: bool, +) -> None: + """ + Regression test ensuring async engine generates metrics + when disable_log_stats=False + (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678) + """ + engine_args = AsyncEngineArgs(model=model, + dtype=dtype, + disable_log_stats=disable_log_stats) + async_engine = AsyncLLMEngine.from_engine_args(engine_args) + for i, prompt in enumerate(example_prompts): + results = async_engine.generate( + prompt, + SamplingParams(max_tokens=max_tokens), + f"request-id-{i}", + ) + # Exhaust the async iterator to make the async engine work + async for _ in results: + pass + + assert_metrics(async_engine.engine, disable_log_stats, + len(example_prompts)) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("disable_log_stats", [True, False]) +def test_engine_log_metrics_regression( + example_prompts, + model: str, + dtype: str, + max_tokens: int, + disable_log_stats: bool, +) -> None: + engine_args = EngineArgs(model=model, + dtype=dtype, + disable_log_stats=disable_log_stats) + engine = LLMEngine.from_engine_args(engine_args) + for i, prompt in enumerate(example_prompts): + engine.add_request( + f"request-id-{i}", + prompt, + SamplingParams(max_tokens=max_tokens), + ) + while engine.has_unfinished_requests(): + engine.step() + + assert_metrics(engine, disable_log_stats, len(example_prompts)) + + +def assert_metrics(engine: LLMEngine, disable_log_stats: bool, + num_requests: int) -> None: + if disable_log_stats: + with pytest.raises(AttributeError): + _ = engine.stat_logger + else: + assert (engine.stat_logger + is not None), "engine.stat_logger should be set" + # Ensure the count bucket of request-level histogram metrics matches + # the number of requests as a simple sanity check to ensure metrics are + # generated + labels = {'model_name': engine.model_config.model} + request_histogram_metrics = [ + "vllm:e2e_request_latency_seconds", + "vllm:request_prompt_tokens", + "vllm:request_generation_tokens", + "vllm:request_params_best_of", + "vllm:request_params_n", + ] + for metric_name in request_histogram_metrics: + metric_value = REGISTRY.get_sample_value(f"{metric_name}_count", + labels) + assert ( + metric_value == num_requests), "Metrics should be collected" From cf8cac8c701079a3fda068ffd1cd6f72a490aa6d Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 2 May 2024 12:01:00 +0900 Subject: [PATCH 187/413] [mypy][6/N] Fix all the core subdirectory typing (#4450) Co-authored-by: Cade Daniel --- .github/workflows/mypy.yaml | 6 +- format.sh | 2 +- vllm/core/block/block_table.py | 16 ++-- vllm/core/block/common.py | 20 +++- vllm/core/block/cpu_gpu_block_allocator.py | 49 ++++++---- vllm/core/block/interfaces.py | 104 +++++++++++++++++++-- vllm/core/block/naive_block.py | 52 +++++++++-- vllm/core/block/prefix_caching_block.py | 85 ++++++++++++----- vllm/core/block_manager_v2.py | 9 +- vllm/core/evictor_v2.py | 15 ++- 10 files changed, 275 insertions(+), 83 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index a19be8525f902..5b2bad1476dc3 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -33,6 +33,7 @@ jobs: - name: Mypy run: | mypy vllm/attention --config-file pyproject.toml + mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml @@ -42,9 +43,6 @@ jobs: mypy vllm/engine --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml - mypy vllm/lora --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml - - # TODO(sang): Fix nested dir - mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/lora --config-file pyproject.toml diff --git a/format.sh b/format.sh index bd12e61d77806..49149afe41d04 100755 --- a/format.sh +++ b/format.sh @@ -95,7 +95,7 @@ echo 'vLLM yapf: Done' # Run mypy echo 'vLLM mypy:' mypy vllm/attention --config-file pyproject.toml -mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index f1b65b2514f76..b0d9511fba521 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -40,7 +40,9 @@ def __init__( ): self._block_size = block_size self._allocator = block_allocator - self._blocks: Optional[List[Block]] = _blocks + if _blocks is None: + _blocks = [] + self._blocks: List[Block] = _blocks # Use helper method instead of directly calculating, as blocks # may not be allocated. @@ -104,7 +106,7 @@ def append_token_ids(self, token_ids (List[int]): The sequence of token IDs to be appended. """ assert self._is_allocated - assert self._blocks is not None + assert len(self._blocks) > 0 self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) @@ -141,6 +143,7 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None: blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) for _ in range(blocks_to_allocate): + assert len(self._blocks) > 0 self._blocks.append( self._allocator.allocate_mutable(prev_block=self._blocks[-1], device=device)) @@ -159,6 +162,7 @@ def fork(self) -> "BlockTable": the current instance. """ assert self._is_allocated + assert len(self._blocks) > 0 forked_blocks = self._allocator.fork(self._blocks[-1]) return BlockTable( block_size=self._block_size, @@ -177,10 +181,10 @@ def free(self) -> None: assert self._is_allocated for block in self._blocks: self._allocator.free(block) - self._blocks = None + self._blocks = [] @property - def physical_block_ids(self) -> List[int]: + def physical_block_ids(self) -> List[Optional[int]]: """Returns a list of physical block indices for the blocks in the BlockTable. @@ -235,7 +239,7 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], def _get_all_token_ids(self) -> List[int]: # NOTE: This function is O(seq_len); use sparingly. - token_ids = [] + token_ids: List[int] = [] if not self._is_allocated: return token_ids @@ -247,7 +251,7 @@ def _get_all_token_ids(self) -> List[int]: @property def _is_allocated(self) -> bool: - return self._blocks is not None + return len(self._blocks) > 0 @property def _num_empty_slots(self) -> int: diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index f11234a0bf2dd..3f97a1210b096 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Protocol from vllm.core.block.interfaces import Block, BlockAllocator @@ -7,7 +7,19 @@ RefCount = int -class RefCounter: +class RefCounterProtocol(Protocol): + + def incr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def decr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def get(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + +class RefCounter(RefCounterProtocol): """A class for managing reference counts for a set of block indices. The RefCounter class maintains a dictionary that maps block indices to their @@ -54,7 +66,7 @@ def as_readonly(self) -> "ReadOnlyRefCounter": return ReadOnlyRefCounter(self) -class ReadOnlyRefCounter: +class ReadOnlyRefCounter(RefCounterProtocol): """A read-only view of the RefCounter class. The ReadOnlyRefCounter class provides a read-only interface to access the @@ -96,7 +108,7 @@ class CopyOnWriteTracker: def __init__( self, - refcounter: RefCounter, + refcounter: RefCounterProtocol, allocator: BlockAllocator, ): self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 23e1a4cf91266..d25d22cf52838 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -1,6 +1,6 @@ -from typing import Dict, List, Optional +from typing import Dict, FrozenSet, List, Optional -from vllm.core.block.interfaces import (Block, BlockAllocator, +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator @@ -57,15 +57,15 @@ def create( cpu_block_ids = block_ids[num_gpu_blocks:] if allocator_type == "naive": - gpu_allocator = NaiveBlockAllocator( - create_block=NaiveBlock, + gpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore num_blocks=num_gpu_blocks, block_size=block_size, block_ids=gpu_block_ids, ) - cpu_allocator = NaiveBlockAllocator( - create_block=NaiveBlock, + cpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore num_blocks=num_cpu_blocks, block_size=block_size, block_ids=cpu_block_ids, @@ -105,13 +105,14 @@ def __init__( Device.GPU: gpu_block_allocator, } - self._block_ids_to_allocator = {} + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} for _, allocator in self._allocators.items(): for block_id in allocator.all_block_ids: self._block_ids_to_allocator[block_id] = allocator - def allocate_mutable(self, prev_block: Optional[Block], - device: Device) -> Block: + def allocate_mutable(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: """Allocates a new mutable block on the specified device. Args: @@ -122,10 +123,13 @@ def allocate_mutable(self, prev_block: Optional[Block], Returns: Block: The newly allocated mutable block. """ + assert device is not None return self._allocators[device].allocate_mutable(prev_block) - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int], device: Device) -> Block: + def allocate_immutable(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: """Allocates a new immutable block with the provided token IDs on the specified device. @@ -140,6 +144,7 @@ def allocate_immutable(self, prev_block: Optional[Block], Block: The newly allocated immutable block containing the provided token IDs. """ + assert device is not None return self._allocators[device].allocate_immutable( prev_block, token_ids) @@ -149,7 +154,9 @@ def free(self, block: Block) -> None: Args: block (Block): The block to be freed. """ - allocator = self._block_ids_to_allocator[block.block_id] + block_id = block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] return allocator.free(block) def fork(self, last_block: Block) -> List[Block]: @@ -163,19 +170,22 @@ def fork(self, last_block: Block) -> List[Block]: List[Block]: A new list of blocks that shares the same memory as the original sequence. """ - allocator = self._block_ids_to_allocator[last_block.block_id] + block_id = last_block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] return allocator.fork(last_block) - def get_num_free_blocks(self, device: Device) -> int: + def get_num_free_blocks(self, device: Optional[Device] = None) -> int: """Returns the number of free blocks available on the specified device. Args: device (Device): The device for which to query the number of free - blocks. + blocks. AssertionError is raised if None is passed. Returns: int: The number of free blocks available on the specified device. """ + assert device is not None return self._allocators[device].get_num_free_blocks() def clear_copy_on_writes(self) -> Dict[int, List[int]]: @@ -210,5 +220,12 @@ def get_common_computed_block_ids( return self._allocators[device].get_common_computed_block_ids( seq_block_ids) - def all_block_ids(self) -> frozenset[int]: + @property + def all_block_ids(self) -> FrozenSet[int]: return frozenset(self._block_ids_to_allocator.keys()) + + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError + + def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: + raise NotImplementedError diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 440d6a4b04d3b..08d2f87301d92 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -3,6 +3,8 @@ from vllm.utils import Device +BlockId = int + class Block(ABC): @@ -15,6 +17,12 @@ def append_token_ids(self, token_ids: List[int]) -> None: def block_id(self) -> Optional[int]: pass + @block_id.setter + @abstractmethod + def block_id(self, value: Optional[int]) -> None: + """NOTE: Do not use this API outside Block.""" + self._block_id = value + @property @abstractmethod def token_ids(self) -> List[int]: @@ -35,6 +43,27 @@ def is_full(self) -> bool: def prev_block(self) -> Optional["Block"]: pass + @property + @abstractmethod + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + @abstractmethod + def computed(self, value) -> bool: + """Should be only used by PrefixCacingAllocator""" + raise NotImplementedError + + @property + @abstractmethod + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + @abstractmethod + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + class Factory(Protocol): @abstractmethod @@ -48,6 +77,17 @@ def __call__( ) -> "Block": pass + @property + @abstractmethod + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined or not supported. + + For the content-based hash to be defined, the current block must be + full. + """ + return None + class BlockAllocator(ABC): @@ -57,7 +97,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block: @abstractmethod def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int], device: Device) -> Block: + token_ids: List[int]) -> Block: pass @abstractmethod @@ -69,7 +109,7 @@ def fork(self, last_block: Block) -> List[Block]: pass @abstractmethod - def get_num_free_blocks(self, device: Device) -> int: + def get_num_free_blocks(self) -> int: pass @property @@ -82,11 +122,12 @@ def clear_copy_on_writes(self) -> Dict[int, List[int]]: pass @abstractmethod - def mark_blocks_as_accessed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: pass @abstractmethod - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass @abstractmethod @@ -94,21 +135,66 @@ def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: pass + @abstractmethod + def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def promote_to_immutable_block(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + class NoFreeBlocksError(ValueError): pass -class DeviceAwareBlockAllocator(BlockAllocator): +class DeviceAwareBlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: pass @abstractmethod - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int], device: Device) -> Block: + def allocate_immutable(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: pass @abstractmethod - def get_num_free_blocks(self, device: Device) -> int: + def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> Dict[int, List[int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, seq_block_ids: List[List[int]]) -> List[int]: pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index a0bf33912d935..10af129246889 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -1,10 +1,9 @@ -from typing import Dict, Iterable, List, Optional, Set +from typing import Dict, FrozenSet, Iterable, List, Optional, Set from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device -BlockId = int Refcount = int @@ -49,8 +48,10 @@ def __init__( allocator=self, ) - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + def allocate_immutable(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: """Allocates a new immutable block with the given token IDs, linked to the previous block. @@ -63,11 +64,14 @@ def allocate_immutable(self, prev_block: Optional[Block], Returns: Block: The newly allocated immutable block. """ + assert device is None block = self.allocate_mutable(prev_block=prev_block) block.append_token_ids(token_ids) return block - def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: """Allocates a new mutable block, linked to the previous block. Args: @@ -78,6 +82,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block: Returns: Block: The newly allocated mutable block. """ + assert device is None block_id = self._allocate_new_block_id() return self._create_block( prev_block=prev_block, @@ -88,6 +93,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block: ) def free(self, block: Block) -> None: + assert block.block_id is not None self._free_block_id(block.block_id) # Mark the block as having no allocation. @@ -111,6 +117,7 @@ def fork(self, last_block: Block) -> List[Block]: for block in source_blocks: # Increment refcount for each block. + assert block.block_id is not None refcount = self._refcounter.incr(block.block_id) assert refcount != 1, "can't fork free'd block" @@ -126,7 +133,8 @@ def fork(self, last_block: Block) -> List[Block]: return forked_blocks - def get_num_free_blocks(self) -> int: + def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + assert device is None return len(self._free_block_indices) def _allocate_new_block_id(self) -> BlockId: @@ -148,7 +156,7 @@ def refcounter(self): return self._refcounter @property - def all_block_ids(self): + def all_block_ids(self) -> FrozenSet[int]: return self._all_block_indices def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: @@ -200,6 +208,9 @@ def get_common_computed_block_ids( """ return [] + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix @@ -224,13 +235,13 @@ class NaiveBlock(Block): """ def __init__(self, - prev_block: Block, + prev_block: Optional[Block], token_ids: List[int], block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, _cow_target: Optional[Block] = None): - self._token_ids = [] + self._token_ids: List[int] = [] self._block_size = block_size self._prev_block = prev_block self._block_id = block_id @@ -256,6 +267,22 @@ def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: assert self.num_empty_slots >= len(token_ids) self._token_ids.extend(token_ids) + @property + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + def computed(self, value) -> None: + raise NotImplementedError + + @property + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + @property def block_id(self) -> Optional[int]: return self._block_id @@ -276,9 +303,14 @@ def num_empty_slots(self) -> int: def token_ids(self) -> List[int]: return self._token_ids + @property def block_size(self) -> int: return self._block_size @property def prev_block(self) -> Optional["Block"]: return self._prev_block + + @property + def content_hash(self) -> Optional[int]: + return None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 292a750146ae6..e9000c9bfff7f 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,16 +1,15 @@ """Token blocks.""" from itertools import takewhile from os.path import commonprefix -from typing import Dict, Iterable, List, Optional +from typing import Dict, FrozenSet, Iterable, List, Optional from vllm.core.block.common import (CopyOnWriteTracker, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor PrefixHash = int -BlockId = int # By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME # so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, @@ -38,7 +37,7 @@ def __init__( num_blocks: int, block_size: int, block_ids: Optional[Iterable[int]] = None, - eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU, ): # A mapping of prefix hash to block index. All blocks which have a # prefix hash will be in this dict, even if they have refcount 0. @@ -49,7 +48,7 @@ def __init__( # An allocator for blocks that do not have prefix hashes. self._hashless_allocator = NaiveBlockAllocator( - create_block=self._create_block, + create_block=self._create_block, # type: ignore num_blocks=num_blocks, block_size=block_size, block_ids=block_ids, @@ -79,7 +78,7 @@ def _create_block( block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, - computed: Optional[bool] = False, + computed: bool = False, ) -> Block: # Bind block to self. allocator = self @@ -93,8 +92,10 @@ def _create_block( computed=computed, ) - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + def allocate_immutable(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: """Allocates an immutable block with the given token IDs, reusing cached blocks if possible. @@ -105,6 +106,7 @@ def allocate_immutable(self, prev_block: Optional[Block], Returns: Block: The allocated immutable block. """ + assert device is None assert_prefix_caching_block_or_none(prev_block) block = self._create_block( @@ -127,16 +129,20 @@ def allocate_immutable(self, prev_block: Optional[Block], return block - def allocate_mutable(self, prev_block: Block) -> Block: + def allocate_mutable(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: """Allocates a mutable block. If there are no free blocks, this will evict unused cached blocks. Args: prev_block (Block): The previous block in the sequence. + None is not allowed unlike it is super class. Returns: Block: The allocated mutable block. """ + assert device is None assert_prefix_caching_block_or_none(prev_block) try: @@ -144,6 +150,7 @@ def allocate_mutable(self, prev_block: Block) -> Block: prev_block=prev_block) assert block.block_id not in self._blocks + assert block.block_id is not None self._blocks[block.block_id] = block return block except BlockAllocator.NoFreeBlocksError: @@ -183,6 +190,7 @@ def allocate_mutable(self, prev_block: Block) -> Block: assert block.content_hash is None assert block.block_id not in self._blocks + assert block.block_id is not None self._blocks[block.block_id] = block return block @@ -225,6 +233,7 @@ def _free_block_id_for_block(self, block_id: BlockId, # We have fork case where block would get more than one ref, # so we cannot free it from tracking if ref cnt large than 1 if refcount <= 1: + assert block.block_id is not None del self._blocks[block.block_id] return self._hashless_allocator.free(block) @@ -233,6 +242,7 @@ def _free_block_id_for_block(self, block_id: BlockId, # If no longer used, add the block to the evictor. if refcount == 0: assert block.content_hash in self._cached_blocks + assert block.block_id is not None del self._blocks[block.block_id] self.evictor.add(block.block_id, block.content_hash, block.num_tokens_total, block.last_accessed) @@ -268,18 +278,18 @@ def fork(self, last_block: Block) -> List[Block]: return forked_blocks - def get_num_free_blocks(self) -> int: + def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + assert device is None # The number of free blocks is the number of hashless free blocks # plus the number of blocks evictor could free from its list. return self._hashless_allocator.get_num_free_blocks( ) + self.evictor.num_blocks @property - def all_block_ids(self) -> frozenset[int]: + def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids - def promote_to_immutable_block(self, - block: "PrefixCachingBlock") -> BlockId: + def promote_to_immutable_block(self, block: Block) -> BlockId: """Once a mutable block is full, it can be promoted to an immutable block. This means that its content can be referenced by future blocks having the same prefix. @@ -289,7 +299,7 @@ def promote_to_immutable_block(self, block. Args: - block (PrefixCachingBlock): The mutable block to be promoted. + block: The mutable block to be promoted. Returns: BlockId: Either the original block index, or the block index of @@ -385,8 +395,11 @@ def get_common_computed_block_ids( takewhile(lambda block_id: self.block_is_computed(block_id), seq[:-1])) for seq in seq_block_ids ] - res = commonprefix([ids for ids in ids_list if ids != []]) - return res + # It returns a list of int although type annotation says list of string. + return commonprefix([ + ids for ids in ids_list # type: ignore + if ids != [] + ]) class PrefixCachingBlock(Block): @@ -403,7 +416,7 @@ class PrefixCachingBlock(Block): token_ids (List[int]): The initial token IDs to be stored in the block. block_size (int): The maximum number of token IDs that can be stored in the block. - prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix + prefix_caching_allocator (BlockAllocator): The prefix caching block allocator associated with this block. block_id (Optional[int], optional): The physical block index of this block. Defaults to None. @@ -411,21 +424,25 @@ class PrefixCachingBlock(Block): def __init__( self, - prev_block: Optional["PrefixCachingBlock"], + prev_block: Optional[Block], token_ids: List[int], block_size: int, - prefix_caching_allocator: PrefixCachingBlockAllocator, + prefix_caching_allocator: BlockAllocator, block_id: Optional[int] = None, - computed: Optional[bool] = False, + computed: bool = False, ): + assert isinstance(prefix_caching_allocator, + PrefixCachingBlockAllocator), ( + "Currently this class is only tested with " + "PrefixCachingBlockAllocator.") assert_prefix_caching_block_or_none(prev_block) self._prev_block = prev_block self._cached_content_hash: Optional[int] = None self._cached_num_tokens_total: Optional[int] = None self._prefix_caching_allocator = prefix_caching_allocator - self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME - self.computed = computed + self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self._computed = computed self._block = NaiveBlock( prev_block=prev_block, @@ -436,6 +453,22 @@ def __init__( _cow_target=self, ) + @property + def computed(self) -> bool: + return self._computed + + @computed.setter + def computed(self, value) -> None: + self._computed = value + + @property + def last_accessed(self) -> float: + return self._last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._last_accessed = last_accessed_ts + def append_token_ids(self, token_ids: List[int]) -> None: """Appends the given token IDs to the block and registers the block as immutable if the block becomes full. @@ -483,7 +516,7 @@ def num_tokens_total(self) -> int: if self._cached_num_tokens_total is not None: return self._cached_num_tokens_total - _block = self + _block: Optional[Block] = self self._cached_num_tokens_total = 0 # TODO: current implement here take O(N^2), we expect future @@ -524,8 +557,10 @@ def content_hash(self) -> Optional[int]: return None is_first_block = self._prev_block is None - prev_block_hash = (None if is_first_block else - self._prev_block.content_hash) + prev_block_hash = ( + None if is_first_block else + self._prev_block.content_hash # type: ignore + ) # Previous block exists but does not yet have a hash. # Return no hash in this case. diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 0857605e2d005..3fbd8b787cf6c 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -190,7 +190,7 @@ def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids assert all(b is not None for b in block_ids) - return block_ids + return block_ids # type: ignore def access_all_blocks_in_seq(self, seq: Sequence, now: float): # Update the last accessed time of all the blocks accessed @@ -204,7 +204,9 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float): block_ids = [] for block_id in block_table.physical_block_ids: block_ids.append(block_id) - self.block_allocator.mark_blocks_as_accessed(block_ids, now) + self.block_allocator.mark_blocks_as_accessed( + block_ids, # type: ignore + now) def mark_blocks_as_computed(self, seq_group: SequenceGroup): # The only need for mark block as computed is for prefix caching, @@ -227,8 +229,9 @@ def get_common_computed_block_ids( seq_block_ids = [ self.block_tables[seq.seq_id].physical_block_ids for seq in seqs ] + # NOTE(sang): This assumes seq_block_ids doesn't contain any None. return self.block_allocator.get_common_computed_block_ids( - seq_block_ids) + seq_block_ids) # type: ignore def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: src_block_table = self.block_tables[parent_seq.seq_id] diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py index b902a39263d14..57759b29347f4 100644 --- a/vllm/core/evictor_v2.py +++ b/vllm/core/evictor_v2.py @@ -32,15 +32,20 @@ def evict(self) -> Tuple[int, int]: @abstractmethod def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, - last_accessed: int): + last_accessed: float): """Adds block to the evictor, making it a candidate for eviction""" pass @abstractmethod - def update(self, block_id: int, last_accessed: int): + def update(self, block_id: int, last_accessed: float): """Update corresponding block's access time in metadata""" pass + @abstractmethod + def remove(self, block_id: int): + """Remove a given block id from the cache.""" + pass + @abstractproperty def num_blocks(self) -> int: pass @@ -55,7 +60,7 @@ class BlockMetaData(): """ def __init__(self, content_hash: int, num_hashed_tokens: int, - last_accessed: int): + last_accessed: float): self.content_hash = content_hash self.num_hashed_tokens = num_hashed_tokens self.last_accessed = last_accessed @@ -96,12 +101,12 @@ def evict(self) -> Tuple[int, int]: return evicted_block_id, evicted_block.content_hash def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, - last_accessed: int): + last_accessed: float): self.free_table[block_id] = BlockMetaData(content_hash, num_hashed_tokens, last_accessed) - def update(self, block_id: int, last_accessed: int): + def update(self, block_id: int, last_accessed: float): self.free_table[block_id].last_accessed = last_accessed def remove(self, block_id: int): From 2a85f9300733c09ec90819bc6df4bff8f103fd67 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 1 May 2024 21:28:21 -0700 Subject: [PATCH 188/413] [Core][Distributed] enable multiple tp group (#4512) Co-authored-by: Zhuohan Li --- .buildkite/test-pipeline.yaml | 11 ++++++-- .buildkite/test-template.j2 | 3 ++ tests/distributed/test_pynccl.py | 28 +++++++++++++++++++ .../device_communicators/pynccl.py | 5 +++- 4 files changed, 43 insertions(+), 4 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 641f366d06031..d518fb9ccecfa 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -25,19 +25,24 @@ steps: - label: Distributed Comm Ops Test command: pytest -v -s test_comm_ops.py working_dir: "/vllm-workspace/tests/distributed" - num_gpus: 2 # only support 1 or 2 for now. + num_gpus: 2 - label: Distributed Tests working_dir: "/vllm-workspace/tests/distributed" - num_gpus: 2 # only support 1 or 2 for now. + num_gpus: 2 commands: - - pytest -v -s test_pynccl.py - pytest -v -s test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py +- label: Distributed Tests (Multiple Groups) + working_dir: "/vllm-workspace/tests/distributed" + num_gpus: 4 + commands: + - pytest -v -s test_pynccl.py + - label: Engine Test command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 5c9515840bb03..2cb21cacd065b 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -45,6 +45,9 @@ steps: plugins: - kubernetes: podSpec: + {% if step.num_gpus %} + priorityClassName: gpu-priority-cls-{{ step.num_gpus }} + {% endif %} volumes: - name: dshm emptyDir: diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 6d7d4a5806bd0..e71d839648c83 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -58,6 +58,34 @@ def test_pynccl(): distributed_run(worker_fn, 2) +@worker_fn_wrapper +def multiple_tp_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + groups = [ + torch.distributed.new_group(ranks=[0, 1], backend="gloo"), + torch.distributed.new_group(ranks=[2, 3], backend="gloo") + ] + group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] + comm = NCCLCommunicator(group=group, device=device) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) + # two groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + comm.all_reduce(tensor) + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_multiple_tp(): + distributed_run(worker_fn, 4) + + @worker_fn_wrapper def worker_fn_with_cudagraph(): with torch.no_grad(): diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index f21fcd262d810..758994352e3de 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -232,6 +232,7 @@ def __init__( assert dist.get_backend(group) != dist.Backend.NCCL, ( "NCCLCommunicator should be attached to a non-NCCL group.") self.group = group + # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) if self.rank == 0: @@ -239,7 +240,9 @@ def __init__( else: self.unique_id = NcclUniqueId() tensor = torch.ByteTensor(list(self.unique_id.internal)) - dist.broadcast(tensor, src=0, group=group) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte From 7038e8b80303bf6128acbe508dec910183a1be56 Mon Sep 17 00:00:00 2001 From: alexm-nm <59768536+alexm-nm@users.noreply.github.com> Date: Thu, 2 May 2024 12:56:22 -0400 Subject: [PATCH 189/413] [Kernel] Support running GPTQ 8-bit models in Marlin (#4533) --- csrc/ops.h | 4 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 552 ++++++++++++------ csrc/quantization/gptq_marlin/gptq_marlin.cuh | 8 +- .../gptq_marlin/gptq_marlin_repack.cu | 152 +++-- tests/models/test_gptq_marlin.py | 13 +- vllm/_custom_ops.py | 14 +- .../layers/quantization/gptq_marlin.py | 134 ++--- 7 files changed, 553 insertions(+), 324 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 04b97d1784cd2..8ae052427052f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor &g_idx, torch::Tensor &perm, torch::Tensor &workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, @@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack( torch::Tensor &b_q_weight, torch::Tensor &perm, int64_t size_k, - int64_t size_n); + int64_t size_n, + int64_t num_bits); #endif void squeezellm_gemm( diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 9902f55167d89..fd0837f0cb39c 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -32,7 +32,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, int4 *__restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template = 8.0"); return torch::empty({1, 1}); @@ -114,11 +115,21 @@ template __device__ inline int lop3(int a, int b, int c) { return res; } +// Constructs destination register by taking bytes from 2 sources (based on mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +__device__ inline FragB dequant_4bit(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -139,6 +150,24 @@ __device__ inline FragB dequant(int q) { return frag_b; } +__device__ inline FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. __device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { @@ -162,6 +191,13 @@ __device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2, frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float *c, FragS &s) { + __half *s_ptr = reinterpret_cast<__half *>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int *lock, int count) { if (threadIdx.x == 0) { @@ -250,7 +286,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, } } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; int same_group_id[stages]; auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); @@ -767,10 +828,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + FragB frag_b0; + FragB frag_b1; + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); - FragB frag_b0 = dequant(b_quant); + } else { + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -782,8 +856,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -808,13 +880,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -861,7 +933,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped portioning + // finally have to globally reduce over the results. As the striped partitioning // minimizes the number of such reductions and our outputs are usually rather // small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { @@ -951,13 +1023,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto write = [&](int idx, float c0, float c1, FragS &s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { res = __hmul2(res, s[0]); } ((half2 *)sh)[idx] = res; }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1023,6 +1097,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // ensure all shared memory accesses are static. Note that both pipelines // have even length meaning that the next iteration will always start at // index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { #pragma unroll @@ -1070,23 +1145,63 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (num_bits == 8) { if (s_sh_wr_pred) { - cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (num_bits == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } } } @@ -1125,28 +1240,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } - // if (blockIdx.x == 0 && threadIdx.x == 0) { - // printf("Move\n"); - // } start_pipes(); } } } } -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ +#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ - Marlin, \ + Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ + Marlin \ <<>>( \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ prob_k, locks); \ @@ -1158,28 +1270,92 @@ typedef struct { int num_threads; } thread_config_t; -thread_config_t small_batch_thread_configs[] = { +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N + {64, 256, 256}, // Default (max cache usage) + {64, 128, 128}, // Reduce N, reduce warps + {128, 64, 128}, // Reduce N more, but increase K + }; -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority +int get_scales_cache_size(thread_config_t const &th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 128, 128}, // Reduce N 2X, same K - // {128, 64, 128}, // Reduce N 4X, increase K 2X -}; + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; -bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, - int prob_k) { + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const &th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -1201,62 +1377,79 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, return false; } + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + return true; } -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - - // TODO: Enable if needed after some more testing - if (prob_m <= 0) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + for (auto th_config : thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; } } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } + printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM " + "GPU cache. This may " + "hurt performance. Consider upgrading your GPU.\n"); + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } - return thread_config_t{-1, -1, -1}; + return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ +#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) - -void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, - void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k, - void *workspace, bool has_act_order, bool is_k_full, - int num_groups, int group_size, int dev = 0, - cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, - int sms = -1, int max_par = 16) { + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, + void *g_idx, void *perm, void *a_tmp, int prob_m, + int prob_n, int prob_k, void *workspace, int num_bits, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k, + int thread_n, int sms, int max_par) { + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1274,25 +1467,34 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, TORCH_CHECK(max_shared_mem > 0); // Set thread config - thread_config_t th_config; + exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config - th_config = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; } else { // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); } - TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + - " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + - str(prob_n) + "]"); - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; @@ -1352,28 +1554,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, } // Main loop - for (int i = 0; i < tot_m_blocks; i += 4) { + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { int thread_m_blocks = tot_m_blocks - i; prob_m = tot_m - 16 * i; int par = 1; - if (thread_m_blocks > 4) { + if (thread_m_blocks > exec_cfg.max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * thread_m_blocks - pad) / 64; + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; } // Define kernel configurations if (false) { } - CALL_IF(16, 4, 256) - CALL_IF(8, 8, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) + CALL_IF(4, 32, 2, 256) + CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 4, 128) + CALL_IF(4, 4, 8, 128) + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1395,33 +1601,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &g_idx, torch::Tensor &perm, torch::Tensor &workspace, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full) { + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + // Verify A - TORCH_CHECK(a.size(0) == size_m, - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - TORCH_CHECK(a.size(1) == size_k, - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(gptq_marlin::tile_size)); + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(gptq_marlin::tile_size)); - TORCH_CHECK( - b_q_weight.size(1) % gptq_marlin::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(gptq_marlin::tile_size)); - int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) * - gptq_marlin::pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); // Verify device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); @@ -1457,9 +1662,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, // Verify g_idx and perm TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || (g_idx.size(0) == size_k && perm.size(0) == size_k), - "Unexpected g_idx.size(0) = " + str(g_idx.size(0)) + - " and perm.size(0) = " + str(perm.size(0)) + - ", where size_k = " + str(size_k)); + "Unexpected g_idx.size(0) = ", g_idx.size(0), + " and perm.size(0) = ", perm.size(0), + ", where size_k = ", size_k); // Detect groupsize and act_order int num_groups = -1; @@ -1475,9 +1680,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, if (has_act_order) { if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, - "size_k = " + str(size_k) + - ", is not divisible by num_groups = " + str(num_groups)); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); group_size = size_k / num_groups; } else { group_size = 0; @@ -1485,10 +1689,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, } else { if (num_groups > 1) { - TORCH_CHECK(size_k % num_groups == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); group_size = size_k / num_groups; } else { group_size = -1; @@ -1496,23 +1699,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, } // Verify workspace size - TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(gptq_marlin::min_thread_n)); + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); int min_workspace_size = (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - gptq_marlin::marlin_cuda( + gptq_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups, - group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, - sms, gptq_marlin::max_par); + size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, gptq_marlin::max_par); return c; } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh index 8cfce6b2575d5..35ea48aaba310 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; static constexpr int max_par = 16; -static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit - template struct Vec { T elems[n]; @@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool "r"(smem), "l"(glob_ptr), "n"(BYTES)); } -__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("{\n" - " .reg .b64 p;\n" - " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" - " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index fa45ce68a0c77..0d3da6240dbca 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -template +template __global__ void marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const *__restrict__ perm_ptr, @@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } // namespace gptq_marlin torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, - int64_t size_k, int64_t size_n) { + int64_t size_k, int64_t size_n, + int64_t num_bits) { TORCH_CHECK_NOT_IMPLEMENTED( false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, #else -template +template __global__ void marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const *__restrict__ perm_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + int k_tiles = size_k / tile_k_size; int n_tiles = size_n / tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); @@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, sh_pipe_ptr += perm_size; } + constexpr int tile_ints = tile_k_size / pack_factor; + constexpr int stage_n_threads = tile_n_size / 4; - constexpr int stage_k_threads = - has_perm ? tile_k_size : tile_k_size / pack_factor_4bit; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; auto load_perm_to_shared = [&](int k_tile_id) { @@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, reinterpret_cast(sh_perm_ptr); int src_k = sh_perm_int_ptr[k_id]; - int src_k_packed = src_k / pack_factor_4bit; + int src_k_packed = src_k / pack_factor; - cp_async4_stream( + cp_async4( &sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast(&( b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); @@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int n_id = threadIdx.x % stage_n_threads; int first_k = k_tile_id * tile_k_size; - int first_k_packed = first_k / pack_factor_4bit; + int first_k_packed = first_k / pack_factor; - cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + - first_n + (n_id * 4)]))); + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + + first_n + (n_id * 4)]))); } } @@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int cur_n = warp_id * 16 + tc_col; constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - uint32_t vals[pack_factor_4bit]; + uint32_t vals[8]; if constexpr (has_perm) { for (int i = 0; i < 4; i++) { int k_idx = tc_row + tc_offsets[i]; uint32_t src_k = sh_perm_int_ptr[k_idx]; - uint32_t src_k_pos = src_k % pack_factor_4bit; + uint32_t src_k_pos = src_k % pack_factor; uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; - uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; - uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; vals[i] = b1_cur_val; vals[4 + i] = b2_cur_val; @@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } else { - uint32_t b1_val_1 = sh_stage_int_ptr[cur_n]; - uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n]; - - uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8]; - uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8]; + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; #pragma unroll - for (int i = 0; i < 2; i++) { - int cur_elem = tc_row + tc_offsets[i]; - vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf; - vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf; + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } #pragma unroll - for (int i = 2; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i] - 8; - vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf; - vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf; + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; } } + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7}; + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - uint32_t res = 0; + uint32_t res = 0; #pragma unroll - for (int i = 0; i < pack_factor_4bit; i++) { - res |= vals[pack_idx[i]] << (i * 4); - } + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + out_ptr[out_offset + th_id * 4 + warp_id] = res; - out_ptr[out_offset + th_id * 4 + warp_id] = res; + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { @@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } // namespace gptq_marlin +#define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin::marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin::marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } + torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, - int64_t size_k, int64_t size_n) { + int64_t size_k, int64_t size_n, + int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / num_bits; + // Verify B - TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0), + TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, - ", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit); + ", size_k = ", size_k, ", pack_factor = ", pack_factor); TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n); @@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, auto options = torch::TensorOptions() .dtype(b_q_weight.dtype()) .device(b_q_weight.device()); - torch::Tensor out = torch::empty( - {size_k / gptq_marlin::tile_size, - size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit}, - options); + torch::Tensor out = + torch::empty({size_k / gptq_marlin::tile_size, + size_n * gptq_marlin::tile_size / pack_factor}, + options); // Detect if there is act_order bool has_perm = perm.size(0) != 0; @@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); - if (has_perm) { - cudaFuncSetAttribute( - gptq_marlin::marlin_repack_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_mem); - gptq_marlin::marlin_repack_kernel - <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); - - } else { - cudaFuncSetAttribute( - gptq_marlin::marlin_repack_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_mem); - gptq_marlin::marlin_repack_kernel - <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); + if (false) { + } + CALL_IF(4, false) + CALL_IF(4, true) + CALL_IF(8, false) + CALL_IF(8, true) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, + ", has_perm = ", has_perm); } return out; diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index dc027697ffd4d..4d73843f970c4 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -39,6 +39,13 @@ ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"), # act_order==True, group_size=32 ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"), + + # 8-bit, act_order==True, group_size=channelwise + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"), + # 8-bit, act_order==True, group_size=128 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"), + # 8-bit, act_order==True, group_size=32 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"), ] @@ -65,8 +72,7 @@ def test_models( dtype=dtype, quantization="marlin", max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1, - disable_custom_all_reduce=True) + tensor_parallel_size=1) gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) @@ -78,8 +84,7 @@ def test_models( dtype=dtype, quantization="gptq", max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1, - disable_custom_all_reduce=True) + tensor_parallel_size=1) gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4af8b09b1e16c..3faed5ea85307 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, # gptq_marlin def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n) + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, size_m: int, - size_n: int, size_k: int, + perm: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: int, size_n: int, size_k: int, is_k_full: bool) -> torch.Tensor: return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, - workspace, size_m, size_n, size_k, - is_k_full) + workspace, num_bits, size_m, size_n, + size_k, is_k_full) # fp8 diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index efbffa0878c4b..e2464008a875f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Any, Dict, List, Optional -import numpy import torch from torch.nn.parameter import Parameter @@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 -GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4] +GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] GPTQ_MARLIN_SUPPORTED_SYM = [True] -# Precompute permutations for Marlin weight and scale shuffling -# -# Marlin works on [16,64] tiles. The goal of the permutations -# is to reorder the weight data so that it is compatible -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the -# kernel will get the data as it is needed for tensor-core -# (without the need to use ldmatrix instructions) -def _get_perms(): - perm = [] - for i in range(32): - perm1 = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm) - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore - perm = torch.from_numpy(perm) +# Permutations for Marlin scale shuffling +def get_scale_perms(num_bits): scale_perm = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) @@ -59,23 +30,21 @@ def _get_perms(): for i in range(4): scale_perm_single.extend( [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return perm, scale_perm, scale_perm_single - - -_perm, _scale_perm, _scale_perm_single = _get_perms() + return scale_perm, scale_perm_single def get_pack_factor(num_bits): - assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, ( - f"Unsupported num_bits = {num_bits}") + assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" return 32 // num_bits -def marlin_permute_scales(s, size_k, size_n, group_size): +def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): + scale_perm, scale_perm_single = get_scale_perms(num_bits) if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: - s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s @@ -279,13 +248,15 @@ def create_weights( requires_grad=False, ) set_weight_attrs( - qweight, { + qweight, + { **extra_weight_attrs, "input_dim": 0, "output_dim": 1, "packed_dim": 0, "pack_factor": self.quant_config.pack_factor, - }) + }, + ) # Activation order g_idx = Parameter( @@ -296,10 +267,13 @@ def create_weights( requires_grad=False, ) # Ignore warning from fused linear layers such as QKVParallelLinear. - set_weight_attrs(g_idx, { - **extra_weight_attrs, "input_dim": 0, - "ignore_warning": True - }) + set_weight_attrs( + g_idx, + { + **extra_weight_attrs, "input_dim": 0, + "ignore_warning": True + }, + ) g_idx_sort_indices = Parameter( torch.empty( @@ -320,29 +294,34 @@ def create_weights( requires_grad=False, ) set_weight_attrs( - scales, { + scales, + { **extra_weight_attrs, "input_dim": scales_and_zp_input_dim, "output_dim": 1, - }) + }, + ) # Quantized zero-points qzeros = Parameter( - torch.empty(scales_and_zp_size, - output_size_per_partition // - self.quant_config.pack_factor, - dtype=torch.int32, - device="meta"), + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + device="meta", + ), requires_grad=False, ) set_weight_attrs( - qzeros, { + qzeros, + { **extra_weight_attrs, "input_dim": scales_and_zp_input_dim, "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - }) + }, + ) # Allocate marlin workspace max_workspace_size = ( @@ -405,13 +384,14 @@ def replace_tensor(name, new_t): else: # Reset g_idx related tensors - layer.g_idx = Parameter(torch.empty(0, - dtype=torch.int, - device=cur_device), - requires_grad=False) - layer.g_idx_sort_indices = Parameter(torch.empty( - 0, dtype=torch.int, device=cur_device), - requires_grad=False) + layer.g_idx = Parameter( + torch.empty(0, dtype=torch.int, device=cur_device), + requires_grad=False, + ) + layer.g_idx_sort_indices = Parameter( + torch.empty(0, dtype=torch.int, device=cur_device), + requires_grad=False, + ) # Repack weights marlin_qweight = ops.gptq_marlin_repack( @@ -419,6 +399,7 @@ def replace_tensor(name, new_t): layer.g_idx_sort_indices, part_size_k, part_size_n, + self.quant_config.weight_bits, ) replace_tensor("qweight", marlin_qweight) @@ -428,15 +409,28 @@ def replace_tensor(name, new_t): if self.quant_config.desc_act: scales_size_k = full_size_k - marlin_scales = marlin_permute_scales(layer.scales, scales_size_k, - scales_size_n, - self.quant_config.group_size) + marlin_scales = marlin_permute_scales( + layer.scales, + scales_size_k, + scales_size_n, + self.quant_config.group_size, + self.quant_config.weight_bits, + ) replace_tensor("scales", marlin_scales) - output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales, - layer.g_idx, layer.g_idx_sort_indices, - layer.workspace, size_m, part_size_n, - part_size_k, layer.is_k_full) + output = ops.gptq_marlin_gemm( + reshaped_x, + layer.qweight, + layer.scales, + layer.g_idx, + layer.g_idx_sort_indices, + layer.workspace, + self.quant_config.weight_bits, + size_m, + part_size_n, + part_size_k, + layer.is_k_full, + ) if bias is not None: output.add_(bias) # In-place add From fb087af52e3834d98250a455355a3ef329663168 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 3 May 2024 02:47:41 +0900 Subject: [PATCH 190/413] [mypy][7/N] Cover all directories (#4555) --- .github/workflows/mypy.yaml | 2 ++ format.sh | 2 ++ 2 files changed, 4 insertions(+) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 5b2bad1476dc3..a20753d8a7702 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -45,4 +45,6 @@ jobs: mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml + mypy vllm/logging --config-file pyproject.toml + mypy vllm/model_executor --config-file pyproject.toml diff --git a/format.sh b/format.sh index 49149afe41d04..233e6af0c9479 100755 --- a/format.sh +++ b/format.sh @@ -107,6 +107,8 @@ mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml +mypy vllm/logging --config-file pyproject.toml +mypy vllm/model_executor --config-file pyproject.toml CODESPELL_EXCLUDES=( From 5ad60b0cbd0a396eb3f1fda6bbf2c95aff6d5ecf Mon Sep 17 00:00:00 2001 From: Hu Dong Date: Fri, 3 May 2024 01:50:25 +0800 Subject: [PATCH 191/413] [Misc] Exclude the `tests` directory from being packaged (#4552) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d534cec437261..801d8d50db1dc 100644 --- a/setup.py +++ b/setup.py @@ -404,7 +404,7 @@ def _read_requirements(filename: str) -> List[str]: "Topic :: Scientific/Engineering :: Artificial Intelligence", ], packages=find_packages(exclude=("benchmarks", "csrc", "docs", "examples", - "tests")), + "tests*")), python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, From 1ff0c73a79b0c2788b12bd83523b74c01d414480 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Thu, 2 May 2024 18:52:51 +0100 Subject: [PATCH 192/413] [BugFix] Include target-device specific requirements.txt in sdist (#4559) --- MANIFEST.in | 3 +++ 1 file changed, 3 insertions(+) diff --git a/MANIFEST.in b/MANIFEST.in index d385f194c6c0f..82be639ef4d73 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,9 @@ include LICENSE include requirements-common.txt include requirements-cuda.txt +include requirements-rocm.txt +include requirements-neuron.txt +include requirements-cpu.txt include CMakeLists.txt recursive-include cmake * From 5b8a7c1cb0f1bb81266bae98944c055a8abb1a68 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 2 May 2024 11:13:25 -0700 Subject: [PATCH 193/413] [Misc] centralize all usage of environment variables (#4548) --- vllm/attention/backends/rocm_flash_attn.py | 5 +- vllm/attention/selector.py | 6 +- vllm/config.py | 5 - .../device_communicators/custom_all_reduce.py | 8 +- vllm/distributed/parallel_state.py | 4 +- vllm/distributed/utils.py | 7 +- vllm/engine/async_llm_engine.py | 5 +- vllm/entrypoints/openai/api_server.py | 4 +- vllm/envs.py | 160 ++++++++++++++++++ vllm/executor/cpu_executor.py | 5 +- vllm/executor/multiproc_worker_utils.py | 5 +- vllm/executor/ray_gpu_executor.py | 8 +- vllm/logger.py | 6 +- vllm/model_executor/model_loader/loader.py | 7 +- .../model_executor/model_loader/tensorizer.py | 12 +- vllm/transformers_utils/tokenizer.py | 2 +- vllm/usage/usage_lib.py | 16 +- vllm/utils.py | 19 ++- 18 files changed, 220 insertions(+), 64 deletions(-) create mode 100644 vllm/envs.py diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 934acea0a3d60..b7d15de772556 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,10 +1,10 @@ """Attention layer ROCm GPUs.""" -import os from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type import torch +import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) @@ -156,8 +156,7 @@ def __init__( self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. - self.use_triton_flash_attn = (os.environ.get( - "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN if self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7cc17f21dcd0e..7ae8c31fae1ac 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,18 +1,16 @@ import enum -import os from functools import lru_cache from typing import Type import torch +import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.utils import is_cpu, is_hip logger = init_logger(__name__) -VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" - class _Backend(enum.Enum): FLASH_ATTN = enum.auto() @@ -79,7 +77,7 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "package is not found. Please install it for better performance.") return _Backend.XFORMERS - backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + backend_by_env_var = envs.VLLM_ATTENTION_BACKEND if backend_by_env_var is not None: return _Backend[backend_by_env_var] diff --git a/vllm/config.py b/vllm/config.py index 257d49b6e804f..aaa2f60739d55 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,5 @@ import enum import json -import os from dataclasses import dataclass, field, fields from typing import TYPE_CHECKING, ClassVar, List, Optional, Union @@ -24,10 +23,6 @@ logger = init_logger(__name__) -# If true, will load models from ModelScope instead of Hugging Face Hub. -VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE", - "False").lower() == "true" - _GB = 1 << 30 diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index ec4533326e841..cc5f8166877ce 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,10 +1,10 @@ -import os from contextlib import contextmanager from typing import Any, List, Optional import torch import torch.distributed as dist +import vllm.envs as envs from vllm.logger import init_logger try: @@ -54,9 +54,9 @@ def init_custom_ar() -> None: return # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported - if "CUDA_VISIBLE_DEVICES" in os.environ: - device_ids = list( - map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(","))) + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) else: device_ids = list(range(num_dev)) # this checks hardware and driver support for NVLink diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6ca6fc5b5f9fe..a82a1254693df 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -4,11 +4,11 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" import contextlib -import os from typing import Optional import torch +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -80,7 +80,7 @@ def init_distributed_environment( # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 if local_rank == -1 and distributed_init_method == "env://": - local_rank = int(os.environ['LOCAL_RANK']) + local_rank = envs.LOCAL_RANK global _LOCAL_RANK _LOCAL_RANK = local_rank diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 9a13b94c3ada1..1965d4c1d3cbc 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -9,6 +9,7 @@ import torch import torch.distributed as dist +import vllm.envs as envs from vllm.logger import init_logger from .parallel_state import get_cpu_world_group, get_local_rank @@ -102,11 +103,13 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: is_distributed = dist.is_initialized() num_dev = torch.cuda.device_count() - cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices is None: cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT path = os.path.expanduser( - f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" + ) os.makedirs(os.path.dirname(path), exist_ok=True) if (not is_distributed or get_local_rank() == 0) \ and (not os.path.exists(path)): diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5591893d267a2..cf5053bba1d48 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,5 +1,4 @@ import asyncio -import os import time from functools import partial from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, @@ -7,6 +6,7 @@ from transformers import PreTrainedTokenizer +import vllm.envs as envs from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs @@ -20,8 +20,7 @@ from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) -ENGINE_ITERATION_TIMEOUT_S = int( - os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")) +ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S class AsyncEngineDeadError(RuntimeError): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 40103f70a31a3..8b3c5ea9de9c0 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1,7 +1,6 @@ import asyncio import importlib import inspect -import os import re from contextlib import asynccontextmanager from http import HTTPStatus @@ -16,6 +15,7 @@ from starlette.routing import Mount import vllm +import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -129,7 +129,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): allow_headers=args.allowed_headers, ) - if token := os.environ.get("VLLM_API_KEY") or args.api_key: + if token := envs.VLLM_API_KEY or args.api_key: @app.middleware("http") async def authentication(request: Request, call_next): diff --git a/vllm/envs.py b/vllm/envs.py new file mode 100644 index 0000000000000..26ed731caa5ff --- /dev/null +++ b/vllm/envs.py @@ -0,0 +1,160 @@ +import os +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional + +if TYPE_CHECKING: + VLLM_HOST_IP: str = "" + VLLM_USE_MODELSCOPE: bool = False + VLLM_INSTANCE_ID: Optional[str] = None + VLLM_NCCL_SO_PATH: Optional[str] = None + LD_LIBRARY_PATH: Optional[str] = None + VLLM_USE_TRITON_FLASH_ATTN: bool = False + LOCAL_RANK: int = 0 + CUDA_VISIBLE_DEVICES: Optional[str] = None + VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 + VLLM_API_KEY: Optional[str] = None + S3_ACCESS_KEY_ID: Optional[str] = None + S3_SECRET_ACCESS_KEY: Optional[str] = None + S3_ENDPOINT_URL: Optional[str] = None + VLLM_CONFIG_ROOT: str = "" + VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" + VLLM_NO_USAGE_STATS: bool = False + VLLM_DO_NOT_TRACK: bool = False + VLLM_USAGE_SOURCE: str = "" + VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_LOGGING_CONFIG_PATH: Optional[str] = None + VLLM_TRACE_FUNCTION: int = 0 + VLLM_ATTENTION_BACKEND: Optional[str] = None + VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_USE_RAY_COMPILED_DAG: bool = False + VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" + +environment_variables: Dict[str, Callable[[], Any]] = { + # used in distributed environment to determine the master address + 'VLLM_HOST_IP': + lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), + + # If true, will load models from ModelScope instead of Hugging Face Hub. + # note that the value is true or false, not numbers + "VLLM_USE_MODELSCOPE": + lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", + + # Instance id represents an instance of the VLLM. All processes in the same + # instance should have the same instance id. + "VLLM_INSTANCE_ID": + lambda: os.environ.get("VLLM_INSTANCE_ID", None), + + # path to cudatoolkit home directory, under which should be bin, include, + # and lib directories. + "CUDA_HOME": + lambda: os.environ.get("CUDA_HOME", None), + + # Path to the NCCL library file. It is needed because nccl>=2.19 brought + # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 + "VLLM_NCCL_SO_PATH": + lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), + + # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl + # library file in the locations specified by `LD_LIBRARY_PATH` + "LD_LIBRARY_PATH": + lambda: os.environ.get("LD_LIBRARY_PATH", None), + + # flag to control if vllm should use triton flash attention + "VLLM_USE_TRITON_FLASH_ATTN": + lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in + ("true", "1")), + + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": + lambda: int(os.environ.get("LOCAL_RANK", "0")), + + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": + lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), + + # timeout for each iteration in the engine + "VLLM_ENGINE_ITERATION_TIMEOUT_S": + lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), + + # API key for VLLM API server + "VLLM_API_KEY": + lambda: os.environ.get("VLLM_API_KEY", None), + + # S3 access information, used for tensorizer to load model from S3 + "S3_ACCESS_KEY_ID": + lambda: os.environ.get("S3_ACCESS_KEY", None), + "S3_SECRET_ACCESS_KEY": + lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), + "S3_ENDPOINT_URL": + lambda: os.environ.get("S3_ENDPOINT_URL", None), + + # Root directory for VLLM configuration files + # Note that this not only affects how vllm finds its configuration files + # during runtime, but also affects how vllm installs its configuration + # files during **installation**. + "VLLM_CONFIG_ROOT": + lambda: os.environ.get("VLLM_CONFIG_ROOT", None) or os.getenv( + "XDG_CONFIG_HOME", None) or os.path.expanduser("~/.config"), + + # Usage stats collection + "VLLM_USAGE_STATS_SERVER": + lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"), + "VLLM_NO_USAGE_STATS": + lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", + "VLLM_DO_NOT_TRACK": + lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get( + "DO_NOT_TRACK", None) or "0") == "1", + "VLLM_USAGE_SOURCE": + lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), + + # Logging configuration + # If set to 0, vllm will not configure logging + # If set to 1, vllm will configure logging using the default configuration + # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH + "VLLM_CONFIGURE_LOGGING": + lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_LOGGING_CONFIG_PATH": + lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), + + # Trace function calls + # If set to 1, vllm will trace function calls + # Useful for debugging + "VLLM_TRACE_FUNCTION": + lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), + + # Backend for attention computation + # Available options: + # - "TORCH_SDPA": use torch.nn.MultiheadAttention + # - "FLASH_ATTN": use FlashAttention + # - "XFORMERS": use XFormers + # - "ROCM_FLASH": use ROCmFlashAttention + "VLLM_ATTENTION_BACKEND": + lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), + + # CPU key-value cache space + # default is 4GB + "VLLM_CPU_KVCACHE_SPACE": + lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + "VLLM_USE_RAY_COMPILED_DAG": + lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)), + + # Use dedicated multiprocess context for workers. + # Both spawn and fork work + "VLLM_WORKER_MULTIPROC_METHOD": + lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), +} + + +def __getattr__(name): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index da1b500cddaf6..733eef828adc4 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,8 +1,8 @@ -import os from typing import Dict, List, Set, Tuple import torch +import vllm.envs as envs from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -152,8 +152,7 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: logger.warning("Prefix caching is not supported on CPU, disable it.") config.enable_prefix_caching = False - kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0") - kv_cache_space = int(kv_cache_space_str) + kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE if kv_cache_space >= 0: if kv_cache_space == 0: diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index 0c04796bc38e3..62887533f5c27 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -12,6 +12,7 @@ from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union) +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -26,9 +27,7 @@ JOIN_TIMEOUT_S = 2 -# Use dedicated multiprocess context for workers. -# Both spawn and fork work -mp_method = os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") +mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD mp = multiprocessing.get_context(mp_method) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 16d239b9ab580..4684b857ccd39 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -5,6 +5,7 @@ from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +import vllm.envs as envs from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray @@ -21,10 +22,7 @@ logger = init_logger(__name__) -# If the env var is set, it uses the Ray's compiled DAG API -# which optimizes the control plane overhead. -# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) +USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG class RayGPUExecutor(DistributedGPUExecutor): @@ -145,7 +143,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", "VLLM_INSTANCE_ID": VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": - os.getenv("VLLM_TRACE_FUNCTION", "0"), + str(envs.VLLM_TRACE_FUNCTION), }, ) for (node_id, _) in worker_node_and_gpu_ids] self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) diff --git a/vllm/logger.py b/vllm/logger.py index 40c29da2b70ce..153cdfb373bb4 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -10,8 +10,10 @@ from os import path from typing import Dict, Optional -VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) -VLLM_LOGGING_CONFIG_PATH = os.getenv("VLLM_LOGGING_CONFIG_PATH") +import vllm.envs as envs + +VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING +VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 70e64167f8698..bafa2de62e5df 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -9,9 +9,10 @@ import torch from torch import nn -from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig, - LoadFormat, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VisionLanguageConfig) +from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 0ce9fa95aa7e5..af433b86e604d 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -11,6 +11,7 @@ from torch import nn from transformers import PretrainedConfig +import vllm.envs as envs from vllm.config import ModelConfig, ParallelConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -142,13 +143,10 @@ class TensorizerArgs: def __post_init__(self): self.file_obj = self.tensorizer_uri - self.s3_access_key_id = (self.s3_access_key_id - or os.environ.get("S3_ACCESS_KEY_ID")) or None - self.s3_secret_access_key = ( - self.s3_secret_access_key - or os.environ.get("S3_SECRET_ACCESS_KEY")) or None - self.s3_endpoint = (self.s3_endpoint - or os.environ.get("S3_ENDPOINT_URL")) or None + self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID + self.s3_secret_access_key = (self.s3_secret_access_key + or envs.S3_SECRET_ACCESS_KEY) + self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL self.stream_params = { "s3_access_key_id": self.s3_access_key_id, "s3_secret_access_key": self.s3_secret_access_key, diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 9066db5a9e7f1..f5684dbf1271c 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -5,7 +5,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) -from vllm.config import VLLM_USE_MODELSCOPE +from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizers import BaichuanTokenizer diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index b2672f7f1da61..9029a5b16af72 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -15,20 +15,22 @@ import requests import torch -_config_home = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) +import vllm.envs as envs + +_config_home = envs.VLLM_CONFIG_ROOT _USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json") _USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, "vllm/do_not_track") _USAGE_STATS_ENABLED = None -_USAGE_STATS_SERVER = os.environ.get("VLLM_USAGE_STATS_SERVER", - "https://stats.vllm.ai") +_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER def is_usage_stats_enabled(): """Determine whether or not we can send usage stats to the server. The logic is as follows: - By default, it should be enabled. - - Two environment variables can disable it: + - Three environment variables can disable it: + - VLLM_DO_NOT_TRACK=1 - DO_NOT_TRACK=1 - VLLM_NO_USAGE_STATS=1 - A file in the home directory can disable it if it exists: @@ -36,8 +38,8 @@ def is_usage_stats_enabled(): """ global _USAGE_STATS_ENABLED if _USAGE_STATS_ENABLED is None: - do_not_track = os.environ.get("DO_NOT_TRACK", "0") == "1" - no_usage_stats = os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1" + do_not_track = envs.VLLM_DO_NOT_TRACK + no_usage_stats = envs.VLLM_NO_USAGE_STATS do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH) _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats @@ -167,7 +169,7 @@ def _report_usage_once(self, model_architecture: str, # Metadata self.log_time = _get_current_timestamp_ns() - self.source = os.environ.get("VLLM_USAGE_SOURCE", "production") + self.source = envs.VLLM_USAGE_SOURCE data = vars(self) if extra_kvs: diff --git a/vllm/utils.py b/vllm/utils.py index 88447878f1706..ce55253ce2199 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -21,6 +21,7 @@ import torch from packaging.version import Version, parse +import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger T = TypeVar("T") @@ -174,7 +175,7 @@ def get_vllm_instance_id(): Instance id represents an instance of the VLLM. All processes in the same instance should have the same instance id. """ - return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}") + return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}" @lru_cache(maxsize=None) @@ -243,7 +244,7 @@ async def consumer(): def get_ip() -> str: - host_ip = os.environ.get("HOST_IP") + host_ip = envs.VLLM_HOST_IP if host_ip: return host_ip @@ -269,7 +270,8 @@ def get_ip() -> str: warnings.warn( "Failed to get the IP address, using 0.0.0.0 by default." - "The value can be set by the environment variable HOST_IP.", + "The value can be set by the environment variable" + " VLLM_HOST_IP or HOST_IP.", stacklevel=2) return "0.0.0.0" @@ -314,7 +316,7 @@ def cdiv(a: int, b: int) -> int: @lru_cache(maxsize=None) def get_nvcc_cuda_version() -> Optional[Version]: - cuda_home = os.environ.get('CUDA_HOME') + cuda_home = envs.CUDA_HOME if not cuda_home: cuda_home = '/usr/local/cuda' if os.path.isfile(cuda_home + '/bin/nvcc'): @@ -581,7 +583,7 @@ def find_library(lib_name: str) -> str: # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line] # `LD_LIBRARY_PATH` searches the library in the user-defined paths - env_ld_library_path = os.getenv("LD_LIBRARY_PATH") + env_ld_library_path = envs.LD_LIBRARY_PATH if not locs and env_ld_library_path: locs = [ os.path.join(dir, lib_name) @@ -594,14 +596,15 @@ def find_library(lib_name: str) -> str: def find_nccl_library(): - so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") + so_file = envs.VLLM_NCCL_SO_PATH + VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT # check if we have vllm-managed nccl vllm_nccl_path = None if torch.version.cuda is not None: cuda_major = torch.version.cuda.split(".")[0] path = os.path.expanduser( - f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*") + f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*") files = glob.glob(path) vllm_nccl_path = files[0] if files else None @@ -626,7 +629,7 @@ def enable_trace_function_call_for_thread() -> None: if enabled via the VLLM_TRACE_FUNCTION environment variable """ - if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): + if envs.VLLM_TRACE_FUNCTION: tmp_dir = tempfile.gettempdir() filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" f"_thread_{threading.get_ident()}_" From 32881f3f3106e17d2fd52d8ac00217a0f0b2476a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Moskal?= Date: Thu, 2 May 2024 11:23:37 -0700 Subject: [PATCH 194/413] [kernel] fix sliding window in prefix prefill Triton kernel (#4405) Co-authored-by: SangBin Cho --- tests/kernels/test_prefix_prefill.py | 34 ++++++++-- vllm/attention/backends/flash_attn.py | 1 + vllm/attention/backends/rocm_flash_attn.py | 1 + vllm/attention/backends/xformers.py | 1 + vllm/attention/ops/paged_attn.py | 2 + vllm/attention/ops/prefix_prefill.py | 75 ++++++++++++++++------ 6 files changed, 91 insertions(+), 23 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index ad31b0a7c2a19..8ab1167384c45 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -15,6 +15,7 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -22,11 +23,13 @@ @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, num_queries_per_kv: int, head_size: int, + sliding_window: int, dtype: torch.dtype, device: str, ) -> None: @@ -123,12 +126,32 @@ def test_contexted_kv_attention( # Warm up the Triton kernel by calling it once before actually measuring # generation time - context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, - b_start_loc, b_seq_len, b_ctx_len, max_input_len) + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + sliding_window=sliding_window) torch.cuda.synchronize() start_time = time.time() - context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, - b_start_loc, b_seq_len, b_ctx_len, max_input_len) + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + sliding_window=sliding_window) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") @@ -156,6 +179,9 @@ def test_contexted_kv_attention( attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( subquery_lens, seq_lens) + if sliding_window > 0: + attn_bias = attn_bias.make_local_attention_from_bottomright( + sliding_window) output_ref = xops.memory_efficient_attention_forward( query, key, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 12e8c4404b94e..10b8c19b7499e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -249,6 +249,7 @@ def forward( prefill_meta.context_lens, prefill_meta.max_subquery_len, self.alibi_slopes, + self.sliding_window[0], ) if decode_meta := attn_metadata.decode_metadata: # Decoding run. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b7d15de772556..3bc436315c3de 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -307,6 +307,7 @@ def forward( prefill_meta.context_lens, prefill_meta.max_subquery_len, self.alibi_slopes, + self.sliding_window[0], ) if decode_meta := attn_metadata.decode_metadata: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 572a4dc79a719..dc64ac0bf985d 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -246,6 +246,7 @@ def forward( prefill_meta.context_lens, prefill_meta.max_subquery_len, self.alibi_slopes, + self.sliding_window, ) assert output[:num_prefill_tokens].shape == out.shape output[:num_prefill_tokens] = out diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index cd0690a4ba957..c20b94ac8315b 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -172,6 +172,7 @@ def forward_prefix( context_lens: torch.Tensor, max_subquery_len: int, alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], ) -> torch.Tensor: output = torch.empty_like(query) context_attention_fwd( @@ -188,6 +189,7 @@ def forward_prefix( context_lens, max_subquery_len, alibi_slopes, + sliding_window, ) return output diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 4896cf3909c6e..79878b26c5294 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -50,6 +50,7 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, # head size BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -62,42 +63,53 @@ def _fwd_kernel( cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + # start position inside of the query + # generally, N goes over kv, while M goes over query_len block_start_loc = BLOCK_M * start_m # initialize offsets + # [N]; starts at 0 offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] q = tl.load(Q + off_q, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len), - other=0.0) + other=0.0) # [M,D] - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + # compute query against context (no causal mask here) for start_n in range(0, cur_batch_ctx_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) + other=0) # [N] + # [D,N] off_k = (bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) + # [N,D] off_v = ( bn[:, None] * stride_v_cache_bs + cur_kv_head * stride_v_cache_h + @@ -106,23 +118,39 @@ def _fwd_kernel( k = tl.load(K_cache + off_k, mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) + other=0.0) # [D,N] - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] qk += tl.dot(q, k) qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")) qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). + qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + # -- update output accumulator -- # scale p p_scale = beta / l_i_new @@ -134,7 +162,7 @@ def _fwd_kernel( v = tl.load(V_cache + off_v, mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) + other=0.0) # [N,D] p = p.to(v.dtype) acc += tl.dot(p, v) @@ -149,8 +177,10 @@ def _fwd_kernel( k_ptrs = K + off_k v_ptrs = V + off_v + # block_mask is 0 when we're already past the current query length block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + # compute query against itself (with causal mask) for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- @@ -163,8 +193,13 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale + # apply causal mask qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -636,7 +671,8 @@ def context_attention_fwd(q, b_seq_len, b_ctx_len, max_input_len, - alibi_slopes=None): + alibi_slopes=None, + sliding_window=None): cap = torch.cuda.get_device_capability() BLOCK = 128 if cap[0] >= 8 else 64 @@ -644,7 +680,7 @@ def context_attention_fwd(q, Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv # round up Lk to a power of 2 - this is required for Triton block size - Lk_padded = 2**((Lk - 1).bit_length()) + Lk_padded = triton.next_power_of_2(Lk) sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] @@ -749,6 +785,7 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window if sliding_window is not None else 0, num_warps=num_warps, num_stages=1, ) From 9b5c9f9484858279a937498ebf9239a9df67f61f Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Thu, 2 May 2024 14:29:07 -0500 Subject: [PATCH 195/413] [CI/Build] AMD CI pipeline with extended set of tests. (#4267) Co-authored-by: simon-mo --- .buildkite/run-amd-test.sh | 58 +++++++++++++++-------------------- .buildkite/run-benchmarks.sh | 5 +++ .buildkite/test-pipeline.yaml | 15 ++++++++- .buildkite/test-template.j2 | 21 ++++++++++--- Dockerfile.rocm | 13 ++++---- 5 files changed, 67 insertions(+), 45 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 38aff57a410dc..c04e05a994894 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,10 +1,11 @@ -# This script build the ROCm docker image and run the API server inside the container. -# It serves a sanity check for compilation and basic model usage. +# This script build the ROCm docker image and runs test inside it. set -ex # Print ROCm version +echo "--- ROCm info" rocminfo +echo "--- Resetting GPUs" echo "reset" > /opt/amdgpu/etc/gpu_state @@ -16,37 +17,28 @@ while true; do fi done +echo "--- Building container" +sha=$(git rev-parse --short HEAD) +container_name=rocm_${sha} +docker build \ + -t ${container_name} \ + -f Dockerfile.rocm \ + --progress plain \ + . + +remove_docker_container() { + docker rm -f ${container_name} || docker image rm -f ${container_name} || true +} +trap remove_docker_container EXIT +echo "--- Running container" -# Try building the docker image -docker build -t rocm -f Dockerfile.rocm . - -# Setup cleanup -remove_docker_container() { docker rm -f rocm || true; } -trap remove_docker_container EXIT -remove_docker_container - -# Run the image -export HIP_VISIBLE_DEVICES=1 -docker run --device /dev/kfd --device /dev/dri --network host -e HIP_VISIBLE_DEVICES --name rocm rocm python3 -m vllm.entrypoints.api_server & - -# Wait for the server to start -wait_for_server_to_start() { - timeout=300 - counter=0 - - while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do - sleep 1 - counter=$((counter + 1)) - if [ $counter -ge $timeout ]; then - echo "Timeout after $timeout seconds" - break - fi - done -} -wait_for_server_to_start +docker run \ + --device /dev/kfd --device /dev/dri \ + --network host \ + --rm \ + -e HF_TOKEN \ + --name ${container_name} \ + ${container_name} \ + /bin/bash -c $(echo $1 | sed "s/^'//" | sed "s/'$//") -# Test a simple prompt -curl -X POST -H "Content-Type: application/json" \ - localhost:8000/generate \ - -d '{"prompt": "San Francisco is a"}' diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index f6a542afe1a3d..7fbad1c4bd950 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -53,6 +53,11 @@ echo '```' >> benchmark_results.md tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines echo '```' >> benchmark_results.md +# if the agent binary is not found, skip uploading the results, exit 0 +if [ ! -f /workspace/buildkite-agent ]; then + exit 0 +fi + # upload the results to buildkite /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d518fb9ccecfa..e49a5650c44ea 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -20,6 +20,7 @@ steps: - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Core Test + mirror_hardwares: [amd] command: pytest -v -s core - label: Distributed Comm Ops Test @@ -29,7 +30,10 @@ steps: - label: Distributed Tests working_dir: "/vllm-workspace/tests/distributed" - num_gpus: 2 + + num_gpus: 2 # only support 1 or 2 for now. + mirror_hardwares: [amd] + commands: - pytest -v -s test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py @@ -44,6 +48,7 @@ steps: - pytest -v -s test_pynccl.py - label: Engine Test + mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test @@ -54,6 +59,7 @@ steps: - label: Examples Test working_dir: "/vllm-workspace/examples" + mirror_hardwares: [amd] commands: # install aws cli for llava_example.py - pip install awscli @@ -67,16 +73,19 @@ steps: parallelism: 4 - label: Models Test + mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py - label: Llava Test + mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models/test_llava.py - label: Prefix Caching Test + mirror_hardwares: [amd] commands: - pytest -v -s prefix_caching @@ -84,12 +93,15 @@ steps: command: pytest -v -s samplers - label: LogitsProcessor Test + mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py - label: Worker Test + mirror_hardwares: [amd] command: pytest -v -s worker - label: Speculative decoding tests + mirror_hardwares: [amd] command: pytest -v -s spec_decode - label: LoRA Test %N @@ -107,6 +119,7 @@ steps: - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" + mirror_hardwares: [amd] commands: - pip install aiohttp - bash run-benchmarks.sh diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 2cb21cacd065b..ea02b6b1e9c9e 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -16,18 +16,29 @@ steps: limit: 5 - wait - - label: "AMD Test" - agents: - queue: amd - command: bash .buildkite/run-amd-test.sh + - group: "AMD Tests" + depends_on: ~ + steps: + {% for step in steps %} + {% if step.mirror_hardwares and "amd" in step.mirror_hardwares %} + - label: "AMD: {{ step.label }}" + agents: + queue: amd + command: bash .buildkite/run-amd-test.sh "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'" + env: + DOCKER_BUILDKIT: "1" + {% endif %} + {% endfor %} - label: "Neuron Test" + depends_on: ~ agents: queue: neuron command: bash .buildkite/run-neuron-test.sh soft_fail: true - - label: "CPU Test" + - label: "Intel Test" + depends_on: ~ command: bash .buildkite/run-cpu-test.sh {% for step in steps %} diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 3f84b949481d1..d04bb9915e2ab 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -46,7 +46,7 @@ RUN apt-get update && apt-get install -y \ ### Mount Point ### # When launching the container, mount the code directory to /app -ARG APP_MOUNT=/app +ARG APP_MOUNT=/vllm-workspace VOLUME [ ${APP_MOUNT} ] WORKDIR ${APP_MOUNT} @@ -89,15 +89,16 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ && cd ../..; \ fi -COPY ./ /app/vllm +WORKDIR /vllm-workspace +COPY . . RUN python3 -m pip install --upgrade pip numba -RUN cd /app \ - && cd vllm \ - && pip install -U -r requirements-rocm.txt \ - && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \ +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -U -r requirements-rocm.txt \ + && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cd .. RUN python3 -m pip install --upgrade pip From 0f8a91401c89ac0a8018def3756829611b57727f Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 3 May 2024 06:31:20 +0900 Subject: [PATCH 196/413] [Core] Ignore infeasible swap requests. (#4557) --- tests/basic_correctness/test_preemption.py | 85 ++++++++++++++++++++ tests/core/test_block_manager.py | 2 +- tests/core/test_chunked_prefill_scheduler.py | 5 +- tests/core/test_scheduler.py | 30 ++++++- vllm/core/block/cpu_gpu_block_allocator.py | 19 ++--- vllm/core/block/interfaces.py | 21 +++-- vllm/core/block/naive_block.py | 6 +- vllm/core/block/prefix_caching_block.py | 3 + vllm/core/block_manager_v1.py | 19 ++++- vllm/core/block_manager_v2.py | 4 +- vllm/core/interfaces.py | 2 +- vllm/core/scheduler.py | 33 +++++--- 12 files changed, 187 insertions(+), 42 deletions(-) diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 1adfc7dddd6fa..ffb0717b3bfdb 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -7,6 +7,7 @@ """ import pytest +from vllm import SamplingParams from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, ENABLE_ARTIFICIAL_PREEMPT) @@ -136,3 +137,87 @@ def test_swap( assert hf_output_ids[j] == vllm_output_ids[j], ( f"Test{i} output{j}:\nHF: {hf_output_ids}\n" f"vLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("beam_width", [4]) +def test_swap_infeasible( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Verify infeasible swap request will be ignored.""" + BLOCK_SIZE = 16 + prefill_blocks = 2 + decode_blocks = max_tokens // BLOCK_SIZE + example_prompts = example_prompts[:1] + + vllm_model = vllm_runner( + model, + dtype=dtype, + swap_space=10, + block_size=BLOCK_SIZE, + # Since beam search have more than 1 sequence, prefill + decode blocks + # are not enough to finish. + num_gpu_blocks_override=prefill_blocks + decode_blocks, + max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, + ) + sampling_params = SamplingParams(n=beam_width, + use_beam_search=True, + temperature=0.0, + max_tokens=max_tokens, + ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + # Verify the request is ignored and not hang. + assert req_outputs[0].outputs[0].finish_reason == "length" + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_preemption_infeasible( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """Verify infeasible preemption request will be ignored.""" + BLOCK_SIZE = 16 + prefill_blocks = 2 + decode_blocks = max_tokens // BLOCK_SIZE + vllm_model = vllm_runner( + model, + dtype=dtype, + block_size=BLOCK_SIZE, + # Not enough gpu blocks to complete a single sequence. + # preemption should happen, and the sequence should be + # ignored instead of hanging forever. + num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, + max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), + ) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + # Verify the request is ignored and not hang. + for req_output in req_outputs: + outputs = req_output.outputs + assert len(outputs) == 1 + assert outputs[0].finish_reason == "length" diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 62984ef4caabb..9f9a6180add78 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -224,7 +224,7 @@ def test_swap(): # Swap seq group from CPU -> GPU. cpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_in(seq_group) + assert block_manager.can_swap_in(seq_group) == AllocStatus.OK before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_in(seq_group) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index cce396bf4953c..92498c0014666 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -4,6 +4,7 @@ import pytest # noqa from vllm.config import CacheConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler from vllm.sequence import Logprob, SequenceGroup @@ -410,7 +411,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # Add 1 more task. Swap is not possible, so prefill is running. scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = False + scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER _, seq_group2 = create_dummy_prompt("2", prompt_length=60) scheduler.add_seq_group(seq_group2) @@ -423,7 +424,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.scheduled_seq_groups[0].seq_group == seq_group2 # Now although swap is possible, running prefill is prioritized. - scheduler.block_manager.can_swap_in.return_value = True + scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index ab471d206618b..1358dffec8104 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -791,7 +791,7 @@ def test_schedule_swapped_cannot_swap_in(): # The last request should be swapped out. scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = False + scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER # Since we cannot swap in, none of the requests are swapped in. budget = create_token_budget() remaining_swapped, output = scheduler._schedule_swapped( @@ -803,6 +803,34 @@ def test_schedule_swapped_cannot_swap_in(): assert len(output.prefill_seq_groups) == 0 +def test_infeasible_swap(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + for _ in range(2): + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + append_new_token_seq_group(60, seq_group, 1) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + # The last request should be swapped out. + scheduler.block_manager.can_swap_in = MagicMock() + scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER + # Since we cannot swap in, none of the requests are swapped in. + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 0 + assert len(output.infeasible_seq_groups) == 2 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(output.decode_seq_groups) == 0 + assert len(output.prefill_seq_groups) == 0 + + def test_schedule_swapped_blocks_to_copy(): scheduler = initialize_scheduler() swapped = deque() diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index d25d22cf52838..5b25e1bcdada0 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -110,9 +110,8 @@ def __init__( for block_id in allocator.all_block_ids: self._block_ids_to_allocator[block_id] = allocator - def allocate_mutable(self, - prev_block: Optional[Block], - device: Optional[Device] = None) -> Block: + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: """Allocates a new mutable block on the specified device. Args: @@ -123,13 +122,10 @@ def allocate_mutable(self, Returns: Block: The newly allocated mutable block. """ - assert device is not None return self._allocators[device].allocate_mutable(prev_block) - def allocate_immutable(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Optional[Device] = None) -> Block: + def allocate_immutable(self, prev_block: Optional[Block], + token_ids: List[int], device: Device) -> Block: """Allocates a new immutable block with the provided token IDs on the specified device. @@ -144,7 +140,6 @@ def allocate_immutable(self, Block: The newly allocated immutable block containing the provided token IDs. """ - assert device is not None return self._allocators[device].allocate_immutable( prev_block, token_ids) @@ -175,7 +170,7 @@ def fork(self, last_block: Block) -> List[Block]: allocator = self._block_ids_to_allocator[block_id] return allocator.fork(last_block) - def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + def get_num_free_blocks(self, device: Device) -> int: """Returns the number of free blocks available on the specified device. Args: @@ -185,9 +180,11 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int: Returns: int: The number of free blocks available on the specified device. """ - assert device is not None return self._allocators[device].get_num_free_blocks() + def get_num_total_blocks(self, device: Device) -> int: + return self._allocators[device].get_num_total_blocks() + def clear_copy_on_writes(self) -> Dict[int, List[int]]: """Clears the copy-on-write (CoW) state and returns the mapping of source to destination block IDs. diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 08d2f87301d92..634c4016ca19c 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -108,6 +108,10 @@ def free(self, block: Block) -> None: def fork(self, last_block: Block) -> List[Block]: pass + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + @abstractmethod def get_num_free_blocks(self) -> int: pass @@ -152,20 +156,21 @@ class NoFreeBlocksError(ValueError): class DeviceAwareBlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, - prev_block: Optional[Block], - device: Optional[Device] = None) -> Block: + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: + pass + + @abstractmethod + def allocate_immutable(self, prev_block: Optional[Block], + token_ids: List[int], device: Device) -> Block: pass @abstractmethod - def allocate_immutable(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Optional[Device] = None) -> Block: + def get_num_free_blocks(self, device: Device) -> int: pass @abstractmethod - def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + def get_num_total_blocks(self, device: Device) -> int: pass @abstractmethod diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 10af129246889..a1b901bf78efc 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -133,10 +133,12 @@ def fork(self, last_block: Block) -> List[Block]: return forked_blocks - def get_num_free_blocks(self, device: Optional[Device] = None) -> int: - assert device is None + def get_num_free_blocks(self) -> int: return len(self._free_block_indices) + def get_num_total_blocks(self) -> int: + return len(self._all_block_indices) + def _allocate_new_block_id(self) -> BlockId: if not self._free_block_indices: raise BlockAllocator.NoFreeBlocksError() diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index e9000c9bfff7f..4a37e8f87c379 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -285,6 +285,9 @@ def get_num_free_blocks(self, device: Optional[Device] = None) -> int: return self._hashless_allocator.get_num_free_blocks( ) + self.evictor.num_blocks + def get_num_total_blocks(self) -> int: + return self._hashless_allocator.get_num_total_blocks() + @property def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 4a9a2999e3913..268c5c135d887 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -47,6 +47,10 @@ def free(self, block: PhysicalTokenBlock) -> None: def get_num_free_blocks(self) -> int: pass + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + @abstractmethod def contains_block(self, block_hash: int) -> bool: pass @@ -131,6 +135,9 @@ def get_num_free_blocks(self) -> int: return (self.num_blocks - self.current_num_blocks + self.evictor.num_blocks) + def get_num_total_blocks(self) -> int: + return self.num_blocks + def contains_block(self, block_hash: int) -> bool: return block_hash in self.cached_blocks or block_hash in self.evictor @@ -190,6 +197,9 @@ def free(self, block: PhysicalTokenBlock) -> None: def get_num_free_blocks(self) -> int: return len(self.free_blocks) + def get_num_total_blocks(self) -> int: + return self.num_blocks + def contains_block(self, block_hash: int) -> bool: raise NotImplementedError( "Invalid codepath for uncached block allocator.") @@ -444,7 +454,7 @@ def _get_physical_blocks( def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> bool: + num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" blocks = self._get_physical_blocks(seq_group) @@ -454,7 +464,12 @@ def can_swap_in(self, # at least one free block right after the swap-in. # NOTE: This should match the logic in can_append_slot(). num_required_blocks = len(blocks) + num_swapped_seqs - return num_free_blocks - num_required_blocks >= self.watermark_blocks + if self.gpu_allocator.get_num_total_blocks() < num_required_blocks: + return AllocStatus.NEVER + elif num_free_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER def swap_in(self, seq_group: SequenceGroup, diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 3fbd8b787cf6c..ce90ce2f17278 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -238,8 +238,8 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_tables[child_seq.seq_id] = src_block_table.fork() def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - return False + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.LATER def swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> Dict[int, int]: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 56c2c5995c38b..09ccaddb62615 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -63,7 +63,7 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: @abstractmethod def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: + num_lookahead_slots: int) -> AllocStatus: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index b17b6cc7fe733..7c55b08d4857d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -210,6 +210,8 @@ class SchedulerSwappedInOutputs: blocks_to_copy: Dict[int, List[int]] # The number of slots for lookahead decoding. num_lookahead_slots: int + # Infeasible sequence groups. + infeasible_seq_groups: List[SequenceGroup] @classmethod def create_empty(cls) -> "SchedulerSwappedInOutputs": @@ -219,6 +221,7 @@ def create_empty(cls) -> "SchedulerSwappedInOutputs": blocks_to_swap_in={}, blocks_to_copy={}, num_lookahead_slots=0, + infeasible_seq_groups=[], ) @@ -511,14 +514,26 @@ def _schedule_swapped( prefill_seq_groups: List[ScheduledSequenceGroup] = [] now = time.time() swapped_queue = policy.sort_by_priority(now, swapped_queue) + infeasible_seq_groups: List[SequenceGroup] = [] leftover_swapped: Deque[SequenceGroup] = deque() while swapped_queue: seq_group = swapped_queue[0] # If the sequence group cannot be swapped in, stop. - if not self.block_manager.can_swap_in(seq_group): + alloc_status = self.block_manager.can_swap_in(seq_group) + if alloc_status == AllocStatus.LATER: break + elif alloc_status == AllocStatus.NEVER: + logger.warning( + "Failing the request %s because there's not enough kv " + "cache blocks to run the entire sequence.", + seq_group.request_id) + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + infeasible_seq_groups.append(seq_group) + swapped_queue.popleft() + continue lora_int_id = 0 if self.lora_enabled: @@ -569,7 +584,9 @@ def _schedule_swapped( blocks_to_swap_in=blocks_to_swap_in, blocks_to_copy=blocks_to_copy, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False)) + is_prefill=False), + infeasible_seq_groups=infeasible_seq_groups, + ) def _schedule_prefills( self, @@ -777,7 +794,8 @@ def _schedule_default(self) -> SchedulerOutputs: blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, swapped_in.blocks_to_copy), - ignored_seq_groups=prefills.ignored_seq_groups, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, ) @@ -893,15 +911,6 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), ) - def _can_swap_in(self, seq_group: SequenceGroup) -> bool: - # Swapping in is considered decode. - is_prefill = False - - return self.block_manager.can_swap_in( - seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), - ) - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Schedule sequence groups. # This function call changes the internal states of the scheduler From 344a5d0c332c3945caf336fd1d21f450f1455e6c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 2 May 2024 17:32:33 -0700 Subject: [PATCH 197/413] [Core][Distributed] enable allreduce for multiple tp groups (#4566) --- tests/distributed/test_pynccl.py | 43 +++++++++++++++++++++++++--- vllm/distributed/communication_op.py | 1 - vllm/distributed/parallel_state.py | 36 ++++++++++++++++------- vllm/worker/worker.py | 13 +++++---- 4 files changed, 71 insertions(+), 22 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index e71d839648c83..b6f461b76ed03 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -3,9 +3,13 @@ import pytest import torch +import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, ncclGetUniqueId) -from vllm.distributed.parallel_state import init_distributed_environment +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group, + init_distributed_environment, with_pynccl_for_all_reduce) from vllm.utils import update_environment_variables @@ -67,7 +71,7 @@ def multiple_tp_worker_fn(): ] group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] comm = NCCLCommunicator(group=group, device=device) - tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) # two groups can communicate independently if torch.distributed.get_rank() in [0, 1]: comm.all_reduce(tensor) @@ -81,9 +85,40 @@ def multiple_tp_worker_fn(): @pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 2 GPUs to run the test.") + reason="Need at least 4 GPUs to run the test.") def test_pynccl_multiple_tp(): - distributed_run(worker_fn, 4) + # this tests pynccl for multiple tp groups, in a standalone way + # i.e. call `comm.all_reduce` directly + distributed_run(multiple_tp_worker_fn, 4) + + +@worker_fn_wrapper +def multiple_tp_with_vllm_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + torch.cuda.set_device(torch.distributed.get_rank()) + ensure_model_parallel_initialized(2, 2) + pynccl_utils.init_process_group( + group=get_tensor_model_parallel_cpu_group()) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) + with with_pynccl_for_all_reduce(): + # two tp groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + tensor = tensor_model_parallel_all_reduce(tensor) + tensor = tensor_model_parallel_all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + tensor = tensor_model_parallel_all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_multiple_tp_with_vllm(): + # this tests pynccl for multiple tp groups, together with vllm + # i.e. call `tensor_model_parallel_all_reduce` + distributed_run(multiple_tp_with_vllm_worker_fn, 4) @worker_fn_wrapper diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 8b2c26c3a8afb..b539a7beedbfe 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -34,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: if out is not None: return out if is_pynccl_enabled_for_all_reduce(): - # TODO: support multiple parallel groups. pynccl_utils.all_reduce(input_) else: torch.distributed.all_reduce(input_, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a82a1254693df..be5bb4e857caf 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -14,7 +14,8 @@ logger = init_logger(__name__) # Tensor model parallel group that the current rank belongs to. -_TENSOR_MODEL_PARALLEL_GROUP = None +_TP_DEVICE_GROUP = None +_TP_CPU_GROUP = None # Pipeline model parallel group that the current rank belongs to. _PIPELINE_MODEL_PARALLEL_GROUP = None @@ -132,15 +133,17 @@ def initialize_model_parallel( rank = torch.distributed.get_rank() # Build the tensor model-parallel groups. - global _TENSOR_MODEL_PARALLEL_GROUP - assert _TENSOR_MODEL_PARALLEL_GROUP is None, ( + global _TP_DEVICE_GROUP, _TP_CPU_GROUP + assert _TP_DEVICE_GROUP is None, ( "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) group = torch.distributed.new_group(ranks, backend=backend) + cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: - _TENSOR_MODEL_PARALLEL_GROUP = group + _TP_DEVICE_GROUP = group + _TP_CPU_GROUP = cpu_group # Build the pipeline model-parallel groups. global _PIPELINE_MODEL_PARALLEL_GROUP @@ -185,7 +188,7 @@ def ensure_model_parallel_initialized( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TENSOR_MODEL_PARALLEL_GROUP is not None + return (_TP_DEVICE_GROUP is not None and _PIPELINE_MODEL_PARALLEL_GROUP is not None) @@ -197,9 +200,16 @@ def get_cpu_world_group(): def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" - assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( + assert _TP_DEVICE_GROUP is not None, ( "tensor model parallel group is not initialized") - return _TENSOR_MODEL_PARALLEL_GROUP + return _TP_DEVICE_GROUP + + +def get_tensor_model_parallel_cpu_group(): + """Get the tensor model parallel cpu group the caller rank belongs to.""" + assert _TP_CPU_GROUP is not None, ( + "tensor model parallel cpu group is not initialized") + return _TP_CPU_GROUP def get_pipeline_model_parallel_group(): @@ -277,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank(): def destroy_model_parallel(): """Set the groups to none and destroy them.""" - global _TENSOR_MODEL_PARALLEL_GROUP - if _TENSOR_MODEL_PARALLEL_GROUP: - torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP) - _TENSOR_MODEL_PARALLEL_GROUP = None + global _TP_DEVICE_GROUP + if _TP_DEVICE_GROUP: + torch.distributed.destroy_process_group(_TP_DEVICE_GROUP) + _TP_DEVICE_GROUP = None + global _TP_CPU_GROUP + if _TP_CPU_GROUP: + torch.distributed.destroy_process_group(_TP_CPU_GROUP) + _TP_CPU_GROUP = None global _PIPELINE_MODEL_PARALLEL_GROUP if _PIPELINE_MODEL_PARALLEL_GROUP: torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 39ad428f16fe3..808261e47318b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,6 +11,7 @@ VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, + get_tensor_model_parallel_cpu_group, init_distributed_environment) from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( @@ -288,6 +289,9 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + if pynccl_utils.is_initialized(): pynccl_world_size = pynccl_utils.get_world_size() if pynccl_world_size != parallel_config.world_size: @@ -298,12 +302,9 @@ def init_worker_distributed_environment( elif parallel_config.world_size > 1: # NOTE(woosuk): We don't initialize pynccl process group when world size # is 1. - # NOTE(kaichao): By default, pynccl will use information inside - # `parallel_state` for initialization. - pynccl_utils.init_process_group() - - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + # NOTE(kaichao): By default, pynccl is initialized for tp group. + pynccl_utils.init_process_group( + group=get_tensor_model_parallel_cpu_group()) # Initialize a custom fast all-reduce implementation. if not parallel_config.disable_custom_all_reduce: From 808632d3b4effd3c0807325b529d0354894c31b1 Mon Sep 17 00:00:00 2001 From: "Yang, Bo" Date: Thu, 2 May 2024 18:35:18 -0700 Subject: [PATCH 198/413] [BugFix] Prevent the task of `_force_log` from being garbage collected (#4567) --- vllm/entrypoints/openai/api_server.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 8b3c5ea9de9c0..f9e294af47253 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,6 +4,7 @@ import re from contextlib import asynccontextmanager from http import HTTPStatus +from typing import Any, Set import fastapi import uvicorn @@ -33,6 +34,8 @@ openai_serving_completion: OpenAIServingCompletion logger = init_logger(__name__) +_running_tasks: Set[asyncio.Task[Any]] = set() + @asynccontextmanager async def lifespan(app: fastapi.FastAPI): @@ -43,7 +46,9 @@ async def _force_log(): await engine.do_log_stats() if not engine_args.disable_log_stats: - asyncio.create_task(_force_log()) + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) yield From ce3f1eedf8e7e015054a166f17205eb3206e4625 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Fri, 3 May 2024 12:48:08 +0800 Subject: [PATCH 199/413] [Misc] remove chunk detected debug logs (#4571) --- vllm/engine/llm_engine.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 19e7143ac2b45..94a5b397a4d43 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -662,10 +662,10 @@ def _get_stats( # decode seq_groups in scheduled_seq_groups. if scheduler_outputs is not None: num_generation_tokens_from_prefill_groups = 0. - if scheduler_outputs.num_prefill_groups > 0 and len( - scheduler_outputs.scheduled_seq_groups - ) != scheduler_outputs.num_prefill_groups: - print("DETECTED CHUNKED") + # NOTE: if scheduler_outputs.num_prefill_groups > 0 and + # the len of scheduler_outputs.scheduled_seq_groups is != + # scheduler_outputs.num_prefill_groups, this means that + # chunked prefills have been detected. for idx, scheduled_seq_group in enumerate( scheduler_outputs.scheduled_seq_groups): From 2d7bce9cd5981db146b18a8a95c5a7e0480687bd Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 2 May 2024 22:13:49 -0700 Subject: [PATCH 200/413] [Doc] add env vars to the doc (#4572) --- docs/source/index.rst | 1 + docs/source/serving/env_vars.rst | 9 +++++++++ vllm/envs.py | 7 +++++++ 3 files changed, 17 insertions(+) create mode 100644 docs/source/serving/env_vars.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index e0269987ec5d8..5cc28a2d70139 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -75,6 +75,7 @@ Documentation serving/deploying_with_docker serving/distributed_serving serving/metrics + serving/env_vars serving/usage_stats serving/integrations diff --git a/docs/source/serving/env_vars.rst b/docs/source/serving/env_vars.rst new file mode 100644 index 0000000000000..0ce1374a3967b --- /dev/null +++ b/docs/source/serving/env_vars.rst @@ -0,0 +1,9 @@ +Environment Variables +======================== + +vLLM uses the following environment variables to configure the system: + +.. literalinclude:: ../../../vllm/envs.py + :language: python + :start-after: begin-env-vars-definition + :end-before: end-env-vars-definition diff --git a/vllm/envs.py b/vllm/envs.py index 26ed731caa5ff..2dbb57e6253a7 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -28,6 +28,11 @@ VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. + +# begin-env-vars-definition + environment_variables: Dict[str, Callable[[], Any]] = { # used in distributed environment to determine the master address 'VLLM_HOST_IP': @@ -148,6 +153,8 @@ lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), } +# end-env-vars-definition + def __getattr__(name): # lazy evaluation of environment variables From 3521ba4f2554bcf246a95a9fb2d1b80990a6835b Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 4 May 2024 02:20:12 +0900 Subject: [PATCH 201/413] [Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518) --- .../kernels/benchmark_paged_attention.py | 25 ++- csrc/attention/attention_kernels.cu | 76 ++++---- csrc/cpu/attention.cpp | 92 +++++----- csrc/ops.h | 8 +- tests/kernels/test_attention.py | 35 ++-- tests/kernels/test_prefix_prefill.py | 16 +- tests/samplers/test_sampler.py | 34 ++-- tests/spec_decode/e2e/conftest.py | 4 +- tests/spec_decode/test_multi_step_worker.py | 24 +-- tests/spec_decode/test_ngram_worker.py | 24 ++- tests/spec_decode/utils.py | 8 +- tests/test_logits_processor.py | 8 +- tests/worker/test_model_runner.py | 99 +++++------ vllm/_custom_ops.py | 18 +- vllm/attention/backends/flash_attn.py | 44 ++--- vllm/attention/backends/rocm_flash_attn.py | 60 +++---- vllm/attention/backends/torch_sdpa.py | 36 ++-- vllm/attention/backends/xformers.py | 65 ++++--- vllm/attention/ops/paged_attn.py | 35 ++-- vllm/config.py | 23 ++- vllm/engine/arg_utils.py | 14 +- vllm/entrypoints/llm.py | 7 +- vllm/model_executor/layers/sampler.py | 6 +- vllm/model_executor/sampling_metadata.py | 63 ++++--- vllm/worker/cpu_model_runner.py | 58 +++--- vllm/worker/model_runner.py | 167 +++++++++--------- vllm/worker/neuron_model_runner.py | 30 ++-- 27 files changed, 554 insertions(+), 525 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 5c3650fa72d17..ca7967c1ab0d2 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -16,7 +16,7 @@ def main( version: str, num_seqs: int, - context_len: int, + seq_len: int, num_query_heads: int, num_kv_heads: int, head_size: int, @@ -48,12 +48,12 @@ def main( dtype=torch.float, device=device) - context_lens = [context_len for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) + seq_lens = [seq_len for _ in range(num_seqs)] + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -77,8 +77,7 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -110,9 +109,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -129,9 +128,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -166,7 +165,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--context-len", type=int, default=4096) + parser.add_argument("--seq_len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", @@ -199,7 +198,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: main( version=args.version, num_seqs=args.batch_size, - context_len=args.context_len, + seq_len=args.seq_len, num_query_heads=args.num_query_heads, num_kv_heads=args.num_kv_heads, head_size=args.head_size, diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index f3a5bbfd3098d..8b1b5e098015f 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -104,7 +104,7 @@ __device__ void paged_attention_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -115,23 +115,23 @@ __device__ void paged_attention_kernel( const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { // No work to do. Terminate the thread block. return; } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); @@ -245,12 +245,12 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= context_len; + const bool mask = token_idx >= seq_len; logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); @@ -364,14 +364,14 @@ __device__ void paged_attention_kernel( } else { v_vec = *reinterpret_cast(v_ptr + offset); } - if (block_idx == num_context_blocks - 1) { + if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); @@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel( const float kv_scale) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, + out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel( const float kv_scale) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel( const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; @@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel( num_kv_heads, \ scale, \ block_tables_ptr, \ - context_lens_ptr, \ + seq_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -639,8 +639,8 @@ void paged_attention_v1_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& seq_lens, + int max_seq_len, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -664,11 +664,11 @@ void paged_attention_v1_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(float); + int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! @@ -715,8 +715,8 @@ void paged_attention_v1_launcher( num_kv_heads, \ scale, \ block_tables, \ - context_lens, \ - max_context_len, \ + seq_lens, \ + max_seq_len, \ alibi_slopes, \ kv_scale); @@ -746,9 +746,9 @@ void paged_attention_v1( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { @@ -790,7 +790,7 @@ void paged_attention_v1( num_kv_heads, \ scale, \ block_tables_ptr, \ - context_lens_ptr, \ + seq_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -803,7 +803,7 @@ void paged_attention_v1( exp_sums_ptr, \ max_logits_ptr, \ tmp_out_ptr, \ - context_lens_ptr, \ + seq_lens_ptr, \ max_num_partitions); template< @@ -824,8 +824,8 @@ void paged_attention_v2_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& seq_lens, + int max_seq_len, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -852,10 +852,10 @@ void paged_attention_v2_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); @@ -909,8 +909,8 @@ void paged_attention_v2_launcher( num_kv_heads, \ scale, \ block_tables, \ - context_lens, \ - max_context_len, \ + seq_lens, \ + max_seq_len, \ alibi_slopes, \ kv_scale); @@ -943,9 +943,9 @@ void paged_attention_v2( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 365bbd5e23728..c1d765be05598 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -70,11 +70,11 @@ template FORCE_INLINE std::pair reduceSoftmaxAlibi(T *data, const int size, const int capacity, const float alibi_slope, const int start_index, - const int context_len) { - data[0] += alibi_slope * (start_index - context_len + 1); + const int seq_len) { + data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); + T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); data[i] = qk; max = max >= qk ? max : qk; } @@ -225,7 +225,7 @@ struct paged_attention_v1_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] + const int *__restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -235,32 +235,32 @@ struct paged_attention_v1_impl { static_assert(BLOCK_SIZE == 16); - int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; - int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); const int parallel_work_item_num = omp_get_max_threads(); size_t logits_bytes = - parallel_work_item_num * max_context_len_padded * sizeof(float); + parallel_work_item_num * max_seq_len_padded * sizeof(float); float *logits = (float *)std::aligned_alloc( 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_context_len_padded] + // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int context_len = context_lens[seq_idx]; + int seq_len = seq_lens[seq_idx]; const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; const scalar_t *__restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const int last_block_token_num = - context_len - (block_num - 1) * BLOCK_SIZE; + seq_len - (block_num - 1) * BLOCK_SIZE; float *__restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_context_len_padded; + logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { @@ -278,11 +278,11 @@ struct paged_attention_v1_impl { // Compute softmax if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, context_len, + reduceSoftmaxAlibi(thread_block_logits, seq_len, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - context_len); + seq_len); } else { - reduceSoftmax(thread_block_logits, context_len, + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); } @@ -340,7 +340,7 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); @@ -348,8 +348,8 @@ template void paged_attention_v1_impl_launcher( torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seq_lens, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { case 64: @@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher( #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v1_impl_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - context_lens, max_context_len, alibi_slopes); + seq_lens, max_seq_len, alibi_slopes); #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, + torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes, const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); @@ -448,7 +448,7 @@ struct paged_attention_v2_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] + const int *__restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -465,22 +465,22 @@ struct paged_attention_v2_impl { for (int partition_idx = 0; partition_idx < max_num_partitions; ++partition_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= context_len) + if (start_token_idx >= seq_len) continue; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; const bool no_reduce = (partition_num == 1); - const int context_token_num = - (std::min(context_len, start_token_idx + PARTITION_SIZE) - + const int token_num = + (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); const int block_num = - (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int last_block_token_num = - context_token_num - (block_num - 1) * BLOCK_SIZE; + token_num - (block_num - 1) * BLOCK_SIZE; const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; @@ -507,10 +507,10 @@ struct paged_attention_v2_impl { std::pair max_and_sum; if (alibi_slopes) { max_and_sum = reduceSoftmaxAlibi( - logits, context_token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, context_len); + logits, token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, seq_len); } else { - max_and_sum = reduceSoftmax(logits, context_token_num, + max_and_sum = reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); } @@ -583,9 +583,9 @@ struct paged_attention_v2_impl { #pragma omp parallel for collapse(2) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -612,9 +612,9 @@ struct paged_attention_v2_impl { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -649,7 +649,7 @@ struct paged_attention_v2_impl { paged_attention_v2_impl::call( \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); @@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher( torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { case 64: @@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher( #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v2_impl_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, block_size, \ - max_context_len, alibi_slopes); + num_kv_heads, scale, block_tables, seq_lens, block_size, \ + max_seq_len, alibi_slopes); #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, + torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes, const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); diff --git a/csrc/ops.h b/csrc/ops.h index 8ae052427052f..9541adcb3de88 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -10,9 +10,9 @@ void paged_attention_v1( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, + torch::Tensor& seq_lens, int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); @@ -28,9 +28,9 @@ void paged_attention_v2( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, + torch::Tensor& seq_lens, int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 9b1f3e30b6dca..84539205e0ae3 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> None: @@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention( num_seqs = query.shape[0] block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() + seq_lens = seq_lens.cpu().tolist() for i in range(num_seqs): q = query[i].unsqueeze(0) block_table = block_tables[i] - context_len = int(context_lens[i]) + seq_len = int(seq_lens[i]) keys = [] values = [] - for j in range(context_len): + for j in range(seq_len): block_number = int(block_table[j // block_size]) block_offset = j % block_size @@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention( alibi_bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. - position_ids = torch.arange(context_len).int() - alibi_bias = (position_ids - context_len + 1).float() + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) @@ -149,13 +149,13 @@ def test_paged_attention( if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -186,16 +186,15 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, ) elif version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -218,9 +217,9 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -255,7 +254,7 @@ def test_paged_attention( key_cache, value_cache, block_tables, - context_lens, + seq_lens, scale, alibi_slopes, ) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 8ab1167384c45..5a5987e2242fa 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -51,12 +51,12 @@ def test_contexted_kv_attention( cache_size = 640 block_size = 32 max_block_per_request = 64 - subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv - num_tokens = sum(subquery_lens) + num_tokens = sum(query_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) @@ -75,15 +75,15 @@ def test_contexted_kv_attention( num_kv_heads, head_size, dtype=dtype) - k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -92,7 +92,7 @@ def test_contexted_kv_attention( dtype=torch.long), dim=0) for i in range(BS): - for j in range(subquery_lens[i]): + for j in range(query_lens[i]): k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + @@ -178,7 +178,7 @@ def test_contexted_kv_attention( value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - subquery_lens, seq_lens) + query_lens, seq_lens) if sliding_window > 0: attn_bias = attn_bias.make_local_attention_from_bottomright( sliding_window) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 7859f0b21812f..e4fea165a4d46 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -58,7 +58,7 @@ def _do_sample( device: str, ): seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -68,12 +68,12 @@ def _do_sample( sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -421,7 +421,7 @@ def run_test_case(*, "Invalid test case, need seq_group_metadata_list" batch_size = 0 - prompt_lens = [] + seq_lens = [] sampling_params_per_row = [] for sgm in seq_group_metadata_list: sampling_params = sgm.sampling_params @@ -431,7 +431,7 @@ def run_test_case(*, # a prompt seq_group has only one sequence seq_data = next(iter(sgm.seq_data.values())) prompt_len = seq_data.get_prompt_len() - prompt_lens.append(prompt_len) + seq_lens.append(prompt_len) if sgm.sampling_params.prompt_logprobs: # with prompt_logprobs each token in the prompt has a row in @@ -451,8 +451,8 @@ def run_test_case(*, _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens=prompt_lens if prompt_lens else None, - subquery_lens=prompt_lens if prompt_lens else None, + seq_lens=seq_lens if seq_lens else None, + query_lens=seq_lens if seq_lens else None, device=device, pin_memory=model_runner.pin_memory) # the logits tensor is modified in-place by the sampler @@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str): seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): expected: Optional[List[int]] = None sampling_type = random.randint(0, 3) @@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str): sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) def test_sampling(model_runner: ModelRunner): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) sampler_output = sampler(logits=fake_logits, @@ -575,7 +575,7 @@ def test_sampling(model_runner: ModelRunner): # Shuffle the batch and resample target_index = list(range(batch_size)) for list_to_shuffle in (target_index, seq_group_metadata_list, - expected_tokens, prompt_lens): + expected_tokens, seq_lens): random.Random(seed).shuffle(list_to_shuffle) target_index = torch.tensor(target_index) input_tensor.data = input_tensor.index_select(0, target_index) @@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): assert len(warpers) == 2 # top_p and top_k seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): ), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 0eb784a9c5ac5..492620cf6e2cf 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -45,7 +45,7 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: int = 4, enforce_eager: bool = False, - max_context_len_to_capture: int = 8192, + max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: @@ -66,7 +66,7 @@ def __init__( gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, enforce_eager=enforce_eager, - max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, engine_use_ray=True, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 98f2731de9aa3..cc0427633e688 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int): list(range(block_size * 2)), ] - final_seq_lens = [ + final_prompt_lens = [ len(prompt + output) + num_steps for prompt, output in zip(prompts, prev_output_tokens) ] @@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int): prompts, num_gpu_blocks, block_size, - final_seq_lens, + final_prompt_lens, continuations=prev_output_tokens) assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access @@ -103,17 +103,21 @@ def test_same_output_for_single_step(): [6, 7, 8, 9, 10], ] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] multi_step_execute_model_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) single_step_execute_model_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) @@ -181,7 +185,7 @@ def test_same_output_for_multi_step(): random.randint(0, 1000) for _ in range(random.randint(10, 20)) ] for _ in range(10)] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) multi_step_worker.execute_model = patch_execute_model_with_seeds( @@ -195,7 +199,7 @@ def test_same_output_for_multi_step(): num_gpu_blocks, block_size, continuations=continuations, - final_seq_lens=final_seq_lens), ) + final_prompt_lens=final_prompt_lens), ) # Run multi-step. zero_kv_cache(multi_step_worker.cache_engine) @@ -217,7 +221,7 @@ def test_same_output_for_multi_step(): num_gpu_blocks, block_size, continuations=continuations, - final_seq_lens=final_seq_lens)) + final_prompt_lens=final_prompt_lens)) single_step_output.extend( worker.execute_model(**execute_model_data.to_dict(), )) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py index ee4135015713d..e7e2e87f599dd 100644 --- a/tests/spec_decode/test_ngram_worker.py +++ b/tests/spec_decode/test_ngram_worker.py @@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match(): ] proposal_len = 5 - final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] ngram_sampler_output_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) proposals = proposer.get_proposals( **ngram_sampler_output_data.to_dict(), @@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): ] proposal_len = 5 - final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] ngram_sampler_output_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) proposals = proposer.get_proposals( **ngram_sampler_output_data.to_dict(), @@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all(): ] proposal_len = 5 - final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] ngram_sampler_output_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) proposals = proposer.get_proposals( **ngram_sampler_output_data.to_dict(), diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 4f8295d25cf41..87c7d88a80f42 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts( prompts: List[List[int]], num_gpu_blocks: int, block_size: int, - final_seq_lens: List[int], + final_prompt_lens: List[int], continuations: Optional[List[List[int]]] = None, seq_ids: Optional[List[int]] = None, ) -> List[SequenceGroupMetadata]: @@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts( free_gpu_blocks.pop() for _ in range(round_up_to_next_block(final_len, block_size)) ] - for i, final_len in enumerate(final_seq_lens) + for i, final_len in enumerate(final_prompt_lens) } return [ @@ -251,13 +251,13 @@ def create_batch(batch_size, prev_output_tokens = [[ next(iterator) for _ in range(prev_output_token_len) ] for _ in range(batch_size)] - final_seq_lens = [ + final_prompt_lens = [ len(prompt) + len(prev_output_token) + k + 1 for prompt, prev_output_token in zip(prompts, prev_output_tokens) ] execute_model_data = create_execute_model_data( create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks, - block_size, final_seq_lens, + block_size, final_prompt_lens, prev_output_tokens, seq_ids), ) return execute_model_data, prompts, prev_output_tokens diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index dbaeb4de18258..179e8d25a341b 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -70,7 +70,7 @@ def pick_ith(token_ids, logits): return logits seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -81,12 +81,12 @@ def pick_ith(token_ids, logits): logits_processors=[pick_ith]), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) logits_processor_output = logits_processor( diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 56fe6db589f18..e7975d0ef48b9 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size): lora_config=None) model_runner.set_block_size(16) - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = SequenceData(list(range(prompt_len))) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for seq_len in seq_lens: expected_selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += prompt_len - (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _, - slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) - assert return_prompt_lens == prompt_lens + seq_len - 1) + selected_token_start_idx += seq_len + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is True - assert torch.allclose(attn_metadata.prompt_lens_tensor, - torch.tensor(prompt_lens, device=device)) - assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.max_prompt_len == max(prompt_lens) + assert torch.allclose( + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_seq_len == max(seq_lens) # Test subquery start locs. start_idx = 0 start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seq_len in seq_lens: + start_idx += seq_len start_loc.append(start_idx) assert torch.allclose( attn_metadata.subquery_start_loc, @@ -75,17 +75,16 @@ def test_prepare_prompt(batch_size): # equivalent to subquery_start_loc. start_idx = 0 seq_start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seq_len in seq_lens: + start_idx += seq_len seq_start_loc.append(start_idx) assert torch.allclose( attn_metadata.seq_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) - assert attn_metadata.max_context_len is None assert torch.allclose( - attn_metadata.context_lens, - torch.zeros(attn_metadata.context_lens.shape[0], + attn_metadata.context_lens_tensor, + torch.zeros(attn_metadata.context_lens_tensor.shape[0], dtype=torch.int, device=device)) @@ -96,18 +95,18 @@ def test_prepare_prompt(batch_size): # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert len(input_tokens) == sum(prompt_lens) - assert len(input_positions) == sum(prompt_lens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) - assert len(input_tokens) == sum(prompt_lens) - assert len(input_positions) == sum(prompt_lens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, @@ -146,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size): lora_config=None) model_runner.set_block_size(16) - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = list(range(prompt_len)) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = list(range(seq_len)) seq_data = SequenceData(seq_data) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -172,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is False - assert attn_metadata.prompt_lens is None - assert attn_metadata.max_prompt_len is None + assert attn_metadata.seq_lens is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_context_len == max(prompt_lens) + assert attn_metadata.max_seq_len == max(seq_lens) assert torch.allclose( - attn_metadata.context_lens[:len(prompt_lens)], - torch.tensor(prompt_lens, dtype=torch.int, device=device)) + attn_metadata.seq_lens_tensor[:len(seq_lens)], + torch.tensor(seq_lens, dtype=torch.int, device=device)) # block table's first index corresponds to each batch, meaning in # decoding it is each token. @@ -198,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for seq_len in seq_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -241,14 +239,13 @@ def test_empty_seq_group(): assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _, - slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - assert len(return_prompt_lens) == 0 + assert len(return_seq_lens) == 0 @pytest.fixture @@ -288,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): model_runner.set_block_size(16) # Add prefill requests. - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] prefill_metadata_list = [] decode_metadata_list = [] @@ -297,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_batch_size = batch_size - prefill_batch_size for i in range(prefill_batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = SequenceData(list(range(prompt_len))) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -314,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Add decode requests for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(prompt_len)) + seq_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(seq_len)) seq_data = SequenceData(prompt_toks) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -343,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): else: assert attn_metadata.num_decode_tokens == _get_graph_batch_size( decode_batch_size) - assert attn_metadata.num_prefill_tokens == sum(prompt_lens) + assert attn_metadata.num_prefill_tokens == sum(seq_lens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3faed5ea85307..b43f646fec88e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -39,17 +39,17 @@ def paged_attention_v1( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, - max_context_len: int, + max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, - context_lens, block_size, max_context_len, - alibi_slopes, kv_cache_dtype, kv_scale) + num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, kv_scale) def paged_attention_v2( @@ -63,17 +63,17 @@ def paged_attention_v2( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, - max_context_len: int, + max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, context_lens, block_size, - max_context_len, alibi_slopes, kv_cache_dtype, + block_tables, seq_lens, block_size, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 10b8c19b7499e..fc7501ed5e91f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. @@ -223,8 +223,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -245,9 +245,9 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, - prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window[0], ) @@ -258,8 +258,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 3bc436315c3de..c411b3971b8f1 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] class ROCmFlashAttentionImpl(AttentionImpl): @@ -247,7 +247,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - assert prefill_meta.prompt_lens is not None + assert prefill_meta.seq_lens is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -260,8 +260,8 @@ def forward( None, prefill_meta.seq_start_loc, prefill_meta.seq_start_loc, - prefill_meta.max_prompt_len, - prefill_meta.max_prompt_len, + prefill_meta.max_seq_len, + prefill_meta.max_seq_len, True, self.scale, ) @@ -274,7 +274,7 @@ def forward( query, key, value, - prefill_meta.prompt_lens, + prefill_meta.seq_lens, self.scale, ) else: @@ -284,8 +284,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, ) @@ -303,9 +303,9 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, - prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window[0], ) @@ -317,8 +317,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -334,13 +334,13 @@ def _naive_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - prompt_lens: List[int], + seq_lens: List[int], scale: float, ) -> torch.Tensor: output = torch.empty_like(query) start = 0 - for _, prompt_len in enumerate(prompt_lens): - end = start + prompt_len + for _, seq_len in enumerate(seq_lens): + end = start + seq_len out = _naive_masked_attention( query[start:end], key[start:end], @@ -349,7 +349,7 @@ def _naive_attention( ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) - start += prompt_len + start += seq_len return output diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 55a7ce59ac6e0..f75a279086a26 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, # or all decoding. True if all sequences are prompts. is_prompt: bool slot_mapping: torch.Tensor - prompt_lens: Optional[List[int]] + seq_lens: Optional[List[int]] def __post_init__(self): # Set during the execution of the first attention op. @@ -136,7 +136,7 @@ def forward( kv_scale) if attn_metadata.is_prompt: - assert attn_metadata.prompt_lens is not None + assert attn_metadata.seq_lens is not None if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -147,13 +147,13 @@ def forward( if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + attn_metadata.seq_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + attn_metadata.seq_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) + att_masks = [None] * len(attn_metadata.seq_lens) attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) @@ -164,9 +164,9 @@ def forward( output = torch.empty( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, - attn_metadata.attn_bias): - end = start + prompt_len + for seq_len, mask in zip(attn_metadata.seq_lens, + attn_metadata.attn_bias): + end = start + seq_len sub_out = scaled_dot_product_attention( query[:, start:end, :], key[:, start:end, :], @@ -189,8 +189,8 @@ def forward( key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + attn_metadata.seq_lens_tensor, + attn_metadata.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -205,13 +205,13 @@ def forward( def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -221,7 +221,7 @@ def _make_alibi_bias( bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( - (1, prompt_len, prompt_len), + (1, seq_len, seq_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) attn_biases.append((bias + inf_mask).to(dtype)) @@ -229,14 +229,14 @@ def _make_alibi_bias( def _make_sliding_window_bias( - prompt_lens: List[int], + seq_lens: List[int], window_size: Optional[int], dtype: torch.dtype, ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: + for seq_len in seq_lens: tensor = torch.full( - (1, prompt_len, prompt_len), + (1, seq_len, seq_len), dtype=dtype, fill_value=1, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index dc64ac0bf985d..60f6d43f2eaa4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] # FIXME: It is for flash attn. - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. @@ -242,9 +241,9 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, - prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window, ) @@ -257,8 +256,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -289,7 +288,7 @@ def _run_memory_efficient_xformers_forward( value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ - assert attn_metadata.prompt_lens is not None + assert attn_metadata.seq_lens is not None original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. @@ -310,7 +309,7 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.prompt_lens) + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) @@ -318,7 +317,7 @@ def _run_memory_efficient_xformers_forward( else: attn_metadata.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.prompt_lens) + attn_metadata.seq_lens) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce @@ -343,8 +342,8 @@ def _run_memory_efficient_xformers_forward( # one. This is inefficient, especially when we have many short prompts. output = torch.empty_like(original_query) start = 0 - for i, prompt_len in enumerate(attn_metadata.prompt_lens): - end = start + prompt_len + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len out = xops.memory_efficient_attention_forward( query[None, start:end], key[None, start:end], @@ -354,7 +353,7 @@ def _run_memory_efficient_xformers_forward( scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out.view_as(original_query[start:end])) - start += prompt_len + start += seq_len return output @@ -362,13 +361,13 @@ def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> LowerTriangularMaskWithTensorBias: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -376,16 +375,16 @@ def _make_alibi_bias( # element. bias = bias[None, :] - bias[:, None] - padded_len = (prompt_len + 7) // 8 * 8 + padded_len = (seq_len + 7) // 8 * 8 num_heads = alibi_slopes.shape[0] bias = torch.empty( 1, # batch size num_heads, - prompt_len, + seq_len, padded_len, device=alibi_slopes.device, dtype=dtype, - )[:, :, :, :prompt_len].copy_(bias) + )[:, :, :, :seq_len].copy_(bias) bias.mul_(alibi_slopes[:, None, None]) if num_heads != num_kv_heads: bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index c20b94ac8315b..00a0f10c0950b 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -13,12 +13,11 @@ @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. WARNING: When it is a prefill request, it doesn't include new - # tokens. When it is for decoding, it includes a new token. - context_lens: Optional[torch.Tensor] - # Maximum context length in the batch. - max_context_len: Optional[int] + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -85,8 +84,8 @@ def forward_decode( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, + seq_lens: torch.Tensor, + max_seq_len: int, kv_cache_dtype: str, num_kv_heads: int, scale: float, @@ -97,7 +96,7 @@ def forward_decode( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use @@ -106,7 +105,7 @@ def forward_decode( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = (max_context_len <= 8192 + use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) if use_v1: # Run PagedAttention V1. @@ -118,9 +117,9 @@ def forward_decode( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -150,9 +149,9 @@ def forward_decode( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -168,9 +167,9 @@ def forward_prefix( value_cache: torch.Tensor, block_tables: torch.Tensor, subquery_start_loc: torch.Tensor, - prompt_lens_tensor: torch.Tensor, + seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, - max_subquery_len: int, + max_query_len: int, alibi_slopes: Optional[torch.Tensor], sliding_window: Optional[int], ) -> torch.Tensor: @@ -185,9 +184,9 @@ def forward_prefix( block_tables, # subquery_start_loc is (batch_size + 1,) subquery_start_loc[:-1], - prompt_lens_tensor, + seq_lens_tensor, context_lens, - max_subquery_len, + max_query_len, alibi_slopes, sliding_window, ) diff --git a/vllm/config.py b/vllm/config.py index aaa2f60739d55..3bdd3f774bc27 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -63,7 +63,10 @@ class ModelConfig: If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode. + to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. """ @@ -84,6 +87,7 @@ def __init__( quantization_param_path: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 5, skip_tokenizer_init: bool = False, ) -> None: @@ -99,6 +103,11 @@ def __init__( self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + if self.max_context_len_to_capture is not None: + raise ValueError("`max_context_len_to_capture` is deprecated. " + "Use `max_seq_len_to_capture` instead.") + self.max_seq_len_to_capture = (max_seq_len_to_capture + or max_context_len_to_capture) self.max_logprobs = max_logprobs self.skip_tokenizer_init = skip_tokenizer_init @@ -190,10 +199,10 @@ def _verify_quantization(self) -> None: "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: - if self.max_context_len_to_capture is None: - self.max_context_len_to_capture = self.max_model_len - self.max_context_len_to_capture = min(self.max_context_len_to_capture, - self.max_model_len) + if self.max_seq_len_to_capture is None: + self.max_seq_len_to_capture = self.max_model_len + self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, + self.max_model_len) def verify_with_parallel_config( self, @@ -772,8 +781,8 @@ def maybe_create_spec_config( max_model_len=None, quantization=draft_quantization, enforce_eager=target_model_config.enforce_eager, - max_context_len_to_capture=target_model_config. - max_context_len_to_capture, + max_seq_len_to_capture=target_model_config. + max_seq_len_to_capture, max_logprobs=target_model_config.max_logprobs, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7637616ae6089..1c8e1079bed58 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,7 +44,8 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: bool = False - max_context_len_to_capture: int = 8192 + max_context_len_to_capture: Optional[int] = None + max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 tokenizer_pool_type: str = "ray" @@ -322,6 +323,14 @@ def add_cli_args( default=EngineArgs.max_context_len_to_capture, help='Maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode. ' + '(DEPRECATED. Use --max-seq_len-to-capture instead' + ')') + parser.add_argument('--max-seq_len-to-capture', + type=int, + default=EngineArgs.max_seq_len_to_capture, + help='Maximum sequence length covered by CUDA ' + 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') parser.add_argument('--disable-custom-all-reduce', action='store_true', @@ -492,7 +501,8 @@ def create_engine_config(self, ) -> EngineConfig: self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, - self.max_logprobs, self.skip_tokenizer_init) + self.max_seq_len_to_capture, self.max_logprobs, + self.skip_tokenizer_init) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b022707794a78..3ed660e183360 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -69,6 +69,9 @@ class LLM: disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. disable_custom_all_reduce: See ParallelConfig @@ -90,7 +93,8 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: int = 4, enforce_eager: bool = False, - max_context_len_to_capture: int = 8192, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: @@ -112,6 +116,7 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index d79c99e5d0a45..2de7763605dfc 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: assert seq_group.is_prompt, ( "Caller should ensure the sequence group is in a prefill stage.") seq_ids = seq_group.seq_ids - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None # prompt has only 1 seq id. assert len(seq_ids) == 1 seq_data = seq_group.seq_data[seq_ids[0]] @@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: prompt_tokens = seq_data.prompt_token_ids # +1 because we are looking for a next prompt token. next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + subquery_len + 1, + next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens)) next_prompt_tokens = prompt_tokens[ next_token_index_start:next_token_index_end] diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 12156b2ba1aa2..9969c45963e9a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -16,17 +16,26 @@ @dataclass class SequenceGroupToSample: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + # Sequence ids for the sequence group in a previous step. seq_ids: List[int] sampling_params: SamplingParams # seq_id -> sequence data. seq_data: Dict[int, SequenceData] - # The length of the prompt of the sequence group. None if it is in a decode + # The length of the sequence (all tokens seen in the past + new token to + # compute attention) of the sequence group. None if it is in a decode # stage. - prompt_len: Optional[int] - # The length of the query tokens to compute in the current step. None if it - # is in a decode stage. The length of subquery_len <= prompt_len. - subquery_len: Optional[int] + seq_len: Optional[int] + # The length of new query tokens to compute in the current step. None if it + # is in a decode stage. The length of query_len <= seq_len if chunked + # prefill is enabled. + query_len: Optional[int] # A random number generator for sampling. generator: Optional[torch.Generator] # True if the sequence group is in prefill stage. False if it is in a @@ -46,8 +55,8 @@ def __post_init__(self): if len(self.prompt_logprob_indices) > 0: assert self.sampling_params.prompt_logprobs is not None if self.is_prompt: - assert self.prompt_len is not None - assert self.subquery_len is not None + assert self.seq_len is not None + assert self.query_len is not None class SamplingMetadata: @@ -94,8 +103,8 @@ def __init__( @staticmethod def prepare( seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], + seq_lens: List[int], + query_lens: Optional[List[int]], device: str, pin_memory: bool, ) -> "SamplingMetadata": @@ -104,8 +113,8 @@ def prepare( selected_token_indices, categorized_sample_indices, num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens, - subquery_lens, device) + ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, + device) selected_token_indices = async_tensor_h2d(selected_token_indices, dtype=torch.long, target_device=device, @@ -137,8 +146,8 @@ def __repr__(self) -> str: def _prepare_seq_groups( seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], + seq_lens: List[int], + query_lens: Optional[List[int]], device: str, ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ SamplingType, List[Tuple[int, int]]], int]: @@ -146,9 +155,9 @@ def _prepare_seq_groups( Args: seq_group_metadata_list: A list of sequence group to batch. - prompt_lens: A list of prompt lens per sequence group. + seq_lens: A list of sequence lens per sequence group. Index of prompt len should match with seq_group_metadata_list. - subquery_lens: A list of query lengths. Prompt lens include the length + query_lens: A list of query lengths. Prompt lens include the length of entire prompt tokens, and it could be shorter. device: A device to use for random number generator, `SequenceGroupToSample.generator`. @@ -189,8 +198,8 @@ def _prepare_seq_groups( is_prompt = seq_group_metadata.is_prompt generator: Optional[torch.Generator] = None # If the current seq group is in decode stage, it is None. - prompt_len: Optional[int] = None - subquery_len: Optional[int] = None + seq_len: Optional[int] = None + query_len: Optional[int] = None prompt_logprob_indices: List[int] = [] sample_indices: List[int] = [] do_sample = seq_group_metadata.do_sample @@ -203,12 +212,12 @@ def _prepare_seq_groups( num_prompts += 1 num_prefill_sample = len(seq_ids) assert num_prefill_sample == 1 - assert subquery_lens is not None and prompt_lens is not None - subquery_len, prompt_len = subquery_lens[i], prompt_lens[i] + assert query_lens is not None and seq_lens is not None + query_len, seq_len = query_lens[i], seq_lens[i] # If we need sampling, exclude num_prefill_sample tokens from # prompt logprob. - prompt_logprob_len = (subquery_len - num_prefill_sample - if do_sample else subquery_len) + prompt_logprob_len = (query_len - num_prefill_sample + if do_sample else query_len) sample_len = num_prefill_sample if do_sample else 0 else: # Decode @@ -267,8 +276,8 @@ def sample(logits): seq_ids=seq_ids, sampling_params=sampling_params, seq_data=seq_group_metadata.seq_data, - prompt_len=prompt_len, - subquery_len=subquery_len, + seq_len=seq_len, + query_len=query_len, generator=generator, is_prompt=is_prompt, prompt_logprob_indices=list(prompt_logprob_indices), @@ -367,8 +376,8 @@ def from_sampling_metadata( and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None prefill_len = len(seq_group.prompt_logprob_indices) temperatures += [temperature] * prefill_len top_ps += [top_p] * prefill_len @@ -397,8 +406,8 @@ def from_sampling_metadata( if is_prompt: prompt_best_of.append(sampling_params.best_of) - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 34d7d3dffea18..193b021b7a11e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -80,7 +80,7 @@ def _prepare_prompt( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - prompt_lens: List[int] = [] + seq_lens: List[int] = [] multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: @@ -92,15 +92,15 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() computed_len = seq_data.get_num_computed_tokens() - prompt_len = len(prompt_tokens) + seq_len = len(prompt_tokens) - prompt_lens.append(prompt_len) # Prompt token num + seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, prompt_len))) + input_positions.extend(list(range(computed_len, seq_len))) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( @@ -109,15 +109,15 @@ def _prepare_prompt( # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, prompt_len - sliding_window). + # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: - start_idx = max(0, prompt_len - self.sliding_window) + start_idx = max(0, seq_len - self.sliding_window) - for i in range(computed_len, prompt_len): + for i in range(computed_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -151,19 +151,19 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - prompt_lens=prompt_lens, - num_prefills=len(prompt_lens), + seq_lens=seq_lens, + seq_lens_tensor=None, + max_seq_len=None, + num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, prefill_metadata=None, decode_metadata=None, - max_context_len=None, - context_lens=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, + return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input) def _prepare_decode( @@ -174,7 +174,7 @@ def _prepare_decode( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - context_lens: List[int] = [] + seq_lens: List[int] = [] block_tables: List[List[int]] = [] for seq_group_metadata in seq_group_metadata_list: @@ -192,9 +192,9 @@ def _prepare_decode( position = seq_len - 1 input_positions.append(position) - context_len = seq_len if self.sliding_window is None else min( + seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) - context_lens.append(context_len) + seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -208,7 +208,7 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - max_context_len = max(context_lens) + max_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -219,9 +219,9 @@ def _prepare_decode( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) max_block_table_len = max( len(block_table) for block_table in block_tables) @@ -236,14 +236,14 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - prompt_lens=None, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_seq_len=max_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), - max_context_len=max_context_len, num_prefills=0, prefill_metadata=None, decode_metadata=None, - context_lens=context_lens, block_tables=block_tables, kv_cache_dtype=self.kv_cache_dtype, ) @@ -265,20 +265,20 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, + (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - # subquery_lens is not needed if chunked prefill is not + seq_lens, + # query_lens is not needed if chunked prefill is not # supported. Since CPU worker doesn't support chunked prefill - # just use prompt_lens instead. - prompt_lens, + # just use seq_lens instead. + seq_lens, self.device, pin_memory=False) # Broadcast the metadata. @@ -300,7 +300,7 @@ def prepare_input_tensors( sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, - prompt_lens=None, + seq_lens=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, generators=None, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0704f5fec54d0..bbb1f5205af5e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple): input_tokens: List[int] input_positions: List[int] attn_metadata: Optional[AttentionMetadataPerStage] - prompt_lens: List[int] - subquery_lens: List[int] + seq_lens: List[int] + query_lens: List[int] lora_index_mapping: List[int] lora_prompt_mapping: List[int] lora_requests: Set[LoRARequest] @@ -56,8 +56,8 @@ def empty(cls): input_tokens=[], input_positions=[], attn_metadata=None, - prompt_lens=[], - subquery_lens=[], + seq_lens=[], + query_lens=[], lora_index_mapping=[], lora_prompt_mapping=[], lora_requests=set(), @@ -134,9 +134,8 @@ def __init__( self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - self.max_context_len_to_capture = ( - self.model_config.max_context_len_to_capture - if self.model_config is not None else 0) + self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture + if self.model_config is not None else 0) self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype @@ -149,7 +148,7 @@ def __init__( self.model: torch.nn.Module # Set after load_model self.block_size: int # Set after initial profiling. # When using CUDA graph, the input block tables must be padded to - # max_context_len_to_capture. However, creating the block table in + # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be @@ -218,7 +217,7 @@ def set_block_size(self, block_size: int) -> None: def get_max_block_per_batch(self) -> int: block_size = self.block_size - return (self.max_context_len_to_capture + block_size - 1) // block_size + return (self.max_seq_len_to_capture + block_size - 1) // block_size def _prepare_prompt( self, @@ -231,9 +230,9 @@ def _prepare_prompt( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() - prompt_lens: List[int] = [] + seq_lens: List[int] = [] context_lens: List[int] = [] - subquery_lens: List[int] = [] + query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] @@ -257,21 +256,19 @@ def _prepare_prompt( token_chunk_size = seq_group_metadata.token_chunk_size seq_data = seq_group_metadata.seq_data[seq_id] - computed_len = seq_data.get_num_computed_tokens() + context_len = seq_data.get_num_computed_tokens() # We should use get_len here because in case of preemption # it contains output tokens. - prefill_end = min(seq_data.get_len(), - computed_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - prompt_len = prefill_end - prompt_lens.append(prompt_len) + seq_len = min(seq_data.get_len(), context_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) # NOTE: This only works for oooooooxxx style attention. if computed_block_nums is not None and len( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window - computed_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[computed_len:] + context_len = len(computed_block_nums) * self.block_size + prompt_tokens = prompt_tokens[context_len:] prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: @@ -285,25 +282,25 @@ def _prepare_prompt( prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. - assert computed_len == 0 + assert context_len == 0 # actual prompt lens - context_lens.append(computed_len) - subquery_lens.append(prompt_len - computed_len) + context_lens.append(context_len) + query_lens.append(seq_len - context_len) input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, prefill_end))) + input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (prompt_len - computed_len) + lora_index_mapping += [lora_id] * (seq_len - context_len) lora_prompt_mapping.extend( [lora_id] * - (prompt_len - computed_len + (seq_len - context_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: @@ -313,24 +310,24 @@ def _prepare_prompt( if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) continue # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, prompt_len - sliding_window). + # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: - assert computed_len == 0, ( + assert context_len == 0, ( "Prefix caching is currently not supported with " "sliding window attention") - start_idx = max(0, prompt_len - self.sliding_window) + start_idx = max(0, seq_len - self.sliding_window) - for i in range(computed_len, prefill_end): + for i in range(context_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -340,9 +337,9 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - max_subquery_len = max(subquery_lens) - max_prompt_len = max(prompt_lens) - assert max_subquery_len > 0 + max_query_len = max(query_lens) + max_seq_len = max(seq_lens) + assert max_query_len > 0 context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, @@ -369,40 +366,39 @@ def _prepare_prompt( # Query length can be shorter than key (i.e., prompt) when prefill # is chunked or prefix cached. - subquery_lens_tensor = torch.tensor(subquery_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - prompt_lens_tensor = torch.tensor(prompt_lens, - dtype=torch.long, - device=self.device) - seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - torch.cumsum(subquery_lens_tensor, + torch.cumsum(query_lens_tensor, dim=0, dtype=subquery_start_loc.dtype, out=subquery_start_loc[1:]) - torch.cumsum(prompt_lens_tensor, + torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - prompt_lens=prompt_lens, - prompt_lens_tensor=prompt_lens_tensor, - max_subquery_len=max_subquery_len, - max_context_len=None, - max_prompt_len=max_prompt_len, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_seq_len=max_seq_len, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, - context_lens=context_lens_tensor, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, ) @@ -411,8 +407,8 @@ def _prepare_prompt( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, - prompt_lens=prompt_lens, - subquery_lens=subquery_lens, + seq_lens=seq_lens, + query_lens=query_lens, lora_index_mapping=lora_index_mapping, lora_prompt_mapping=lora_prompt_mapping, lora_requests=lora_requests, @@ -427,7 +423,7 @@ def _prepare_decode( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - context_lens: List[int] = [] + seq_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] @@ -455,9 +451,9 @@ def _prepare_decode( position = seq_len - 1 input_positions.append(position) - context_len = seq_len if self.sliding_window is None else min( + seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) - context_lens.append(context_len) + seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -477,11 +473,10 @@ def _prepare_decode( # See `capture_model` API for more details. # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) - max_context_len = max(context_lens) - use_captured_graph = ( - not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_context_len <= self.max_context_len_to_capture) + max_seq_len = max(seq_lens) + use_captured_graph = (not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -489,21 +484,21 @@ def _prepare_decode( input_tokens.append(0) input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) - context_lens.append(1) + seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert context_lens_tensor.shape[0] == len(input_tokens) - assert context_lens_tensor.shape[0] == len(input_positions) - assert context_lens_tensor.shape[0] == len(slot_mapping) + assert seq_lens_tensor.shape[0] == len(input_tokens) + assert seq_lens_tensor.shape[0] == len(input_positions) + assert seq_lens_tensor.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -525,14 +520,13 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - prompt_lens=None, - prompt_lens_tensor=None, - max_subquery_len=None, - max_context_len=max_context_len, - max_prompt_len=None, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_query_len=None, + max_seq_len=max_seq_len, subquery_start_loc=None, seq_start_loc=None, - context_lens=context_lens_tensor, + context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) @@ -565,8 +559,8 @@ def prepare_input_tensors( input_tokens, input_positions, prefill_attn_metadata, - prompt_lens, - subquery_lens, + seq_lens, + query_lens, lora_index_mapping, lora_prompt_mapping, lora_requests, @@ -583,13 +577,13 @@ def prepare_input_tensors( decode_slot_mapping, ) = self._prepare_decode(decode_reqs) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, prompt_lens, subquery_lens, - self.device, self.pin_memory) + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 - num_prefills = len(prompt_lens) + num_prefills = len(seq_lens) num_prefill_tokens = len(input_tokens) num_decode_tokens = len(decode_input_tokens) @@ -886,7 +880,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( @@ -908,14 +902,13 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # Create dummy attn_metadata. decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - prompt_lens=None, - prompt_lens_tensor=None, - max_subquery_len=None, - max_context_len=self.max_context_len_to_capture, - max_prompt_len=None, + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_seq_len=self.max_seq_len_to_capture, subquery_start_loc=None, seq_start_loc=None, - context_lens=context_lens[:batch_size], + context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, ) @@ -1025,7 +1018,7 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.decode_metadata.context_lens, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} @@ -1047,8 +1040,8 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_( - attn_metadata.decode_metadata.context_lens, non_blocking=True) + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index a974e85c22f45..a336be04e124f 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -52,7 +52,7 @@ def _prepare_prompt( input_positions: List[List[int]] = [] input_block_ids: List[int] = [] - prompt_lens: List[int] = [] + seq_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -61,26 +61,26 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() - prompt_len = len(prompt_tokens) - prompt_lens.append(prompt_len) + seq_len = len(prompt_tokens) + seq_lens.append(seq_len) input_tokens.append(prompt_tokens) - input_positions.append(list(range(prompt_len))) + input_positions.append(list(range(seq_len))) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] assert len(block_table) == 1 input_block_ids.append(block_table[0]) - max_prompt_len = max(prompt_lens) - assert max_prompt_len > 0 + max_seq_len = max(seq_lens) + assert max_seq_len > 0 input_tokens = make_tensor_with_pad(input_tokens, - max_prompt_len, + max_seq_len, pad=0, dtype=torch.long, device=self.device) input_positions = make_tensor_with_pad(input_positions, - max_prompt_len, + max_seq_len, pad=0, dtype=torch.long, device=self.device) @@ -88,7 +88,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - return input_tokens, input_positions, input_block_ids, prompt_lens + return input_tokens, input_positions, input_block_ids, seq_lens def _prepare_decode( self, @@ -149,18 +149,18 @@ def prepare_input_tensors( # Prepare input tensors. if is_prompt: (input_tokens, input_positions, input_block_ids, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + seq_lens) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - # subquery_lens is not needed if chunked prefill is not + seq_lens, + # query_lens is not needed if chunked prefill is not # supported. Since neuron worker doesn't support chunked prefill - # just use prompt_lens instead. - prompt_lens, + # just use seq_lens instead. + seq_lens, self.device, self.pin_memory) From 7e65477e5e737927c2f07c913ede0763134504a3 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 3 May 2024 13:32:21 -0400 Subject: [PATCH 202/413] [Bugfix] Allow "None" or "" to be passed to CLI for string args that default to None (#4586) --- vllm/engine/arg_utils.py | 32 +++++++++++++++++------------ vllm/entrypoints/openai/cli_args.py | 27 +++++++++++++----------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1c8e1079bed58..78cd07575f17d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -11,6 +11,12 @@ from vllm.utils import str_to_int_tuple +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + @dataclass class EngineArgs: """Arguments for vLLM engine.""" @@ -96,7 +102,7 @@ def add_cli_args( help='Name or path of the huggingface model to use.') parser.add_argument( '--tokenizer', - type=str, + type=nullable_str, default=EngineArgs.tokenizer, help='Name or path of the huggingface tokenizer to use.') parser.add_argument( @@ -105,21 +111,21 @@ def add_cli_args( help='Skip initialization of tokenizer and detokenizer') parser.add_argument( '--revision', - type=str, + type=nullable_str, default=None, help='The specific model version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') parser.add_argument( '--code-revision', - type=str, + type=nullable_str, default=None, help='The specific revision to use for the model code on ' 'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'commit id. If unspecified, will use the default version.') parser.add_argument( '--tokenizer-revision', - type=str, + type=nullable_str, default=None, help='The specific tokenizer version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' @@ -136,7 +142,7 @@ def add_cli_args( action='store_true', help='Trust remote code from huggingface.') parser.add_argument('--download-dir', - type=str, + type=nullable_str, default=EngineArgs.download_dir, help='Directory to download and load the weights, ' 'default to the default cache dir of ' @@ -187,7 +193,7 @@ def add_cli_args( 'supported for common inference criteria.') parser.add_argument( '--quantization-param-path', - type=str, + type=nullable_str, default=None, help='Path to the JSON file containing the KV cache ' 'scaling factors. This should generally be supplied, when ' @@ -304,7 +310,7 @@ def add_cli_args( # Quantization settings. parser.add_argument('--quantization', '-q', - type=str, + type=nullable_str, choices=[*QUANTIZATION_METHODS, None], default=EngineArgs.quantization, help='Method used to quantize the weights. If ' @@ -349,7 +355,7 @@ def add_cli_args( 'asynchronous tokenization. Ignored ' 'if tokenizer_pool_size is 0.') parser.add_argument('--tokenizer-pool-extra-config', - type=str, + type=nullable_str, default=EngineArgs.tokenizer_pool_extra_config, help='Extra config for tokenizer pool. ' 'This should be a JSON string that will be ' @@ -404,7 +410,7 @@ def add_cli_args( # Related to Vision-language models such as llava parser.add_argument( '--image-input-type', - type=str, + type=nullable_str, default=None, choices=[ t.name.lower() for t in VisionLanguageConfig.ImageInputType @@ -417,7 +423,7 @@ def add_cli_args( help=('Input id for image token.')) parser.add_argument( '--image-input-shape', - type=str, + type=nullable_str, default=None, help=('The biggest image input shape (worst for memory footprint) ' 'given an input type. Only used for vLLM\'s profile_run.')) @@ -440,7 +446,7 @@ def add_cli_args( parser.add_argument( '--speculative-model', - type=str, + type=nullable_str, default=EngineArgs.speculative_model, help= 'The name of the draft model to be used in speculative decoding.') @@ -454,7 +460,7 @@ def add_cli_args( parser.add_argument( '--speculative-max-model-len', - type=str, + type=int, default=EngineArgs.speculative_max_model_len, help='The maximum sequence length supported by the ' 'draft model. Sequences over this length will skip ' @@ -475,7 +481,7 @@ def add_cli_args( 'decoding.') parser.add_argument('--model-loader-extra-config', - type=str, + type=nullable_str, default=EngineArgs.model_loader_extra_config, help='Extra config for model loader. ' 'This will be passed to the model loader ' diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 16c5b6c08d37f..2b57ab26bfd31 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -8,7 +8,7 @@ import json import ssl -from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.openai.serving_engine import LoRAModulePath @@ -25,7 +25,10 @@ def __call__(self, parser, namespace, values, option_string=None): def make_arg_parser(): parser = argparse.ArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") - parser.add_argument("--host", type=str, default=None, help="host name") + parser.add_argument("--host", + type=nullable_str, + default=None, + help="host name") parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument( "--uvicorn-log-level", @@ -49,13 +52,13 @@ def make_arg_parser(): default=["*"], help="allowed headers") parser.add_argument("--api-key", - type=str, + type=nullable_str, default=None, help="If provided, the server will require this key " "to be presented in the header.") parser.add_argument("--served-model-name", nargs="+", - type=str, + type=nullable_str, default=None, help="The model name(s) used in the API. If multiple " "names are provided, the server will respond to any " @@ -65,33 +68,33 @@ def make_arg_parser(): "same as the `--model` argument.") parser.add_argument( "--lora-modules", - type=str, + type=nullable_str, default=None, nargs='+', action=LoRAParserAction, help="LoRA module configurations in the format name=path. " "Multiple modules can be specified.") parser.add_argument("--chat-template", - type=str, + type=nullable_str, default=None, help="The file path to the chat template, " "or the template in single-line form " "for the specified model") parser.add_argument("--response-role", - type=str, + type=nullable_str, default="assistant", help="The role name to return if " "`request.add_generation_prompt=true`.") parser.add_argument("--ssl-keyfile", - type=str, + type=nullable_str, default=None, help="The file path to the SSL key file") parser.add_argument("--ssl-certfile", - type=str, + type=nullable_str, default=None, help="The file path to the SSL cert file") parser.add_argument("--ssl-ca-certs", - type=str, + type=nullable_str, default=None, help="The CA certificates file") parser.add_argument( @@ -102,12 +105,12 @@ def make_arg_parser(): ) parser.add_argument( "--root-path", - type=str, + type=nullable_str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") parser.add_argument( "--middleware", - type=str, + type=nullable_str, action="append", default=[], help="Additional ASGI middleware to apply to the app. " From f8e7adda21810104382bdf3febe3ea02c72f7348 Mon Sep 17 00:00:00 2001 From: Sebastian Schoennenbeck Date: Fri, 3 May 2024 20:04:14 +0200 Subject: [PATCH 203/413] Fix/async chat serving (#2727) --- tests/async_engine/test_chat_template.py | 25 +++++++------ tests/entrypoints/openai/test_serving_chat.py | 37 +++++++++++++++++++ tests/entrypoints/test_openai_server.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 12 ++++-- vllm/entrypoints/openai/serving_engine.py | 18 ++++++--- 5 files changed, 73 insertions(+), 21 deletions(-) create mode 100644 tests/entrypoints/openai/test_serving_chat.py diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 8d6ad6706fb0e..64bcba67c3437 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -60,12 +60,13 @@ class MockServingChat: tokenizer: MockTokenizer -def test_load_chat_template(): +@pytest.mark.asyncio +async def test_load_chat_template(): # Testing chatml template tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=chatml_jinja_path) + await OpenAIServingChat._load_chat_template( + mock_serving_chat, chat_template=chatml_jinja_path) template_content = tokenizer.chat_template @@ -76,7 +77,8 @@ def test_load_chat_template(): {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 -def test_no_load_chat_template_filelike(): +@pytest.mark.asyncio +async def test_no_load_chat_template_filelike(): # Testing chatml template template = "../../examples/does_not_exist" tokenizer = MockTokenizer() @@ -84,18 +86,19 @@ def test_no_load_chat_template_filelike(): mock_serving_chat = MockServingChat(tokenizer) with pytest.raises(ValueError, match="looks like a file path"): - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + await OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) -def test_no_load_chat_template_literallike(): +@pytest.mark.asyncio +async def test_no_load_chat_template_literallike(): # Testing chatml template template = "{{ messages }}" tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + await OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) template_content = tokenizer.chat_template assert template_content == template @@ -110,8 +113,8 @@ async def test_get_gen_prompt(model, template, add_generation_prompt, # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + await OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py new file mode 100644 index 0000000000000..269b0823fec05 --- /dev/null +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -0,0 +1,37 @@ +import asyncio +from dataclasses import dataclass + +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + +MODEL_NAME = "openai-community/gpt2" +CHAT_TEMPLATE = "Dummy chat template for testing {}" + + +@dataclass +class MockModelConfig: + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + + +@dataclass +class MockEngine: + + async def get_model_config(self): + return MockModelConfig + + +async def _async_serving_chat_init(): + serving_completion = OpenAIServingChat(MockEngine(), + served_model_names=[MODEL_NAME], + response_role="assistant", + chat_template=CHAT_TEMPLATE) + return serving_completion + + +def test_async_serving_chat_init(): + serving_completion = asyncio.run(_async_serving_chat_init()) + assert serving_completion.tokenizer is not None + assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 1323dba469117..e53e64a0c1ff8 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -150,7 +150,7 @@ def server(zephyr_lora_files): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 599f99e56a726..c8f4a6b315db0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,3 +1,4 @@ +import asyncio import codecs import time from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, @@ -40,9 +41,11 @@ def __init__(self, chat_template: Optional[str] = None): super().__init__(engine=engine, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + await_post_init=self._load_chat_template( + chat_template=chat_template)) + self.response_role = response_role - self._load_chat_template(chat_template) def _parse_chat_message_content( self, @@ -356,7 +359,10 @@ async def chat_completion_full_generator( return response - def _load_chat_template(self, chat_template: Optional[str]): + async def _load_chat_template(self, chat_template: Optional[str]): + while self.tokenizer is None: + # Give the parent class time to load the tokenizer + await asyncio.sleep(0.1) tokenizer = self.tokenizer if chat_template is not None: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 3d5ed328b9d19..21baea2e5e7f6 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,7 @@ import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union from pydantic import Field from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -29,8 +29,11 @@ class LoRAModulePath: class OpenAIServing: - def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + def __init__(self, + engine: AsyncLLMEngine, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]], + await_post_init: Optional[Awaitable[Any]] = None): self.engine = engine self.served_model_names = served_model_names if lora_modules is None: @@ -56,12 +59,12 @@ def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], if event_loop is not None and event_loop.is_running(): # If the current is instanced by Ray Serve, # there is already a running event loop - event_loop.create_task(self._post_init()) + event_loop.create_task(self._post_init(await_post_init)) else: # When using single vLLM without engine_use_ray - asyncio.run(self._post_init()) + asyncio.run(self._post_init(await_post_init)) - async def _post_init(self): + async def _post_init(self, await_post_init): engine_model_config = await self.engine.get_model_config() self.max_model_len = engine_model_config.max_model_len @@ -73,6 +76,9 @@ async def _post_init(self): trust_remote_code=engine_model_config.trust_remote_code, truncation_side="left") + if await_post_init is not None: + await await_post_init + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ From 43c413ec570e94869ee7b7d275de720219a34357 Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Fri, 3 May 2024 15:51:27 -0700 Subject: [PATCH 204/413] [Kernel] Use flashinfer for decoding (#4353) Co-authored-by: LiuXiaoxuanPKU --- csrc/cache.h | 8 + csrc/cache_kernels.cu | 80 +++++++ csrc/pybind.cpp | 4 + .../test_basic_correctness.py | 12 +- .../test_basic_distributed_correctness.py | 14 +- tests/kernels/conftest.py | 8 +- tests/kernels/test_cache.py | 77 ++++++ vllm/_custom_ops.py | 12 + vllm/attention/backends/abstract.py | 13 +- vllm/attention/backends/flashinfer.py | 220 ++++++++++++++++++ vllm/attention/selector.py | 6 + vllm/config.py | 5 + vllm/sequence.py | 4 +- vllm/utils.py | 67 ++++-- vllm/worker/model_runner.py | 123 +++++++--- 15 files changed, 600 insertions(+), 53 deletions(-) create mode 100644 vllm/attention/backends/flashinfer.py diff --git a/csrc/cache.h b/csrc/cache.h index 718a5f6cfd7f7..4c142ce17f1b9 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -24,6 +24,14 @@ void reshape_and_cache( const std::string& kv_cache_dtype, const float kv_scale); +void reshape_and_cache_flash( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); + // Just for unittest void convert_fp8( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 24aaa2ff3e263..42f884c76c620 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel( } } +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int64_t tgt_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + + head_offset; + k_cache[tgt_value_idx] = key[src_key_idx]; + v_cache[tgt_value_idx] = value[src_value_idx]; + } +} } // namespace vllm #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ @@ -275,6 +310,51 @@ void reshape_and_cache( } } +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) +{ + // FIXME: only support auto datatype, does not support fp8 + if (kv_cache_dtype != "auto") { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = k_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int block_stride = k_cache.stride(0); + TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), + "reshape_and_cache_flash", + [&] { + vllm::reshape_and_cache_flash_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + k_cache.data_ptr(), + v_cache.data_ptr(), + slot_mapping.data_ptr(), + block_stride, + key_stride, + value_stride, + num_heads, + head_size, + block_size); + }); +} + namespace vllm { template diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 9839bfc0331c4..173e0b1732e13 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -96,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); + cache_ops.def( + "reshape_and_cache_flash", + &reshape_and_cache_flash, + "Reshape the key and value tensors and cache them"); cache_ops.def( "convert_fp8", &convert_fp8, diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 97cff623c5e1d..d75279dd9cfa9 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -2,12 +2,15 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ +import os + import pytest MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @pytest.mark.parametrize("model", MODELS) @@ -23,11 +26,18 @@ def test_models( max_tokens: int, enforce_eager: bool, ) -> None: + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var == "FLASHINFER" and enforce_eager is False: + pytest.skip("Skipping non-eager test for FlashInferBackend.") + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) + vllm_model = vllm_runner(model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 77aa90b12bf8f..527452630c9f5 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -18,6 +18,7 @@ MODELS = [ os.environ["TEST_DIST_MODEL"], ] +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -33,16 +34,19 @@ def test_models( dtype: str, max_tokens: int, ) -> None: + enforce_eager = False + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var == "FLASHINFER": + enforce_eager = True hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - ) + vllm_model = vllm_runner(model, + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index d26da2c7fe4ee..4f2f9cc3dac7d 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -1,8 +1,14 @@ import pytest -from vllm.utils import create_kv_caches_with_random +from vllm.utils import (create_kv_caches_with_random, + create_kv_caches_with_random_flash) @pytest.fixture() def kv_cache_factory(): return create_kv_caches_with_random + + +@pytest.fixture() +def kv_cache_factory_flashinfer(): + return create_kv_caches_with_random_flash diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index d1051fd7e2f4d..ca215bb75837a 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,6 +5,7 @@ import torch from vllm import _custom_ops as ops +from vllm._C import cache_ops from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] @@ -191,6 +192,82 @@ def test_reshape_and_cache( assert torch.allclose(value_cache, cloned_value_cache) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@torch.inference_mode() +def test_reshape_and_cache_flash( + kv_cache_factory_flashinfer, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, + kv_cache_dtype: str, +) -> None: + if kv_cache_dtype == "fp8": + pytest.skip() + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Create a random slot mapping. + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device=device) + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory_flashinfer( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Clone the KV caches. + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Call the reshape_and_cache kernel. + cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) + + # Run the reference implementation. + block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) + + @pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_heads", NUM_HEADS) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b43f646fec88e..5b56437487477 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -222,6 +222,18 @@ def reshape_and_cache( slot_mapping, kv_cache_dtype, kv_scale) +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, +) -> None: + vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) + + def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, block_mapping: torch.Tensor) -> None: vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index be747c9900368..61c9c81d8a7b8 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar +from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, + TypeVar) import torch @@ -15,7 +16,7 @@ def get_impl_cls() -> Type["AttentionImpl"]: @staticmethod @abstractmethod - def make_metadata(*args, **kwargs) -> "AttentionMetadata": + def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage": raise NotImplementedError @staticmethod @@ -50,13 +51,17 @@ def copy_blocks( class AttentionMetadataPerStage: """Attention metadata for a specific stage. I.e., prefill or decode.""" - def asdict_zerocopy(self) -> Dict[str, Any]: + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() # Note that if we add dataclasses as fields, they will need # similar handling. return { field.name: getattr(self, field.name) - for field in fields(self) + for field in fields(self) if field.name not in skip_fields } diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py new file mode 100644 index 0000000000000..8ab4b1f12ee36 --- /dev/null +++ b/vllm/attention/backends/flashinfer.py @@ -0,0 +1,220 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Set, Tuple, Type + +try: + import flashinfer + from flash_attn import flash_attn_varlen_func + from flashinfer import BatchDecodeWithPagedKVCacheWrapper +except ImportError: + flashinfer = None + flash_attn_varlen_func = None + BatchDecodeWithPagedKVCacheWrapper = None + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataPerStage) + + +class FlashInferBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "FlashInferMetadata": + return FlashInferMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + raise NotImplementedError + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + raise NotImplementedError + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + + +@dataclass +class FlashInferMetadata(AttentionMetadataPerStage): + + is_prompt: bool + + use_cuda_graph: bool = False + + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage since we still + # use flash attention for prefill. + seq_start_loc: Optional[torch.Tensor] = None + max_seq_len: Optional[int] = None + block_tables: Optional[torch.Tensor] = None + + # Metadata for the decode stage + # Workspace buffer required by the kernel, the buffer should not + # be allocated/deacollated by the FalshInfermetadata object. + workspace_buffer: Optional[torch.Tensor] = None + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + # When using flashinfer, we are also creating the FlashInferMetadata, + # which will also call post_init by default, here we want to skip the + # post_init if it's the prefill phase. + if not self.is_prompt: + self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD") + self.decode_wrapper.begin_forward( + self.paged_kv_indptr, + self.paged_kv_indices, + self.paged_kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + data_type=self.data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. + skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + if sliding_window is not None: + raise ValueError("Sliding window is not supported in FlashInfer.") + self.sliding_window = (-1, -1) + self.alibi_slopes = alibi_slopes + self.scale = scale + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + def forward(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata[FlashInferMetadata], + kv_scale: float): + num_tokens, hidden_size = query.shape + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if attn_metadata.num_prefill_tokens > 0: + assert attn_metadata.num_decode_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") + if attn_metadata.num_decode_tokens > 0: + assert attn_metadata.num_prefill_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") + + if kv_cache is not None: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + attn_metadata.kv_cache_dtype, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.block_tables is not None + if kv_cache is None or prefill_meta.block_tables.numel() == 0: + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + else: + raise NotImplementedError( + "Prefix caching is not supported with flashinfer yet.") + else: + assert attn_metadata.decode_metadata is not None + assert attn_metadata.decode_metadata.decode_wrapper is not None + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + output = attn_metadata.decode_metadata.decode_wrapper.forward( + query, + kv_cache, + sm_scale=self.scale, + ) + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7ae8c31fae1ac..34da0f6c6cdfc 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -17,6 +17,7 @@ class _Backend(enum.Enum): XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() + FLASHINFER = enum.auto() @lru_cache(maxsize=None) @@ -41,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend return TorchSDPABackend + elif backend == _Backend.FLASHINFER: + logger.info("Using Flashinfer backend.") + logger.warning("Eager mode is enforced for the Flashinfer backend. ") + from vllm.attention.backends.flashinfer import FlashInferBackend + return FlashInferBackend else: raise ValueError("Invalid attention backend.") diff --git a/vllm/config.py b/vllm/config.py index 3bdd3f774bc27..fe54c54bed48e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -298,6 +298,11 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) + def get_num_attention_heads(self, + parallel_config: "ParallelConfig") -> int: + return self.hf_text_config.num_attention_heads // \ + parallel_config.tensor_parallel_size + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size diff --git a/vllm/sequence.py b/vllm/sequence.py index 0e931ebbb6571..8caf97d30d539 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -579,8 +579,10 @@ class SequenceGroupMetadata: query tokens for prefill, we don't need sampling. token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. - state: Internal state tied to this sequence group. lora_request: LoRA request. + computed_block_nums: The block numbers that are already computed, + used in prefix caching. + state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. """ diff --git a/vllm/utils.py b/vllm/utils.py index ce55253ce2199..b06c8508757c5 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -355,21 +355,9 @@ def _generate_random_fp8( del tensor_tmp -def create_kv_caches_with_random( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: int = 0, - device: Optional[str] = "cuda", -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - +def get_kv_cache_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": if isinstance(model_dtype, str): @@ -388,6 +376,55 @@ def create_kv_caches_with_random( torch_dtype = cache_dtype else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + assert cache_dtype != "fp8" + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + scale = head_size**-0.5 + key_caches, value_caches = [], [] + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + key_value_cache.uniform_(-scale, scale) + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bbb1f5205af5e..ab248596490f6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,6 +9,7 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) +from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -23,8 +24,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available, - make_tensor_with_pad) +from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, + is_pin_memory_available, make_tensor_with_pad) logger = init_logger(__name__) @@ -155,6 +156,9 @@ def __init__( # (max batch size to capture, max context len to capture / block size). self.graph_block_tables: torch.Tensor # Set after initial profiling. + # Set if the backend is flashinfer. + self.flashinfer_workspace_buffer: torch.Tensor + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -315,6 +319,7 @@ def _prepare_prompt( # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and @@ -390,18 +395,26 @@ def _prepare_prompt( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_seq_len=max_seq_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) + if self.attn_backend is FlashInferBackend: + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + use_cuda_graph=False, + seq_start_loc=seq_start_loc, + max_seq_len=max_seq_len, + block_tables=block_tables) + else: + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + subquery_start_loc=subquery_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + ) return PreparePromptMetadata( input_tokens=input_tokens, @@ -429,6 +442,24 @@ def _prepare_decode( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + paged_kv_last_page_len: List[int] = [] + if len(seq_group_metadata_list) == 0: return PrepareDecodeMetadata.empty() @@ -469,6 +500,13 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) + last_page_len = seq_data.get_len() % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + # vLLM uses cuda graph only for decoding requests. # See `capture_model` API for more details. # For decoding requests, batch_size == input_tokens. @@ -518,18 +556,51 @@ def _prepare_decode( device=self.device, ) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_query_len=None, - max_seq_len=max_seq_len, - subquery_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) + if self.attn_backend is FlashInferBackend: + if not hasattr(self, "flashinfer_workspace_buffer"): + # Allocate 16MB workspace buffer + # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html + self.flashinfer_workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) + paged_kv_indptr = torch.tensor(paged_kv_indptr, + dtype=torch.int, + device=self.device) + paged_kv_indices = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len, + dtype=torch.int, + device=self.device) + kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, + self.model_config.dtype) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + use_cuda_graph=False, + workspace_buffer=self.flashinfer_workspace_buffer, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=self.model_config.get_num_attention_heads( + self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads( + self.parallel_config), + head_dim=self.model_config.get_head_size(), + page_size=self.block_size, + data_type=kv_cache_dtype) + else: + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_query_len=None, + max_seq_len=max_seq_len, + subquery_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) return PrepareDecodeMetadata( input_tokens=input_tokens, input_positions=input_positions, From ab502751117d3785384b9c33ee88e0aff93bbf05 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Fri, 3 May 2024 15:52:01 -0700 Subject: [PATCH 205/413] [Speculative decoding] Support target-model logprobs (#4378) --- tests/spec_decode/e2e/conftest.py | 66 +++- tests/spec_decode/e2e/test_logprobs.py | 335 ++++++++++++++++++ .../e2e/test_multistep_correctness.py | 63 +++- tests/spec_decode/test_multi_step_worker.py | 8 + tests/spec_decode/test_spec_decode_worker.py | 29 +- tests/spec_decode/utils.py | 2 + vllm/engine/output_processor/multi_step.py | 18 +- vllm/model_executor/layers/sampler.py | 16 +- vllm/sequence.py | 3 + vllm/spec_decode/batch_expansion.py | 59 ++- vllm/spec_decode/interfaces.py | 5 + vllm/spec_decode/ngram_worker.py | 6 + vllm/spec_decode/spec_decode_worker.py | 100 ++++-- vllm/spec_decode/top1_proposer.py | 2 +- vllm/spec_decode/util.py | 103 +++++- 15 files changed, 728 insertions(+), 87 deletions(-) create mode 100644 tests/spec_decode/e2e/test_logprobs.py diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 492620cf6e2cf..b1ab8a07ca636 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,9 +1,13 @@ import asyncio +import time from itertools import cycle -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import pytest import ray +import torch +from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, + nvmlInit) from tests.conftest import cleanup from vllm import LLM @@ -13,7 +17,7 @@ from vllm.model_executor.utils import set_random_seed from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData +from vllm.sequence import Logprob, MultiModalData from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, random_uuid @@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs, test_name = request.node.name def generator_inner(): - print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') + + wait_for_gpu_memory_to_clear( + devices=list(range(torch.cuda.device_count())), + threshold_bytes=2 * 2**30, + timeout_s=60, + ) use_async = False if "use_async" in kwargs: use_async = kwargs.pop("use_async") + print(f'{use_async=}') + print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs) set_random_seed(seed) @@ -188,6 +199,20 @@ def get_output_from_llm_generator( return tokens, token_ids +def get_logprobs_from_llm_generator( + llm_generator, prompts, + sampling_params) -> List[List[Dict[int, Logprob]]]: + """Returns a dict of (token_id: Logprob) for each generated position, for + each sequence in the batch. + """ + for llm in llm_generator(): + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + logprobs = [output.outputs[0].logprobs[:] for output in outputs] + del llm + + return logprobs + + def run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, batch_size, @@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, print(f'{i=} {baseline_token_ids=}') print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + + +def wait_for_gpu_memory_to_clear(devices: List[int], + threshold_bytes: int, + timeout_s: float = 120) -> None: + # Use nvml instead of pytorch to reduce measurement error from torch cuda + # context. + nvmlInit() + start_time = time.time() + while True: + output = {} + output_raw = {} + for device in devices: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 + output_raw[device] = gb_used + output[device] = f'{gb_used:.02f}' + + print('gpu memory used (GB): ', end='') + for k, v in output.items(): + print(f'{k}={v}; ', end='') + print('') + + dur_s = time.time() - start_time + if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()): + print(f'Done waiting for free GPU memory on devices {devices=} ' + f'({threshold_bytes/2**30=}) {dur_s=:.02f}') + break + + if dur_s >= timeout_s: + raise ValueError(f'Memory of devices {devices=} not free after ' + f'{dur_s=:.02f} ({threshold_bytes/2**30=})') + + time.sleep(5) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py new file mode 100644 index 0000000000000..9572aac7df6e0 --- /dev/null +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -0,0 +1,335 @@ +import math +from itertools import cycle + +import pytest + +from vllm import SamplingParams + +from .conftest import get_logprobs_from_llm_generator + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "max_logprobs": 6, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 7, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_equality(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify output logprobs are equal with and without speculative decoding. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "max_logprobs": 6, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("num_logprobs", [6]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 7, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int, + num_logprobs: int): + """Verify output logprobs are equal with and without spec decode. + This specifies a number of logprobs >1. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + logprob_rank=num_logprobs) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 6, +}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Veriy logprob greedy equality with different speculation lens. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. + "speculative_max_model_len": 32, + }]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_when_skip_speculation(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify logprobs greedy equality when some sequences skip speculation. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify at least one logprob result has num_logprobs+1, which tests the + case where the sampled token is not in top-k logprobs. + + Ideally, this test should validate equality with non-spec by getting + logprobs. This is left as future improvement. + """ + batch_size = 8 + max_output_len = output_len + force_output_len = True + logprob_rank = 5 + + temperature = 1.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + logprobs=logprob_rank, + ) + + spec_batch_logprobs = get_logprobs_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + num_returned_logprobs = [ + len(logprob_dict) for seq_logprobs in spec_batch_logprobs + for logprob_dict in seq_logprobs + ] + + # Assert one of the returned logprobs has > num_logprobs (indicating the + # sampled token is not in top-k). + assert any([ + num_returned > logprob_rank for num_returned in num_returned_logprobs + ]) + + +def run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + logprob_rank: int = 1): + """Helper method that compares the logprobs outputs of both the baseline LLM + and the test LLM. It asserts greedy equality of the logprobs when the + temperature is zero. + """ + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + logprobs=logprob_rank, + ) + + spec_batch_logprobs = get_logprobs_from_llm_generator( + test_llm_generator, prompts, sampling_params) + baseline_batch_logprobs = get_logprobs_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + assert len(baseline_batch_logprobs) == len(prompts) + assert len(spec_batch_logprobs) == len(prompts) + + # For each sequence in the batch. + for i, (baseline_logprobs, spec_logprobs) in enumerate( + zip(baseline_batch_logprobs, spec_batch_logprobs)): + assert len(spec_logprobs) == len(baseline_logprobs) + + # For each generated position of the sequence. + for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( + zip(spec_logprobs, baseline_logprobs)): + + # Map rank to token/logprob in spec output. + spec_rank_to_token_id = { + value.rank: key + for key, value in spec_pos_logprobs.items() + } + spec_rank_to_logprob = { + value.rank: value.logprob + for key, value in spec_pos_logprobs.items() + } + + # Map rank to token/logprob in baseline output. + baseline_rank_to_token_id = { + value.rank: key + for key, value in baseline_pos_logprobs.items() + } + baseline_rank_to_logprob = { + value.rank: value.logprob + for key, value in baseline_pos_logprobs.items() + } + + # Assert set of ranks returned is equal. + assert set(spec_rank_to_token_id.keys()) == set( + baseline_rank_to_token_id.keys()) + + # Assert each logprob/token id is correct, keyed by rank. + for rank in sorted(set(spec_rank_to_token_id.keys())): + assert spec_rank_to_token_id[ + rank] == baseline_rank_to_token_id[rank], f"{rank}" + assert math.isclose( + a=spec_rank_to_logprob[rank], + b=baseline_rank_to_logprob[rank], + abs_tol=1e-1, + ) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index f99e0f6778e59..f15fcc4746d20 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -41,24 +41,17 @@ @pytest.mark.parametrize( "common_llm_kwargs", - [ - { - # Use a small model for a fast test. - # Note this is repeated in the test body; to initialize a tokenizer. - "model": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, + [{ + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", - # Required for spec decode. - "use_v2_block_manager": True, + # Skip cuda graph recording for fast test. + "enforce_eager": True, - # whether use AsyncLLM engine - "use_async": async_mode, - } - # Try both async and sync engine execution - for async_mode in [True, False] - ]) + # Required for spec decode. + "use_v2_block_manager": True, + }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ @@ -117,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, assert actual_tokens.strip() == expected_tokens.strip() +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Use AsyncLLM engine + "use_async": True, + }]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_with_async_engine(test_llm_generator, + baseline_llm_generator, + batch_size: int): + """Verify spec decode works well with async LLM engine. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=32, + force_output_len=True) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index cc0427633e688..a33fd71459455 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -292,6 +292,10 @@ def test_draft_proposals_full_speculation_len(): vocab_size, device=device, dtype=torch.float32), + logprobs=torch.rand(batch_size, + vocab_size, + device=device, + dtype=torch.float32), sampled_token_ids=torch.randint(low=0, high=vocab_size, size=(batch_size, ), @@ -392,6 +396,10 @@ def test_draft_proposals_mixed_k(): vocab_size, device=device, dtype=torch.float32), + logprobs=torch.rand(expected_num_proposal_seqs, + vocab_size, + device=device, + dtype=torch.float32), sampled_token_ids=torch.randint( low=0, high=vocab_size, diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 91315df9b5e60..6763583aa85cc 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -192,8 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) target_worker.execute_model.return_value = [target_output[0]] @@ -273,8 +279,14 @@ def test_correctly_formats_output(k: int, batch_size: int): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) target_worker.execute_model.return_value = [target_output[0]] @@ -294,7 +306,9 @@ def test_correctly_formats_output(k: int, batch_size: int): num_lookahead_slots=k) expected_output = create_sampler_output_list( - rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) + token_ids=rejection_sampler_output.transpose(0, 1), + probs=[None for _ in range(k + 1)], + logprobs=[None for _ in range(k + 1)]) seq_ids = [ next(iter(seq_group_metadata.seq_data.keys())) @@ -328,7 +342,6 @@ def test_correctly_formats_output(k: int, batch_size: int): continue assert actual_by_step[i].output_token == expected_by_step[ i].output_token - assert actual_by_step[i].logprobs == expected_by_step[i].logprobs @pytest.mark.parametrize('k', [1, 2]) @@ -387,8 +400,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) target_worker.execute_model.return_value = [target_output[0]] diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 87c7d88a80f42..f0f0d09106a00 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -201,6 +201,7 @@ def assert_logprobs_dict_allclose( def create_sampler_output_list( token_ids: torch.Tensor, probs: Iterable[Optional[torch.Tensor]], + logprobs: Iterable[Optional[torch.Tensor]], seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: num_steps, batch_size = token_ids.shape token_ids_by_step = token_ids.tolist() @@ -222,6 +223,7 @@ def create_sampler_output_list( ) for seq_index, token_id in enumerate(token_ids_by_step[step]) ], sampled_token_probs=probs[step], + logprobs=logprobs[step], sampled_token_ids=token_ids[step]) for step in range(num_steps) ] diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 9abd87a4d5a9a..5f2f433aa811f 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,3 +1,4 @@ +import functools from typing import Callable, List from transformers import PreTrainedTokenizer @@ -8,8 +9,8 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, + SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter @@ -48,10 +49,14 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: # TODO(sang): Prompt logprob currently not implemented in multi step # workers. + self._log_prompt_logprob_unsupported_warning_once() + + @staticmethod + @functools.lru_cache() + def _log_prompt_logprob_unsupported_warning_once(): logger.warning( "Prompt logprob is not supported by multi step workers. " "(e.g., speculative decode uses multi step workers).") - pass def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: @@ -89,6 +94,7 @@ def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], sampling_params: SamplingParams) -> None: output_token_ids = [sample.output_token for sample in valid_samples] + output_logprobs = [sample.logprobs for sample in valid_samples] # Truncate to max_tokens if necessary. remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + @@ -113,11 +119,11 @@ def _process_seq_outputs(self, seq: Sequence, # Incrementally append tokens to the sequence, as if we had only one new # token. - for output_token_id in output_token_ids: + for output_token_id, output_logprob in zip(output_token_ids, + output_logprobs): seq.append_token_id( token_id=output_token_id, - # TODO emit logprobs in multi-step decoding. - logprobs={output_token_id: Logprob(0.0)}, + logprobs=output_logprob, ) new_char_count = 0 diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2de7763605dfc..1f19d2053d996 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -103,8 +103,7 @@ def forward( if self.include_gpu_probs_tensor: assert maybe_sampled_tokens_tensor is not None - sampled_tokens_tensor = maybe_sampled_tokens_tensor - on_device_tensors = (probs, sampled_tokens_tensor) + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) else: on_device_tensors = None @@ -965,8 +964,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, has implications on the overall design of the sampler, e.g. how to record accurate logprobs for the user, so this improvement is deferred to later. """ - logprobs[sample_indices, :] = -float('inf') - logprobs[sample_indices, greedy_samples] = 0.0 + # NOTE: logprobs are not modified so they can be returned to the user. probs[sample_indices, :] = 0 probs[sample_indices, greedy_samples] = 1.0 @@ -976,7 +974,8 @@ def _build_sampler_output( sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], - on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]], + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor]], ) -> SamplerOutput: """Construct Python objects with the output of sampling. @@ -1005,14 +1004,17 @@ def _build_sampler_output( # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: - sampled_token_probs, sampled_token_ids = on_device_tensors + (sampled_token_probs, logprobs_tensor, + sampled_token_ids) = on_device_tensors else: - sampled_token_probs, sampled_token_ids = (None, None) + sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, + None) return SamplerOutput( outputs=sampler_output, sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, + logprobs=logprobs_tensor, ) diff --git a/vllm/sequence.py b/vllm/sequence.py index 8caf97d30d539..35ac59d69f117 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -700,6 +700,9 @@ class SamplerOutput: # On-device tensor containing probabilities of each token. sampled_token_probs: Optional["torch.Tensor"] = None + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + # On-device tensor containing the sampled token ids. sampled_token_ids: Optional["torch.Tensor"] = None diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 8b113e93474ff..8b302ba1aabeb 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -94,7 +94,7 @@ def score_proposals( assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - all_tokens, all_probs = self._contract_batch( + all_tokens, all_probs, spec_logprobs = self._contract_batch( contracted_bs=len(seq_group_metadata_list), target_sampler_output=target_sampler_output, proposals=proposals, @@ -107,6 +107,7 @@ def score_proposals( return SpeculativeScores( probs=all_probs, token_ids=all_tokens, + logprobs=spec_logprobs, ) def _expand_batch( @@ -148,12 +149,12 @@ def _expand_batch( return (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) - def _contract_batch(self, contracted_bs: int, - target_sampler_output: List[SamplerOutput], - proposals: SpeculativeProposals, - num_scoring_tokens: int, non_spec_indices: List[int], - spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor]: + def _contract_batch( + self, contracted_bs: int, + target_sampler_output: List[SamplerOutput], + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], + k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -161,8 +162,9 @@ def _contract_batch(self, contracted_bs: int, contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. """ - (target_token_ids, target_probs, non_spec_target_token_ids, - non_spec_target_probs) = self._split_scoring_output( + (target_token_ids, target_probs, target_logprobs, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token @@ -179,6 +181,8 @@ def _contract_batch(self, contracted_bs: int, spec_expanded_bs, k + 1) target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1, self._vocab_size) + target_logprobs = target_logprobs.squeeze().reshape( + spec_expanded_bs, k + 1, self._vocab_size) all_tokens = torch.full(size=(contracted_bs, k + 1), fill_value=-1, @@ -189,16 +193,26 @@ def _contract_batch(self, contracted_bs: int, self._vocab_size, device=self._device, dtype=torch.float32) + all_logprobs = torch.full(size=( + contracted_bs, + k + 1, + self._vocab_size, + ), + fill_value=-float("inf"), + device=self._device, + dtype=torch.float32) if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs + all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs + all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs + return all_tokens, all_probs, all_logprobs def _create_scoring_model_input( self, @@ -308,7 +322,8 @@ def _create_single_target_seq_group_metadata( def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: """Split the target model output into speculative and non-speculative output. """ @@ -328,21 +343,29 @@ def _split_scoring_output( ) = sampler_output.sampled_token_probs.split(split_sizes) (spec_sampled_tokens, non_spec_sampled_tokens ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) + ( + spec_logprobs, + non_spec_logprobs, + ) = sampler_output.logprobs.split(split_sizes) # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens - target_token_ids, target_probs = sampler_output_to_torch( - [sampler_output], True) + sampler_output.logprobs = spec_logprobs + (target_token_ids, target_probs, + target_logprobs) = sampler_output_to_torch([sampler_output], True) # Convert non-speculative output tokens to tensors. sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens - non_spec_target_token_ids, non_spec_target_probs = ( - sampler_output_to_torch([sampler_output], True)) - - return (target_token_ids, target_probs, non_spec_target_token_ids, - non_spec_target_probs) + sampler_output.logprobs = non_spec_logprobs + (non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], + True) + + return (target_token_ids, target_probs, target_logprobs, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index dd040779922e9..489d940a88856 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -38,6 +38,11 @@ class SpeculativeScores: # Probabilities of the speculative tokens according to the scoring model. probs: torch.Tensor + # Log-probabilities of the speculative tokens according to the scoring + # model. These values can be used to generate Logprob objects that are + # returned to the user. + logprobs: torch.Tensor + # Token ids sampled from the scoring model. Used for speculative bonus # tokens and also non-speculative normal decoding. token_ids: torch.Tensor diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 696ca964328cf..cacaca697526c 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -140,11 +140,17 @@ def sampler_output( device=self.device, ) token_probs.scatter_(2, indices, 1) + token_logprobs = torch.zeros( + (len(seq_group_metadata_list), sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device, + ) for i in range(len(seq_group_metadata_list)): outputs.append( SamplerOutput( outputs=None, sampled_token_probs=token_probs[i], + logprobs=token_logprobs, sampled_token_ids=token_ids[i], )) return outputs, False diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index e33bb4f3f6337..503519a0dfc4b 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -5,15 +5,16 @@ from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceGroupOutput, SequenceOutput) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker -from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, +from vllm.spec_decode.util import (create_sequence_group_output, + get_all_num_logprobs, get_all_seq_ids, + get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase @@ -258,6 +259,7 @@ def _run_no_spec( # overhead when the engine runs in a different process than the workers. sampler_output.probs = None sampler_output.sampled_tokens = None + sampler_output.logprobs = None return [sampler_output] @nvtx_range("spec_decode_worker._run_speculative_decoding_step") @@ -298,12 +300,15 @@ def _run_speculative_decoding_step( ) #logger.info("verify proposals") - accepted_token_ids = self._verify_tokens(seq_group_metadata_list, - proposal_scores, proposals, k) + accepted_token_ids, target_logprobs = self._verify_tokens( + seq_group_metadata_list, proposal_scores, proposals, k) #logger.info("create output list") - return self._create_output_sampler_list(seq_group_metadata_list, - accepted_token_ids, k) + return self._create_output_sampler_list( + seq_group_metadata_list, + accepted_token_ids, + target_logprobs=target_logprobs, + k=k) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( @@ -312,9 +317,12 @@ def _verify_tokens( proposal_scores: SpeculativeScores, proposals: SpeculativeProposals, max_proposal_len: int, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """Determine which speculative tokens are accepted using the probabilities of each token according to the proposer and scorer models. + + Returns a tuple of Tensors, one for the accepted token ids and one for + the logprobs according to the scoring model. """ proposal_lens_list = proposals.proposal_lens.tolist() @@ -361,17 +369,19 @@ def _verify_tokens( non_spec_token_ids[:, 1:] = -1 accepted_token_ids = torch.cat( [accepted_token_ids, non_spec_token_ids]) + logprobs = proposal_scores.logprobs # Rearrange so that results are in the order of the original seq group # metadata. accepted_token_ids[original_indices] = accepted_token_ids.clone() - return accepted_token_ids + return accepted_token_ids, logprobs def _create_output_sampler_list( self, seq_group_metadata_list: List[SequenceGroupMetadata], accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] + target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] k: int, ) -> List[SamplerOutput]: """Given the accepted token ids, create a list of SamplerOutput. @@ -379,30 +389,68 @@ def _create_output_sampler_list( The output is padded with -1 tokens such that each sequence has the same number of outputs. """ + batch_size, num_steps = accepted_token_ids.shape + + # Organize input tensors by step instead of by sequence. + target_logprobs_by_step = target_logprobs.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1) + + # Get the logprobs/rank of the accepted tokens. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs( + logprob_tensor=target_logprobs_by_step, + sampled_token_ids=accepted_token_ids_by_step, + ) + + # Get the top-k logprobs (which may or may not include the logprob of + # the accepted token). + (topk_logprobs_by_step, + topk_indices_by_step) = target_logprobs_by_step.topk( + k=self.scorer_worker.model_config.max_logprobs, + dim=-1, + ) + + # Get the sequence ids and num_logprobs (sampling parameter) in the + # batch. seq_ids = get_all_seq_ids(seq_group_metadata_list) - - # shape: [k+1, batch_size] - accepted_token_ids_by_step = accepted_token_ids.transpose(0, - 1).tolist() + num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) + + # Serialize all tensors to CPU Python lists. + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + accepted_token_id_ranks_by_step = ( + accepted_token_id_ranks_by_step.tolist()) + accepted_token_id_logprobs_by_step = ( + accepted_token_id_logprobs_by_step.tolist()) + topk_logprobs_by_step = topk_logprobs_by_step.tolist() + topk_indices_by_step = topk_indices_by_step.tolist() + + # Construct the output on a per-step, per-sequence basis. sampler_output_list = [] - for token_ids_by_step in accepted_token_ids_by_step: - if all(token_id == -1 for token_id in token_ids_by_step): + for step_index in range(num_steps): + if all(token_id == -1 + for token_id in accepted_token_ids_by_step[step_index]): break step_output_token_ids = [] - for token_id, seq_id in zip(token_ids_by_step, seq_ids): + for sequence_index in range(batch_size): + # Each sequence may have a different num_logprobs; retrieve it. + num_logprobs = num_logprobs_per_seq[sequence_index] + step_output_token_ids.append( - SequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq_id, - output_token=token_id, - # TODO Add verifier logprobs. - logprobs={token_id: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, + create_sequence_group_output( + token_id=accepted_token_ids_by_step[step_index] + [sequence_index], + token_id_logprob_rank=accepted_token_id_ranks_by_step[ + step_index][sequence_index], + token_id_logprob=accepted_token_id_logprobs_by_step[ + step_index][sequence_index], + seq_id=seq_ids[sequence_index], + topk_token_ids=topk_indices_by_step[step_index] + [sequence_index][:num_logprobs], + topk_logprobs=topk_logprobs_by_step[step_index] + [sequence_index][:num_logprobs], )) + sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 6766a2deb8eb8..56c63887b0315 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -166,7 +166,7 @@ def _merge_outputs( return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs = sampler_output_to_torch( + proposal_tokens, proposal_probs, _ = sampler_output_to_torch( sampler_output, sampler_transposed) # Now, reformat the output GPU tensors such that each sequence has diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 894d2fd915948..d6f80c82b80bf 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -1,10 +1,11 @@ from contextlib import contextmanager from itertools import chain -from typing import List, Tuple +from typing import Dict, List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, + SequenceGroupOutput, SequenceOutput) SeqId = int @@ -21,6 +22,89 @@ def get_all_seq_ids( ])) +def get_all_num_logprobs( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + """Given a list of SequenceGroupMetadata, create a list of all num_logprobs. + + If the sampling params do not call for any logprobs, return 0 for that + sequence. + """ + + all_num_logprobs = [] + for seq_group_metadata in seq_group_metadata_list: + num_logprobs = seq_group_metadata.sampling_params.logprobs + if seq_group_metadata.sampling_params.logprobs is None: + num_logprobs = 0 + all_num_logprobs.append(num_logprobs) + + return all_num_logprobs + + +def get_sampled_token_logprobs( + # shape [num_steps, batch_size, vocab_size] + logprob_tensor: torch.Tensor, + sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size] +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the logprobs for the sampled tokens. Returns the ranks and logprobs. + """ + num_steps, batch_size, vocab_size = logprob_tensor.shape + + selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1), + torch.arange(batch_size), + sampled_token_ids, ] + expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( + -1, -1, vocab_size) + sampled_token_ids_ranks = (logprob_tensor >= + expanded_selected_logprobs).sum(-1) + + return sampled_token_ids_ranks, selected_logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[int], + topk_logprobs: List[float], +) -> SequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[int]): The list of top-k token ids. + topk_logprobs (List[float]): The list of top-k logprobs. + """ + # vLLM logprobs always include the sampled token. In addition, the user may + # request topk-logprobs (where top-k varies per user up to max_logprobs). + logprobs: Dict[int, Logprob] = { + token_id: Logprob( + logprob=token_id_logprob, + rank=token_id_logprob_rank, + ), + } + logprobs.update({ + topk_token_ids[topk_logprob_index]: Logprob( + logprob=topk_logprobs[topk_logprob_index], + rank=topk_logprob_index + 1, + ) + for topk_logprob_index, _ in enumerate(topk_token_ids) + }) + + return SequenceGroupOutput( + samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs=logprobs) + ], + # TODO add prompt logprobs support. + prompt_logprobs=None, + ) + + def split_batch_by_proposal_len( seq_group_metadata_list: List[SequenceGroupMetadata], proposal_lens: List[int], select_proposal_len_zero: bool @@ -49,8 +133,8 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( - sampler_output_list: List[SamplerOutput], - sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]: + sampler_output_list: List[SamplerOutput], sampler_transposed: bool +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Utility function which converts a list of SamplerOutput to tensors. sampler_transposed here is used as the indicator for whether @@ -76,6 +160,15 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_probs = sampled_token_probs.transpose(0, 1) + # shape: [batch_size, num_sampler_output, vocab_size] + sampled_token_logprobs = torch.stack( + [sampler_output.logprobs for sampler_output in sampler_output_list], + dim=0, + ) + + if sampler_transposed: + sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1) + # shape: [batch_size, num_sampler_output] sampled_token_ids = torch.stack( [ @@ -87,7 +180,7 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_ids = sampled_token_ids.transpose(0, 1) - return sampled_token_ids, sampled_token_probs + return sampled_token_ids, sampled_token_probs, sampled_token_logprobs def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, From 344bf7cd2d66a8b13f216f61c7a6d5d70576a498 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 3 May 2024 15:55:56 -0700 Subject: [PATCH 206/413] [Misc] add installation time env vars (#4574) --- setup.py | 33 ++++++++++++++++++-------- vllm/envs.py | 66 +++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 81 insertions(+), 18 deletions(-) diff --git a/setup.py b/setup.py index 801d8d50db1dc..3768daf9d6fab 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import importlib.util import io import logging import os @@ -13,10 +14,23 @@ from setuptools.command.build_ext import build_ext from torch.utils.cpp_extension import CUDA_HOME + +def load_module_from_path(module_name, path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) -# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] -VLLM_TARGET_DEVICE = os.getenv("VLLM_TARGET_DEVICE", "cuda") + +# cannot import envs directly because it depends on vllm, +# which is not installed yet +envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py')) + +VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE # vLLM only supports Linux platform assert sys.platform.startswith( @@ -60,7 +74,7 @@ class cmake_build_ext(build_ext): def compute_num_jobs(self): # `num_jobs` is either the value of the MAX_JOBS environment variable # (if defined) or the number of CPUs available. - num_jobs = os.environ.get("MAX_JOBS", None) + num_jobs = envs.MAX_JOBS if num_jobs is not None: num_jobs = int(num_jobs) logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) @@ -78,7 +92,7 @@ def compute_num_jobs(self): # environment variable (if defined) or 1. # when it is set, we reduce `num_jobs` to avoid # overloading the system. - nvcc_threads = os.getenv("NVCC_THREADS", None) + nvcc_threads = envs.NVCC_THREADS if nvcc_threads is not None: nvcc_threads = int(nvcc_threads) logger.info( @@ -104,7 +118,7 @@ def configure(self, ext: CMakeExtension) -> None: # Select the build type. # Note: optimization level + debug info are set by the build type default_cfg = "Debug" if self.debug else "RelWithDebInfo" - cfg = os.getenv("CMAKE_BUILD_TYPE", default_cfg) + cfg = envs.CMAKE_BUILD_TYPE or default_cfg # where .so files will be written, should be the same for all extensions # that use the same CMakeLists.txt. @@ -118,7 +132,7 @@ def configure(self, ext: CMakeExtension) -> None: '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), ] - verbose = bool(int(os.getenv('VERBOSE', '0'))) + verbose = envs.VERBOSE if verbose: cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON'] @@ -205,8 +219,7 @@ def _is_neuron() -> bool: subprocess.run(["neuron-ls"], capture_output=True, check=True) except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): torch_neuronx_installed = False - return torch_neuronx_installed or os.environ.get("VLLM_BUILD_WITH_NEURON", - False) + return torch_neuronx_installed or envs.VLLM_BUILD_WITH_NEURON def _is_cpu() -> bool: @@ -214,7 +227,7 @@ def _is_cpu() -> bool: def _install_punica() -> bool: - return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))) + return envs.VLLM_INSTALL_PUNICA_KERNELS def get_hipcc_rocm_version(): @@ -377,7 +390,7 @@ def _read_requirements(filename: str) -> List[str]: package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } -if os.environ.get("VLLM_USE_PRECOMPILED"): +if envs.VLLM_USE_PRECOMPILED: ext_modules = [] package_data["vllm"].append("*.so") diff --git a/vllm/envs.py b/vllm/envs.py index 2dbb57e6253a7..91cc8f3be775c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -27,6 +27,14 @@ VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" + VLLM_TARGET_DEVICE: str = "cuda" + MAX_JOBS: Optional[str] = None + NVCC_THREADS: Optional[str] = None + VLLM_BUILD_WITH_NEURON: bool = False + VLLM_USE_PRECOMPILED: bool = False + VLLM_INSTALL_PUNICA_KERNELS: bool = False + CMAKE_BUILD_TYPE: Optional[str] = None + VERBOSE: bool = False # The begin-* and end* here are used by the documentation generator # to extract the used env vars. @@ -34,6 +42,56 @@ # begin-env-vars-definition environment_variables: Dict[str, Callable[[], Any]] = { + + # ================== Installation Time Env Vars ================== + + # Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] + "VLLM_TARGET_DEVICE": + lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), + + # Maximum number of compilation jobs to run in parallel. + # By default this is the number of CPUs + "MAX_JOBS": + lambda: os.getenv("MAX_JOBS", None), + + # Number of threads to use for nvcc + # By default this is 1. + # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. + "NVCC_THREADS": + lambda: os.getenv("NVCC_THREADS", None), + + # If set, vllm will build with Neuron support + "VLLM_BUILD_WITH_NEURON": + lambda: bool(os.environ.get("VLLM_BUILD_WITH_NEURON", False)), + + # If set, vllm will use precompiled binaries (*.so) + "VLLM_USE_PRECOMPILED": + lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), + + # If set, vllm will install Punica kernels + "VLLM_INSTALL_PUNICA_KERNELS": + lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))), + + # CMake build type + # If not set, defaults to "Debug" or "RelWithDebInfo" + # Available options: "Debug", "Release", "RelWithDebInfo" + "CMAKE_BUILD_TYPE": + lambda: os.getenv("CMAKE_BUILD_TYPE"), + + # If set, vllm will print verbose logs during installation + "VERBOSE": + lambda: bool(int(os.getenv('VERBOSE', '0'))), + + # Root directory for VLLM configuration files + # Note that this not only affects how vllm finds its configuration files + # during runtime, but also affects how vllm installs its configuration + # files during **installation**. + "VLLM_CONFIG_ROOT": + lambda: os.environ.get("VLLM_CONFIG_ROOT", None) or os.getenv( + "XDG_CONFIG_HOME", None) or os.path.expanduser("~/.config"), + + # ================== Runtime Env Vars ================== + # used in distributed environment to determine the master address 'VLLM_HOST_IP': lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), @@ -93,14 +151,6 @@ "S3_ENDPOINT_URL": lambda: os.environ.get("S3_ENDPOINT_URL", None), - # Root directory for VLLM configuration files - # Note that this not only affects how vllm finds its configuration files - # during runtime, but also affects how vllm installs its configuration - # files during **installation**. - "VLLM_CONFIG_ROOT": - lambda: os.environ.get("VLLM_CONFIG_ROOT", None) or os.getenv( - "XDG_CONFIG_HOME", None) or os.path.expanduser("~/.config"), - # Usage stats collection "VLLM_USAGE_STATS_SERVER": lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"), From bc8ad68455ce41ba672764f4a53df5a87d1dbe99 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 3 May 2024 17:47:07 -0700 Subject: [PATCH 207/413] [Misc][Refactor] Introduce ExecuteModelData (#4540) --- tests/spec_decode/test_multi_step_worker.py | 98 ++++++++++---------- tests/spec_decode/test_ngram_worker.py | 64 ++++++------- tests/spec_decode/test_spec_decode_worker.py | 95 +++++++++---------- tests/spec_decode/utils.py | 50 +--------- tests/worker/test_swap.py | 30 ++++-- vllm/core/scheduler.py | 4 + vllm/engine/async_llm_engine.py | 16 ++-- vllm/engine/llm_engine.py | 12 ++- vllm/executor/cpu_executor.py | 37 ++------ vllm/executor/executor_base.py | 22 ++--- vllm/executor/gpu_executor.py | 33 ++----- vllm/executor/neuron_executor.py | 33 +++---- vllm/executor/ray_gpu_executor.py | 19 +--- vllm/sequence.py | 32 ++++++- vllm/spec_decode/batch_expansion.py | 30 ++---- vllm/spec_decode/interfaces.py | 15 +-- vllm/spec_decode/multi_step_worker.py | 54 +++++------ vllm/spec_decode/ngram_worker.py | 62 +++++-------- vllm/spec_decode/spec_decode_worker.py | 90 +++++------------- vllm/spec_decode/top1_proposer.py | 22 ++--- vllm/worker/cpu_worker.py | 25 ++--- vllm/worker/worker.py | 23 +++-- vllm/worker/worker_base.py | 8 +- 23 files changed, 359 insertions(+), 515 deletions(-) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index a33fd71459455..cb2de97a4af94 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -5,13 +5,12 @@ import torch from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker from .utils import (assert_logprobs_dict_allclose, create_batch, - create_execute_model_data, create_seq_group_metadata_from_prompts, create_worker, patch_execute_model_with_seeds, zero_kv_cache) @@ -105,31 +104,32 @@ def test_same_output_for_single_step(): final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] - multi_step_execute_model_data = create_execute_model_data( - seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens)) - - single_step_execute_model_data = create_execute_model_data( - seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens)) + multi_step_seq_group = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) actual_output, _ = multi_step_worker.sampler_output( - **multi_step_execute_model_data.to_dict(), sample_len=num_steps) + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=multi_step_seq_group), + sample_len=num_steps) assert len(actual_output) == num_steps actual_output = actual_output[0] + single_step_seq_group = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + zero_kv_cache(worker.cache_engine) set_random_seed(seed) expected_output = worker.execute_model( - **single_step_execute_model_data.to_dict(), )[0] + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=single_step_seq_group))[0] actual_token_ids = [ output.samples[0].output_token for output in actual_output @@ -193,19 +193,20 @@ def test_same_output_for_multi_step(): worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) continuations = [[1] for _ in prompts] - execute_model_data = create_execute_model_data( - create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_prompt_lens=final_prompt_lens), ) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) # Run multi-step. zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) multi_step_output, _ = multi_step_worker.sampler_output( - **execute_model_data.to_dict(), sample_len=num_steps) + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list), + sample_len=num_steps) # Run single-step repeatedly. zero_kv_cache(worker.cache_engine) @@ -215,16 +216,16 @@ def test_same_output_for_multi_step(): for _ in multi_step_output: - execute_model_data = create_execute_model_data( - create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_prompt_lens=final_prompt_lens)) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) single_step_output.extend( - worker.execute_model(**execute_model_data.to_dict(), )) + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list))) # Append output tokens to new sequence data. for i, seq_group_output in enumerate(single_step_output[-1]): @@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len(): ) for _ in range(k) ], True - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) - proposals = proposer.get_proposals( - **execute_model_data.to_dict(), - proposal_len=k, - ) + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations(): max_proposal_len=prompt_len + k - 1, ) - execute_model_data, _, _ = create_batch(batch_size, - k, - prompt_len=prompt_len) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + prompt_len=prompt_len) - proposals = proposer.get_proposals( - **execute_model_data.to_dict(), - proposal_len=k, - ) + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k(): ) for _ in range(k) ], True - execute_model_data, _, _ = create_batch( + seq_group_metadata_list, _, _ = create_batch( batch_size, k, prompt_len=prompt_len, prev_output_token_len=prev_output_token_len, ) - proposals = proposer.get_proposals( - **execute_model_data.to_dict(), - proposal_len=k, - ) + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py index e7e2e87f599dd..de305c4030aa9 100644 --- a/tests/spec_decode/test_ngram_worker.py +++ b/tests/spec_decode/test_ngram_worker.py @@ -1,10 +1,10 @@ import torch +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.top1_proposer import Top1Proposer -from .utils import (create_execute_model_data, - create_seq_group_metadata_from_prompts, create_worker) +from .utils import create_seq_group_metadata_from_prompts, create_worker def test_ngram_algo_correctness_for_single_no_match(): @@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match(): proposal_len = 5 final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] - ngram_sampler_output_data = create_execute_model_data( - seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens)) - - proposals = proposer.get_proposals( - **ngram_sampler_output_data.to_dict(), - proposal_len=proposal_len, - ) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=proposal_len), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): proposal_len = 5 final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] - ngram_sampler_output_data = create_execute_model_data( - seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens)) - - proposals = proposer.get_proposals( - **ngram_sampler_output_data.to_dict(), - proposal_len=proposal_len, - ) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=proposal_len), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all(): proposal_len = 5 final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] - ngram_sampler_output_data = create_execute_model_data( - seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens)) - - proposals = proposer.get_proposals( - **ngram_sampler_output_data.to_dict(), - proposal_len=proposal_len, - ) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=proposal_len), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 6763583aa85cc..ef9d32f73d668 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -7,7 +7,7 @@ from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) @@ -15,8 +15,7 @@ from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) -from .utils import (ExecuteModelData, create_batch, create_sampler_output_list, - mock_worker) +from .utils import create_batch, create_sampler_output_list, mock_worker @pytest.mark.parametrize('k', [1, 2, 6]) @@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + worker.execute_model(execute_model_req=execute_model_req) call_args_list = draft_worker.get_spec_proposals.call_args_list assert len(call_args_list) == 1 for args, _ in call_args_list: - (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, - blocks_to_copy, actual_k) = args - actual_execute_model_data = ExecuteModelData(seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy) - assert actual_execute_model_data == execute_model_data - assert actual_k == k + actual_execute_model_data = args[0] + assert actual_execute_model_data == execute_model_req @pytest.mark.parametrize('k', [1, 2, 6]) @@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, prompts, prev_output_tokens = create_batch( + seq_group_metadata_list, prompts, prev_output_tokens = create_batch( batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( @@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int): target_worker.execute_model.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) seen_contexts = [] call_args_list = target_worker.execute_model.call_args_list assert len(call_args_list) == 1 - for args, kwargs in call_args_list: - target_execute_model_data = ExecuteModelData.from_dict(kwargs) + for _, kwargs in call_args_list: + seq_group_metadata_list = kwargs[ + "execute_model_req"].seq_group_metadata_list - assert len(target_execute_model_data.seq_group_metadata_list) == ( - k + 1) * batch_size - for seq_group_metadata in ( - target_execute_model_data.seq_group_metadata_list): + assert len(seq_group_metadata_list) == (k + 1) * batch_size + for seq_group_metadata in seq_group_metadata_list: for seq_data in seq_group_metadata.seq_data.values(): seen_contexts.append(seq_data.get_token_ids()) @@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( proposal_token_ids=proposal_token_ids, @@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): rejection_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) assert len(rejection_sampler.call_args_list) == 1 _, kwargs = rejection_sampler.call_args_list[0] @@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( proposal_token_ids=proposal_token_ids, @@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int): rejection_sampler.return_value = rejection_sampler_output - output = worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + output = worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) expected_output = create_sampler_output_list( token_ids=rejection_sampler_output.transpose(0, 1), @@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int): seq_ids = [ next(iter(seq_group_metadata.seq_data.keys())) - for seq_group_metadata in execute_model_data.seq_group_metadata_list + for seq_group_metadata in seq_group_metadata_list ] actual_output_by_seq = {seq_id: [] for seq_id in seq_ids} expected_output_by_seq = {seq_id: [] for seq_id in seq_ids} @@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( proposal_token_ids=proposal_token_ids, @@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): metrics_collector.maybe_collect_rejsample_metrics.return_value = ( mock_rejsample_metrics) - output = worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + output = worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics call_args_list = ( @@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - execute_model_data, prompts, prev_output_tokens = create_batch( - batch_size, k, prev_output_token_len=0) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + prev_output_token_len=0) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) - out = worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + out = worker.execute_model(execute_model_req=execute_model_req) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" assert out[ 0].sampled_tokens is None, "expect gpu tensor references to be None" - draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) - target_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) + draft_worker.execute_model.assert_called_once_with(execute_model_req) + target_worker.execute_model.assert_called_once_with(execute_model_req) @pytest.mark.parametrize('k', [0, 5]) @@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - execute_model_data, prompts, prev_output_tokens = create_batch( - batch_size, k, prev_output_token_len=0) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + prev_output_token_len=0) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) - out = worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + out = worker.execute_model(execute_model_req=execute_model_req) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" assert out[ 0].sampled_tokens is None, "expect gpu tensor references to be None" - draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) - target_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) + draft_worker.execute_model.assert_called_once_with(execute_model_req) + target_worker.execute_model.assert_called_once_with(execute_model_req) @pytest.mark.skip_global_cleanup diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f0f0d09106a00..f288652d51556 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, fields from itertools import count from typing import Dict, Iterable, List, Optional, Union from unittest.mock import MagicMock @@ -16,50 +15,10 @@ from vllm.worker.worker import Worker -@dataclass -class ExecuteModelData: - """Helper data structure which facilitates cleaner tests. - """ - seq_group_metadata_list: List[SequenceGroupMetadata] - blocks_to_swap_in: Dict[int, int] - blocks_to_swap_out: Dict[int, int] - blocks_to_copy: Dict[int, List[int]] - - def to_dict(self): - return dict( - (field.name, getattr(self, field.name)) for field in fields(self)) - - @classmethod - def from_dict(cls, d): - cleaned = dict((field.name, d[field.name]) for field in fields(cls)) - return cls(**cleaned) - - def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size -def create_execute_model_data( - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]] = None, - blocks_to_swap_out: Optional[Dict[int, int]] = None, - blocks_to_copy: Optional[Dict[int, int]] = None, -) -> ExecuteModelData: - if blocks_to_swap_in is None: - blocks_to_swap_in = {} - if blocks_to_swap_out is None: - blocks_to_swap_out = {} - if blocks_to_copy is None: - blocks_to_copy = {} - - return ExecuteModelData( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) - - def mock_worker(cls=None, vocab_size: int = 30_000, max_model_len: int = 2048, @@ -258,8 +217,7 @@ def create_batch(batch_size, for prompt, prev_output_token in zip(prompts, prev_output_tokens) ] - execute_model_data = create_execute_model_data( - create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks, - block_size, final_prompt_lens, - prev_output_tokens, seq_ids), ) - return execute_model_data, prompts, prev_output_tokens + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, final_prompt_lens, + prev_output_tokens, seq_ids) + return seq_group_metadata_list, prompts, prev_output_tokens diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 1804cf78d8003..07bcd343a96a6 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -1,6 +1,7 @@ import torch from vllm.engine.arg_utils import EngineArgs +from vllm.sequence import ExecuteModelRequest from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.worker import Worker @@ -54,10 +55,14 @@ def test_swap() -> None: # Test swap out. blocks_to_swap_out = {3: 72, 56: 35, 84: 34} - worker.execute_model(seq_group_metadata_list=[], - blocks_to_swap_in={}, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy={}) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=[], + blocks_to_swap_in={}, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy={}, + ) + worker.execute_model(execute_model_req=execute_model_req) + for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] @@ -66,14 +71,19 @@ def test_swap() -> None: assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) # Test swap in. - blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71} - worker.execute_model(seq_group_metadata_list=[], - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out={}, - blocks_to_copy={}) + execute_model_req.blocks_to_swap_out = {} + execute_model_req.blocks_to_swap_in = { + 19: 45, + 67: 23, + 12: 78, + 40: 99, + 1: 71 + } + worker.execute_model(execute_model_req=execute_model_req) + for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in blocks_to_swap_in.items(): + for src, dst in execute_model_req.blocks_to_swap_in.items(): assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7c55b08d4857d..a9e0b05b8db67 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -128,6 +128,8 @@ class SchedulerOutputs: ignored_seq_groups: List[SequenceGroup] # The number of slots for lookahead decoding. num_lookahead_slots: int + # The number of requests in the running queue + running_queue_size: int def __post_init__(self): # Swap in and swap out should never happen at the same time. @@ -797,6 +799,7 @@ def _schedule_default(self) -> SchedulerOutputs: ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), ) def _schedule_chunked_prefill(self): @@ -883,6 +886,7 @@ def _schedule_chunked_prefill(self): swapped_in.blocks_to_copy), ignored_seq_groups=prefills.ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), ) def _schedule(self) -> SchedulerOutputs: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index cf5053bba1d48..9f72a0d11974f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -16,7 +16,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData, SamplerOutput +from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -210,12 +210,16 @@ async def step_async(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): # Execute the model. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + ) output = await self.model_executor.execute_model_async( - seq_group_metadata_list, - scheduler_outputs.blocks_to_swap_in, - scheduler_outputs.blocks_to_swap_out, - scheduler_outputs.blocks_to_copy, - num_lookahead_slots=scheduler_outputs.num_lookahead_slots) + execute_model_req) else: output = [] diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 94a5b397a4d43..342f2c796d6fb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -22,8 +22,8 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata, +from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput, + Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, @@ -583,12 +583,16 @@ def step(self) -> List[RequestOutput]: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if not scheduler_outputs.is_empty(): - output = self.model_executor.execute_model( + execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_copy=scheduler_outputs.blocks_to_copy, - num_lookahead_slots=scheduler_outputs.num_lookahead_slots) + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + ) + output = self.model_executor.execute_model( + execute_model_req=execute_model_req) else: output = [] diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 733eef828adc4..a2212459f034e 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Set, Tuple +from typing import List, Set, Tuple import torch @@ -7,7 +7,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -72,18 +72,10 @@ def initialize_cache(self, num_gpu_blocks: int, logger.info("# CPU blocks: %d", num_gpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int) -> List[SamplerOutput]: - output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -104,19 +96,10 @@ def check_health(self) -> None: class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int, - ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model)( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - num_lookahead_slots=num_lookahead_slots) + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) return output async def check_health_async(self) -> None: diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 96cd18250bb37..08aa58999b1ec 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Set, Tuple +from typing import List, Optional, Set, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput class ExecutorBase(ABC): @@ -68,12 +68,9 @@ def initialize_cache(self, num_gpu_blocks: int, raise NotImplementedError @abstractmethod - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int) -> List[SamplerOutput]: + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes at least one model step on the given sequences.""" raise NotImplementedError @@ -107,13 +104,8 @@ class ExecutorAsyncBase(ExecutorBase): @abstractmethod async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int, - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index a58856a12f0c8..1af3bcf380843 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -3,7 +3,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -117,20 +117,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int, - ) -> List[SamplerOutput]: - output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - num_lookahead_slots=num_lookahead_slots, - ) + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -154,16 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int, + execute_model_req: ExecuteModelRequest, ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model)( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - num_lookahead_slots=num_lookahead_slots) + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) return output diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 8a3b9cde84311..e7f0e887921b7 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,9 +1,9 @@ -from typing import Dict, List, Set, Tuple +from typing import List, Set, Tuple from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import make_async logger = init_logger(__name__) @@ -45,20 +45,18 @@ def initialize_cache(self, num_gpu_blocks: int, """ self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int) -> List[SamplerOutput]: - assert (blocks_to_swap_in == {} and blocks_to_swap_out == {} - and blocks_to_copy == {}), ( + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + assert (execute_model_req.blocks_to_swap_in == {} + and execute_model_req.blocks_to_swap_out == {} + and execute_model_req.blocks_to_copy == {}), ( "Cache operations are not supported for Neuron backend.") - assert num_lookahead_slots == 0, ( + assert execute_model_req.num_lookahead_slots == 0, ( "lookahead not supported for Neuron backend.") output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list) + execute_model_req.seq_group_metadata_list) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -80,14 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): async def execute_model_async( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int, + execute_model_req: ExecuteModelRequest, ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model)( - seq_group_metadata_list=seq_group_metadata_list, ) + output = await make_async( + self.driver_worker.execute_model + )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, ) return output async def check_health_async(self) -> None: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4684b857ccd39..afc1c886722e6 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -10,7 +10,7 @@ DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) @@ -166,21 +166,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int = 0) -> List[SamplerOutput]: + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: all_outputs = self._run_workers( "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - "num_lookahead_slots": num_lookahead_slots, - }, + driver_kwargs={"execute_model_req": execute_model_req}, use_ray_compiled_dag=USE_RAY_COMPILED_DAG) # Only the driver worker returns the sampling results. diff --git a/vllm/sequence.py b/vllm/sequence.py index 35ac59d69f117..f2939eff7959b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,7 +1,7 @@ """Sequence and its related classes.""" import copy import enum -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Union from vllm.block import LogicalTokenBlock @@ -734,3 +734,33 @@ def __repr__(self) -> str: f"sampled_token_probs={sampled_token_probs_repr}, " f"sampled_token_ids={sampled_token_ids_repr}, " f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") + + +@dataclass +class ExecuteModelRequest: + """The model execution request.""" + # The sequence group metadata list. + seq_group_metadata_list: List[SequenceGroupMetadata] + # Blocks to swap in. Dict of CPU -> GPU block number. + blocks_to_swap_in: Dict[int, int] = field(default_factory=dict) + # Blocks to swap out. Dict of GPU -> CPU block number. + blocks_to_swap_out: Dict[int, int] = field(default_factory=dict) + # Blocks to copy. Source to a list of dest blocks. + blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict) + # The number of slots for lookahead decoding. + num_lookahead_slots: int = 0 + # The number of requests in the running queue. + running_queue_size: int = 0 + + def clone( + self, seq_group_metadata_list: List[SequenceGroupMetadata] + ) -> "ExecuteModelRequest": + """Clone the request with a new sequence group metadata list.""" + return ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=self.blocks_to_swap_in.copy(), + blocks_to_swap_out=self.blocks_to_swap_out.copy(), + blocks_to_copy=self.blocks_to_copy.copy(), + num_lookahead_slots=self.num_lookahead_slots, + running_queue_size=self.running_queue_size, + ) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 8b302ba1aabeb..d5fd96907ddd7 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,9 +1,10 @@ from itertools import chain, count -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Iterator, List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, + SequenceGroupMetadata) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, @@ -40,11 +41,7 @@ def __init__(self, scorer_worker: WorkerBase, device: str, @nvtx_range("BatchExpansionTop1Scorer.score_proposals") def score_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, + execute_model_req: ExecuteModelRequest, proposals: SpeculativeProposals, ) -> SpeculativeScores: """Score the proposed tokens via the scorer model. @@ -57,11 +54,7 @@ def score_proposals( no speculation is produced for that sequence. Args: - seq_group_metadata_list: The input sequence group metadata. - blocks_to_swap_in: This is passed to the worker during scoring. - blocks_to_swap_out: This is passed to the worker during scoring. - blocks_to_copy: This is passed to the worker during scoring. - k: The fixed proposal length. + execute_model_req: The execution request. proposals: The speculative proposals to score. Returns: SpeculativeScores: The scores of each speculative token, along with @@ -80,28 +73,25 @@ def score_proposals( (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) = self._expand_batch( - seq_group_metadata_list=seq_group_metadata_list, + seq_group_metadata_list=execute_model_req.seq_group_metadata_list, proposal_token_ids_list=proposal_token_ids_list_without_skips, proposal_lens_list=proposal_lens_list, ) target_sampler_output = self._scorer_worker.execute_model( - seq_group_metadata_list=target_seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list, )) assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] all_tokens, all_probs, spec_logprobs = self._contract_batch( - contracted_bs=len(seq_group_metadata_list), + contracted_bs=len(execute_model_req.seq_group_metadata_list), target_sampler_output=target_sampler_output, proposals=proposals, num_scoring_tokens=num_scoring_tokens, non_spec_indices=non_spec_indices, spec_indices=spec_indices, - k=k, + k=execute_model_req.num_lookahead_slots, ) return SpeculativeScores( diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 489d940a88856..d311bfe984cbc 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Optional import torch -from vllm.sequence import SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest @dataclass @@ -58,11 +57,7 @@ class SpeculativeProposer(ABC): @abstractmethod def get_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: raise NotImplementedError @@ -72,11 +67,7 @@ class SpeculativeScorer(ABC): @abstractmethod def score_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, + execute_model_req: ExecuteModelRequest, proposals: SpeculativeProposals, ) -> SpeculativeScores: raise NotImplementedError diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index d031bc85af160..5044cc1ef85fd 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,9 +1,10 @@ import copy -from typing import Dict, List, Tuple +from typing import List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker @@ -44,10 +45,7 @@ def set_include_gpu_probs_tensor(self): @torch.inference_mode() def sampler_output( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, sample_len: int, ) -> Tuple[List[SamplerOutput], bool]: """Run the model forward pass sample_len times. Returns the list of @@ -57,26 +55,24 @@ def sampler_output( For multi step worker, this indicator shall be True. """ - self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in, - blocks_to_swap_out, blocks_to_copy) + self._raise_if_unsupported(execute_model_req) # Shallow copy input data so modifications (such as appending tokens) # do not cause side-effects. copied_seq_group_metadata_list = self._shallow_copy_inputs( - seq_group_metadata_list) + execute_model_req.seq_group_metadata_list) + copied_execute_model_req = execute_model_req.clone( + copied_seq_group_metadata_list) # Assert enough KV space for sample_len tokens per sequence. - self._assert_enough_kv_space(seq_group_metadata_list, sample_len) + self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list, + sample_len) # Run model sample_len times. model_outputs = [] for _ in range(sample_len): model_output = super().execute_model( - seq_group_metadata_list=copied_seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + execute_model_req=copied_execute_model_req) assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] @@ -89,23 +85,13 @@ def sampler_output( def get_spec_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. """ - return self._proposer.get_proposals( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - max_proposal_len, - ) + return self._proposer.get_proposals(execute_model_req) def _append_new_tokens( self, model_output: SamplerOutput, @@ -196,20 +182,22 @@ def _assert_enough_kv_space( def _raise_if_unsupported( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, ) -> None: """MultiStepWorker does not yet implement support for cache swap operations or beam search. """ - if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): raise NotImplementedError( "MultiStepWorker does not support cache operations") if any( len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in seq_group_metadata_list): + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): raise NotImplementedError( "MultiStepWorker does not support beam search.") diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index cacaca697526c..fed8be42054a5 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -1,8 +1,8 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -46,13 +46,7 @@ def set_include_gpu_probs_tensor(self): # NGram don't need gpu sampler pass - def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - ) -> None: + def execute_model(self, execute_model_req: ExecuteModelRequest) -> None: """NGram doesn't depend on model execution, just pass this function""" pass @@ -71,10 +65,7 @@ def get_cache_block_size_bytes(self): def sampler_output( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, sample_len: int, ) -> Tuple[Optional[List[SamplerOutput]], bool]: """NGram match algo to pick proposal candidate. Returns the list of @@ -83,16 +74,11 @@ def sampler_output( For ngram worker, we already done needed transposed internal, so the indicator pass to sampler_output_to_torch shall be False. """ - self._raise_if_unsupported( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - ) + self._raise_if_unsupported(execute_model_req) arr = [] has_spec_out = False - for seq_group_metadata in seq_group_metadata_list: + for seq_group_metadata in execute_model_req.seq_group_metadata_list: seq_data = next(iter(seq_group_metadata.seq_data.values())) input_ids = torch.as_tensor(seq_data.get_token_ids(), @@ -135,17 +121,19 @@ def sampler_output( indices = token_ids.unsqueeze(2) token_probs = torch.zeros( - (len(seq_group_metadata_list), sample_len, self.vocab_size), + (len(execute_model_req.seq_group_metadata_list), sample_len, + self.vocab_size), dtype=torch.float32, device=self.device, ) token_probs.scatter_(2, indices, 1) token_logprobs = torch.zeros( - (len(seq_group_metadata_list), sample_len, self.vocab_size), + (len(execute_model_req.seq_group_metadata_list), sample_len, + self.vocab_size), dtype=torch.float32, device=self.device, ) - for i in range(len(seq_group_metadata_list)): + for i in range(len(execute_model_req.seq_group_metadata_list)): outputs.append( SamplerOutput( outputs=None, @@ -157,40 +145,32 @@ def sampler_output( def get_spec_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. """ - return self._proposer.get_proposals( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - max_proposal_len, - ) + return self._proposer.get_proposals(execute_model_req) def _raise_if_unsupported( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, ) -> None: """NGramWorker does not yet implement support for cache swap operations or beam search. """ - if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): raise NotImplementedError( "NGramWorker does not support cache operations") if any( len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in seq_group_metadata_list): + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): raise NotImplementedError( "NGramWorker does not support beam search.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 503519a0dfc4b..c2b119fbd5036 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,11 +1,12 @@ from functools import cached_property -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) @@ -189,69 +190,37 @@ def initialize_cache(self, num_gpu_blocks: int, @torch.inference_mode() def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - num_lookahead_slots: int, - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ - assert seq_group_metadata_list is not None, ( + assert execute_model_req.seq_group_metadata_list is not None, ( "speculative decoding " "requires non-None seq_group_metadata_list") - #logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d", - # num_lookahead_slots) - # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. - if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0: - return self._run_no_spec( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) - - return self._run_speculative_decoding_step( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - k=num_lookahead_slots, - ) + if execute_model_req.num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list) == 0: + return self._run_no_spec(execute_model_req) + + return self._run_speculative_decoding_step(execute_model_req) @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Run a prefill step, without any speculation. The input is sent to the proposer and scorer model so that the KV cache is consistent between the two. """ #logger.info("run proposer worker no spec") - self.proposer_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + self.proposer_worker.execute_model(execute_model_req) #logger.info("run target worker no spec") - sampler_output = self.scorer_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -264,13 +233,8 @@ def _run_no_spec( @nvtx_range("spec_decode_worker._run_speculative_decoding_step") def _run_speculative_decoding_step( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Execute a single step of speculative decoding. This invokes the proposer worker to get k speculative tokens for each @@ -282,33 +246,25 @@ def _run_speculative_decoding_step( #logger.info("get spec proposals") # Generate proposals using draft worker. - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None - proposals = self.proposer_worker.get_spec_proposals( - seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, - blocks_to_copy, k) + proposals = self.proposer_worker.get_spec_proposals(execute_model_req) #logger.info("score proposals") proposal_scores = self.scorer.score_proposals( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - k, + execute_model_req, proposals, ) #logger.info("verify proposals") accepted_token_ids, target_logprobs = self._verify_tokens( - seq_group_metadata_list, proposal_scores, proposals, k) + execute_model_req.seq_group_metadata_list, proposal_scores, + proposals, execute_model_req.num_lookahead_slots) #logger.info("create output list") return self._create_output_sampler_list( - seq_group_metadata_list, + execute_model_req.seq_group_metadata_list, accepted_token_ids, target_logprobs=target_logprobs, - k=k) + k=execute_model_req.num_lookahead_slots) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 56c63887b0315..eb622a0e2e7f4 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -1,8 +1,9 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.util import sampler_output_to_torch @@ -40,17 +41,15 @@ def __init__( def get_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: """Get speculative proposals given the input batch. Sequences which would exceed the max model length are skipped during speculation. """ + proposal_len = execute_model_req.num_lookahead_slots + seq_group_metadata_list = execute_model_req.seq_group_metadata_list # Split speculative- and non-speculative- sequences. ( @@ -66,11 +65,12 @@ def get_proposals( # token_ids is like [batch] format in proposal_len size list, # while if it is false, the format would be [proposal_len] # in batch size list - maybe_sampler_output, transposed = self._worker.sampler_output( + nonzero_execute_model_req = ExecuteModelRequest( seq_group_metadata_list=nonzero_proposal_len_seqs, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, + num_lookahead_slots=proposal_len, + ) + maybe_sampler_output, transposed = self._worker.sampler_output( + execute_model_req=nonzero_execute_model_req, sample_len=proposal_len, ) else: diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 83ededd742533..4420d4cc9e12f 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -13,7 +13,7 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -256,22 +256,24 @@ def cache_copy( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, - blocks_to_swap_in: Optional[Dict[int, int]] = None, - blocks_to_swap_out: Optional[Dict[int, int]] = None, - blocks_to_copy: Optional[Dict[int, List[int]]] = None, + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> List[SamplerOutput]: + + if execute_model_req is None: + seq_group_metadata_list = None + else: + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups: int = len(seq_group_metadata_list) - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None - assert len(blocks_to_swap_in) == 0 - assert len(blocks_to_swap_out) == 0 + assert execute_model_req is not None + blocks_to_copy = execute_model_req.blocks_to_copy + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, - "blocks_to_copy": blocks_to_copy, + "blocks_to_copy": execute_model_req.blocks_to_copy, } broadcast_tensor_dict(data, src=0) else: @@ -279,7 +281,6 @@ def execute_model( num_seq_groups = data["num_seq_groups"] blocks_to_copy = data["blocks_to_copy"] - assert blocks_to_copy is not None self.cache_copy(blocks_to_copy) # If there is no input, we don't need to execute the model. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 808261e47318b..4add36e94f723 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -18,7 +18,7 @@ init_custom_ar) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase @@ -211,19 +211,21 @@ def cache_swap( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, - blocks_to_swap_in: Optional[Dict[int, int]] = None, - blocks_to_swap_out: Optional[Dict[int, int]] = None, - blocks_to_copy: Optional[Dict[int, List[int]]] = None, - num_lookahead_slots: int = 0, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + if execute_model_req is None: + seq_group_metadata_list = None + else: + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + if self.is_driver_worker: assert seq_group_metadata_list is not None + assert execute_model_req is not None num_seq_groups = len(seq_group_metadata_list) - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None + blocks_to_swap_in = execute_model_req.blocks_to_swap_in + blocks_to_swap_out = execute_model_req.blocks_to_swap_out + blocks_to_copy = execute_model_req.blocks_to_copy data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_swap_in": blocks_to_swap_in, @@ -238,9 +240,6 @@ def execute_model( blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_copy = data["blocks_to_copy"] - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) # If there is no input, we don't need to execute the model. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 0a89e3a79769f..fb32feaca0c94 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -5,7 +5,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) @@ -48,10 +48,8 @@ def initialize_cache(self, num_gpu_blocks: int, @abstractmethod def execute_model( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, - int], - blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" raise NotImplementedError From 36fb68f94792a8cec8df5b58bab7ab4d4d6158b4 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 4 May 2024 16:18:00 +0900 Subject: [PATCH 208/413] [Doc] Chunked Prefill Documentation (#4580) --- docs/source/index.rst | 1 + docs/source/models/performance.rst | 38 ++++++++++++++++++++++++++++++ vllm/config.py | 5 ++-- 3 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 docs/source/models/performance.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 5cc28a2d70139..4022c590843e6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -87,6 +87,7 @@ Documentation models/adding_model models/engine_args models/lora + models/performance .. toctree:: :maxdepth: 1 diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst new file mode 100644 index 0000000000000..067757699f32a --- /dev/null +++ b/docs/source/models/performance.rst @@ -0,0 +1,38 @@ +.. _performance: + +Performance and Tuning +====================== + +Chunked Prefill +--------------- +vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. + +You can enable the feature by specifying + +.. code-block:: python + + llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True) + # Set max_num_batched_tokens to tune performance. + # NOTE: 512 is the default max_num_batched_tokens for chunked prefill. + # llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512) + +By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. This policy optimizes the TTFT (time to thefirst token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. + +Once chunked prefill is enabled, the policy is changed to + +- prioritize decode requests. It batches all pending decode requests to the batch before scheduling any prefill. +- When there are available token_budget (`max_num_batched_tokens`), it schedules pending prefills. If a last pending prefill request cannot fit into `max_num_batched_tokens`, it chunks it. + +This policy has two benefits. + +- It improves ITL (inter token latency) and generation decode because decode requests are prioritized. +- It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. + +You can tune the performance by changing `max_num_batched_tokens`. +By default, it is set to 512, which has the best ITL on A100 in the initial benchmark. +Smaller batch size achieves better ITL because there are fewer prefills interrupting decodes. +Higher batch size achieves better TTFT as you can put more prefill to the batch. +If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). +Note that the default batch size (512) is optimized for ITL, and it may have lower throughput than the default scheduler. We recommend you set `max_num_batched_tokens > 2048` for throughput. + +See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). diff --git a/vllm/config.py b/vllm/config.py index fe54c54bed48e..6c65bbe247f84 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -607,8 +607,9 @@ def __init__( self.max_num_batched_tokens = max_num_batched_tokens else: if enable_chunked_prefill: - # For chunked prefill, choose the well-tuned batch size. - self.max_num_batched_tokens = 768 + # It is the values that have the best balance between ITL + # and TTFT on A100. Note it is not optimized for throughput. + self.max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. From 2a052011ca473a9dc8160f3daa1f5f63a2ad1fe3 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Sat, 4 May 2024 14:45:16 -0400 Subject: [PATCH 209/413] [Kernel] Support MoE Fp8 Checkpoints for Mixtral (Static Weights with Dynamic/Static Activations) (#4527) Follow on to #4332 to enable FP8 checkpoint loading for Mixtral and supersedes #4436. This PR enables the following checkpoint loading features for Mixtral: Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model Supports static or dynamic activation quantization with static weight quantization (all per tensor) Supports different scales for each expert weight Supports Fp8 in QKV layer Notes: The Expert Gate/Router always runs at half / full precision for now. If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance. --- tests/kernels/test_moe.py | 4 +- vllm/model_executor/models/mixtral.py | 171 ++++++++++++++++++-------- 2 files changed, 122 insertions(+), 53 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 046f11d957bdd..2356b9ec18b0d 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype): for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) - vllm_moe.ws[i][:] = torch.cat(weights, dim=0) - vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data + vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 9ff9ba298588a..efa4de7516212 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -78,6 +78,8 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size // self.tp_size + self.quant_config = quant_config + # FIXME(pcmoritz): Make this more general to support different # quantization schemes self.use_fp8 = isinstance(quant_config, Fp8Config) @@ -86,55 +88,79 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype + # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear(self.hidden_size, self.num_total_experts, bias=False, params_dtype=self.params_dtype, quant_config=None) - self.ws = nn.Parameter( + if self.use_fp8: + params_dtype = torch.float8_e4m3fn + + self.w13_weight = nn.Parameter( torch.empty(self.num_total_experts, 2 * self.intermediate_size, self.hidden_size, - dtype=self.params_dtype)) - self.w2s = nn.Parameter( + dtype=params_dtype)) + self.w2_weight = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, - dtype=self.params_dtype)) + dtype=params_dtype)) - set_weight_attrs(self.ws, { + set_weight_attrs(self.w13_weight, { "weight_loader": self.weight_loader, }) - set_weight_attrs(self.w2s, { + set_weight_attrs(self.w2_weight, { "weight_loader": self.weight_loader, }) - # Scaling factors for FP8 weights - self.ws_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False) if self.use_fp8 else None - self.w2s_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False) if self.use_fp8 else None - - # Scaling factors for FP8 activations - need_act_scales = (self.use_fp8 - and quant_config.activation_scheme == "static") - self.as_scale = nn.Parameter( - torch.zeros(1, dtype=torch.float32), - requires_grad=False) if need_act_scales else None - self.a2s_scale = nn.Parameter( - torch.zeros(1, dtype=torch.float32), - requires_grad=False) if need_act_scales else None - - if need_act_scales: - set_weight_attrs(self.as_scale, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.a2s_scale, { - "weight_loader": self.weight_loader, - }) + # Used for fp8. + self.w13_scale = None + self.w2_scale = None + self.a13_scale = None + self.a2_scale = None + + if self.use_fp8: + # WEIGHT_SCALE (for fp8) + self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, + dtype=torch.float32), + requires_grad=False) + self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, + dtype=torch.float32), + requires_grad=False) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(self.w13_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2_scale, { + "weight_loader": self.weight_loader, + }) + + # ACT_SCALE (for fp8) + if quant_config.activation_scheme == "static": + if not quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + self.a13_scale = nn.Parameter(torch.zeros( + self.num_total_experts, dtype=torch.float32), + requires_grad=False) + self.a2_scale = nn.Parameter(torch.zeros( + self.num_total_experts, dtype=torch.float32), + requires_grad=False) + + set_weight_attrs(self.a13_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.a2_scale, { + "weight_loader": self.weight_loader, + }) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): @@ -149,20 +175,49 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_size:2 * shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] - if "act_scale" in weight_name: - param_data[:] = param_data[:].max(loaded_weight) + if "act_scale" in weight_name or "weight_scale" in weight_name: + param_data[expert_id] = loaded_weight def process_weights_after_loading(self): - if self.use_fp8: - ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn) - w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn) + # Fp8 is the only case where we need to process after loading. + if not self.use_fp8: + return + + # If checkpoint is fp16, quantize here. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like(self.w13_weight.data, + dtype=torch.float8_e4m3fn) + w2_weight = torch.empty_like(self.w2_weight.data, + dtype=torch.float8_e4m3fn) for expert in range(self.num_total_experts): - ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant( - self.ws.data[expert, :, :]) - w2s[expert, :, :], self.w2s_scale[ - expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :]) - self.ws = nn.Parameter(ws, requires_grad=False) - self.w2s = nn.Parameter(w2s, requires_grad=False) + w13_weight[expert, :, :], self.w13_scale[ + expert] = ops.scaled_fp8_quant( + self.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], self.w2_scale[ + expert] = ops.scaled_fp8_quant( + self.w2_weight.data[expert, :, :]) + self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) + self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) + + # If checkpoint is fp8 + static, cleanup act_scales. + # Since state_dict has an act_scale per expert but our kernels + # are passed one act_scale shared across all experts. + elif self.quant_config.activation_scheme == "static": + if self.a13_scale is None or self.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + + if (not all_close_1d(self.a13_scale) + or not all_close_1d(self.a2_scale)): + print_warning_once( + "Found act_scales that are not equal for fp8 MoE layer. " + "Using the maximum across experts for each layer. ") + + self.a13_scale = nn.Parameter(self.a13_scale.max(), + requires_grad=False) + self.a2_scale = nn.Parameter(self.a2_scale.max(), + requires_grad=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape @@ -170,17 +225,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, + self.w13_weight, + self.w2_weight, router_logits, self.top_k, renormalize=True, inplace=True, use_fp8=self.use_fp8, - w1_scale=self.ws_scale, - w2_scale=self.w2s_scale, - a1_scale=self.as_scale, - a2_scale=self.a2s_scale) + w1_scale=self.w13_scale, + w2_scale=self.w2_scale, + a1_scale=self.a13_scale, + a2_scale=self.a2_scale) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -222,7 +277,9 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = sliding_window - if isinstance(quant_config, Fp8Config): + if isinstance( + quant_config, + Fp8Config) and not quant_config.is_checkpoint_fp8_serialized: print_warning_once( "For Mixtral FP8 quantization, we currently do not quantize " "the attention layers until their FP8 performance is improved." @@ -461,16 +518,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] expert_params_mapping = [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id) + ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ # These are the weights for the experts # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", + ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", f"experts.{expert_id}.{weight_name}.weight", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] ] + [ # These are the activation scales for the experts # (param_name, weight_name, expert_id) - ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale", + ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", f"experts.{expert_id}.{weight_name}.act_scale", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] @@ -512,3 +576,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) From 021b1a2ab7497769dae8a67ea3467e4bafb474c5 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Sat, 4 May 2024 13:44:36 -0700 Subject: [PATCH 210/413] [CI] check size of the wheels (#4319) --- .buildkite/check-wheel-size.py | 33 +++++++++++++++++++++++++++++++++ Dockerfile | 12 ++++++++---- 2 files changed, 41 insertions(+), 4 deletions(-) create mode 100644 .buildkite/check-wheel-size.py diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py new file mode 100644 index 0000000000000..8178fba552c46 --- /dev/null +++ b/.buildkite/check-wheel-size.py @@ -0,0 +1,33 @@ +import os +import zipfile + +MAX_SIZE_MB = 100 + + +def print_top_10_largest_files(zip_file): + with zipfile.ZipFile(zip_file, 'r') as z: + file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] + file_sizes.sort(key=lambda x: x[1], reverse=True) + for f, size in file_sizes[:10]: + print(f"{f}: {size/(1024*1024)} MBs uncompressed.") + + +def check_wheel_size(directory): + for root, _, files in os.walk(directory): + for f in files: + if f.endswith(".whl"): + wheel_path = os.path.join(root, f) + wheel_size = os.path.getsize(wheel_path) + wheel_size_mb = wheel_size / (1024 * 1024) + if wheel_size_mb > MAX_SIZE_MB: + print( + f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) " + f"compare to the allowed size ({MAX_SIZE_MB} MB).") + print_top_10_largest_files(wheel_path) + return 1 + return 0 + + +if __name__ == "__main__": + import sys + sys.exit(check_wheel_size(sys.argv[1])) diff --git a/Dockerfile b/Dockerfile index e8a9842c089dd..90be3a30f89b1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ #################### BASE BUILD IMAGE #################### # prepare basic build environment -FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev RUN apt-get update -y \ && apt-get install -y python3-pip git @@ -16,7 +16,7 @@ RUN apt-get update -y \ # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-12.1/compat/ +RUN ldconfig /usr/local/cuda-12.4/compat/ WORKDIR /workspace @@ -75,6 +75,10 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/pip \ python3 setup.py bdist_wheel --dist-dir=dist +# check the size of the wheel, we cannot upload wheels larger than 100MB +COPY .buildkite/check-wheel-size.py check-wheel-size.py +RUN python3 check-wheel-size.py dist + # the `vllm_nccl` package must be installed from source distribution # pip is too smart to store a wheel in the cache, and other CI jobs # will directly use the wheel from the cache, which is not what we want. @@ -102,7 +106,7 @@ RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ #################### vLLM installation IMAGE #################### # image with vLLM installed -FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base +FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base WORKDIR /vllm-workspace RUN apt-get update -y \ @@ -112,7 +116,7 @@ RUN apt-get update -y \ # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-12.1/compat/ +RUN ldconfig /usr/local/cuda-12.4/compat/ # install vllm wheel first, so that torch etc will be installed RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ From 43029870694de0789a10ab49f181f1cba6ec741a Mon Sep 17 00:00:00 2001 From: DearPlanet <149305930+DearPlanet@users.noreply.github.com> Date: Sun, 5 May 2024 06:39:34 +0800 Subject: [PATCH 211/413] [Bugfix] Fix inappropriate content of model_name tag in Prometheus metrics (#3937) --- tests/metrics/test_metrics.py | 30 +++++++++++++++++++++++++++++ vllm/config.py | 25 ++++++++++++++++++++++++ vllm/engine/arg_utils.py | 20 +++++++++++++++++-- vllm/engine/llm_engine.py | 5 +++-- vllm/entrypoints/openai/cli_args.py | 10 ---------- 5 files changed, 76 insertions(+), 14 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 311e60ba60f61..e0aa14f165c2d 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,3 +1,5 @@ +from typing import List + import pytest from prometheus_client import REGISTRY @@ -76,6 +78,34 @@ def test_metric_counter_generation_tokens( f"metric: {metric_count!r}") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize( + "served_model_name", + [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]]) +def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, + served_model_name: List[str]) -> None: + vllm_model = vllm_runner(model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.3, + served_model_name=served_model_name) + stat_logger = vllm_model.model.llm_engine.stat_logger + metrics_tag_content = stat_logger.labels["model_name"] + + del vllm_model + + if served_model_name is None or served_model_name == []: + assert metrics_tag_content == model, ( + f"Metrics tag model_name is wrong! expect: {model!r}\n" + f"actual: {metrics_tag_content!r}") + else: + assert metrics_tag_content == served_model_name[0], ( + f"Metrics tag model_name is wrong! expect: " + f"{served_model_name[0]!r}\n" + f"actual: {metrics_tag_content!r}") + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [4]) diff --git a/vllm/config.py b/vllm/config.py index 6c65bbe247f84..13bb294591725 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -31,6 +31,8 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. @@ -69,6 +71,10 @@ class ModelConfig: to eager mode skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. + served_model_name: The model name used in metrics tag `model_name`, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, + the model name will be the same as `model`. """ def __init__( @@ -90,6 +96,7 @@ def __init__( max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 5, skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, ) -> None: self.model = model self.tokenizer = tokenizer @@ -117,6 +124,8 @@ def __init__( self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config, max_model_len) + self.served_model_name = get_served_model_name(model, + served_model_name) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() self._verify_quantization() @@ -1150,6 +1159,22 @@ def _get_and_verify_max_len( return int(max_model_len) +def get_served_model_name(model: str, + served_model_name: Optional[Union[str, List[str]]]): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. + """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + @dataclass class DecodingConfig: """Dataclass which contains the decoding strategy of the engine""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 78cd07575f17d..bb8245eb307f7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,7 +1,7 @@ import argparse import dataclasses from dataclasses import dataclass -from typing import Optional +from typing import List, Optional, Union from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -21,6 +21,7 @@ def nullable_str(val: str): class EngineArgs: """Arguments for vLLM engine.""" model: str + served_model_name: Optional[Union[List[str]]] = None tokenizer: Optional[str] = None skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' @@ -489,6 +490,21 @@ def add_cli_args( 'This should be a JSON string that will be ' 'parsed into a dictionary.') + parser.add_argument( + "--served-model-name", + nargs="+", + type=str, + default=None, + help="The model name(s) used in the API. If multiple " + "names are provided, the server will respond to any " + "of the provided names. The model name in the model " + "field of a response will be the first name in this " + "list. If not specified, the model name will be the " + "same as the `--model` argument. Noted that this name(s)" + "will also be used in `model_name` tag content of " + "prometheus metrics, if multiple names provided, metrics" + "tag will take the first one.") + return parser @classmethod @@ -508,7 +524,7 @@ def create_engine_config(self, ) -> EngineConfig: self.quantization, self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, self.max_seq_len_to_capture, self.max_logprobs, - self.skip_tokenizer_init) + self.skip_tokenizer_init, self.served_model_name) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 342f2c796d6fb..b9938b045ba2b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -106,7 +106,7 @@ def __init__( "tensor_parallel_size=%d, disable_custom_all_reduce=%s, " "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, seed=%d)", + "decoding_config=%r, seed=%d, served_model_name=%s)", vllm.__version__, model_config.model, speculative_config, @@ -129,6 +129,7 @@ def __init__( device_config.device, decoding_config, model_config.seed, + model_config.served_model_name, ) # TODO(woosuk): Print more configs in debug mode. @@ -219,7 +220,7 @@ def __init__( if self.log_stats: self.stat_logger = StatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.model), + labels=dict(model_name=model_config.served_model_name), max_model_len=self.model_config.max_model_len) self.stat_logger.info("cache_config", self.cache_config) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 2b57ab26bfd31..4c0cb1e4f3e49 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -56,16 +56,6 @@ def make_arg_parser(): default=None, help="If provided, the server will require this key " "to be presented in the header.") - parser.add_argument("--served-model-name", - nargs="+", - type=nullable_str, - default=None, - help="The model name(s) used in the API. If multiple " - "names are provided, the server will respond to any " - "of the provided names. The model name in the model " - "field of a response will be the first name in this " - "list. If not specified, the model name will be the " - "same as the `--model` argument.") parser.add_argument( "--lora-modules", type=nullable_str, From 8d8357c8ed1f3ddb6a0e8f3287ec669a13d77df1 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Sat, 4 May 2024 17:09:49 -0700 Subject: [PATCH 212/413] bump version to v0.4.2 (#4600) --- .github/workflows/scripts/create_release.js | 2 +- vllm/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/scripts/create_release.js b/.github/workflows/scripts/create_release.js index 0f25624b4c21c..475742118afeb 100644 --- a/.github/workflows/scripts/create_release.js +++ b/.github/workflows/scripts/create_release.js @@ -8,7 +8,7 @@ module.exports = async (github, context, core) => { generate_release_notes: true, name: process.env.RELEASE_TAG, owner: context.repo.owner, - prerelease: false, + prerelease: true, repo: context.repo.repo, tag_name: process.env.RELEASE_TAG, }); diff --git a/vllm/__init__.py b/vllm/__init__.py index ca454efd44b24..59810da3ca411 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams -__version__ = "0.4.1" +__version__ = "0.4.2" __all__ = [ "LLM", From c7f2cf2b7f67bce5842fedfdba508440fe257375 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Sat, 4 May 2024 21:28:58 -0700 Subject: [PATCH 213/413] [CI] Reduce wheel size by not shipping debug symbols (#4602) --- .buildkite/check-wheel-size.py | 3 +++ .github/workflows/publish.yml | 2 ++ 2 files changed, 5 insertions(+) diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 8178fba552c46..90a5e54736cf3 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -25,6 +25,9 @@ def check_wheel_size(directory): f"compare to the allowed size ({MAX_SIZE_MB} MB).") print_top_10_largest_files(wheel_path) return 1 + else: + print(f"Wheel {wheel_path} is within the allowed size " + f"({wheel_size_mb} MB).") return 0 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index d79681f03b003..ac60ce0fed14a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -79,6 +79,8 @@ jobs: - name: Build wheel shell: bash + env: + CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size run: | bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} wheel_name=$(ls dist/*whl | xargs -n 1 basename) From 0650e5935b0f6af35fb2acf71769982c47b804d7 Mon Sep 17 00:00:00 2001 From: zhaoyang-star Date: Mon, 6 May 2024 07:58:55 +0800 Subject: [PATCH 214/413] Disable cuda version check in vllm-openai image (#4530) --- vllm/config.py | 11 +---------- vllm/utils.py | 24 +----------------------- 2 files changed, 2 insertions(+), 33 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 13bb294591725..5c3a8615eefb4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4,15 +4,13 @@ from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import torch -from packaging.version import Version from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, - is_neuron) +from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron GPTQMarlinConfig = get_quantization_config("gptq_marlin") @@ -369,13 +367,6 @@ def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass elif self.cache_dtype == "fp8": - if not is_hip(): - nvcc_cuda_version = get_nvcc_cuda_version() - if nvcc_cuda_version is not None \ - and nvcc_cuda_version < Version("11.8"): - raise ValueError( - "FP8 is not supported when cuda version is" - "lower than 11.8.") logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " diff --git a/vllm/utils.py b/vllm/utils.py index b06c8508757c5..6479a8dab320a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -19,7 +19,6 @@ import psutil import torch -from packaging.version import Version, parse import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger @@ -314,27 +313,6 @@ def cdiv(a: int, b: int) -> int: return -(a // -b) -@lru_cache(maxsize=None) -def get_nvcc_cuda_version() -> Optional[Version]: - cuda_home = envs.CUDA_HOME - if not cuda_home: - cuda_home = '/usr/local/cuda' - if os.path.isfile(cuda_home + '/bin/nvcc'): - logger.info( - 'CUDA_HOME is not found in the environment. ' - 'Using %s as CUDA_HOME.', cuda_home) - else: - logger.warning('Not found nvcc in %s. Skip cuda version check!', - cuda_home) - return None - nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], - universal_newlines=True) - output = nvcc_output.split() - release_idx = output.index("release") + 1 - nvcc_cuda_version = parse(output[release_idx].split(",")[0]) - return nvcc_cuda_version - - def _generate_random_fp8( tensor: torch.tensor, low: float, @@ -560,7 +538,7 @@ def maybe_expand_dim(tensor: torch.Tensor, def merge_dicts(dict1: Dict[Any, List[Any]], dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: """Merge 2 dicts that have key -> List of items. - + When a key conflicts, the values in dict1 is prioritized. """ merged_dict = defaultdict(list) From 323f27b9048713cdbab31995265975842a937167 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 7 May 2024 00:31:05 +0800 Subject: [PATCH 215/413] [Bugfix] Fix `asyncio.Task` not being subscriptable (#4623) --- vllm/engine/async_llm_engine.py | 6 +++--- vllm/entrypoints/openai/api_server.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9f72a0d11974f..37a2dc77a3b50 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,8 @@ import asyncio import time from functools import partial -from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, - Optional, Set, Tuple, Type, Union) +from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, + Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer @@ -327,7 +327,7 @@ def __init__(self, # We need to keep a reference to unshielded # task as well to prevent it from being garbage # collected - self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None + self._background_loop_unshielded: Optional[asyncio.Task] = None self.start_engine_loop = start_engine_loop self._errored_with: Optional[BaseException] = None diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f9e294af47253..44a946f2e32d4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,7 +4,7 @@ import re from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Any, Set +from typing import Set import fastapi import uvicorn @@ -34,7 +34,7 @@ openai_serving_completion: OpenAIServingCompletion logger = init_logger(__name__) -_running_tasks: Set[asyncio.Task[Any]] = set() +_running_tasks: Set[asyncio.Task] = set() @asynccontextmanager From e186d37cb135107a09cd684e4fa2cf30c0ce6f28 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 6 May 2024 15:23:36 -0700 Subject: [PATCH 216/413] [CI] use ccache actions properly in release workflow (#4629) --- .github/workflows/publish.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ac60ce0fed14a..9c35ede5f6781 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -58,6 +58,9 @@ jobs: - name: Setup ccache uses: hendrikmuhs/ccache-action@v1.2 + with: + create-symlink: true + key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - name: Set up Linux Env if: ${{ runner.os == 'Linux' }} From 19cb4716ee700e5d8baa64d7cf14fb5da3737f6d Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Mon, 6 May 2024 16:18:57 -0700 Subject: [PATCH 217/413] [CI] Add retry for agent lost (#4633) --- .buildkite/test-template.j2 | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index ea02b6b1e9c9e..919a09e1cc064 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -14,6 +14,8 @@ steps: automatic: - exit_status: -1 # Agent was lost limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 - wait - group: "AMD Tests" @@ -53,6 +55,8 @@ steps: automatic: - exit_status: -1 # Agent was lost limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 plugins: - kubernetes: podSpec: From bd99d226295776011f4ea4831498a7103bc4e43b Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Tue, 7 May 2024 02:51:59 +0300 Subject: [PATCH 218/413] Update lm-format-enforcer to 0.10.1 (#4631) --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 3abb828116680..bd779d5acb68e 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -14,7 +14,7 @@ pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 tiktoken == 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer == 0.9.8 +lm-format-enforcer == 0.10.1 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 From a98187cf7227695819e199e2e3ad35be0a9a84f3 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 6 May 2024 17:39:28 -0700 Subject: [PATCH 219/413] [Kernel] Make static FP8 scaling more robust (#4570) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously FP8 static scaling works if the scales are overestimating the maxima of all activation tensors during computation. However this will not always be the case even if the scales were calibrated very carefully. For example, with the activations in my checkpoint https://huggingface.co/pcmoritz/Mixtral-8x7B-v0.1-fp8-act-scale (which was calibrated on https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), I'm getting the following mostly random performance on MMLU: | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.2295|± |0.0035| | - humanities |N/A |none | 5|acc |0.2421|± |0.0062| | - other |N/A |none | 5|acc |0.2398|± |0.0076| | - social_sciences|N/A |none | 5|acc |0.2171|± |0.0074| | - stem |N/A |none | 5|acc |0.2125|± |0.0073| With the fix in this PR where the scaled activations are clamped between [-std::numeric_limits::max(), std::numeric_limits::max()] to make sure there are no NaNs, the performance is | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7008|± |0.0036| | - humanities |N/A |none | 5|acc |0.6453|± |0.0065| | - other |N/A |none | 5|acc |0.7692|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8083|± |0.0070| | - stem |N/A |none | 5|acc |0.6115|± |0.0083| This is not perfect yet but is getting very close to the FP16 / dynamic activation scale performance. --- csrc/quantization/fp8/fp8_cuda_kernels.cu | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 2477051eb60d7..b9c5d39277ca5 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -17,6 +17,15 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { return old; } +#define FP8_E4M3_MAX std::numeric_limits::max() + +template +__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) { + float x = static_cast(val) / scale; + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + // Compute the absolute maximum m of the input tensor and store // m / float8_e4m3::max() in *scale. Each thread block performs a // reduction tree and the memory in scale is atomically updated. @@ -67,7 +76,7 @@ __global__ void scaled_fp8_quant_kernel( int64_t num_elems) { int i = blockDim.x * blockIdx.x + threadIdx.x; while (i < num_elems) { - out[i] = static_cast(input[i] / *scale); + out[i] = scaled_fp8_conversion(input[i], *scale); i += blockDim.x * gridDim.x; } } From 63575bc2e197b85ce1c911421ff30c5459e35e9c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 6 May 2024 21:30:27 -0700 Subject: [PATCH 220/413] [Core][Optimization] change python dict to pytorch tensor (#4607) --- csrc/cache.h | 2 +- csrc/cache_kernels.cu | 20 +++-------- csrc/cpu/cache.cpp | 20 ++++------- tests/core/test_scheduler.py | 8 ++--- tests/kernels/test_cache.py | 21 ++++++----- tests/worker/test_swap.py | 2 +- vllm/attention/backends/abstract.py | 2 +- vllm/attention/backends/flash_attn.py | 2 +- vllm/attention/backends/flashinfer.py | 2 +- vllm/attention/backends/rocm_flash_attn.py | 2 +- vllm/attention/backends/torch_sdpa.py | 2 +- vllm/attention/backends/xformers.py | 2 +- vllm/attention/ops/paged_attn.py | 2 +- vllm/core/scheduler.py | 41 +++++++++++----------- vllm/distributed/communication_op.py | 7 ++++ vllm/sequence.py | 6 ++-- vllm/worker/cache_engine.py | 2 +- vllm/worker/cpu_worker.py | 7 ++-- vllm/worker/worker.py | 8 +++-- 19 files changed, 77 insertions(+), 81 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 4c142ce17f1b9..10871b3670bac 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -13,7 +13,7 @@ void swap_blocks( void copy_blocks( std::vector& key_caches, std::vector& value_caches, - const std::map>& block_mapping); + torch::Tensor& block_mapping); void reshape_and_cache( torch::Tensor& key, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 42f884c76c620..1e02f7fcbae4c 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel( void copy_blocks( std::vector& key_caches, std::vector& value_caches, - const std::map>& block_mapping) { + torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -114,17 +114,9 @@ void copy_blocks( key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); } - // Create block mapping array. - std::vector block_mapping_vec; - for (const auto& pair : block_mapping) { - int64_t src_block_number = pair.first; - for (int64_t dst_block_number : pair.second) { - block_mapping_vec.push_back(src_block_number); - block_mapping_vec.push_back(dst_block_number); - } - } - int64_t* block_mapping_array = block_mapping_vec.data(); - int num_pairs = block_mapping_vec.size() / 2; + + // block_mapping is a 2D tensor with shape (num_pairs, 2). + int num_pairs = block_mapping.size(0); // Move the data structures to the GPU. // NOTE: This synchronizes the CPU and GPU. @@ -132,8 +124,6 @@ void copy_blocks( key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::Tensor value_cache_ptrs_tensor = torch::from_blob( value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor block_mapping_tensor = torch::from_blob( - block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); // Launch the kernel. const int numel_per_block = key_caches[0][0].numel(); @@ -146,7 +136,7 @@ void copy_blocks( vllm::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), value_cache_ptrs_tensor.data_ptr(), - block_mapping_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); })); } diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 7849a5df991b1..95e3f11900fde 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -8,16 +8,16 @@ template void copy_blocks_cpu_impl( std::vector &key_caches, std::vector &value_caches, - const std::vector> mapping_pairs, + const torch::Tensor& mapping_pairs, const int element_num_per_block, const int layer_num) { - const size_t pair_num = mapping_pairs.size(); + const size_t pair_num = mapping_pairs.size(0); const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; #pragma omp parallel for collapse(2) for (int layer = 0; layer < layer_num; ++layer) { for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; + int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item(); int64_t target_offset = - element_num_per_block * mapping_pairs[pair].second; + element_num_per_block * mapping_pairs[pair][1].item(); scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); scalar_t *source_ptr = key_cache_ptr + source_offset; scalar_t *target_ptr = key_cache_ptr + target_offset; @@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl( void copy_blocks(std::vector &key_caches, std::vector &value_caches, - const std::map> &block_mapping) { + torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { return; } - std::vector> mapping_pairs; - mapping_pairs.reserve(block_mapping.size()); - for (const auto &pair : block_mapping) { - for (const auto &dst : pair.second) { - mapping_pairs.emplace_back(pair.first, dst); - } - } - const int element_num_per_block = key_caches[0][0].numel(); VLLM_DISPATCH_FLOATING_TYPES( key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) - copy_blocks_cpu_impl(key_caches, value_caches, mapping_pairs, + copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, element_num_per_block, num_layers); CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) }); diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 1358dffec8104..348169035ae97 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -568,7 +568,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # Both should be preempted, not swapped. assert output.blocks_to_swap_out == {} # Nothing is copied. - assert output.blocks_to_copy == {} + assert output.blocks_to_copy == [] def test_decode_swap_beam_search(): @@ -618,7 +618,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # Both should be preempted, not swapped. assert output.blocks_to_swap_out == expected_swap_mapping # Nothing is copied. - assert output.blocks_to_copy == {} + assert output.blocks_to_copy == [] def test_schedule_decode_blocks_to_copy_update(): @@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update(): assert output.blocks_to_swap_out == {} # Since append_slot returns the source -> dist mapping, it should # applied. - assert output.blocks_to_copy == {2: [3]} + assert output.blocks_to_copy == [(2, 3)] def test_schedule_swapped_simple(): @@ -853,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy(): assert len(remaining_swapped) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 - assert output.blocks_to_copy == {2: [3]} + assert output.blocks_to_copy == [(2, 3)] def test_scheduling_budget(): diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index ca215bb75837a..94a577139596e 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -63,12 +63,13 @@ def test_copy_blocks( src_blocks = random.sample(range(num_blocks), num_mappings) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) - block_mapping = {} + block_mapping = [] for i in range(num_mappings): src = src_blocks[i] dst1 = dst_blocks[2 * i] dst2 = dst_blocks[2 * i + 1] - block_mapping[src] = [dst1, dst2] + block_mapping.append((src, dst1)) + block_mapping.append((src, dst2)) # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, @@ -81,15 +82,17 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - ops.copy_blocks(key_caches, value_caches, block_mapping) + block_mapping_tensor = torch.tensor(block_mapping, + dtype=torch.int64, + device=device).view(-1, 2) + ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) # Run the reference implementation. - for src, dsts in block_mapping.items(): - for dst in dsts: - for cloned_key_cache in cloned_key_caches: - cloned_key_cache[dst].copy_(cloned_key_cache[src]) - for cloned_value_cache in cloned_value_caches: - cloned_value_cache[dst].copy_(cloned_value_cache[src]) + for src, dst in block_mapping: + for cloned_key_cache in cloned_key_caches: + cloned_key_cache[dst].copy_(cloned_key_cache[src]) + for cloned_value_cache in cloned_value_caches: + cloned_value_cache[dst].copy_(cloned_value_cache[src]) # Compare the results. for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 07bcd343a96a6..4d2d3add27d59 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -59,7 +59,7 @@ def test_swap() -> None: seq_group_metadata_list=[], blocks_to_swap_in={}, blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy={}, + blocks_to_copy=[], ) worker.execute_model(execute_model_req=execute_model_req) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 61c9c81d8a7b8..b2b6e7ac810e3 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -42,7 +42,7 @@ def swap_blocks( @abstractmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index fc7501ed5e91f..da672d5df6161 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -48,7 +48,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8ab4b1f12ee36..2851cbe2396b2 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -48,7 +48,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c411b3971b8f1..c3b522e63b4b8 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -46,7 +46,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index f75a279086a26..03825f6023f4c 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -44,7 +44,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 60f6d43f2eaa4..4c7fa71a2c78e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -49,7 +49,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 00a0f10c0950b..6f7fd51c774f8 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -209,7 +209,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index a9e0b05b8db67..de3ecd24e52db 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -13,7 +13,6 @@ from vllm.lora.request import LoRARequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.utils import merge_dicts logger = init_logger(__name__) @@ -122,8 +121,8 @@ class SchedulerOutputs: blocks_to_swap_in: Dict[int, int] # Blocks to swap out. Dict of GPU -> CPU block number. blocks_to_swap_out: Dict[int, int] - # Blocks to copy. Source to a list of dest blocks. - blocks_to_copy: Dict[int, List[int]] + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] # Sequence groups that are going to be ignored. ignored_seq_groups: List[SequenceGroup] # The number of slots for lookahead decoding. @@ -177,7 +176,7 @@ class SchedulerRunningOutputs: # The blocks to swap out. blocks_to_swap_out: Dict[int, int] # The blocks to copy. - blocks_to_copy: Dict[int, List[int]] + blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. num_lookahead_slots: int @@ -189,7 +188,7 @@ def create_empty(cls) -> "SchedulerRunningOutputs": preempted=[], swapped_out=[], blocks_to_swap_out={}, - blocks_to_copy={}, + blocks_to_copy=[], num_lookahead_slots=0, ) @@ -209,7 +208,7 @@ class SchedulerSwappedInOutputs: # The blocks to swap in. blocks_to_swap_in: Dict[int, int] # The blocks to copy. - blocks_to_copy: Dict[int, List[int]] + blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. num_lookahead_slots: int # Infeasible sequence groups. @@ -221,7 +220,7 @@ def create_empty(cls) -> "SchedulerSwappedInOutputs": decode_seq_groups=[], prefill_seq_groups=[], blocks_to_swap_in={}, - blocks_to_copy={}, + blocks_to_copy=[], num_lookahead_slots=0, infeasible_seq_groups=[], ) @@ -394,7 +393,7 @@ def _schedule_running( """ # Blocks that need to be swapped or copied before model execution. blocks_to_swap_out: Dict[int, int] = {} - blocks_to_copy: Dict[int, List[int]] = {} + blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] @@ -511,7 +510,7 @@ def _schedule_swapped( """ # Blocks that need to be swapped or copied before model execution. blocks_to_swap_in: Dict[int, int] = {} - blocks_to_copy: Dict[int, List[int]] = {} + blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] now = time.time() @@ -794,8 +793,8 @@ def _schedule_default(self) -> SchedulerOutputs: num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, - swapped_in.blocks_to_copy), + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, @@ -882,8 +881,8 @@ def _schedule_chunked_prefill(self): num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, - swapped_in.blocks_to_copy), + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), @@ -1011,17 +1010,18 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: def _append_slots( self, seq_group: SequenceGroup, - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: List[Tuple[int, int]], ) -> None: """Appends new slots to the sequences in the given sequence group. Args: seq_group (SequenceGroup): The sequence group containing the sequences to append slots to. - blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source - block indices to lists of destination block indices. This - dictionary is updated with the new source and destination block - indices for the appended slots. + blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two + ints, the first int is the source block index, and the second + int is the destination block index. This list is updated with + the new source and destination block indices for the appended + slots. """ num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) @@ -1029,9 +1029,8 @@ def _append_slots( cows = self.block_manager.append_slots(seq, num_lookahead_slots) for src, dests in cows.items(): - if src not in blocks_to_copy: - blocks_to_copy[src] = [] - blocks_to_copy[src].extend(dests) + for dest in dests: + blocks_to_copy.append((src, dest)) def _preempt( self, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index b539a7beedbfe..817bd6d812e48 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -203,6 +203,9 @@ def broadcast_tensor_dict( group=metadata_group) async_handles = [] for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue async_handles.append( torch.distributed.broadcast(tensor, src=src, @@ -224,6 +227,10 @@ def broadcast_tensor_dict( tensor = torch.empty(value.size, dtype=value.dtype, device="cuda") + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue async_handle = torch.distributed.broadcast(tensor, src=src, async_op=True, diff --git a/vllm/sequence.py b/vllm/sequence.py index f2939eff7959b..b486d1fedebd3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -2,7 +2,7 @@ import copy import enum from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock from vllm.lora.request import LoRARequest @@ -745,8 +745,8 @@ class ExecuteModelRequest: blocks_to_swap_in: Dict[int, int] = field(default_factory=dict) # Blocks to swap out. Dict of GPU -> CPU block number. blocks_to_swap_out: Dict[int, int] = field(default_factory=dict) - # Blocks to copy. Source to a list of dest blocks. - blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict) + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) # The number of slots for lookahead decoding. num_lookahead_slots: int = 0 # The number of requests in the running queue. diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index c34ee0648626b..26a60c652b6f4 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -77,7 +77,7 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None: self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) - def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + def copy(self, src_to_dsts: torch.Tensor) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) @staticmethod diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 4420d4cc9e12f..e1ef500ac07b8 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -248,9 +248,9 @@ def _init_cache_engine(self) -> None: def cache_copy( self, - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: torch.Tensor, ) -> None: - if blocks_to_copy: + if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() @@ -269,6 +269,9 @@ def execute_model( num_seq_groups: int = len(seq_group_metadata_list) assert execute_model_req is not None blocks_to_copy = execute_model_req.blocks_to_copy + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device="cpu", + dtype=torch.int64).view(-1, 2) assert len(execute_model_req.blocks_to_swap_in) == 0 assert len(execute_model_req.blocks_to_swap_out) == 0 data: Dict[str, Any] = { diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 4add36e94f723..538332ad003f1 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -197,7 +197,7 @@ def cache_swap( self, blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: torch.Tensor, ) -> None: # Issue cache operations. # TODO(woosuk): Profile swapping overhead and optimize if needed. @@ -205,7 +205,7 @@ def cache_swap( self.cache_engine.swap_in(blocks_to_swap_in) if blocks_to_swap_out: self.cache_engine.swap_out(blocks_to_swap_out) - if blocks_to_copy: + if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() @@ -225,7 +225,9 @@ def execute_model( num_seq_groups = len(seq_group_metadata_list) blocks_to_swap_in = execute_model_req.blocks_to_swap_in blocks_to_swap_out = execute_model_req.blocks_to_swap_out - blocks_to_copy = execute_model_req.blocks_to_copy + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_swap_in": blocks_to_swap_in, From 478aed5827169ef3ee07fdab42a935532a9ff68d Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Tue, 7 May 2024 11:23:17 -0500 Subject: [PATCH 221/413] [Build/CI] Fixing 'docker run' to re-enable AMD CI tests. (#4642) --- .buildkite/run-amd-test.sh | 2 +- .buildkite/test-pipeline.yaml | 8 ++++---- .buildkite/test-template.j2 | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index c04e05a994894..ce508e4748aba 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -40,5 +40,5 @@ docker run \ -e HF_TOKEN \ --name ${container_name} \ ${container_name} \ - /bin/bash -c $(echo $1 | sed "s/^'//" | sed "s/'$//") + /bin/bash -c "${@}" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e49a5650c44ea..cee5e7e9d2a73 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -48,7 +48,7 @@ steps: - pytest -v -s test_pynccl.py - label: Engine Test - mirror_hardwares: [amd] + #mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test @@ -73,13 +73,13 @@ steps: parallelism: 4 - label: Models Test - mirror_hardwares: [amd] + #mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py - label: Llava Test - mirror_hardwares: [amd] + #mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models/test_llava.py @@ -101,7 +101,7 @@ steps: command: pytest -v -s worker - label: Speculative decoding tests - mirror_hardwares: [amd] + #mirror_hardwares: [amd] command: pytest -v -s spec_decode - label: LoRA Test %N diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 919a09e1cc064..174c756ae74a3 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -26,7 +26,7 @@ steps: - label: "AMD: {{ step.label }}" agents: queue: amd - command: bash .buildkite/run-amd-test.sh "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'" + command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}" env: DOCKER_BUILDKIT: "1" {% endif %} From 10760da8003824e208c94fb2bfcdb6fdd0f4edda Mon Sep 17 00:00:00 2001 From: Austin Veselka <50646302+FurtherAI@users.noreply.github.com> Date: Tue, 7 May 2024 12:59:07 -0500 Subject: [PATCH 222/413] [Bugfix] Fixed error in slice_lora_b for MergedQKVParallelLinearWithLora (#4609) --- vllm/lora/fully_sharded_layers.py | 54 +++++++++++++++++-------------- vllm/lora/layers.py | 30 ++++++++++++----- 2 files changed, 52 insertions(+), 32 deletions(-) diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 1720566840bb1..ffdc32b7339af 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -1,5 +1,5 @@ # pylint: disable=unused-argument -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Union import torch import torch.nn as nn @@ -51,10 +51,9 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: lora_a = lora_a[:, start_idx:start_idx + shard_size] return lora_a - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, @@ -88,7 +87,7 @@ def can_replace_layer(cls, source_layer: nn.Module, ) -def _mcp_apply_weights(x, bias, layer): +def _mcp_apply(x, bias, layer): """ MergedColumnParallelLinearWithShardedLoRA and QKVParallelLinearWithShardedLora share the same @@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer): """ # expecting 2 for column parallel and 3 for qkv n = len(layer.lora_a_stacked) - output = layer.base_layer.linear_method.apply_weights( - layer.base_layer, x, bias) + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape @@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA( Based on S-LoRA, slicing happens along the rank dim. """ - def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_a[0] is None or lora_a[1] is None: + return lora_a output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size lora_a = [ - lora_a[i][:, output_start_idx:output_start_idx + output_shard_size] - for i in range(2) + lora_a[0][:, + output_start_idx:output_start_idx + output_shard_size], + lora_a[1][:, output_start_idx:output_start_idx + output_shard_size] ] return lora_a - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - return _mcp_apply_weights(x, bias, self) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply(x, bias, self) @classmethod @_fully_sharded_can_replace @@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): Based on S-LoRA, slicing happens along the rank dim. """ - def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None: + return lora_a shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)] lora_a = [ - lora_a[i][:, start_idx[i]:start_idx[i] + - shard_size[i]] if lora_a[i] is not None else None - for i in range(3) + lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]], + lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]], + lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]] ] return lora_a - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - return _mcp_apply_weights(x, bias, self) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply(x, bias, self) @classmethod @_fully_sharded_can_replace @@ -218,9 +225,8 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: lora_b = lora_b[:, start_idx:end_idx] return lora_b - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x) + def apply(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index b3609666b2ec7..90f63c34fb2d3 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument import math from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -145,11 +145,15 @@ def __post_init__(self): class BaseLayerWithLoRA(nn.Module): - def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + def slice_lora_a( + self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: """Slice lora a if splitting for tensor parallelism.""" ... - def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + def slice_lora_b( + self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: """Slice lora b if splitting with tensor parallelism.""" ... @@ -539,10 +543,16 @@ def reset_lora(self, index: int): self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[1][index] = 0 - def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: return lora_a - def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_b[0] is None or lora_b[1] is None: + return lora_b shard_size = self.output_dim start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size @@ -767,10 +777,15 @@ def reset_lora(self, index: int): self.lora_a_stacked[2][index] = 0 self.lora_b_stacked[2][index] = 0 - def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: return lora_a - def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + lora_b_q, lora_b_k, lora_b_v = None, None, None if lora_b[0] is not None: lora_b_q = lora_b[0][:, self.q_proj_shard_size * self.q_shard_id:self.q_proj_shard_size * @@ -992,7 +1007,6 @@ def forward(self, input_): @property def weight(self): - return self.base_layer.weight if hasattr( self.base_layer, "weight") else self.base_layer.qweight From 469f85c7829c301b6dec48725951b5501c18d611 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 7 May 2024 11:06:32 -0700 Subject: [PATCH 223/413] [Core][Optimization] change copy-on-write from dict[int, list] to list (#4648) --- tests/core/block/test_block_table.py | 6 ++---- tests/core/test_block_manager.py | 6 +++++- tests/core/test_scheduler.py | 4 ++-- vllm/core/block/common.py | 21 ++++++++++----------- vllm/core/block/cpu_gpu_block_allocator.py | 8 ++++---- vllm/core/block/interfaces.py | 6 +++--- vllm/core/block/naive_block.py | 8 ++++---- vllm/core/block/prefix_caching_block.py | 8 ++++---- vllm/core/block_manager_v1.py | 10 +++++----- vllm/core/block_manager_v2.py | 3 ++- vllm/core/interfaces.py | 3 ++- vllm/core/scheduler.py | 5 +---- 12 files changed, 44 insertions(+), 44 deletions(-) diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py index 3481d6b4312c1..6fb95cfdfab81 100644 --- a/tests/core/block/test_block_table.py +++ b/tests/core/block/test_block_table.py @@ -410,8 +410,7 @@ def test_cow(block_size: int, sequence_len: int, append_len: int, expected_src = static_block_table.physical_block_ids[cow_block_id] expected_dst = appender_block_table.physical_block_ids[cow_block_id] - assert expected_src in cows - assert expected_dst in cows[expected_src] + assert (expected_src, expected_dst) in cows else: # Otherwise, there should be no copy-on-write. assert not cows @@ -490,8 +489,7 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int, expected_src = static_block_table.physical_block_ids[cow_block_id] expected_dst = appender_block_table.physical_block_ids[cow_block_id] - assert expected_src in cows - assert expected_dst in cows[expected_src] + assert (expected_src, expected_dst) in cows static_block_table.free() appender_block_table.free() diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 9f9a6180add78..08d34efb8302c 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -1,4 +1,5 @@ import time +from collections import defaultdict from typing import List import pytest @@ -155,7 +156,10 @@ def test_append_slot_cow(): cows = block_manager.append_slots(child) assert cows - for src_block, dst_blocks in cows.items(): + dict_cows = defaultdict(list) + for src_block, dst_block in cows: + dict_cows[src_block].append(dst_block) + for src_block, dst_blocks in dict_cows.items(): assert src_block not in dst_blocks after_blocks = block_manager.get_num_free_gpu_blocks() diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 348169035ae97..3f0c918a89abb 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -636,7 +636,7 @@ def test_schedule_decode_blocks_to_copy_update(): # The last request should be swapped out. scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = {2: [3]} + scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() remaining_running, output = scheduler._schedule_running( @@ -845,7 +845,7 @@ def test_schedule_swapped_blocks_to_copy(): # The last request should be swapped out. scheduler.block_manager.append_slots = MagicMock() - scheduler.block_manager.append_slots.return_value = {2: [3]} + scheduler.block_manager.append_slots.return_value = [(2, 3)] budget = create_token_budget() remaining_swapped, output = scheduler._schedule_swapped( diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index 3f97a1210b096..4d7a12165cb01 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,5 +1,4 @@ -from collections import defaultdict -from typing import Dict, Iterable, List, Optional, Protocol +from typing import Dict, Iterable, List, Optional, Protocol, Tuple from vllm.core.block.interfaces import Block, BlockAllocator @@ -111,7 +110,7 @@ def __init__( refcounter: RefCounterProtocol, allocator: BlockAllocator, ): - self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list) + self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] self._refcounter = refcounter self._allocator = allocator @@ -152,25 +151,25 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: # Track src/dst copy. assert src_block_id is not None assert block_id is not None - self._copy_on_writes[src_block_id].append(block_id) + self._copy_on_writes.append((src_block_id, block_id)) return block_id - def clear_cows(self) -> Dict[BlockId, List[BlockId]]: + def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: """Clears the copy-on-write tracking information and returns the current state. - This method returns a dictionary mapping source block indices to lists - of destination block indices for the current copy-on-write operations. + This method returns a list mapping source block indices to + destination block indices for the current copy-on-write operations. It then clears the internal tracking information. Returns: - Dict[BlockId, List[BlockId]]: A dictionary mapping source - block indices to lists of destination block indices for the + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices for the current copy-on-write operations. """ - cows = dict(self._copy_on_writes) - self._copy_on_writes.clear() + cows = self._copy_on_writes + self._copy_on_writes = [] return cows diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 5b25e1bcdada0..0577ca76ea971 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -1,4 +1,4 @@ -from typing import Dict, FrozenSet, List, Optional +from typing import Dict, FrozenSet, List, Optional, Tuple from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, DeviceAwareBlockAllocator) @@ -185,13 +185,13 @@ def get_num_free_blocks(self, device: Device) -> int: def get_num_total_blocks(self, device: Device) -> int: return self._allocators[device].get_num_total_blocks() - def clear_copy_on_writes(self) -> Dict[int, List[int]]: + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: """Clears the copy-on-write (CoW) state and returns the mapping of source to destination block IDs. Returns: - Dict[int, List[int]]: A dictionary mapping source block IDs to lists - of destination block IDs. + List[Tuple[int, int]]: A list mapping source block IDs to + destination block IDs. """ # CoW only supported on GPU device = Device.GPU diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 634c4016ca19c..140fbbb0949cc 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, FrozenSet, List, Optional, Protocol +from typing import FrozenSet, List, Optional, Protocol, Tuple from vllm.utils import Device @@ -122,7 +122,7 @@ def all_block_ids(self) -> FrozenSet[int]: pass @abstractmethod - def clear_copy_on_writes(self) -> Dict[int, List[int]]: + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: pass @abstractmethod @@ -187,7 +187,7 @@ def all_block_ids(self) -> FrozenSet[int]: pass @abstractmethod - def clear_copy_on_writes(self) -> Dict[int, List[int]]: + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: pass @abstractmethod diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index a1b901bf78efc..ae01930878254 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -1,4 +1,4 @@ -from typing import Dict, FrozenSet, Iterable, List, Optional, Set +from typing import FrozenSet, Iterable, List, Optional, Set, Tuple from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) @@ -175,12 +175,12 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: """ return self._cow_tracker.cow_block_if_not_appendable(block) - def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: """Returns the copy-on-write source->destination mapping and clears it. Returns: - Dict[BlockId, List[BlockId]]: A dictionary mapping source - block indices to lists of destination block indices. + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. """ return self._cow_tracker.clear_cows() diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 4a37e8f87c379..882f301c1f697 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,7 +1,7 @@ """Token blocks.""" from itertools import takewhile from os.path import commonprefix -from typing import Dict, FrozenSet, Iterable, List, Optional +from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple from vllm.core.block.common import (CopyOnWriteTracker, get_all_blocks_recursively) @@ -337,12 +337,12 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: """ return self._cow_tracker.cow_block_if_not_appendable(block) - def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: """Returns the copy-on-write source->destination mapping and clears it. Returns: - Dict[BlockId, List[BlockId]]: A dictionary mapping source - block indices to lists of destination block indices. + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. """ return self._cow_tracker.clear_cows() diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 268c5c135d887..4e7392f3486c9 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -5,7 +5,7 @@ from os.path import commonprefix from typing import Dict, List, Optional from typing import Sequence as GenericSequence -from typing import Set +from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor @@ -386,7 +386,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int = 0, - ) -> Dict[int, List[int]]: + ) -> List[Tuple[int, int]]: """Allocate a physical slot for a new token.""" logical_blocks = seq.logical_token_blocks block_table = self.block_tables[seq.seq_id] @@ -405,7 +405,7 @@ def append_slots( # Allocate a new physical block. new_block = self._allocate_last_physical_block(seq) block_table.append(new_block) - return {} + return [] # We want to append the token to the last physical block. last_block = block_table[-1] @@ -418,7 +418,7 @@ def append_slots( maybe_new_block = self._maybe_promote_last_block( seq, last_block) block_table[-1] = maybe_new_block - return {} + return [] else: # The last block is shared with other sequences. # Copy on Write: Allocate a new block and copy the tokens. @@ -426,7 +426,7 @@ def append_slots( block_table[-1] = new_block self.gpu_allocator.free(last_block) - return {last_block.block_number: [new_block.block_number]} + return [(last_block.block_number, new_block.block_number)] def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: # NOTE: fork does not allocate a new physical block. diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index ce90ce2f17278..3b483e67ad9c1 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -1,6 +1,7 @@ """A block manager that manages token blocks.""" from typing import Dict, List, Optional from typing import Sequence as GenericSequence +from typing import Tuple from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator @@ -166,7 +167,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, - ) -> Dict[int, List[int]]: + ) -> List[Tuple[int, int]]: block_table = self.block_tables[seq.seq_id] diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 09ccaddb62615..ab2c8ea0053dd 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from typing import Dict, List from typing import Sequence as GenericSequence +from typing import Tuple from vllm.sequence import Sequence, SequenceGroup @@ -54,7 +55,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, - ) -> Dict[int, List[int]]: + ) -> List[Tuple[int, int]]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index de3ecd24e52db..f426ee95c0ca2 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1027,10 +1027,7 @@ def _append_slots( for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): cows = self.block_manager.append_slots(seq, num_lookahead_slots) - - for src, dests in cows.items(): - for dest in dests: - blocks_to_copy.append((src, dest)) + blocks_to_copy.extend(cows) def _preempt( self, From 8344f7742b794ca6ec9bcb891c178cd0551f23d0 Mon Sep 17 00:00:00 2001 From: leiwen83 Date: Wed, 8 May 2024 02:40:18 +0800 Subject: [PATCH 224/413] [Bug fix][Core] fixup ngram not setup correctly (#4551) Co-authored-by: Lei Wen Co-authored-by: Cade Daniel Co-authored-by: Cody Yu --- tests/spec_decode/e2e/conftest.py | 24 ++++++++++++++++++------ vllm/executor/gpu_executor.py | 4 ++++ vllm/spec_decode/spec_decode_worker.py | 14 +++++++------- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index b1ab8a07ca636..eda7293ea7cee 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -55,7 +55,7 @@ def __init__( ) -> None: if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True - self.engine_args = AsyncEngineArgs( + engine_args = AsyncEngineArgs( model=model, tokenizer=tokenizer, tokenizer_mode=tokenizer_mode, @@ -76,6 +76,8 @@ def __init__( **kwargs, ) self.request_counter = Counter() + self.llm_engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS) def generate( self, @@ -88,9 +90,6 @@ def generate( multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: - llm_engine = AsyncLLMEngine.from_engine_args( - self.engine_args, usage_context=UsageContext.LLM_CLASS) - if prompts is None: raise ValueError("prompts must be provided.") if isinstance(prompts, str): @@ -111,8 +110,8 @@ def generate( async def get_output(prompt, sampling_param) -> str: request_id = random_uuid() - results_generator = llm_engine.generate(prompt, sampling_param, - request_id) + results_generator = self.llm_engine.generate( + prompt, sampling_param, request_id) final_output = None async for request_output in results_generator: final_output = request_output @@ -185,12 +184,25 @@ def generator_outer(): return generator_outer +def maybe_assert_ngram_worker(llm): + # Verify the proposer worker is ngram if ngram is specified. + if (not isinstance(llm, AsyncLLM) + and llm.llm_engine.speculative_config is not None + and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0): + from vllm.spec_decode.ngram_worker import NGramWorker + assert isinstance( + llm.llm_engine.model_executor.driver_worker.proposer_worker, + NGramWorker) + + def get_output_from_llm_generator( llm_generator, prompts, sampling_params) -> Tuple[List[str], List[List[int]]]: tokens = [] token_ids = [] for llm in llm_generator(): + maybe_assert_ngram_worker(llm) + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) token_ids = [output.outputs[0].token_ids for output in outputs] tokens = [output.outputs[0].text for output in outputs] diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 1af3bcf380843..e8559b6a5c0fe 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -82,6 +82,10 @@ def _init_spec_worker(self): draft_worker_kwargs.update( model_config=self.speculative_config.draft_model_config, parallel_config=self.speculative_config.draft_parallel_config, + ngram_prompt_lookup_max=self.speculative_config. + ngram_prompt_lookup_max, + ngram_prompt_lookup_min=self.speculative_config. + ngram_prompt_lookup_min, # TODO allow draft-model specific load config. #load_config=self.load_config, ) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index c2b119fbd5036..84ec974806c7e 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -57,13 +57,10 @@ def create_worker( draft_worker_kwargs, ) -> "SpecDecodeWorker": - if "ngram_prompt_lookup_max" in draft_worker_kwargs: - ngram_prompt_lookup_max = ( - draft_worker_kwargs.pop("ngram_prompt_lookup_max")) - ngram_prompt_lookup_min = ( - draft_worker_kwargs.pop("ngram_prompt_lookup_min")) - else: - ngram_prompt_lookup_max = 0 + ngram_prompt_lookup_max = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_max")) + ngram_prompt_lookup_min = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_min")) if ngram_prompt_lookup_max > 0: proposer_worker = NGramWorker(**draft_worker_kwargs) @@ -72,6 +69,9 @@ def create_worker( else: proposer_worker = MultiStepWorker(**draft_worker_kwargs) + logger.info("Configuring SpecDecodeWorker with proposer=%s", + type(proposer_worker)) + return SpecDecodeWorker( proposer_worker, scorer_worker, From cc466a32903d53d0ceca459b766d74ad668c8f87 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 7 May 2024 19:34:47 -0700 Subject: [PATCH 225/413] [Core][Distributed] support cpu&device in broadcast tensor dict (#4660) [Core][Distributed] support both cpu and device tensor in broadcast tensor dict (#4660) --- tests/distributed/test_comm_ops.py | 7 +++- vllm/distributed/communication_op.py | 56 +++++++++++++++++----------- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index aa9e0537c6910..9a7a1f07e1b8d 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -77,14 +77,18 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, init_test_distributed_environment(1, tensor_parallel_size, rank, distributed_init_port) test_dict = { + # device tensor "a": torch.arange(8, dtype=torch.float32, device="cuda"), - "b": torch.arange(16, dtype=torch.int8, device="cuda"), + # CPU tensor + "b": torch.arange(16, dtype=torch.int8, device="cpu"), "c": "test", "d": [1, 2, 3], "e": { "a": 1, "b": 2 }, + # empty tensor + "f": torch.tensor([], dtype=torch.float32, device="cuda"), } if rank == 0: @@ -97,6 +101,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, assert recv_dict["c"] == test_dict["c"] assert recv_dict["d"] == test_dict["d"] assert recv_dict["e"] == test_dict["e"] + assert torch.allclose(recv_dict["f"], test_dict["f"]) @pytest.mark.skipif(torch.cuda.device_count() < 2, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 817bd6d812e48..80d03129bdb9b 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -137,7 +137,7 @@ def broadcast_object_list(obj_list: List[Any], return obj_list -TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"]) +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) def _split_tensor_dict( @@ -152,15 +152,13 @@ def _split_tensor_dict( tensor_list = [] for key, value in tensor_dict.items(): if isinstance(value, torch.Tensor): - # Note(youkaichao): currently this only supports broadcasting - # tensors on cuda. In the future, we can add device as a field in - # TensorMetadata to support broadcasting tensors on different - # devices. - assert value.is_cuda, ( - f"Tensor {key}: {value} is not on cuda. Currently we only " - f"support broadcasting tensors on cuda.") - metadata_list.append((key, TensorMetadata(value.dtype, - value.size()))) + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = "cpu" if value.is_cpu else "cuda" + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) tensor_list.append(value) else: metadata_list.append((key, value)) @@ -206,11 +204,19 @@ def broadcast_tensor_dict( if tensor.numel() == 0: # Skip broadcasting empty tensors. continue - async_handles.append( - torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True)) + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True) + async_handles.append(handle) for async_handle in async_handles: async_handle.wait() @@ -226,16 +232,24 @@ def broadcast_tensor_dict( if isinstance(value, TensorMetadata): tensor = torch.empty(value.size, dtype=value.dtype, - device="cuda") + device=value.device) if tensor.numel() == 0: # Skip broadcasting empty tensors. tensor_dict[key] = tensor continue - async_handle = torch.distributed.broadcast(tensor, - src=src, - async_op=True, - group=group) - async_handles.append(async_handle) + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True) + async_handles.append(handle) tensor_dict[key] = tensor else: tensor_dict[key] = value From d7740ea4dcee4ab75d7d6eef723f33cae957b288 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 9 May 2024 00:42:28 +0900 Subject: [PATCH 226/413] [Core] Optimize sampler get_logprobs (#4594) --- vllm/model_executor/layers/sampler.py | 117 +++++++++++++++----------- 1 file changed, 68 insertions(+), 49 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1f19d2053d996..e52e350d2726f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -782,13 +782,14 @@ def _get_logprobs( top_logprobs, top_token_ids = torch.topk(logprobs, largest_num_logprobs, dim=-1) - top_logprobs = top_logprobs.cpu() - top_token_ids = top_token_ids.cpu() else: top_logprobs, top_token_ids = None, None - selected_logprobs = selected_logprobs.cpu() - ranks = ranks.cpu() + selected_logprobs = selected_logprobs.to('cpu') + ranks = ranks.to('cpu') + if top_logprobs is not None and top_token_ids is not None: + top_logprobs = top_logprobs.to('cpu') + top_token_ids = top_token_ids.to('cpu') # Find prompt/sample logprobs. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] @@ -828,37 +829,48 @@ def _get_prompt_logprob_if_needed( # Find prompt logprobs prompt_logprobs: Optional[PromptLogprobs] = None - if (is_prompt and sampling_params.prompt_logprobs is not None): + if is_prompt and sampling_params.prompt_logprobs is not None: prompt_logprobs = [] num_logprobs = sampling_params.prompt_logprobs next_prompt_tokens = _get_next_prompt_tokens(seq_group) - for token_id in next_prompt_tokens: + # Pre-select indexes and create a list. It is faster than calling .item + # repetitively. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + + for idx, token_id in enumerate(next_prompt_tokens): # Calculate the prompt logprob of the real prompt tokens. - # Use tuple here for performance (to use to_list()). # {token_id: (logprob, rank_from_vocab)} prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { - token_id: (selected_logprobs[selected_logprobs_idx].item(), - ranks[selected_logprobs_idx].item()) + token_id: (selected_logprob_items[idx], rank_items[idx]) } # Add top K prompt logprobs along with its rank. if num_logprobs > 0: - prompt_logprobs_dict.update( - zip( - top_token_ids[top_logprob_idx, :num_logprobs].tolist(), - zip( - top_logprobs[ - top_logprob_idx, :num_logprobs].tolist(), - # This is ranks. Since top_logprob is sorted, - # we can just use a range here. - range(1, num_logprobs + 1)))) + top_ids = top_token_ids[ + top_logprob_idx, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + prompt_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip(top_ids, top_probs, + top_ranks) + }) prompt_logprobs.append({ token_id: Logprob(*logprob_and_rank) for token_id, logprob_and_rank in prompt_logprobs_dict.items() }) # + 1 to go to the next prompt token. top_logprob_idx += 1 - selected_logprobs_idx += 1 + + # + len(next_prompt_tokens) to go to the next prompt. + selected_logprobs_idx += len(next_prompt_tokens) return prompt_logprobs, top_logprob_idx, selected_logprobs_idx @@ -874,47 +886,54 @@ def _get_sampled_logprob_if_needed( ): """Compute the sample logprob if needed.""" seq_ids = seq_group.seq_ids - num_logprobs = seq_group.sampling_params.logprobs - if num_logprobs is None: - num_logprobs = 0 + num_logprobs = seq_group.sampling_params.logprobs or 0 sampled_logprobs: SampleLogprobs = [] next_token_ids, parent_seq_ids = sample_result if seq_group.do_sample: assert len(next_token_ids) > 0 - for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids): - # Calculate the sample logprob of the real sampled tokens. - # Use tuple here for performance (to use to_list()). - # token_id: (logprob, rank_from_vocab) - sampled_logprobs_dict: Dict[int, Tuple[float, int]] = { - next_token_id: - (selected_logprobs[selected_logprobs_idx].item(), - ranks[selected_logprobs_idx].item()) + # Pre-select items from tensor. tolist() is faster than repetitive + # `.item()` calls. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + for idx, (next_token_id, + parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)): + # Get the logprob of a sampled token. + sampled_logprobs_dict = { + next_token_id: (selected_logprob_items[idx], rank_items[idx]) } - # +1 to go to the next sampled token. Note that - # selected_logprobs can contain duplicates unlike top_logprobs - # when beam search is enabled. - selected_logprobs_idx += 1 - - # Second, add top K logprobs along with its rank. - if num_logprobs >= 0: - sampled_logprobs_dict.update( - zip( - top_token_ids[top_logprob_idx + - parent_id, :num_logprobs].tolist(), - zip( - top_logprobs[top_logprob_idx + - parent_id, :num_logprobs].tolist(), - # This is rank. Since top_logprob is sorted, we - # can just use a range here. - range(1, num_logprobs + 1)))) + # Get top K logprobs. + if num_logprobs > 0: + top_ids = top_token_ids[top_logprob_idx + + parent_id, :num_logprobs].tolist() + top_probs = top_logprobs[top_logprob_idx + + parent_id, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + sampled_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip(top_ids, top_probs, + top_ranks) + }) + sampled_logprobs.append({ token_id: Logprob(*logprob_and_rank) for token_id, logprob_and_rank in sampled_logprobs_dict.items() }) - # There are len(seq_ids) number of sampled tokens for the current - # sequence group in top_logprobs. Jump to the next seq_group. + + # NOTE: This part of code is not intuitive. `selected_logprobs` include + # logprobs for the current step, which has len(next_token_ids) tokens + # per sequence group. `logprobs` includes logprobs from the previous + # steps, which has len(seq_ids) tokens per sequence group. + + # Iterate to the next sequence group in a batch. + selected_logprobs_idx += len(next_token_ids) + # Iterate to the next sequence group in a batch. top_logprob_idx += len(seq_ids) return sampled_logprobs, top_logprob_idx, selected_logprobs_idx From f6a593093ac201c286e99a849091801a88d83622 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Thu, 9 May 2024 00:44:35 +0900 Subject: [PATCH 227/413] [CI] Make mistral tests pass (#4596) --- .buildkite/test-pipeline.yaml | 2 +- tests/conftest.py | 62 +++++++++++++++++++ tests/models/test_big_models.py | 2 +- tests/models/test_mistral.py | 33 +++++----- .../model_executor/layers/rotary_embedding.py | 5 +- 5 files changed, 85 insertions(+), 19 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index cee5e7e9d2a73..2eeba904a209d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -76,7 +76,7 @@ steps: #mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py + - pytest -v -s models --ignore=models/test_llava.py - label: Llava Test #mirror_hardwares: [amd] diff --git a/tests/conftest.py b/tests/conftest.py index 671326915b22b..1f2ad1cbd7298 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -272,6 +272,68 @@ def generate_greedy_logprobs( all_logprobs.append(seq_logprobs) return all_logprobs + def generate_greedy_logprobs_limit( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + ) -> List[Tuple[List[int], str]]: + all_logprobs = [] + all_output_ids = [] + all_output_strs = [] + + for prompt in prompts: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + output = self.model.generate( + input_ids.cuda(), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + ) + + seq_logprobs = [] + for _, hidden_states in enumerate(output.hidden_states): + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if getattr(self.model.get_output_embeddings(), "bias", + None) is not None: + logits += self.model.get_output_embeddings( + ).bias.unsqueeze(0) + logprobs = torch.nn.functional.log_softmax(logits, + dim=-1, + dtype=torch.float32) + seq_logprobs.append(logprobs) + + # convert to dict + seq_logprobs_lst = [] + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(num_logprobs) + + tok_logprobs_dct = {} + for token_id, logprob in zip(topk.indices[0], topk.values[0]): + tok_logprobs_dct[token_id.item()] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + seq_ids = output.sequences[0] + output_len = seq_ids.shape[0] - input_ids.shape[1] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + def __del__(self): del self.model cleanup() diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 3dde498bcd639..c02204f16ac68 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -8,7 +8,7 @@ MODELS = [ "meta-llama/Llama-2-7b-hf", - # "mistralai/Mistral-7B-v0.1", # Broken + # "mistralai/Mistral-7B-v0.1", # Tested by test_mistral.py # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 7aeff3a913098..33d28da85d9e7 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -4,6 +4,8 @@ """ import pytest +from tests.models.utils import check_logprobs_close + MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", ] @@ -11,30 +13,31 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.skip( - "Two problems: 1. Failing correctness tests. 2. RuntimeError: expected " - "scalar type BFloat16 but found Half (only in CI).") +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) def test_models( hf_runner, vllm_runner, - example_long_prompts, + example_prompts, model: str, dtype: str, max_tokens: int, + num_logprobs: int, ) -> None: + # TODO(sang): Sliding window should be tested separately. hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens) + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) del hf_model vllm_model = vllm_runner(model, dtype=dtype) - vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens) + vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) del vllm_model - - for i in range(len(example_long_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] - vllm_output_ids, vllm_output_str = vllm_outputs[i] - assert hf_output_str == vllm_output_str, ( - f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") - assert hf_output_ids == vllm_output_ids, ( - f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 857d70fadcb57..f41e0f30a4e4b 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -109,7 +109,7 @@ def _forward( key_pass = key[..., self.rotary_dim:] self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( - positions.device) + positions.device, dtype=query.dtype) cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] cos, sin = cos_sin.chunk(2, dim=-1) @@ -143,7 +143,8 @@ def forward( key: torch.Tensor, offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - self.cos_sin_cache = self.cos_sin_cache.to(positions.device) + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) # ops.rotary_embedding()/batched_rotary_embedding() # are in-place operations that update the query and key tensors. if offsets is not None: From 0f9a6e3d229cade0ae9a53a4f69a38f52e430bd0 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Thu, 9 May 2024 00:19:58 +0800 Subject: [PATCH 228/413] [Bugfix][Kernel] allow non-power-of-2 for prefix prefill with alibi (#4573) --- tests/kernels/test_prefix_prefill.py | 243 ++++++++++++++++++++++++++- vllm/attention/ops/prefix_prefill.py | 41 +++-- 2 files changed, 267 insertions(+), 17 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 5a5987e2242fa..99fda8364dc0e 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -1,3 +1,4 @@ +import math import random import time @@ -6,11 +7,12 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask +from vllm.attention.backends.xformers import _make_alibi_bias from vllm.attention.ops.prefix_prefill import context_attention_fwd NUM_HEADS = [64] NUM_QUERIES_PER_KV = [1, 8, 64] -HEAD_SIZES = [128, 96] +HEAD_SIZES = [128, 96, 24] DTYPES = [torch.float16] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) @@ -207,3 +209,242 @@ def test_contexted_kv_attention( print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.reshape(output.shape) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_contexted_kv_attention_alibi( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + dtype: torch.dtype, + device: str, +) -> None: + random.seed(0) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.set_default_device(device) + + # Need this, otherwise when we capture the graph the process + # for GPU 1 would run on both GPU0 and GPU1 and things would hang + # + # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523 + torch.cuda.set_device(device) + + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + # Fork from: vllm/vllm/model_executor/models/bloom.py#L44 + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + alibi_slopes = _get_alibi_slopes(num_heads).to(device) + + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv + + num_tokens = sum(query_lens) + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + v_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + values = torch.arange(0, cache_size, dtype=torch.long) + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) + b_seq_len = torch.tensor(seq_lens, dtype=torch.long) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + dtype=torch.long), + dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long), + dim=0) + for i in range(BS): + for j in range(query_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] + # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] + # to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_kv_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + # Warm up the Triton kernel by calling it once before actually measuring + # generation time + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=alibi_slopes) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + scale = float(1.0 / (head_size**0.5)) + + # NOTE(DefTruth): In order to reuse _make_alibi_bias function, + # we have to pad query tensor before MQA/GQA expanding. + if query.shape[0] != key.shape[0]: + query_pad = torch.empty(sum(seq_lens), + num_heads, + head_size, + dtype=dtype) + query_pad.uniform_(-1e-3, 1e-3) + seq_start = 0 + query_start = 0 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + query_pad[seq_start:seq_end, ...] = torch.cat([ + torch.zeros( + seq_len - query_len, num_heads, head_size, dtype=dtype), + query[query_start:query_end, ...] + ], + dim=0) + seq_start += seq_len + query_start += query_len + query = query_pad + + if num_kv_heads != num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # + # see also: vllm/model_executor/layers/attention.py + query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv, + query.shape[-1]) + key = key[:, :, None, :].expand(key.shape[0], num_kv_heads, + num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], num_kv_heads, + num_queries_per_kv, value.shape[-1]) + + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + output_ref = torch.empty_like(output) + seq_start = 0 + query_start = 0 + start_time = time.time() + # Attention with alibi slopes. + # FIXME(DefTruth): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. + # modified from: vllm/attention/backends/xformers.py#L343 + for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): + seq_end = seq_start + seq_len + query_end = query_start + query_len + out = xops.memory_efficient_attention_forward(query[:, + seq_start:seq_end], + key[:, + seq_start:seq_end], + value[:, + seq_start:seq_end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + out = out.view_as(query[:, seq_start:seq_end]).view( + seq_len, num_heads, head_size) + output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:, + ...]) + seq_start += seq_len + query_start += query_len + torch.cuda.synchronize() + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 79878b26c5294..997b25e887e30 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -472,7 +472,8 @@ def _fwd_kernel_alibi( stride_v_cache_bl, num_queries_per_kv: int, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, ): # attn_bias[] @@ -493,21 +494,24 @@ def _fwd_kernel_alibi( # initialize offsets offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) - q = tl.load( - Q + off_q, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, - other=0.0) + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) alibi_slope = tl.load(Alibi_slopes + cur_head) alibi_start_q = tl.arange( @@ -532,8 +536,9 @@ def _fwd_kernel_alibi( offs_d[None, :] * stride_v_cache_d + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, - mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, - other=0.0) + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -567,7 +572,8 @@ def _fwd_kernel_alibi( acc = acc * acc_scale[:, None] # update acc v = tl.load(V_cache + off_v, - mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -600,8 +606,9 @@ def _fwd_kernel_alibi( # -- compute qk ---- k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len), other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -637,8 +644,9 @@ def _fwd_kernel_alibi( # update acc v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < - cur_batch_seq_len - cur_batch_ctx_len, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len), other=0.0) p = p.to(v.dtype) @@ -656,7 +664,8 @@ def _fwd_kernel_alibi( out_ptrs = Out + off_o tl.store(out_ptrs, acc, - mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) return @torch.inference_mode() @@ -690,7 +699,6 @@ def context_attention_fwd(q, num_warps = 8 if Lk <= 64 else 8 if alibi_slopes is not None: - assert Lk == Lk_padded _fwd_kernel_alibi[grid]( q, k, @@ -735,6 +743,7 @@ def context_attention_fwd(q, num_queries_per_kv=num_queries_per_kv, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, From 5510cf0e8a6a3ee56daefb86b145c7f2a000817f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 8 May 2024 09:59:31 -0700 Subject: [PATCH 229/413] [Misc] Add `get_name` method to attention backends (#4685) --- vllm/attention/backends/abstract.py | 5 +++++ vllm/attention/backends/flash_attn.py | 4 ++++ vllm/attention/backends/flashinfer.py | 16 +++++++--------- vllm/attention/backends/rocm_flash_attn.py | 4 ++++ vllm/attention/backends/torch_sdpa.py | 4 ++++ vllm/attention/backends/xformers.py | 4 ++++ vllm/worker/model_runner.py | 5 ++--- 7 files changed, 30 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index b2b6e7ac810e3..02a2fd603faa8 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,6 +9,11 @@ class AttentionBackend(ABC): """Abstract class for attention backends.""" + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + @staticmethod @abstractmethod def get_impl_cls() -> Type["AttentionImpl"]: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index da672d5df6161..bee482c3431c4 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -19,6 +19,10 @@ class FlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "flash-attn" + @staticmethod def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 2851cbe2396b2..67b99ba2eade4 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,16 +1,10 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type -try: - import flashinfer - from flash_attn import flash_attn_varlen_func - from flashinfer import BatchDecodeWithPagedKVCacheWrapper -except ImportError: - flashinfer = None - flash_attn_varlen_func = None - BatchDecodeWithPagedKVCacheWrapper = None - +import flashinfer import torch +from flash_attn import flash_attn_varlen_func +from flashinfer import BatchDecodeWithPagedKVCacheWrapper from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -20,6 +14,10 @@ class FlashInferBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "flashinfer" + @staticmethod def get_impl_cls() -> Type["FlashInferImpl"]: return FlashInferImpl diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c3b522e63b4b8..10c94f02ff05b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -17,6 +17,10 @@ class ROCmFlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "rocm-flash-attn" + @staticmethod def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: return ROCmFlashAttentionImpl diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 03825f6023f4c..c1c07abef0ce6 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -15,6 +15,10 @@ class TorchSDPABackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "torch-sdpa" + @staticmethod def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 4c7fa71a2c78e..2a9150dea5875 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -20,6 +20,10 @@ class XFormersBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "xformers" + @staticmethod def get_impl_cls() -> Type["XFormersImpl"]: return XFormersImpl diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ab248596490f6..c96f13c590fc4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,7 +9,6 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) -from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -395,7 +394,7 @@ def _prepare_prompt( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - if self.attn_backend is FlashInferBackend: + if self.attn_backend.get_name() == "flashinfer": attn_metadata = self.attn_backend.make_metadata( is_prompt=True, use_cuda_graph=False, @@ -556,7 +555,7 @@ def _prepare_decode( device=self.device, ) - if self.attn_backend is FlashInferBackend: + if self.attn_backend.get_name() == "flashinfer": if not hasattr(self, "flashinfer_workspace_buffer"): # Allocate 16MB workspace buffer # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html From ad932a221d2a4c1e6355021bb9e9c47f7a179e51 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 8 May 2024 10:33:18 -0700 Subject: [PATCH 230/413] [Core] Faster startup for LoRA enabled models (#4634) --- vllm/lora/models.py | 10 ++++++++++ vllm/lora/worker_manager.py | 26 ++++++++++++++++++++++---- vllm/worker/model_runner.py | 29 +++++++++++++++-------------- 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 50d7e9133e0e8..cd45040bcca5d 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -119,6 +119,16 @@ def __init__( self.rank = rank self.loras: Dict[str, LoRALayerWeights] = loras + def clone(self, lora_model_id: int) -> "LoRAModel": + """Return a copy of the object with different ids. + + Will share the underlying tensors.""" + return self.__class__( + lora_model_id, + rank=self.rank, + loras=self.loras.copy(), + ) + @property def extra_vocab_size(self) -> int: return max(lora.extra_vocab_size diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index ec3c10c591a18..377f561cceaf2 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod, abstractproperty -from typing import Any, Dict, List, Set, Type +from contextlib import contextmanager +from typing import Any, Dict, List, Literal, Set, Type, Union import torch @@ -25,6 +26,17 @@ def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, self.device = device self.lora_config = lora_config + # If False, do not cache. If None, cache is empty. + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False + @abstractproperty def is_enabled(self) -> bool: ... @@ -174,9 +186,15 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: if lora_request.lora_int_id in self.list_loras(): return False - return self._lora_manager.add_lora( - self._lora_manager.create_dummy_lora(lora_request.lora_int_id, - rank, self.embedding_modules)) + if isinstance(self._cached_dummy_lora, LoRAModel): + dummy_lora = self._cached_dummy_lora.clone( + lora_request.lora_int_id) + else: + dummy_lora = self._lora_manager.create_dummy_lora( + lora_request.lora_int_id, rank, self.embedding_modules) + if self._cached_dummy_lora is None: + self._cached_dummy_lora = dummy_lora + return self._lora_manager.add_lora(dummy_lora) def add_lora(self, lora_request: LoRARequest) -> bool: if lora_request.lora_int_id in self.list_loras(): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c96f13c590fc4..46c6730645c1b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -835,20 +835,21 @@ def profile_run(self) -> None: dummy_lora_requests = [] dummy_lora_requests_per_seq = [] if self.lora_config: - for idx in range(self.lora_config.max_loras): - lora_id = idx + 1 - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_local_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) - dummy_lora_requests.append(dummy_lora_request) - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(max_num_seqs) - ] + with self.lora_manager.dummy_lora_cache(): + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. From 20cfcdec998b39f5dbb0dc89efe4122f95f5cb16 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 12:07:05 -0700 Subject: [PATCH 231/413] [Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659) --- csrc/cache.h | 4 +-- csrc/cache_kernels.cu | 16 ++++++--- csrc/cpu/cache.cpp | 4 +-- tests/core/test_block_manager.py | 4 +-- tests/core/test_chunked_prefill_scheduler.py | 32 +++++++++--------- tests/core/test_scheduler.py | 34 ++++++++++---------- tests/kernels/test_cache.py | 13 +++++--- tests/worker/test_swap.py | 24 +++++++------- vllm/attention/backends/abstract.py | 2 +- vllm/attention/backends/flash_attn.py | 4 +-- vllm/attention/backends/flashinfer.py | 2 +- vllm/attention/backends/rocm_flash_attn.py | 4 +-- vllm/attention/backends/torch_sdpa.py | 4 +-- vllm/attention/ops/paged_attn.py | 4 +-- vllm/core/block_manager_v1.py | 12 ++++--- vllm/core/block_manager_v2.py | 4 +-- vllm/core/interfaces.py | 6 ++-- vllm/core/scheduler.py | 32 +++++++++--------- vllm/sequence.py | 8 ++--- vllm/worker/cache_engine.py | 6 ++-- vllm/worker/worker.py | 27 ++++++++++++---- 21 files changed, 137 insertions(+), 109 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 10871b3670bac..212a3bf3ddc1c 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -8,12 +8,12 @@ void swap_blocks( torch::Tensor& src, torch::Tensor& dst, - const std::map& block_mapping); + const torch::Tensor& block_mapping); void copy_blocks( std::vector& key_caches, std::vector& value_caches, - torch::Tensor& block_mapping); + const torch::Tensor& block_mapping); void reshape_and_cache( torch::Tensor& key, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1e02f7fcbae4c..76db96f099c69 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -23,7 +23,7 @@ void swap_blocks( torch::Tensor& src, torch::Tensor& dst, - const std::map& block_mapping) { + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; @@ -40,6 +40,11 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } + // NOTE(youkaichao): keep in mind that `block_mapping` should be + // a cpu tensor, otherwise every `item` call will require a gpu-cpu + // synchronization. + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + char *src_ptr = static_cast(src.data_ptr()); char *dst_ptr = static_cast(dst.data_ptr()); @@ -47,9 +52,10 @@ void swap_blocks( const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. - for (const auto& pair : block_mapping) { - int64_t src_block_number = pair.first; - int64_t dst_block_number = pair.second; + const int64_t num_blocks = block_mapping.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); int64_t src_offset = src_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes; cudaMemcpyAsync( @@ -97,7 +103,7 @@ __global__ void copy_blocks_kernel( void copy_blocks( std::vector& key_caches, std::vector& value_caches, - torch::Tensor& block_mapping) { + const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 95e3f11900fde..620d11ef1ed6c 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -83,7 +83,7 @@ void reshape_and_cache_cpu_impl( void copy_blocks(std::vector &key_caches, std::vector &value_caches, - torch::Tensor& block_mapping) { + const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -128,6 +128,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, } void swap_blocks(torch::Tensor &src, torch::Tensor &dst, - const std::map &block_mapping) { + const torch::Tensor&block_mapping) { TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") } diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 08d34efb8302c..9db58e075196d 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -219,7 +219,7 @@ def test_swap(): before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_out(seq_group) - assert list(mapping.keys()) == gpu_blocks + assert [x[0] for x in mapping] == gpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) @@ -232,7 +232,7 @@ def test_swap(): before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_in(seq_group) - assert list(mapping.keys()) == cpu_blocks + assert [x[0] for x in mapping] == cpu_blocks after_cpu_blocks = block_manager.get_num_free_cpu_blocks() after_gpu_blocks = block_manager.get_num_free_gpu_blocks() assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 92498c0014666..3649e6b003a5d 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -355,8 +355,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 0 assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out != [] + assert out.blocks_to_swap_in == [] # Add 1 more task. Swap should be prioritized over new prefill. _, seq_group = create_dummy_prompt("2", prompt_length=60) @@ -365,8 +365,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in != [] + assert out.blocks_to_swap_out == [] def test_running_prefill_prioritized_over_swap(): @@ -406,8 +406,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 0 assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out != [] + assert out.blocks_to_swap_in == [] # Add 1 more task. Swap is not possible, so prefill is running. scheduler.block_manager.can_swap_in = MagicMock() @@ -419,8 +419,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == [] + assert out.blocks_to_swap_out == [] assert out.scheduled_seq_groups[0].seq_group == seq_group2 # Now although swap is possible, running prefill is prioritized. @@ -429,8 +429,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == [] + assert out.blocks_to_swap_out == [] assert not seq_group2.is_prefill() assert out.scheduled_seq_groups[0].seq_group == seq_group2 append_new_token(seq_group2, 1) @@ -440,8 +440,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 1 - assert out.blocks_to_swap_in == {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in == [] + assert out.blocks_to_swap_out == [] assert not seq_group2.is_prefill() assert out.scheduled_seq_groups[0].seq_group == seq_group2 append_new_token(seq_group2, 1) @@ -451,8 +451,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in != [] + assert out.blocks_to_swap_out == [] def test_chunked_prefill_preempt(): @@ -493,8 +493,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 0 assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out == {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out == [] + assert out.blocks_to_swap_in == [] # Make sure we can reschedule preempted request. _, out = schedule_and_update_computed_tokens(scheduler) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 3f0c918a89abb..6bcabc4f95fa9 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -293,8 +293,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 2 assert out.num_batched_tokens == 2 - assert out.blocks_to_swap_out != {} - assert out.blocks_to_swap_in == {} + assert out.blocks_to_swap_out != [] + assert out.blocks_to_swap_in == [] append_new_token(out, 1) # Add 1 more task. Swap should be prioritized over prefill. @@ -305,8 +305,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert len(out.scheduled_seq_groups) == 3 # 3 decodes. It is swapped in. assert out.num_batched_tokens == 3 - assert out.blocks_to_swap_in != {} - assert out.blocks_to_swap_out == {} + assert out.blocks_to_swap_in != [] + assert out.blocks_to_swap_out == [] def initialize_scheduler(*, @@ -566,7 +566,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # NOTE: When enable_chunk is False, num_seqs budget is not updated. # assert budget.num_curr_seqs == 1 # Both should be preempted, not swapped. - assert output.blocks_to_swap_out == {} + assert output.blocks_to_swap_out == [] # Nothing is copied. assert output.blocks_to_copy == [] @@ -599,7 +599,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): scheduler.block_manager.can_append_slots.side_effect = ( cannot_append_second_group) scheduler.block_manager.swap_out = MagicMock() - expected_swap_mapping = {"5": "7"} + expected_swap_mapping = [("5", "7")] scheduler.block_manager.swap_out.return_value = expected_swap_mapping remainig_running, output = scheduler._schedule_running( @@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update(): assert len(output.preempted) == 0 assert len(output.swapped_out) == 0 # Nothing is preempted. - assert output.blocks_to_swap_out == {} + assert output.blocks_to_swap_out == [] # Since append_slot returns the source -> dist mapping, it should # applied. assert output.blocks_to_copy == [(2, 3)] @@ -658,7 +658,7 @@ def test_schedule_swapped_simple(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -674,9 +674,9 @@ def test_schedule_swapped_simple(): assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 # swap in is the reverse of swap out - blocks_to_swap_in_reverse = {} - for swapin, swapout in output.blocks_to_swap_in.items(): - blocks_to_swap_in_reverse[swapout] = swapin + blocks_to_swap_in_reverse = [] + for swapin, swapout in output.blocks_to_swap_in: + blocks_to_swap_in_reverse.append((swapout, swapin)) assert blocks_to_swap_out == blocks_to_swap_in_reverse @@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) @@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for i in range(4): _, seq_group = create_dummy_prompt(str(i), prompt_length=60) scheduler._allocate_and_set_running(seq_group) @@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = set() - blocks_to_swap_out = {} + blocks_to_swap_out = [] for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, @@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) @@ -808,7 +808,7 @@ def test_infeasible_swap(): swapped = deque() policy = PolicyFactory.get_policy(policy_name="fcfs") curr_loras = None - blocks_to_swap_out = {} + blocks_to_swap_out = [] for _ in range(2): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) @@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy(): _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) - blocks_to_swap_out = {} + blocks_to_swap_out = [] scheduler._swap_out(seq_group, blocks_to_swap_out) swapped.append(seq_group) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 94a577139596e..8a27d51bb78d5 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -315,7 +315,10 @@ def test_swap_blocks( else: dst_blocks = random.sample(range(num_blocks), num_mappings) - block_mapping = dict(zip(src_blocks, dst_blocks)) + block_mapping = list(zip(src_blocks, dst_blocks)) + block_mapping_tensor = torch.tensor(block_mapping, + dtype=torch.int64, + device="cpu").view(-1, 2) # Create the KV caches on the first device. src_key_caches, src_value_caches = kv_cache_factory( @@ -331,10 +334,12 @@ def test_swap_blocks( src_value_caches_clone = src_value_caches[0].clone() # Call the swap_blocks kernel. - ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) - ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping) + ops.swap_blocks(src_key_caches[0], dist_key_caches[0], + block_mapping_tensor) + ops.swap_blocks(src_value_caches[0], dist_value_caches[0], + block_mapping_tensor) - for src, dst in block_mapping.items(): + for src, dst in block_mapping: assert torch.allclose(src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()) assert torch.allclose(src_value_caches_clone[src].cpu(), diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 4d2d3add27d59..d941ffdb5588a 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -54,10 +54,10 @@ def test_swap() -> None: a.cuda(), b.cuda(), rtol=0.0, atol=0.0) # Test swap out. - blocks_to_swap_out = {3: 72, 56: 35, 84: 34} + blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)] execute_model_req = ExecuteModelRequest( seq_group_metadata_list=[], - blocks_to_swap_in={}, + blocks_to_swap_in=[], blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=[], ) @@ -66,24 +66,24 @@ def test_swap() -> None: for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in blocks_to_swap_out.items(): + for src, dst in blocks_to_swap_out: assert allclose(gpu_key_cache[src], cpu_key_cache[dst]) assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) # Test swap in. - execute_model_req.blocks_to_swap_out = {} - execute_model_req.blocks_to_swap_in = { - 19: 45, - 67: 23, - 12: 78, - 40: 99, - 1: 71 - } + execute_model_req.blocks_to_swap_out = [] + execute_model_req.blocks_to_swap_in = [ + (19, 45), + (67, 23), + (12, 78), + (40, 99), + (1, 71), + ] worker.execute_model(execute_model_req=execute_model_req) for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in execute_model_req.blocks_to_swap_in.items(): + for src, dst in execute_model_req.blocks_to_swap_in: assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 02a2fd603faa8..64ccb309a0480 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -39,7 +39,7 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bee482c3431c4..c2fec9153f2d8 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -5,7 +5,7 @@ flashinfer for all the attention operations. """ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch from flash_attn import flash_attn_varlen_func @@ -45,7 +45,7 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 67b99ba2eade4..8f13f3525512b 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -39,7 +39,7 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 10c94f02ff05b..8fc1af1aa1e1c 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,6 +1,6 @@ """Attention layer ROCm GPUs.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -43,7 +43,7 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c1c07abef0ce6..c29218dfd0cfc 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -1,7 +1,7 @@ """ Attention layer with torch scaled_dot_product_attention and PagedAttention.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch from torch.nn.functional import scaled_dot_product_attention @@ -41,7 +41,7 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 6f7fd51c774f8..3c010b67b3120 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -196,7 +196,7 @@ def forward_prefix( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 4e7392f3486c9..52a170d79e4e7 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -473,11 +473,12 @@ def can_swap_in(self, def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> Dict[int, int]: + num_lookahead_slots: int = 0) -> List[Tuple[int, int]]: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" # CPU block -> GPU block. + # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): new_block_table: BlockTable = [] @@ -500,14 +501,16 @@ def swap_in(self, cpu_block.block_number: gpu_block.block_number for cpu_block, gpu_block in mapping.items() } - return block_number_mapping + # convert to list of tuples once here + return list(block_number_mapping.items()) def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: # GPU block -> CPU block. + # dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): new_block_table: BlockTable = [] @@ -530,7 +533,8 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: gpu_block.block_number: cpu_block.block_number for gpu_block, cpu_block in mapping.items() } - return block_number_mapping + # convert to list of tuples once here + return list(block_number_mapping.items()) def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 3b483e67ad9c1..f0bc96564050a 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -243,13 +243,13 @@ def can_swap_in(self, seq_group: SequenceGroup, return AllocStatus.LATER def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> Dict[int, int]: + num_lookahead_slots: int) -> List[Tuple[int, int]]: raise NotImplementedError def can_swap_out(self, seq_group: SequenceGroup) -> bool: return False - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: raise NotImplementedError def get_num_free_gpu_blocks(self) -> int: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index ab2c8ea0053dd..b2a5e41990f39 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -1,6 +1,6 @@ import enum from abc import ABC, abstractmethod -from typing import Dict, List +from typing import List from typing import Sequence as GenericSequence from typing import Tuple @@ -69,7 +69,7 @@ def can_swap_in(self, seq_group: SequenceGroup, @abstractmethod def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> Dict[int, int]: + num_lookahead_slots: int) -> List[Tuple[int, int]]: pass @abstractmethod @@ -77,7 +77,7 @@ def can_swap_out(self, seq_group: SequenceGroup) -> bool: pass @abstractmethod - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f426ee95c0ca2..35e3db18f1c43 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -117,10 +117,10 @@ class SchedulerOutputs: num_prefill_groups: int # Total number of batched tokens. num_batched_tokens: int - # Blocks to swap in. Dict of CPU -> GPU block number. - blocks_to_swap_in: Dict[int, int] - # Blocks to swap out. Dict of GPU -> CPU block number. - blocks_to_swap_out: Dict[int, int] + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] # Sequence groups that are going to be ignored. @@ -174,7 +174,7 @@ class SchedulerRunningOutputs: # Sequences that are swapped out. swapped_out: List[SequenceGroup] # The blocks to swap out. - blocks_to_swap_out: Dict[int, int] + blocks_to_swap_out: List[Tuple[int, int]] # The blocks to copy. blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. @@ -187,7 +187,7 @@ def create_empty(cls) -> "SchedulerRunningOutputs": prefill_seq_groups=[], preempted=[], swapped_out=[], - blocks_to_swap_out={}, + blocks_to_swap_out=[], blocks_to_copy=[], num_lookahead_slots=0, ) @@ -206,7 +206,7 @@ class SchedulerSwappedInOutputs: # phase. I.e., it means the prefill has been chunked. prefill_seq_groups: List[SequenceGroup] # The blocks to swap in. - blocks_to_swap_in: Dict[int, int] + blocks_to_swap_in: List[Tuple[int, int]] # The blocks to copy. blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. @@ -219,7 +219,7 @@ def create_empty(cls) -> "SchedulerSwappedInOutputs": return SchedulerSwappedInOutputs( decode_seq_groups=[], prefill_seq_groups=[], - blocks_to_swap_in={}, + blocks_to_swap_in=[], blocks_to_copy=[], num_lookahead_slots=0, infeasible_seq_groups=[], @@ -392,7 +392,7 @@ def _schedule_running( scheduling and SchedulerRunningOutputs. """ # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_out: Dict[int, int] = {} + blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] @@ -509,7 +509,7 @@ def _schedule_swapped( SchedulerSwappedInOutputs. """ # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_in: Dict[int, int] = {} + blocks_to_swap_in: List[Tuple[int, int]] = [] blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] @@ -1032,7 +1032,7 @@ def _append_slots( def _preempt( self, seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_out: List[Tuple[int, int]], preemption_mode: Optional[PreemptionMode] = None, ) -> PreemptionMode: # If preemption mode is not specified, we determine the mode as follows: @@ -1073,24 +1073,24 @@ def _preempt_by_recompute( def _preempt_by_swap( self, seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_out: List[Tuple[int, int]], ) -> None: self._swap_out(seq_group, blocks_to_swap_out) def _swap_in( self, seq_group: SequenceGroup, - blocks_to_swap_in: Dict[int, int], + blocks_to_swap_in: List[Tuple[int, int]], ) -> None: mapping = self.block_manager.swap_in(seq_group) - blocks_to_swap_in.update(mapping) + blocks_to_swap_in.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): seq.status = SequenceStatus.RUNNING def _swap_out( self, seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_out: List[Tuple[int, int]], ) -> None: if not self.block_manager.can_swap_out(seq_group): # FIXME(woosuk): Abort the sequence group instead of aborting the @@ -1099,7 +1099,7 @@ def _swap_out( "Aborted due to the lack of CPU swap space. Please increase " "the swap space to avoid this error.") mapping = self.block_manager.swap_out(seq_group) - blocks_to_swap_out.update(mapping) + blocks_to_swap_out.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED diff --git a/vllm/sequence.py b/vllm/sequence.py index b486d1fedebd3..42b508b517200 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -741,10 +741,10 @@ class ExecuteModelRequest: """The model execution request.""" # The sequence group metadata list. seq_group_metadata_list: List[SequenceGroupMetadata] - # Blocks to swap in. Dict of CPU -> GPU block number. - blocks_to_swap_in: Dict[int, int] = field(default_factory=dict) - # Blocks to swap out. Dict of GPU -> CPU block number. - blocks_to_swap_out: Dict[int, int] = field(default_factory=dict) + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list) + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) # The number of slots for lookahead decoding. diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 26a60c652b6f4..1fb63a3e47921 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,5 +1,5 @@ """CacheEngine class for managing the KV cache.""" -from typing import Dict, List +from typing import List import torch @@ -67,12 +67,12 @@ def _allocate_kv_cache( device=device)) return kv_cache - def swap_in(self, src_to_dst: Dict[int, int]) -> None: + def swap_in(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], src_to_dst) - def swap_out(self, src_to_dst: Dict[int, int]) -> None: + def swap_out(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_layers): self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 538332ad003f1..313bcf25d8870 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -195,15 +195,14 @@ def _warm_up_model(self) -> None: def cache_swap( self, - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_in: torch.Tensor, + blocks_to_swap_out: torch.Tensor, blocks_to_copy: torch.Tensor, ) -> None: # Issue cache operations. - # TODO(woosuk): Profile swapping overhead and optimize if needed. - if blocks_to_swap_in: + if blocks_to_swap_in.numel() > 0: self.cache_engine.swap_in(blocks_to_swap_in) - if blocks_to_swap_out: + if blocks_to_swap_out.numel() > 0: self.cache_engine.swap_out(blocks_to_swap_out) if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @@ -219,12 +218,26 @@ def execute_model( else: seq_group_metadata_list = execute_model_req.seq_group_metadata_list + blocks_to_swap_in: torch.Tensor + blocks_to_swap_out: torch.Tensor + blocks_to_copy: torch.Tensor if self.is_driver_worker: assert seq_group_metadata_list is not None assert execute_model_req is not None num_seq_groups = len(seq_group_metadata_list) - blocks_to_swap_in = execute_model_req.blocks_to_swap_in - blocks_to_swap_out = execute_model_req.blocks_to_swap_out + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor( + execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor( + execute_model_req.blocks_to_swap_out, + device="cpu", + dtype=torch.int64).view(-1, 2) + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device=self.device, dtype=torch.int64).view(-1, 2) From 230c4b38c10148c3fbbbb67cd1046766d73c865a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 8 May 2024 13:14:02 -0700 Subject: [PATCH 232/413] [CI/Test] fix swap test for multi gpu (#4689) --- tests/kernels/test_cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 8a27d51bb78d5..4cae15c79c489 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -222,11 +222,12 @@ def test_reshape_and_cache_flash( random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) + torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) qkv = torch.randn(num_tokens, 3, @@ -245,6 +246,7 @@ def test_reshape_and_cache_flash( head_size, kv_cache_dtype, dtype, + device=device, ) key_cache, value_cache = key_caches[0], value_caches[0] From 89579a201f2c84b512f0e1006ac2ea0d979803ab Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 8 May 2024 13:15:34 -0700 Subject: [PATCH 233/413] [Misc] Use vllm-flash-attn instead of flash-attn (#4686) --- Dockerfile | 21 --------------------- requirements-cuda.txt | 1 + setup.py | 14 +++++++++----- vllm/attention/backends/flash_attn.py | 2 +- vllm/attention/backends/flashinfer.py | 2 +- vllm/attention/selector.py | 7 ++++--- 6 files changed, 16 insertions(+), 31 deletions(-) diff --git a/Dockerfile b/Dockerfile index 90be3a30f89b1..ddca95c0e8786 100644 --- a/Dockerfile +++ b/Dockerfile @@ -87,23 +87,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### -#################### FLASH_ATTENTION Build IMAGE #################### -FROM dev as flash-attn-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# flash attention version -ARG flash_attn_version=v2.5.8 -ENV FLASH_ATTN_VERSION=${flash_attn_version} - -WORKDIR /usr/src/flash-attention-v2 - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ - --no-build-isolation --no-deps --no-cache-dir - -#################### FLASH_ATTENTION Build IMAGE #################### - #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base @@ -122,10 +105,6 @@ RUN ldconfig /usr/local/cuda-12.4/compat/ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/pip \ pip install dist/*.whl --verbose - -RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ - --mount=type=cache,target=/root/.cache/pip \ - pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir #################### vLLM installation IMAGE #################### diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 6548d7a6684b2..ba8c614d205d2 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 +vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0 diff --git a/setup.py b/setup.py index 3768daf9d6fab..d9ba96b82329a 100644 --- a/setup.py +++ b/setup.py @@ -355,14 +355,18 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda(): requirements = _read_requirements("requirements-cuda.txt") - cuda_major = torch.version.cuda.split(".")[0] + cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: if "vllm-nccl-cu12" in req: - modified_requirements.append( - req.replace("vllm-nccl-cu12", f"vllm-nccl-cu{cuda_major}")) - else: - modified_requirements.append(req) + req = req.replace("vllm-nccl-cu12", + f"vllm-nccl-cu{cuda_major}") + elif ("vllm-flash-attn" in req + and not (cuda_major == "12" and cuda_minor == "1")): + # vllm-flash-attn is built only for CUDA 12.1. + # Skip for other versions. + continue + modified_requirements.append(req) requirements = modified_requirements elif _is_hip(): requirements = _read_requirements("requirements-rocm.txt") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c2fec9153f2d8..4bad226512b69 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -8,7 +8,7 @@ from typing import List, Optional, Tuple, Type import torch -from flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 8f13f3525512b..36e162671f944 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -3,8 +3,8 @@ import flashinfer import torch -from flash_attn import flash_attn_varlen_func from flashinfer import BatchDecodeWithPagedKVCacheWrapper +from vllm_flash_attn import flash_attn_varlen_func from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 34da0f6c6cdfc..f4446bac6b8d2 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -76,11 +76,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: return _Backend.XFORMERS try: - import flash_attn # noqa: F401 + import vllm_flash_attn # noqa: F401 except ImportError: logger.info( - "Cannot use FlashAttention-2 backend because the flash_attn " - "package is not found. Please install it for better performance.") + "Cannot use FlashAttention-2 backend because the vllm_flash_attn " + "package is not found. `pip install vllm-flash-attn` for better " + "performance.") return _Backend.XFORMERS backend_by_env_var = envs.VLLM_ATTENTION_BACKEND From f942efb5a3712498b8b583d2d9345f98d15f22f0 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 8 May 2024 14:44:00 -0700 Subject: [PATCH 234/413] [Dynamic Spec Decoding] Auto-disable by the running queue size (#4592) Co-authored-by: Cade Daniel --- tests/samplers/test_rejection_sampler.py | 13 +++- .../e2e/test_multistep_correctness.py | 34 ++++++++ .../spec_decode/e2e/test_ngram_correctness.py | 2 +- tests/spec_decode/test_dynamic_spec_decode.py | 77 +++++++++++++++++++ vllm/config.py | 29 +++++-- vllm/engine/arg_utils.py | 10 +++ vllm/executor/gpu_executor.py | 2 + .../layers/rejection_sampler.py | 11 ++- vllm/sequence.py | 6 ++ vllm/spec_decode/spec_decode_worker.py | 59 +++++++++----- vllm/spec_decode/top1_proposer.py | 23 ++++-- 11 files changed, 227 insertions(+), 39 deletions(-) create mode 100644 tests/spec_decode/test_dynamic_spec_decode.py diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 13b5b80cccfdc..00a2379502e6d 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -42,9 +42,11 @@ def mock_causal_accepted_tensor( @pytest.mark.parametrize( "which_tokens_accepted", ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) +@pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_correct_output_format(which_tokens_accepted: str, seed: int, +def test_correct_output_format(which_tokens_accepted: str, + disable_bonus_tokens: bool, seed: int, device: str): """Verify the output has correct format given predetermined accepted matrix. """ @@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, size=(batch_size, 1), dtype=torch.int64) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens) rejection_sampler.init_gpu_tensors(rank=0) output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access accepted, @@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int, bonus_token_ids, ) - # Bonus tokens are currently disabled. Verify they're set to -1. + expected_bonus_token_ids = bonus_token_ids.clone() + # If bonus tokens disabled. Verify they are set to -1. # See https://github.com/vllm-project/vllm/issues/4212 - expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1 + if disable_bonus_tokens: + expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1 if which_tokens_accepted == "all_tokens_accepted": # Expect all tokens to be equal to draft tokens. diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index f15fcc4746d20..94d71fb012727 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator, force_output_len=True) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-160m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + "speculative_disable_by_batch_size": 2, + }, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_disable_speculation(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when all sequences disable speculation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 44ef400c91d34..c2004ff061a1e 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -57,7 +57,7 @@ @pytest.mark.parametrize("output_len", [ 256, ]) -@pytest.mark.parametrize("batch_size", [1, 64]) +@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("seed", [1]) def test_ngram_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, batch_size: int, diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py new file mode 100644 index 0000000000000..948a74b22f0ae --- /dev/null +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -0,0 +1,77 @@ +from unittest.mock import MagicMock + +import pytest +import torch + +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.metrics import AsyncMetricsCollector +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker +from vllm.spec_decode.top1_proposer import Top1Proposer + +from .utils import create_batch, mock_worker + + +@pytest.mark.parametrize('queue_size', [2, 4]) +@pytest.mark.parametrize('batch_size', [1, 2, 3, 6]) +@pytest.mark.parametrize('k', [1, 2, 5, 7, 10]) +@torch.inference_mode() +def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): + """Verify that speculative tokens are disabled when the batch size + exceeds the threshold. + """ + disable_by_batch_size = 3 + + draft_worker = mock_worker(cls=MultiStepWorker) + target_worker = mock_worker() + rejection_sampler = MagicMock(spec=RejectionSampler) + metrics_collector = MagicMock(spec=AsyncMetricsCollector) + worker = SpecDecodeWorker(proposer_worker=draft_worker, + scorer_worker=target_worker, + rejection_sampler=rejection_sampler, + metrics_collector=metrics_collector, + disable_by_batch_size=disable_by_batch_size) + + exception_secret = 'artificial stop' + draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + running_queue_size=queue_size) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + + # When the batch size is larger than the threshold, + # we expect no speculative tokens (0). + expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0 + assert seq_group_metadata_list[ + 0].num_speculative_tokens == expected_num_spec_tokens + + draft_worker.sampler_output.side_effect = ValueError(exception_secret) + + proposer = Top1Proposer( + worker=draft_worker, + device='cpu', # not used + vocab_size=100, # not used + # Must be long enough to avoid being skipped due to length. + max_proposal_len=1024, + ) + + if queue_size < disable_by_batch_size: + # Should raise exception when executing the mocked draft model. + with pytest.raises(ValueError, match=exception_secret): + proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) + else: + # Should not execute the draft model because spec decode is disabled + # for all requests. Accordingly, the proposal length should be 0. + proposals = proposer.get_proposals( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) + assert proposals.proposal_lens.tolist() == [0] * batch_size diff --git a/vllm/config.py b/vllm/config.py index 5c3a8615eefb4..a2cb9b32c65fc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -692,6 +692,7 @@ def maybe_create_spec_config( speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], ) -> Optional["SpeculativeConfig"]: @@ -720,6 +721,9 @@ def maybe_create_spec_config( use_v2_block_manager (bool): Whether vLLM is configured to use the v2 block manager or not. Used for raising an error since the v2 block manager is required with spec decode. + speculative_disable_by_batch_size (Optional[int]): Disable + speculative decoding for new incoming requests when the number + of enqueue requests is larger than this value, if provided. ngram_prompt_lookup_max (Optional[int]): Max size of ngram token window, if provided. ngram_prompt_lookup_min (Optional[int]): Min size of ngram token @@ -730,7 +734,7 @@ def maybe_create_spec_config( the necessary conditions are met, else None. """ - if (speculative_model is None and num_speculative_tokens is None): + if speculative_model is None and num_speculative_tokens is None: return None if speculative_model is not None and num_speculative_tokens is None: @@ -739,6 +743,12 @@ def maybe_create_spec_config( "num_speculative_tokens to be provided, but found " f"{speculative_model=} and {num_speculative_tokens=}.") + if (speculative_disable_by_batch_size is not None + and speculative_disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{speculative_disable_by_batch_size=}") + assert (speculative_model is not None and num_speculative_tokens is not None) @@ -807,6 +817,7 @@ def maybe_create_spec_config( draft_model_config, draft_parallel_config, num_speculative_tokens, + speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, ) @@ -876,8 +887,9 @@ def __init__( draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, - ngram_prompt_lookup_max: int, - ngram_prompt_lookup_min: int, + speculative_disable_by_batch_size: Optional[int], + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], ): """Create a SpeculativeConfig object. @@ -886,12 +898,19 @@ def __init__( draft_parallel_config: ParallelConfig for the draft model. num_speculative_tokens: The number of tokens to sample from the draft model before scoring with the target model. + speculative_disable_by_batch_size: Disable speculative + decoding for new incoming requests when the number of + enqueue requests is larger than this value. + ngram_prompt_lookup_max: Max size of ngram token window. + ngram_prompt_lookup_min: Min size of ngram token window. """ self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens - self.ngram_prompt_lookup_max = ngram_prompt_lookup_max - self.ngram_prompt_lookup_min = ngram_prompt_lookup_min + self.speculative_disable_by_batch_size = \ + speculative_disable_by_batch_size + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bb8245eb307f7..c99b1806c7d1d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -83,6 +83,7 @@ class EngineArgs: speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None speculative_max_model_len: Optional[int] = None + speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None @@ -467,6 +468,13 @@ def add_cli_args( 'draft model. Sequences over this length will skip ' 'speculation.') + parser.add_argument( + '--speculative-disable-by-batch-size', + type=int, + default=EngineArgs.speculative_disable_by_batch_size, + help='Disable speculative decoding for new incoming requests ' + 'if the number of enqueue requests is larger than this value.') + parser.add_argument( '--ngram-prompt-lookup-max', type=int, @@ -547,6 +555,8 @@ def create_engine_config(self, ) -> EngineConfig: target_dtype=self.dtype, speculative_model=self.speculative_model, num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_by_batch_size=self. + speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index e8559b6a5c0fe..fa3480fa64837 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -93,6 +93,8 @@ def _init_spec_worker(self): spec_decode_worker = SpecDecodeWorker.create_worker( scorer_worker=target_worker, draft_worker_kwargs=draft_worker_kwargs, + disable_by_batch_size=self.speculative_config. + speculative_disable_by_batch_size, ) assert self.parallel_config.world_size == 1, ( diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 5edbbf2c70a49..b5f1e55d0e839 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -12,15 +12,21 @@ class RejectionSampler(nn.Module): https://arxiv.org/pdf/2302.01318.pdf. """ - def __init__(self, strict_mode: bool = False): + def __init__(self, + disable_bonus_tokens: bool = True, + strict_mode: bool = False): """Create a rejection sampler. Args: + disable_bonus_tokens: Whether or not to disable the bonus token. + Require when bonus tokens will cause corrupt KV cache for + proposal methods that require KV cache. strict_mode: Whether or not to perform shape/device/dtype checks during sampling. This catches correctness issues but adds nontrivial latency. """ super().__init__() + self._disable_bonus_tokens = disable_bonus_tokens self._strict_mode = strict_mode # NOTE: A "bonus token" is accepted iff all proposal tokens are @@ -312,7 +318,8 @@ def _create_output( # proposal methods that require KV cache. We can fix it by "prefilling" # the bonus token in the proposer. The following issue tracks the fix. # https://github.com/vllm-project/vllm/issues/4212 - output_with_bonus_tokens[:, -1] = -1 + if self._disable_bonus_tokens: + output_with_bonus_tokens[:, -1] = -1 # Fill the recovered token ids. output.mul_(~after_false_mask).add_( diff --git a/vllm/sequence.py b/vllm/sequence.py index 42b508b517200..3cebb85b49d27 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -612,6 +612,12 @@ def __init__( self._token_chunk_size = token_chunk_size self.do_sample = do_sample + # The number of speculative tokens adopted in this request. + # None means specuative decoding is not used. + # Zero means speculative decoding is disabled for some reasons. + # TODO: We should maintain this states out of the sequence group. + self.num_speculative_tokens = None + if self._token_chunk_size is None: if is_prompt: self._token_chunk_size = list(seq_data.values())[0].get_len() diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 84ec974806c7e..a4e759095b294 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch @@ -54,7 +54,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): def create_worker( cls, scorer_worker: WorkerBase, - draft_worker_kwargs, + draft_worker_kwargs: Dict[str, Any], + disable_by_batch_size: Optional[int], ) -> "SpecDecodeWorker": ngram_prompt_lookup_max = ( @@ -62,7 +63,9 @@ def create_worker( ngram_prompt_lookup_min = ( draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + disable_bonus_tokens = True if ngram_prompt_lookup_max > 0: + disable_bonus_tokens = False proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, ngram_prompt_lookup_max) @@ -75,9 +78,9 @@ def create_worker( return SpecDecodeWorker( proposer_worker, scorer_worker, - # TODO(cade) disable strict mode for speedup. - rejection_sampler=RejectionSampler(strict_mode=True), - ) + disable_by_batch_size=disable_by_batch_size, + rejection_sampler=RejectionSampler( + disable_bonus_tokens=disable_bonus_tokens, )) def __init__( self, @@ -85,6 +88,7 @@ def __init__( scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, + disable_by_batch_size: Optional[int] = None, ): """ Create a SpecDecodeWorker. @@ -97,11 +101,14 @@ def __init__( Worker. rejection_sampler: A Torch module used to perform modified rejection sampling for speculative decoding. + disable_by_batch_size: If the batch size is larger than this, + disable speculative decoding for new incoming requests. metrics_collector: Helper class for collecting metrics; can be set for testing purposes. """ self.proposer_worker = proposer_worker self.scorer_worker = scorer_worker + self.disable_by_batch_size = disable_by_batch_size or float("inf") self.rejection_sampler = rejection_sampler self._metrics = AsyncMetricsCollector( @@ -199,27 +206,41 @@ def execute_model( "speculative decoding " "requires non-None seq_group_metadata_list") + # When the batch size is too large, disable speculative decoding + # to stop trading off throughput for latency. + disable_all = (execute_model_req.running_queue_size >= + self.disable_by_batch_size) + if disable_all: + for seq_group_metadata in execute_model_req.seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. + # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 + # If no spec tokens, call the proposer and scorer workers normally. - # Used for prefill. + # This happens for prefill, or when the spec decode is disabled + # for this batch. if execute_model_req.num_lookahead_slots == 0 or len( execute_model_req.seq_group_metadata_list) == 0: - return self._run_no_spec(execute_model_req) + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all) return self._run_speculative_decoding_step(execute_model_req) @nvtx_range("spec_decode_worker._run_no_spec") - def _run_no_spec( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - """Run a prefill step, without any speculation. The input is sent to the - proposer and scorer model so that the KV cache is consistent between the - two. + def _run_no_spec(self, execute_model_req: ExecuteModelRequest, + skip_proposer: bool) -> List[SamplerOutput]: + """Run a prefill step, without any speculation. The input is sent to + the proposer and scorer model so that the KV cache is consistent + between the two. When skip_proposer is True, the proposer model is + not called, meaning that the kv-cache in proposer for requests is not + updated, so they cannot enable spec decode in the rest decoding. """ - #logger.info("run proposer worker no spec") - - self.proposer_worker.execute_model(execute_model_req) + if not skip_proposer: + self.proposer_worker.execute_model(execute_model_req) - #logger.info("run target worker no spec") sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -244,22 +265,18 @@ def _run_speculative_decoding_step( sequence. """ - #logger.info("get spec proposals") # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals(execute_model_req) - #logger.info("score proposals") proposal_scores = self.scorer.score_proposals( execute_model_req, proposals, ) - #logger.info("verify proposals") accepted_token_ids, target_logprobs = self._verify_tokens( execute_model_req.seq_group_metadata_list, proposal_scores, proposals, execute_model_req.num_lookahead_slots) - #logger.info("create output list") return self._create_output_sampler_list( execute_model_req.seq_group_metadata_list, accepted_token_ids, diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index eb622a0e2e7f4..ee9462b68dae8 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -56,7 +56,7 @@ def get_proposals( proposal_lens, nonzero_proposal_len_seqs, nonzero_proposal_len_indices, - ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len) + ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) if nonzero_proposal_len_seqs: # Speculate tokens using the draft worker for the speculative @@ -97,17 +97,27 @@ def get_proposals( return proposals - def _split_by_max_model_len( + def _split_by_proposal_len( self, seq_group_metadata_list: List[SequenceGroupMetadata], proposal_len: int, ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: - """Determine which sequences would exceed the max model length.""" + """Split sequences by two groups: + 1. Sequences with non-zero proposal length. + 2. Sequences with zero proposal length (due to disabled speculation + or exceed the maximum model length). + """ proposal_lens: List[int] = [] nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] nonzero_proposal_len_indices: List[int] = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): + # The speculative decoding for this request has been disabled + # (e.g. due to high traffic). + if seq_group_metadata.num_speculative_tokens == 0: + proposal_lens.append(0) + continue + seq_data = next(iter(seq_group_metadata.seq_data.values())) seq_len = seq_data.get_len() @@ -115,13 +125,14 @@ def _split_by_max_model_len( # are supported. # If max_proposal_len is defined, then we shall no exccess this # quota for nonzero_proposal + new_k = 0 if (self.max_proposal_len is None or seq_len + proposal_len < self.max_proposal_len): - proposal_lens.append(proposal_len) + new_k = proposal_len nonzero_proposal_len_seqs.append(seq_group_metadata) nonzero_proposal_len_indices.append(i) - else: - proposal_lens.append(0) + proposal_lens.append(new_k) + seq_group_metadata.num_speculative_tokens = new_k return ( proposal_lens, From 8b9241be3a0020724e145bf600d9710b3d59b167 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Wed, 8 May 2024 16:24:46 -0700 Subject: [PATCH 235/413] [Speculative decoding] [Bugfix] Fix overallocation in ngram + spec logprobs (#4672) --- vllm/spec_decode/ngram_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index fed8be42054a5..f18f9387f5b23 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -138,7 +138,7 @@ def sampler_output( SamplerOutput( outputs=None, sampled_token_probs=token_probs[i], - logprobs=token_logprobs, + logprobs=token_logprobs[i], sampled_token_ids=token_ids[i], )) return outputs, False From e288df0632d5bdde76c20bed8310b46d35b8e5ac Mon Sep 17 00:00:00 2001 From: alexm-nm <59768536+alexm-nm@users.noreply.github.com> Date: Wed, 8 May 2024 20:14:31 -0400 Subject: [PATCH 236/413] [Bugfix] Fine-tune gptq_marlin configs to be more similar to marlin (#4626) --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 48 ++++++++++++++------ 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index fd0837f0cb39c..9c6bff000e916 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -115,7 +115,8 @@ template __device__ inline int lop3(int a, int b, int c) { return res; } -// Constructs destination register by taking bytes from 2 sources (based on mask) +// Constructs destination register by taking bytes from 2 sources (based on +// mask) template __device__ inline uint32_t prmt(uint32_t a) { uint32_t res; @@ -933,9 +934,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -1275,13 +1276,22 @@ typedef struct { thread_config_t tb_cfg; } exec_config_t; -thread_config_t thread_configs[] = { +thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {64, 256, 256}, // Default (max cache usage) - {64, 128, 128}, // Reduce N, reduce warps - {128, 64, 128}, // Reduce N more, but increase K + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, }; @@ -1397,11 +1407,21 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, int max_shared_mem) { int max_m_blocks = 4; while (max_m_blocks > 0) { - for (auto th_config : thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } } } @@ -1574,10 +1594,12 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, } CALL_IF(4, 32, 2, 256) CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 8, 256) CALL_IF(4, 8, 4, 128) CALL_IF(4, 4, 8, 128) CALL_IF(8, 32, 2, 256) CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 8, 256) CALL_IF(8, 8, 4, 128) CALL_IF(8, 4, 8, 128) else { From 16bc0a098f6d34050637c3336183fb6966300dd5 Mon Sep 17 00:00:00 2001 From: Mahmoud Ashraf Date: Thu, 9 May 2024 08:02:31 +0300 Subject: [PATCH 237/413] [Frontend] add tok/s speed metric to llm class when using tqdm (#4400) Co-authored-by: Michael Goin --- vllm/entrypoints/llm.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 3ed660e183360..71620139fba39 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -238,17 +238,25 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() - pbar = tqdm(total=num_requests, - desc="Processed prompts", - dynamic_ncols=True) + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=f"Generation Speed: {0:.2f} toks/s", + ) # Run the engine. outputs: List[RequestOutput] = [] + total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() for output in step_outputs: if output.finished: outputs.append(output) if use_tqdm: + total_toks += (sum( + len(stp.token_ids) for stp in output.outputs)) + spd = total_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" pbar.update(1) if use_tqdm: pbar.close() @@ -256,4 +264,4 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # This is necessary because some requests may be finished earlier than # its previous requests. outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs \ No newline at end of file + return outputs From f12b20deccbc6c8bb5cdeac053d75178341c66c1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 9 May 2024 13:48:33 +0800 Subject: [PATCH 238/413] [Frontend] Move async logic outside of constructor (#4674) --- tests/async_engine/test_chat_template.py | 30 ++++---- tests/entrypoints/openai/test_serving_chat.py | 8 ++- vllm/engine/arg_utils.py | 2 +- vllm/entrypoints/openai/api_server.py | 23 +++++- vllm/entrypoints/openai/serving_chat.py | 72 +++++++++---------- vllm/entrypoints/openai/serving_completion.py | 7 +- vllm/entrypoints/openai/serving_engine.py | 56 +++++---------- 7 files changed, 96 insertions(+), 102 deletions(-) diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 64bcba67c3437..55b730812ea94 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -60,13 +60,12 @@ class MockServingChat: tokenizer: MockTokenizer -@pytest.mark.asyncio -async def test_load_chat_template(): +def test_load_chat_template(): # Testing chatml template tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - await OpenAIServingChat._load_chat_template( - mock_serving_chat, chat_template=chatml_jinja_path) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=chatml_jinja_path) template_content = tokenizer.chat_template @@ -77,8 +76,7 @@ async def test_load_chat_template(): {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 -@pytest.mark.asyncio -async def test_no_load_chat_template_filelike(): +def test_no_load_chat_template_filelike(): # Testing chatml template template = "../../examples/does_not_exist" tokenizer = MockTokenizer() @@ -86,35 +84,33 @@ async def test_no_load_chat_template_filelike(): mock_serving_chat = MockServingChat(tokenizer) with pytest.raises(ValueError, match="looks like a file path"): - await OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) -@pytest.mark.asyncio -async def test_no_load_chat_template_literallike(): +def test_no_load_chat_template_literallike(): # Testing chatml template template = "{{ messages }}" tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - await OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) template_content = tokenizer.chat_template assert template_content == template -@pytest.mark.asyncio @pytest.mark.parametrize( "model,template,add_generation_prompt,expected_output", MODEL_TEMPLATE_GENERATON_OUTPUT) -async def test_get_gen_prompt(model, template, add_generation_prompt, - expected_output): +def test_get_gen_prompt(model, template, add_generation_prompt, + expected_output): # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) mock_serving_chat = MockServingChat(tokenizer) - await OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 269b0823fec05..13e2e372cef33 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -20,11 +20,15 @@ class MockModelConfig: class MockEngine: async def get_model_config(self): - return MockModelConfig + return MockModelConfig() async def _async_serving_chat_init(): - serving_completion = OpenAIServingChat(MockEngine(), + engine = MockEngine() + model_config = await engine.get_model_config() + + serving_completion = OpenAIServingChat(engine, + model_config, served_model_names=[MODEL_NAME], response_role="assistant", chat_template=CHAT_TEMPLATE) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c99b1806c7d1d..5c2acbef13129 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -516,7 +516,7 @@ def add_cli_args( return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + def from_cli_args(cls, args: argparse.Namespace): # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 44a946f2e32d4..362f28d05c3bb 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -4,7 +4,7 @@ import re from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Set +from typing import Optional, Set import fastapi import uvicorn @@ -164,15 +164,32 @@ async def authentication(request: Request, call_next): served_model_names = args.served_model_name else: served_model_names = [args.model] + engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - openai_serving_chat = OpenAIServingChat(engine, served_model_names, + + event_loop: Optional[asyncio.AbstractEventLoop] + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running(): + # If the current is instanced by Ray Serve, + # there is already a running event loop + model_config = event_loop.run_until_complete(engine.get_model_config()) + else: + # When using single vLLM without engine_use_ray + model_config = asyncio.run(engine.get_model_config()) + + openai_serving_chat = OpenAIServingChat(engine, model_config, + served_model_names, args.response_role, args.lora_modules, args.chat_template) openai_serving_completion = OpenAIServingCompletion( - engine, served_model_names, args.lora_modules) + engine, model_config, served_model_names, args.lora_modules) app.root_path = args.root_path uvicorn.run(app, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c8f4a6b315db0..1b469fc59b076 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,4 +1,3 @@ -import asyncio import codecs import time from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, @@ -8,6 +7,7 @@ from openai.types.chat import (ChatCompletionContentPartParam, ChatCompletionRole) +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -35,17 +35,47 @@ class OpenAIServingChat(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, + model_config: ModelConfig, served_model_names: List[str], response_role: str, lora_modules: Optional[List[LoRAModulePath]] = None, chat_template: Optional[str] = None): super().__init__(engine=engine, + model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules, - await_post_init=self._load_chat_template( - chat_template=chat_template)) + lora_modules=lora_modules) self.response_role = response_role + self._load_chat_template(chat_template) + + def _load_chat_template(self, chat_template: Optional[str]): + tokenizer = self.tokenizer + + if chat_template is not None: + try: + with open(chat_template, "r") as f: + tokenizer.chat_template = f.read() + except OSError as e: + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + tokenizer.chat_template = codecs.decode( + chat_template, "unicode_escape") + + logger.info("Using supplied chat template:\n%s", + tokenizer.chat_template) + elif tokenizer.chat_template is not None: + logger.info("Using default chat template:\n%s", + tokenizer.chat_template) + else: + logger.warning( + "No chat template provided. Chat API will not work.") def _parse_chat_message_content( self, @@ -357,36 +387,4 @@ async def chat_completion_full_generator( usage=usage, ) - return response - - async def _load_chat_template(self, chat_template: Optional[str]): - while self.tokenizer is None: - # Give the parent class time to load the tokenizer - await asyncio.sleep(0.1) - tokenizer = self.tokenizer - - if chat_template is not None: - try: - with open(chat_template, "r") as f: - tokenizer.chat_template = f.read() - except OSError as e: - JINJA_CHARS = "{}\n" - if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") - raise ValueError(msg) from e - - # If opening a file fails, set chat template to be args to - # ensure we decode so our escape are interpreted correctly - tokenizer.chat_template = codecs.decode( - chat_template, "unicode_escape") - - logger.info("Using supplied chat template:\n%s", - tokenizer.chat_template) - elif tokenizer.chat_template is not None: - logger.info("Using default chat template:\n%s", - tokenizer.chat_template) - else: - logger.warning( - "No chat template provided. Chat API will not work.") + return response \ No newline at end of file diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 6a7f29c4c96f2..158d8ed7fbbf5 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -4,6 +4,7 @@ from fastapi import Request +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (CompletionRequest, CompletionResponse, @@ -52,11 +53,11 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]: class OpenAIServingCompletion(OpenAIServing): - def __init__(self, - engine: AsyncLLMEngine, + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]] = None): + lora_modules: Optional[List[LoRAModulePath]]): super().__init__(engine=engine, + model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 21baea2e5e7f6..f10718c5f3d80 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,13 +1,12 @@ -import asyncio import json from dataclasses import dataclass from http import HTTPStatus -from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union from pydantic import Field -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from typing_extensions import Annotated +from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, ErrorResponse, @@ -29,13 +28,24 @@ class LoRAModulePath: class OpenAIServing: - def __init__(self, - engine: AsyncLLMEngine, + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]], - await_post_init: Optional[Awaitable[Any]] = None): + lora_modules: Optional[List[LoRAModulePath]]): + super().__init__() + self.engine = engine + self.max_model_len = model_config.max_model_len + + # A separate tokenizer to map token IDs to strings. + self.tokenizer = get_tokenizer( + model_config.tokenizer, + tokenizer_mode=model_config.tokenizer_mode, + tokenizer_revision=model_config.tokenizer_revision, + trust_remote_code=model_config.trust_remote_code, + truncation_side="left") + self.served_model_names = served_model_names + if lora_modules is None: self.lora_requests = [] else: @@ -47,38 +57,6 @@ def __init__(self, ) for i, lora in enumerate(lora_modules, start=1) ] - self.max_model_len = 0 - # Lazy initialized - self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - - try: - event_loop = asyncio.get_running_loop() - except RuntimeError: - event_loop = None - - if event_loop is not None and event_loop.is_running(): - # If the current is instanced by Ray Serve, - # there is already a running event loop - event_loop.create_task(self._post_init(await_post_init)) - else: - # When using single vLLM without engine_use_ray - asyncio.run(self._post_init(await_post_init)) - - async def _post_init(self, await_post_init): - engine_model_config = await self.engine.get_model_config() - self.max_model_len = engine_model_config.max_model_len - - # A separate tokenizer to map token IDs to strings. - self.tokenizer = get_tokenizer( - engine_model_config.tokenizer, - tokenizer_mode=engine_model_config.tokenizer_mode, - tokenizer_revision=engine_model_config.tokenizer_revision, - trust_remote_code=engine_model_config.trust_remote_code, - truncation_side="left") - - if await_post_init is not None: - await await_post_init - async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ From 190bc838e17196733526896bf2861f8d05bd3f43 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 9 May 2024 00:17:17 -0700 Subject: [PATCH 239/413] [Misc] Remove unnecessary ModelRunner imports (#4703) --- tests/samplers/test_sampler.py | 81 ++++++++++------------------------ tests/test_logits_processor.py | 23 +++------- 2 files changed, 31 insertions(+), 73 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index e4fea165a4d46..ddc66aa28a094 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -11,8 +11,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import Counter -from vllm.worker.model_runner import ModelRunner +from vllm.utils import Counter, is_pin_memory_available class MockLogitsSampler(Sampler): @@ -26,20 +25,14 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]: input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - return input_tensor, fake_logits, sampler, model_runner + return input_tensor, fake_logits, sampler VOCAB_SIZE = 32000 @@ -53,7 +46,6 @@ def _do_sample( batch_size: int, input_tensor: torch.Tensor, sampler: MockLogitsSampler, - model_runner: ModelRunner, sampling_params: SamplingParams, device: str, ): @@ -75,7 +67,7 @@ def _do_sample( seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -85,19 +77,16 @@ def test_sampler_all_greedy(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams(temperature=0) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == expected[i].item() - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -105,8 +94,7 @@ def test_sampler_all_random(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -115,15 +103,13 @@ def test_sampler_all_random(seed: int, device: str): temperature=1.0, n=random.randint(1, 10), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -131,7 +117,7 @@ def test_sampler_all_random_seed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 @@ -141,15 +127,13 @@ def test_sampler_all_random_seed(seed: int, device: str): n=random.randint(1, 10), seed=random.randint(0, 10000), ) - sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, + sampler_output = _do_sample(batch_size, fake_logits, sampler, sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -157,7 +141,7 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=1.0, @@ -165,15 +149,13 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): seed=random.randint(0, 10000), ) first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params, device) + sampling_params, device) assert first_sampler_output == second_sampler_output - del model_runner - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @@ -181,20 +163,18 @@ def test_sampler_all_beam(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_params = SamplingParams( temperature=0, best_of=2, use_beam_search=True, ) - _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params, - device) + _do_sample(batch_size, fake_logits, sampler, sampling_params, device) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler # when handling an all-beam search case. - del model_runner @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -448,13 +428,13 @@ def run_test_case(*, ("Invalid test case, expected_penalization does not match computed" "batch size") - _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + _, fake_logits, sampler = _prepare_test(batch_size) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens=seq_lens if seq_lens else None, query_lens=seq_lens if seq_lens else None, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -480,8 +460,6 @@ def run_test_case(*, fake_logits[logits_idx, :] == -float('inf')) == 0, "No tokens should have been penalized" - del model_runner - for test_case in test_cases: run_test_case(**test_case) @@ -492,8 +470,7 @@ def test_sampler_mixed(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler = _prepare_test(batch_size) seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] @@ -534,13 +511,13 @@ def test_sampler_mixed(seed: int, device: str): )) seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - def test_sampling(model_runner: ModelRunner): + def test_sampling(): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -570,7 +547,7 @@ def test_sampling(model_runner: ModelRunner): assert nth_output.output_token in expected_tokens[i] # Test batch - test_sampling(model_runner) + test_sampling() # Shuffle the batch and resample target_index = list(range(batch_size)) @@ -583,9 +560,7 @@ def test_sampling(model_runner: ModelRunner): # This time, results of seeded random samples will be compared with # the corresponding sample in the pre-shuffled batch - test_sampling(model_runner) - - del model_runner + test_sampling() @pytest.mark.parametrize("seed", RANDOM_SEEDS) @@ -605,12 +580,6 @@ def test_sampler_top_k_top_p(seed: int, device: str): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, @@ -641,7 +610,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): seq_lens, query_lens=seq_lens, device=device, - pin_memory=model_runner.pin_memory) + pin_memory=is_pin_memory_available()) sample_probs = None @@ -657,5 +626,3 @@ def mock_sample(probs, *args, **kwargs): hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) assert torch.allclose(hf_probs, sample_probs, atol=1e-5) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) - - del model_runner diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 179e8d25a341b..4ee980505a3ab 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -9,7 +9,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.utils import is_pin_memory_available class MockLogitsProcessor(LogitsProcessor): @@ -30,21 +30,15 @@ def forward(self, *args, **kwargs): def _prepare_test( - batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]: + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=input_tensor.dtype) logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - return input_tensor, fake_logits, logits_processor, model_runner + return input_tensor, fake_logits, logits_processor RANDOM_SEEDS = list(range(128)) @@ -59,8 +53,7 @@ def test_logits_processors(seed: int, device: str): set_random_seed(seed) torch.set_default_device(device) batch_size = random.randint(1, 256) - input_tensor, fake_logits, logits_processor, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, logits_processor = _prepare_test(batch_size) # This sample logits processor gives infinite score to the i-th token, # where i is the length of the input sequence. @@ -87,8 +80,8 @@ def pick_ith(token_ids, logits): seq_group_metadata_list, seq_lens, query_lens=seq_lens, - device=model_runner.device, - pin_memory=model_runner.pin_memory) + device=device, + pin_memory=is_pin_memory_available()) logits_processor_output = logits_processor( embedding=None, hidden_states=input_tensor, @@ -99,5 +92,3 @@ def pick_ith(token_ids, logits): fake_logits *= logits_processor.scale assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1], 1e-4) - - del model_runner From 0ee535b2945d042cbb1fc6e63fd3fddd94d491f2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 9 May 2024 09:04:59 -0700 Subject: [PATCH 240/413] [Misc] Set block size at initialization & Fix test_model_runner (#4705) --- tests/worker/test_model_runner.py | 90 +++++++++++-------------------- vllm/worker/cpu_model_runner.py | 21 ++++---- vllm/worker/cpu_worker.py | 1 + vllm/worker/model_runner.py | 54 ++++++++----------- vllm/worker/worker.py | 2 +- 5 files changed, 64 insertions(+), 104 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e7975d0ef48b9..3e3d2e3f5c53d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,27 +1,38 @@ import pytest import torch -from vllm.config import ModelConfig, SchedulerConfig from vllm.distributed.parallel_state import init_distributed_environment +from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size +def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + model_runner = ModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + return model_runner + + @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_prompt(batch_size): - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=False) - model_runner = ModelRunner(model_config=None, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) + model_runner = _create_model_runner( + "facebook/opt-125m", + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, + ) seq_lens = [] seq_group_metadata_list = [] @@ -123,27 +134,15 @@ def test_prepare_prompt(batch_size): @pytest.mark.parametrize("batch_size", list(range(1, 257))) def test_prepare_decode_cuda_graph(batch_size): - model_config = ModelConfig( + model_runner = _create_model_runner( "facebook/opt-125m", - "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=False, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=False, ) - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=False) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) seq_lens = [] seq_group_metadata_list = [] @@ -214,23 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size): def test_empty_seq_group(): """Verify prepare prompt and decode returns empty output.""" - model_config = ModelConfig( - "facebook/opt-125m", + model_runner = _create_model_runner( "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=False, ) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=None, - device_config=None, - load_config=None, - lora_config=None) - model_runner.set_block_size(16) seq_group_metadata_list = [] input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( model_runner._prepare_decode(seq_group_metadata_list)) @@ -260,29 +248,15 @@ def distributed_init(): @pytest.mark.parametrize("batch_size", list(range(2, 128))) @pytest.mark.parametrize("enforce_eager", [True, False]) def test_hybrid_batches(batch_size, enforce_eager, distributed_init): - - model_config = ModelConfig( - "facebook/opt-125m", + model_runner = _create_model_runner( "facebook/opt-125m", - tokenizer_mode="auto", - trust_remote_code=False, seed=0, dtype="float16", - revision=None, enforce_eager=enforce_eager, + max_num_batched_tokens=100000, + max_num_seqs=100000, + enable_chunked_prefill=True, ) - scheduler_config = SchedulerConfig(100000, - 100000, - 100000, - enable_chunked_prefill=True) - model_runner = ModelRunner(model_config=model_config, - parallel_config=None, - scheduler_config=scheduler_config, - device_config=None, - load_config=None, - lora_config=None, - is_driver_worker=True) - model_runner.set_block_size(16) # Add prefill requests. seq_lens = [] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 193b021b7a11e..6c8b1685dadcf 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -4,8 +4,9 @@ from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -26,6 +27,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], @@ -39,27 +41,22 @@ def __init__( self.scheduler_config = scheduler_config # Currently, CPU worker doesn't support chunked prefill. assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config self.vision_language_config = vision_language_config self.load_config = load_config self.is_driver_worker = is_driver_worker - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device self.kv_cache_dtype = kv_cache_dtype - - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend(self.model_config.dtype) # Lazy initialization. self.model: nn.Module # Set after init_Model - self.block_size: int # Set after initial profiling. def load_model(self) -> None: self.model = get_model( diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index e1ef500ac07b8..5e4ae564cb57e 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -151,6 +151,7 @@ def __init__( parallel_config, scheduler_config, device_config, + cache_config, load_config=self.load_config, lora_config=self.lora_config, vision_language_config=self.vision_language_config, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 46c6730645c1b..b5e582116297c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,8 +9,9 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed.device_communicators import (custom_all_reduce, pynccl_utils) @@ -106,6 +107,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, device_config: DeviceConfig, + cache_config: CacheConfig, load_config: LoadConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", @@ -115,48 +117,40 @@ def __init__( self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config self.lora_config = lora_config self.load_config = load_config self.is_driver_worker = is_driver_worker + self.vision_language_config = vision_language_config - # model_config can be None in tests/samplers/test_sampler.py. - # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) - self.device_config = (device_config - if device_config is not None else DeviceConfig()) self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() - # Set after load_model. - self.lora_manager: LRUCacheWorkerLoRAManager = None - + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - - self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture - if self.model_config is not None else 0) - - self.pin_memory = is_pin_memory_available() - self.kv_cache_dtype = kv_cache_dtype - self.vision_language_config = vision_language_config - - self.attn_backend = get_attn_backend( - self.model_config.dtype if model_config is not None else None) - - # Lazy initialization - self.model: torch.nn.Module # Set after load_model - self.block_size: int # Set after initial profiling. # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables: torch.Tensor # Set after initial profiling. + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + dtype=np.int32) + self.attn_backend = get_attn_backend(self.model_config.dtype) + # Lazy initialization + self.model: torch.nn.Module # Set after load_model # Set if the backend is flashinfer. self.flashinfer_workspace_buffer: torch.Tensor + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: with CudaMemoryProfiler() as m: @@ -211,13 +205,6 @@ def load_model(self) -> None: "but the KV cache data type is not FP8. " "KV cache scaling factors will not be used.") - def set_block_size(self, block_size: int) -> None: - self.block_size = block_size - - self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), - dtype=np.int32) - def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size @@ -835,6 +822,7 @@ def profile_run(self) -> None: dummy_lora_requests = [] dummy_lora_requests_per_seq = [] if self.lora_config: + assert self.lora_manager is not None with self.lora_manager.dummy_lora_cache(): for idx in range(self.lora_config.max_loras): lora_id = idx + 1 diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 313bcf25d8870..43f6b2b443b70 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -75,6 +75,7 @@ def __init__( parallel_config, scheduler_config, device_config, + cache_config, load_config=load_config, lora_config=self.lora_config, kv_cache_dtype=self.cache_config.cache_dtype, @@ -184,7 +185,6 @@ def _init_cache_engine(self): self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.gpu_cache = self.cache_engine.gpu_cache - self.model_runner.set_block_size(self.cache_engine.block_size) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: From ff5abcd7463211251bcf19916ae3b45d762f48d4 Mon Sep 17 00:00:00 2001 From: kliuae <17350011+kliuae@users.noreply.github.com> Date: Fri, 10 May 2024 00:19:50 +0800 Subject: [PATCH 241/413] [ROCm] Add support for Punica kernels on AMD GPUs (#3140) Co-authored-by: miloice --- CMakeLists.txt | 16 +- Dockerfile.rocm | 3 + csrc/cuda_compat.h | 6 + csrc/punica/bgmv/bgmv_impl.cuh | 154 +++++++++++++++++++ csrc/punica/bgmv/vec_dtypes.cuh | 5 +- csrc/punica/{punica_ops.cc => punica_ops.cu} | 17 +- csrc/punica/punica_ops.h | 11 ++ csrc/punica/punica_pybind.cpp | 13 ++ csrc/punica/type_convert.h | 82 ++++++++++ setup.py | 6 +- 10 files changed, 287 insertions(+), 26 deletions(-) rename csrc/punica/{punica_ops.cc => punica_ops.cu} (98%) create mode 100644 csrc/punica/punica_ops.h create mode 100644 csrc/punica/punica_pybind.cpp create mode 100644 csrc/punica/type_convert.h diff --git a/CMakeLists.txt b/CMakeLists.txt index f817f3382c5e1..47629f036fb09 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -219,7 +219,8 @@ set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" - "csrc/punica/punica_ops.cc") + "csrc/punica/punica_ops.cu" + "csrc/punica/punica_pybind.cpp") # # Copy GPU compilation flags+update for punica @@ -243,6 +244,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA") endif() endforeach() message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") +elseif(${VLLM_GPU_LANG} STREQUAL "HIP") + set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES}) + message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") endif() if (VLLM_PUNICA_GPU_ARCHES) @@ -277,11 +281,6 @@ add_custom_target(default) if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) -endif() - -if(VLLM_GPU_LANG STREQUAL "CUDA") - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and @@ -292,3 +291,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") add_dependencies(default _punica_C) endif() endif() + +if(VLLM_GPU_LANG STREQUAL "CUDA") + message(STATUS "Enabling moe extension.") + add_dependencies(default _moe_C) +endif() diff --git a/Dockerfile.rocm b/Dockerfile.rocm index d04bb9915e2ab..eefad79e79d83 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -94,6 +94,9 @@ COPY . . RUN python3 -m pip install --upgrade pip numba +# make sure punica kernels are built (for LoRA) +ENV VLLM_INSTALL_PUNICA_KERNELS=1 + RUN --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index c711d8d1b24b9..1ebb2e74a82fc 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -28,6 +28,12 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + #ifndef USE_ROCM #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index dad8805c750cb..8a3b8403b4a6f 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -1,8 +1,14 @@ #pragma once #include +#ifndef USE_ROCM #include +#else +#include +#endif +#ifndef USE_ROCM #include +#endif #include #include #include @@ -11,6 +17,24 @@ namespace cg = cooperative_groups; +#ifdef USE_ROCM +template +__host__ __device__ +inline void* memcpy_blocking(void *dst, const void *src) { + // Does not handle the case of long datatypes + char *d = reinterpret_cast(dst); + const char *s = reinterpret_cast(src); + size_t i = 0; +#pragma unroll + for (i = 0; i < len; ++i) { + d[i] = s[i]; + } + return dst; +} +#endif + +#ifndef USE_ROCM + // nthrs = (32, 4) template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + size_t j = blockIdx.x; + constexpr size_t tile_size = tx * ty * vec_size; + constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size; + __shared__ float y_warpwise[ty]; + + float y = 0; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + x_vec.load(X + (batch_idx * feat_in) + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + } + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += VLLM_SHFL_DOWN_SYNC(sum, offset); + } + + __syncthreads(); + + if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) { + y += sum; + } + } + + if (threadIdx.x == 0) { + y_warpwise[threadIdx.y] = y; + } + __syncthreads(); + + float y_write = 0.f; +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y_write += y_warpwise[i]; + } + + // write Y; + if (threadIdx.x == 0 && threadIdx.y == 0) { + size_t y_idx = batch_idx * full_y_size + y_offset + j; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(y_write)); + } +} + +#endif + // nthrs = (2, 16, 4) template @@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, float sum = 0.f; #pragma unroll for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif } cg::thread_block_tile g = cg::tiled_partition(block); @@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, sum = g.shfl(sum, 0); if (threadIdx.x == 0) { +#ifndef USE_ROCM Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + threadIdx.z * ty + threadIdx.y] += static_cast(sum); +#else + size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); +#endif } } @@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, scale); } } else { +#ifndef USE_ROCM static_assert(feat_in % (vec_size * 32) == 0 || feat_in % (vec_size * 16) == 0 || feat_in % (vec_size * 8) == 0); @@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, full_y_size, num_layers, layer_idx, scale); } +#else + constexpr size_t rocm_warp_size = warpSize; + +#define CHECK_INPUT_TILEABLE_BY(vec_size_) \ + feat_in % (rocm_warp_size * vec_size_) == 0 + +#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \ + if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \ + constexpr size_t vec_size_shrink = vec_size_; \ + constexpr int tx = tx_; \ + constexpr int ty = ty_; \ + dim3 nblks(feat_out, batch_size); \ + dim3 nthrs(tx, ty); \ + bgmv_shrink_kernel \ + <<>>(Y, X, W, indicies, y_offset, \ + full_y_size, num_layers, layer_idx, \ + scale); \ + } + + static_assert(CHECK_INPUT_TILEABLE_BY(32) || + CHECK_INPUT_TILEABLE_BY(16) || + CHECK_INPUT_TILEABLE_BY( 8) || + CHECK_INPUT_TILEABLE_BY( 4) || + CHECK_INPUT_TILEABLE_BY( 2) || + CHECK_INPUT_TILEABLE_BY( 1)); + + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2) + else + LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1) + +#undef CHECK_INPUT_TILEABLE_BY +#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM +#endif } } diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh index cf00d869cf635..2738892e6dc4a 100644 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -1,8 +1,6 @@ #ifndef VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_ -#include -#include #ifdef FLASHINFER_USE_FP8 #include #endif @@ -10,6 +8,9 @@ #include +#include "../type_convert.h" +#include "../../cuda_compat.h" + #define FLASHINFER_INLINE \ inline __attribute__((always_inline)) __device__ __host__ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cu similarity index 98% rename from csrc/punica/punica_ops.cc rename to csrc/punica/punica_ops.cu index 8797fde85744a..61de3b37937cc 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cu @@ -1,12 +1,11 @@ -#include -#include #include #include #include +#include "type_convert.h" +#include "../cuda_compat.h" #include "bgmv/bgmv_config.h" -namespace { //====== utils ====== @@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); } - -} // namespace - -//====== pybind ====== - -#define DEFINE_pybind(name) m.def(#name, &name, #name); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); - m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, - "dispatch_bgmv_low_level"); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h new file mode 100644 index 0000000000000..937e2d1d25d4a --- /dev/null +++ b/csrc/punica/punica_ops.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale); + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp new file mode 100644 index 0000000000000..9490ad59cdd5f --- /dev/null +++ b/csrc/punica/punica_pybind.cpp @@ -0,0 +1,13 @@ +#include + +#include "punica_ops.h" + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h new file mode 100644 index 0000000000000..dff7ce49283d7 --- /dev/null +++ b/csrc/punica/type_convert.h @@ -0,0 +1,82 @@ +#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ +#define CSRC__PUNICA__TYPE_CONVERT_H__ + +#ifndef USE_ROCM + +#include +#include + +#else + +#include +#include + +#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ + +typedef __half nv_half; +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { + return __hip_bfloat162{val, val}; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { + return __hip_bfloat162{vall, valr}; +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T_dst convert_type(T_src val) { + return static_cast(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__half, float>(__half val) { + return __half2float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half convert_type(float val) { + return __float2half(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 convert_type(float val) { + return __float2bfloat16(val); +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T vllm_add(T a, T b) { + return a + b; +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half vllm_add<__half>(__half a, __half b) { + return __hadd(a, b); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { + return __hadd(a, b); +} + +#undef __TYPE_CONVERT__HOST_DEVICE__ + +#endif // USE_ROCM + +#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/setup.py b/setup.py index d9ba96b82329a..0dc8818b44a9e 100644 --- a/setup.py +++ b/setup.py @@ -385,12 +385,12 @@ def _read_requirements(filename: str) -> List[str]: if _is_cuda(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) - if _install_punica(): - ext_modules.append(CMakeExtension(name="vllm._punica_C")) - if not _is_neuron(): ext_modules.append(CMakeExtension(name="vllm._C")) + if _install_punica(): + ext_modules.append(CMakeExtension(name="vllm._punica_C")) + package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } From a3c124570a66f746ba09faabe2e14851386b395a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 10 May 2024 00:53:14 +0800 Subject: [PATCH 242/413] [Bugfix] Fix CLI arguments in OpenAI server docs (#4709) --- docs/source/serving/openai_compatible_server.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index c157d8ba998da..15a8761eb5738 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -108,5 +108,5 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) ```{argparse} :module: vllm.entrypoints.openai.cli_args :func: make_arg_parser -:prog: vllm-openai-server +:prog: -m vllm.entrypoints.openai.api_server ``` \ No newline at end of file From cea64430f615ff90c67ff8375ec86562913c5500 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 9 May 2024 11:10:13 -0600 Subject: [PATCH 243/413] [Bugfix] Update grafana.json (#4711) --- examples/production_monitoring/grafana.json | 432 +++++++++++--------- 1 file changed, 239 insertions(+), 193 deletions(-) diff --git a/examples/production_monitoring/grafana.json b/examples/production_monitoring/grafana.json index 5e9bd5bd03869..273f7f5ac42cf 100644 --- a/examples/production_monitoring/grafana.json +++ b/examples/production_monitoring/grafana.json @@ -1,4 +1,41 @@ { + "__inputs": [ + { + "name": "DS_PROMETHEUS", + "label": "prometheus", + "description": "", + "type": "datasource", + "pluginId": "prometheus", + "pluginName": "Prometheus" + } + ], + "__elements": {}, + "__requires": [ + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "10.4.2" + }, + { + "type": "panel", + "id": "heatmap", + "name": "Heatmap", + "version": "" + }, + { + "type": "datasource", + "id": "prometheus", + "name": "Prometheus", + "version": "1.0.0" + }, + { + "type": "panel", + "id": "timeseries", + "name": "Time series", + "version": "" + } + ], "annotations": { "list": [ { @@ -25,14 +62,14 @@ "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 0, - "id": 29, + "id": null, "links": [], "liveNow": false, "panels": [ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "End to end request latency measured in seconds.", "fieldConfig": { @@ -41,6 +78,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -54,6 +92,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -111,7 +150,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -127,7 +166,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -144,7 +183,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -161,7 +200,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -178,7 +217,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:e2e_request_latency_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:e2e_request_latency_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -195,7 +234,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Number of tokens processed per second", "fieldConfig": { @@ -204,6 +243,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -217,6 +257,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -273,7 +314,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -289,7 +330,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -310,7 +351,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Inter token latency in seconds.", "fieldConfig": { @@ -319,6 +360,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -332,6 +374,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -389,7 +432,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -405,7 +448,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -422,7 +465,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -439,7 +482,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -456,7 +499,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:time_per_output_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_per_output_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -473,7 +516,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Number of requests in RUNNING, WAITING, and SWAPPED state", "fieldConfig": { @@ -482,6 +525,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -495,6 +539,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -552,7 +597,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -568,7 +613,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -585,7 +630,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -606,7 +651,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "P50, P90, P95, and P99 TTFT latency in seconds.", "fieldConfig": { @@ -615,6 +660,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -628,6 +674,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -685,7 +732,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -702,7 +749,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -718,7 +765,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -735,7 +782,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -752,7 +799,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "rate(vllm:time_to_first_token_seconds_sum{model_name=\"$model_name\"}[$__rate_interval])\n/\nrate(vllm:time_to_first_token_seconds_count{model_name=\"$model_name\"}[$__rate_interval])", @@ -769,7 +816,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "description": "Percentage of used cache blocks by vLLM.", "fieldConfig": { @@ -778,6 +825,7 @@ "mode": "palette-classic" }, "custom": { + "axisBorderShow": false, "axisCenteredZero": false, "axisColorMode": "text", "axisLabel": "", @@ -791,6 +839,7 @@ "tooltip": false, "viz": false }, + "insertNulls": false, "lineInterpolation": "linear", "lineWidth": 1, "pointSize": 5, @@ -848,7 +897,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "vllm:gpu_cache_usage_perc{model_name=\"$model_name\"}", @@ -860,7 +909,7 @@ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "editorMode": "code", "expr": "vllm:cpu_cache_usage_perc{model_name=\"$model_name\"}", @@ -875,229 +924,232 @@ "type": "timeseries" }, { - "type": "heatmap", - "title": "Request Prompt Length", + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, "description": "Heatmap of request prompt length", + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, "gridPos": { - "x": 0, - "y": 24, + "h": 8, "w": 12, - "h": 8 - }, - "datasource": { - "uid": "prometheus", - "type": "prometheus" + "x": 0, + "y": 24 }, "id": 12, - "targets": [ - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "refId": "A", - "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", - "range": true, - "instant": false, - "editorMode": "builder", - "legendFormat": "{{le}}", - "useBackend": false, - "disableTextWrap": false, - "fullMetaSearch": false, - "includeNullMetadata": true, - "format": "heatmap" - } - ], "options": { "calculate": false, - "yAxis": { - "axisPlacement": "left", - "reverse": false, - "unit": "none", - "axisLabel": "Prompt Length" - }, - "rowsFrame": { - "layout": "auto", - "value": "Request count" + "cellGap": 1, + "cellValues": { + "unit": "none" }, "color": { - "mode": "scheme", + "exponent": 0.5, "fill": "dark-orange", + "min": 0, + "mode": "scheme", + "reverse": false, "scale": "exponential", - "exponent": 0.5, "scheme": "Spectral", - "steps": 64, - "reverse": false, - "min": 0 + "steps": 64 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" }, - "cellGap": 1, "filterValues": { "le": 1e-9 }, - "tooltip": { - "show": true, - "yHistogram": true - }, "legend": { "show": true }, - "exemplars": { - "color": "rgba(255,0,255,0.7)" + "rowsFrame": { + "layout": "auto", + "value": "Request count" }, - "cellValues": { + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": true + }, + "yAxis": { + "axisLabel": "Prompt Length", + "axisPlacement": "left", + "reverse": false, "unit": "none" } }, + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "format": "heatmap", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "{{le}}", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Request Prompt Length", + "type": "heatmap" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "description": "Heatmap of request generation length", "fieldConfig": { "defaults": { "custom": { - "scaleDistribution": { - "type": "linear" - }, "hideFrom": { + "legend": false, "tooltip": false, - "viz": false, - "legend": false + "viz": false + }, + "scaleDistribution": { + "type": "linear" } } }, "overrides": [] }, - "pluginVersion": "10.2.0" - }, - { - "datasource": { - "uid": "prometheus", - "type": "prometheus" - }, - "type": "heatmap", - "title": "Request Generation Length", - "description": "Heatmap of request generation length", "gridPos": { - "x": 12, - "y": 24, + "h": 8, "w": 12, - "h": 8 + "x": 12, + "y": 24 }, "id": 13, - "targets": [ - { - "datasource": { - "type": "prometheus", - "uid": "prometheus" - }, - "refId": "A", - "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", - "range": true, - "instant": false, - "editorMode": "builder", - "legendFormat": "{{le}}", - "useBackend": false, - "disableTextWrap": false, - "fullMetaSearch": false, - "includeNullMetadata": true, - "format": "heatmap" - } - ], "options": { "calculate": false, - "yAxis": { - "axisPlacement": "left", - "reverse": false, - "unit": "none", - "axisLabel": "Generation Length" - }, - "rowsFrame": { - "layout": "auto", - "value": "Request count" + "cellGap": 1, + "cellValues": { + "unit": "none" }, "color": { - "mode": "scheme", + "exponent": 0.5, "fill": "dark-orange", + "min": 0, + "mode": "scheme", + "reverse": false, "scale": "exponential", - "exponent": 0.5, "scheme": "Spectral", - "steps": 64, - "reverse": false, - "min": 0 + "steps": 64 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" }, - "cellGap": 1, "filterValues": { "le": 1e-9 }, - "tooltip": { - "show": true, - "yHistogram": true - }, "legend": { "show": true }, - "exemplars": { - "color": "rgba(255,0,255,0.7)" + "rowsFrame": { + "layout": "auto", + "value": "Request count" }, - "cellValues": { + "tooltip": { + "mode": "single", + "showColorScale": false, + "yHistogram": true + }, + "yAxis": { + "axisLabel": "Generation Length", + "axisPlacement": "left", + "reverse": false, "unit": "none" } }, - "fieldConfig": { - "defaults": { - "custom": { - "scaleDistribution": { - "type": "linear" - }, - "hideFrom": { - "tooltip": false, - "viz": false, - "legend": false - } - } - }, - "overrides": [] - }, - "pluginVersion": "10.2.0" + "pluginVersion": "10.4.2", + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "${DS_PROMETHEUS}" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "format": "heatmap", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "legendFormat": "{{le}}", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Request Generation Length", + "type": "heatmap" }, { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, + "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.", "fieldConfig": { "defaults": { + "color": { + "mode": "palette-classic" + }, "custom": { - "drawStyle": "line", - "lineInterpolation": "linear", + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", "barAlignment": 0, - "lineWidth": 1, + "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", - "spanNulls": false, + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, "insertNulls": false, - "showPoints": "auto", + "lineInterpolation": "linear", + "lineWidth": 1, "pointSize": 5, - "stacking": { - "mode": "none", - "group": "A" - }, - "axisPlacement": "auto", - "axisLabel": "", - "axisColorMode": "text", - "axisBorderShow": false, "scaleDistribution": { "type": "linear" }, - "axisCenteredZero": false, - "hideFrom": { - "tooltip": false, - "viz": false, - "legend": false + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, - "color": { - "mode": "palette-classic" - }, "mappings": [], "thresholds": { "mode": "absolute", @@ -1123,22 +1175,22 @@ }, "id": 11, "options": { - "tooltip": { - "mode": "single", - "sort": "none" - }, "legend": { - "showLegend": true, + "calcs": [], "displayMode": "list", "placement": "bottom", - "calcs": [] + "showLegend": true + }, + "tooltip": { + "mode": "single", + "sort": "none" } }, "targets": [ { "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "disableTextWrap": false, "editorMode": "builder", @@ -1154,25 +1206,19 @@ } ], "title": "Finish Reason", - "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.", "type": "timeseries" } ], "refresh": "", - "schemaVersion": 37, - "style": "dark", + "schemaVersion": 39, "tags": [], "templating": { "list": [ { - "current": { - "selected": false, - "text": "vllm", - "value": "vllm" - }, + "current": {}, "datasource": { "type": "prometheus", - "uid": "prometheus" + "uid": "${DS_PROMETHEUS}" }, "definition": "label_values(model_name)", "hide": 0, @@ -1201,6 +1247,6 @@ "timezone": "", "title": "vLLM", "uid": "b281712d-8bff-41ef-9f3f-71ad43c05e9b", - "version": 2, + "version": 1, "weekStart": "" } From be0c5180ac0832a0b285d0845d458798bb3f0f4f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Thu, 9 May 2024 14:36:25 -0400 Subject: [PATCH 244/413] [Bugfix] Add logs for all model dtype casting (#4717) --- vllm/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index a2cb9b32c65fc..275814d72e6c3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1063,6 +1063,7 @@ def _get_and_verify_dtype( if config_dtype == torch.float32: # Following the common practice, we use float16 for float32 # models. + logger.info("Casting torch.float32 to torch.float16.") torch_dtype = torch.float16 else: torch_dtype = config_dtype @@ -1087,9 +1088,11 @@ def _get_and_verify_dtype( if torch_dtype != config_dtype: if torch_dtype == torch.float32: # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) pass elif config_dtype == torch.float32: # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) pass else: # Casting between float16 and bfloat16 is allowed with a warning. From ebce310b7433e050086f52ca48571807df467f50 Mon Sep 17 00:00:00 2001 From: Hao Zhang <152229491+sfc-gh-hazhang@users.noreply.github.com> Date: Thu, 9 May 2024 15:37:14 -0700 Subject: [PATCH 245/413] [Model] Snowflake arctic model implementation (#4652) Co-authored-by: Dash Desai <1723932+iamontheinet@users.noreply.github.com> Co-authored-by: Aurick Qiao Co-authored-by: Aurick Qiao Co-authored-by: Aurick Qiao Co-authored-by: Cody Yu --- examples/offline_inference_arctic.py | 26 + .../layers/fused_moe/__init__.py | 4 +- .../layers/fused_moe/fused_moe.py | 137 +++-- .../layers/quantization/__init__.py | 3 + .../layers/quantization/deepspeedfp.py | 194 +++++++ vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/arctic.py | 521 ++++++++++++++++++ vllm/transformers_utils/configs/arctic.py | 204 +++++++ 8 files changed, 1042 insertions(+), 48 deletions(-) create mode 100644 examples/offline_inference_arctic.py create mode 100644 vllm/model_executor/layers/quantization/deepspeedfp.py create mode 100644 vllm/model_executor/models/arctic.py create mode 100644 vllm/transformers_utils/configs/arctic.py diff --git a/examples/offline_inference_arctic.py b/examples/offline_inference_arctic.py new file mode 100644 index 0000000000000..1fec3c99eb47c --- /dev/null +++ b/examples/offline_inference_arctic.py @@ -0,0 +1,26 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="snowflake/snowflake-arctic-instruct", + quantization="deepspeedfp", + tensor_parallel_size=8, + trust_remote_code=True) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. + +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 496d69c89c62b..2926c7d1c8a76 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,7 +1,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_moe, get_config_file_name) + fused_experts, fused_moe, fused_topk, get_config_file_name) __all__ = [ "fused_moe", + "fused_topk", + "fused_experts", "get_config_file_name", ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3cb0419404625..bb7938b3715be 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int, return None -def fused_moe( +def fused_topk( hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, - inplace: bool = False, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. +): assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] + M, _ = hidden_states.shape - E, N, _ = w1.shape if is_hip(): # The MoE kernels are not yet supported on ROCm. @@ -393,6 +349,33 @@ def fused_moe( del token_expert_indicies # Not used. Will be used in the future. if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + M, _ = hidden_states.shape + E, N, _ = w1.shape if override_config: config = override_config @@ -477,3 +460,63 @@ def fused_moe( out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8=use_fp8, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 1c652e347d4ad..5798bc359dcf2 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( @@ -19,6 +21,7 @@ "squeezellm": SqueezeLLMConfig, "gptq_marlin": GPTQMarlinConfig, "marlin": MarlinConfig, + "deepspeedfp": DeepSpeedFPConfig } diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py new file mode 100644 index 0000000000000..31cdffbcf0ab9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -0,0 +1,194 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +class DeepSpeedFPConfig(QuantizationConfig): + """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. + + Args: + weight_bits: the target quantization bits, 6 or 8. + group_size: group size for quantizaiton, default to 128. + """ + + def __init__( + self, + weight_bits: int = 8, + group_size: int = 512, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.valid_types = [torch.bfloat16, torch.float16] + + if self.weight_bits not in (6, 8): + raise ValueError( + "Currently, only 6-bit or 8-bit weight quantization are " + f"supported for DeepSpeed FP quantizaiton, but got " + f"{self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}") + + @classmethod + def get_name(cls) -> str: + return "DeepSpeedFP" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits=weight_bits, group_size=group_size) + + def get_linear_method(self) -> "DeepSpeedFPLinearMethod": + return DeepSpeedFPLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]: + if isinstance(layer, LinearBase): + return DeepSpeedFPLinearMethod(self) + return None + + +class DeepSpeedFPLinearMethod(LinearMethodBase): + """Linear method for DeepSpeedFP quantizer. + + Args: + quant_config: the DeepSpeedFP quantization config. + """ + + def __init__(self, quant_config: DeepSpeedFPConfig): + self.quant_config = quant_config + self.weight = None + + def create_weights(self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs): + del output_size + del input_size + output_size_per_partition = sum(output_partition_sizes) + weight = DeepSpeedFPParameter( + torch.Size((output_size_per_partition, input_size_per_partition)), + params_dtype=params_dtype, + quant_config=self.quant_config, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + layer.register_parameter("weight", weight) + + def quant_weight_loader(param, loaded_weight, *args, **kwargs): + # Calls the original weight loader (if any), quantizes the result, + # and then loads the quantized parameter. + if weight_loader is not None: + orig_param_data = param.data + param.data = param.ds_dequantize() + weight_loader(param, loaded_weight, *args, **kwargs) + param.data, loaded_weight = orig_param_data, param.data + param.ds_quantize_(loaded_weight.cuda()) + + extra_weight_attrs["weight_loader"] = quant_weight_loader + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + y = weight.ds_dequantize() + return F.linear(x, y, bias) + + +class DeepSpeedFPParameter(nn.Parameter): + """ + DeepSpeedFP quantized parameter class that implements fp8/fp6 + quantization deepspeed. Weights are stored in quantized form on + GPUs, and can be dequantized on-the-fly when needed by the model. + """ + + def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig): + try: + import deepspeed + if deepspeed.__version__ < "0.14.2": + raise ImportError("deepspeed version is wrong. Please " + "install deepspeed>=0.14.2.") + from deepspeed.ops.fp_quantizer import FP_Quantize + except ImportError as err: + raise ImportError("Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer.") from err + data = torch.empty(( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8) + self = torch.Tensor._make_subclass(cls, data, data.requires_grad) + self.orig_shape = orig_shape + self.quant_config = quant_config + self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size) + self.fp_quantizer.orig_shape = orig_shape + self.fp_quantizer.orig_dtype = params_dtype + return self + + def ds_quantize_(self, tensor: torch.Tensor): + assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + return self.data.copy_( + self.fp_quantizer.quantize( + tensor.data, + q_bits=self.quant_config.weight_bits, + )) + + def ds_dequantize(self, fp_out=None) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.dequantize( + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + + def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: + """ + Return a tensor where only the weights at `indices` are dequantized + (to save HBM -> SRAM bandwidth). + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.selective_dequantize( + self.data, + indices, + fp_out=fp_out, + q_bits=self.quant_config.weight_bits) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index c5cdc059473b3..d5263b500fe0f 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -54,6 +54,7 @@ "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), } diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py new file mode 100644 index 0000000000000..796cef7c4a735 --- /dev/null +++ b/vllm/model_executor/models/arctic.py @@ -0,0 +1,521 @@ +"""Inference-only Snowflake Arctic model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig, DeepSpeedFPParameter) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.arctic import ArcticConfig + +logger = init_logger(__name__) + + +class ArcticMLP(nn.Module): + + def __init__(self, + config: ArcticConfig, + layer_id: int, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMLP, self).__init__() + self.hidden_size = config.hidden_size + self.expert_id = expert_id + self.layer_id = layer_id + + self.ffn_dim = config.intermediate_size if not is_residual_mlp \ + else self.hidden_size + + self.w13 = MergedColumnParallelLinear(self.hidden_size, + [self.ffn_dim] * 2, + bias=False, + quant_config=quant_config) + self.w2 = RowParallelLinear(self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config) + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states): + gate_up, _ = self.w13(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.w2(hidden_states) + return hidden_states + + +class ArcticMoE(nn.Module): + """ + Model-parallel implementation of Arctic MoE Layer. + """ + + def __init__(self, + config: ArcticConfig, + layer_id: int, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMoE, self).__init__() + + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.hidden_size = config.hidden_size + self.num_experts = config.num_local_experts + self.layer_id = layer_id + self.top_k = config.num_experts_per_tok + self.intermediate_size = config.intermediate_size // self.tp_size + + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 + self.is_quant = isinstance(quant_config, DeepSpeedFPConfig) + self.reduce_results = reduce_results + # Some other parameters + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if not self.is_moe_layer: + self.mlp = ArcticMLP(config, + layer_id=layer_id, + quant_config=quant_config, + reduce_results=reduce_results) + else: + self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config) + if self.is_quant: + self.ws = DeepSpeedFPParameter( + torch.Size((self.num_experts, 2 * self.intermediate_size, + self.hidden_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + self.w2s = DeepSpeedFPParameter( + torch.Size((self.num_experts, self.hidden_size, + self.intermediate_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + else: + self.ws = nn.Parameter( + torch.empty(self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.ds_dequantize() if self.is_quant else param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.is_quant: + param.ds_quantize_(param_data) + + def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + do_normalize = self.top_k > 1 + topk_weights, topk_ids = fused_topk(hidden_states, + router_logits, + self.top_k, + renormalize=do_normalize) + # topk_ids: (num_tokens, k) + if self.is_quant: + if 2 * num_tokens <= self.num_experts: + # If much fewer tokens than experts, use selective dequantize. + ws_dequantized = self.ws.ds_selective_dequantize( + topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize( + topk_ids.flatten()) + # We gathered the experts to the tokens so update the mapping. + topk_ids = torch.arange( + 0, + topk_ids.numel(), + device=topk_ids.device, + ).reshape(topk_ids.shape) + else: + ws_dequantized = self.ws.ds_dequantize() + w2s_dequantized = self.w2s.ds_dequantize() + + final_hidden_states = fused_experts( + hidden_states, + ws_dequantized if self.is_quant else self.ws, + w2s_dequantized if self.is_quant else self.w2s, + topk_weights, + topk_ids, + inplace=True) + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + def forward(self, hidden_states: torch.Tensor): + if self.is_moe_layer: + final_hidden_states = self.local_moe_fused(hidden_states) + else: + final_hidden_states = self.mlp(hidden_states) + return final_hidden_states + + +class ArcticAttention(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + reduce_results=True, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class ArcticDecoderLayer(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 + self.use_residual = config.use_residual and is_moe_layer + self.self_attn = ArcticAttention(config, + layer_idx, + quant_config=quant_config) + self.block_sparse_moe = ArcticMoE( + config, + layer_id=layer_idx, + quant_config=quant_config, + reduce_results=(not self.use_residual)) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + if self.use_residual: + self.residual_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.residual_mlp = ArcticMLP(config, + layer_id=layer_idx, + is_residual_mlp=True, + reduce_results=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual_input = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual_input + hidden_states + + residual_attn = hidden_states + if self.use_residual: + hidden_states = self.residual_layernorm(hidden_states) + hidden_states = self.residual_mlp(hidden_states) + residual_mlp = hidden_states + hidden_states = self.post_attention_layernorm(residual_input) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_mlp + hidden_states + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states = residual_attn + hidden_states + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_attn + hidden_states + return hidden_states + + +class ArcticModel(nn.Module): + + def __init__( + self, + config: ArcticConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size) + self.layers = nn.ModuleList([ + ArcticDecoderLayer(config, layer_idx, quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self._attn_implementation = config._attn_implementation + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, kv_caches[i], + attn_metadata) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class ArcticForCausalLM(nn.Module): + + def __init__(self, + config: ArcticConfig, + quant_config: Optional[QuantizationConfig] = None, + **kwargs) -> None: + super().__init__() + self.config = config + self.model = ArcticModel(config, quant_config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + ) + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.unpadded_vocab_size = config.vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + mlp_params_mapping = [] + expert_params_mapping = [] + num_layers = self.config.num_hidden_layers + + for layer in range(num_layers): + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", 1)) + if layer % 2 == 0: + # MLP layers + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + else: + # MoE layers + for expert_id in range(self.config.num_local_experts): + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + expert_params_mapping.append( + ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + + params_dict = dict(self.named_parameters()) + + logger.info( + "It will take ~10 minutes loading from the 16-bit weights. " + "Alternatively, use the prequantized 8-bit weights of arctic " + "and set load-format to `sharded_state` will accelerate loading.") + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id in mlp_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id \ + in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py new file mode 100644 index 0000000000000..7780bf5e78d6d --- /dev/null +++ b/vllm/transformers_utils/configs/arctic.py @@ -0,0 +1,204 @@ +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/configuration_arctic.py +""" Arctic model configuration""" + +from dataclasses import asdict, dataclass +from typing import Any, Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "arctic": "https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/main/config.json", +} + + +@dataclass +class ArcticLoraConfig: + lora_r: int = 64 + lora_alpha: float = 16 + shard_base_weights: bool = False + + +@dataclass +class ArcticQuantizationConfig: + q_bits: int = 8 + rounding: str = "nearest" + mantissa_bits: int = 3 + group_size: int = 128 + + +class ArcticConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an + Arctic model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the #TODO(rsamdani): add what model has the default config.. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ArcticModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Arctic's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import ArcticModel, ArcticConfig + + >>> # Initializing a Arctic 7B style configuration TODO(rsamdani): verify which model does the default configuration correspond to. + >>> configuration = ArcticConfig() + + >>> # Initializing a model from the Arctic 7B style configuration + >>> model = ArcticModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "arctic" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=8, + router_aux_loss_coef=0.001, + moe_layer_frequency=2, + parallel_attn_mlp_res=False, + moe_train_capacity_factor=1, + moe_eval_capacity_factor=1, + enable_expert_tensor_parallelism=False, + moe_min_capacity=0, + moe_token_dropping=True, + quantization=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_aux_loss_coef = router_aux_loss_coef + self.moe_layer_frequency = moe_layer_frequency + self.moe_train_capacity_factor = moe_train_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism + self.moe_min_capacity = moe_min_capacity + self.moe_token_dropping = moe_token_dropping + self.parallel_attn_mlp_res = parallel_attn_mlp_res + if isinstance(quantization, dict): + self.quantization = ArcticQuantizationConfig(**quantization) + else: + self.quantization = quantization + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig": + result = super().from_dict(config_dict, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if isinstance(config.quantization, dict): + config.quantization = ArcticQuantizationConfig(**config.quantization) + return result + + def to_dict(self) -> Dict[str, Any]: + ret = super().to_dict() + if isinstance(ret["quantization"], ArcticQuantizationConfig): + ret["quantization"] = asdict(ret["quantization"]) + return ret From 379da6dcb5f5d062d0452b2fc23291e5113dcf04 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 9 May 2024 16:38:07 -0700 Subject: [PATCH 246/413] [Kernel] [FP8] Improve FP8 linear layer performance (#4691) This PR improves the FP8 performance of linear layers, which had been lacking before (#4118 (comment) and #4118 (comment)). We noticed that CUBLASLt can find a better algorithm if the first dimension of the matrix is greater than 16. So this PR enlarges matrices appropriately during quantization. This improves FP8 performance and removes the performance regression vs. FP16, in many cases exceeding FP16 performance. Here are benchmarks on llama3 70b (ITL numbers for 1000 input and 50 output tokens at fixed qps and at TP 4), all FP8 measurements are for dynamic quantization: qps = 1: 24 ms (FP8, this PR), 32 ms (FP8, previous main), 26 ms (FP16) qps = 2: 26 ms (FP8, this PR), 34ms (FP8, previous main), 28 ms (FP16) qps = 4: 33 ms (FP8, this PR), 44 ms (FP8, previous main), 36 ms (FP16) qps = 6: 46 ms (FP8, this PR), 56 ms (FP8, previous main), 54 ms (FP16) qps = 8: 85 ms (FP8, this PR), 85 ms (FP8, previous main), 138 ms (FP16) --- vllm/_custom_ops.py | 28 ++++++++++++++++++- .../model_executor/layers/quantization/fp8.py | 13 ++++++--- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 5b56437487477..829c47003ad0e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -189,8 +189,34 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def scaled_fp8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, + batch_dim_padding: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensor for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + batch_dim_padding: If specified, pad the first dimension + of the output to at least this value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + if batch_dim_padding: + shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) + output = torch.empty(shape, + device=input.device, + dtype=torch.float8_e4m3fn) + else: + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: scale = torch.zeros(1, device=input.device, dtype=torch.float32) vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b57e1dde81a5f..ff996741c1d00 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -231,9 +231,14 @@ def apply(self, # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.act_scale is None and x_scale computed from x. # If static, layer.act_scale is scalar and x_scale set to act_scale. - qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale) - - # Fused GEMM_DQ + qinput, x_scale = ops.scaled_fp8_quant(x, + layer.act_scale, + batch_dim_padding=17) + + # Fused GEMM_DQ -- note we padded the input above because + # torch._scaled_mm is more performant for matrices with + # batch dimension > 16. Note that this could change + # in the future. output, _ = torch._scaled_mm( qinput, layer.weight, @@ -243,7 +248,7 @@ def apply(self, bias=bias, ) - return output + return torch.narrow(output, 0, 0, x.shape[0]) def all_close_1d(x: torch.Tensor) -> bool: From c83310174055bb124ea2197885b652efd59b7a0f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 9 May 2024 17:04:17 -0700 Subject: [PATCH 247/413] [Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535) --- .buildkite/check-wheel-size.py | 2 +- CMakeLists.txt | 2 +- cmake/utils.cmake | 4 +- csrc/attention/attention_kernels.cu | 286 ++++----- csrc/attention/dtype_fp8.cuh | 16 +- csrc/cache.h | 4 +- csrc/cache_kernels.cu | 143 +++-- .../fp8/{amd_detail => amd}/hip_float8.h | 0 .../fp8/{amd_detail => amd}/hip_float8_impl.h | 0 .../fp8/{amd_detail => amd}/quant_utils.cuh | 59 +- .../fp8/{fp8_cuda_kernels.cu => common.cu} | 0 csrc/quantization/fp8/nvidia/quant_utils.cuh | 568 ++++++++++++++++++ .../fp8_e5m2_kvcache/quant_utils.cuh | 277 --------- tests/kernels/test_attention.py | 4 +- tests/kernels/test_cache.py | 27 +- vllm/_custom_ops.py | 7 +- vllm/utils.py | 2 +- 17 files changed, 843 insertions(+), 558 deletions(-) rename csrc/quantization/fp8/{amd_detail => amd}/hip_float8.h (100%) rename csrc/quantization/fp8/{amd_detail => amd}/hip_float8_impl.h (100%) rename csrc/quantization/fp8/{amd_detail => amd}/quant_utils.cuh (81%) rename csrc/quantization/fp8/{fp8_cuda_kernels.cu => common.cu} (100%) create mode 100644 csrc/quantization/fp8/nvidia/quant_utils.cuh delete mode 100644 csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 90a5e54736cf3..41d9e682572a6 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,7 +1,7 @@ import os import zipfile -MAX_SIZE_MB = 100 +MAX_SIZE_MB = 150 def print_top_10_largest_files(zip_file): diff --git a/CMakeLists.txt b/CMakeLists.txt index 47629f036fb09..1c7dfe0c048b0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,7 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" - "csrc/quantization/fp8/fp8_cuda_kernels.cu" + "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 7c71673e36f29..00c81e4d00ad8 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) "Failed to determine torch nvcc compiler flags") if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) - list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") + list(APPEND GPU_FLAGS "-DENABLE_FP8") endif() if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) list(REMOVE_ITEM GPU_FLAGS @@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) list(APPEND GPU_FLAGS "-DUSE_ROCM" - "-DENABLE_FP8_E4M3" + "-DENABLE_FP8" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" "-fno-gpu-rdc") diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8b1b5e098015f..41b337dd91d36 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -19,21 +19,17 @@ #include #include #include +#include #include "attention_dtypes.h" #include "attention_utils.cuh" -#if defined(ENABLE_FP8_E5M2) -#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "../quantization/fp8/amd_detail/quant_utils.cuh" -#endif - -#include - #ifdef USE_ROCM #include + #include "../quantization/fp8/amd/quant_utils.cuh" typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM @@ -92,7 +88,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int PARTITION_SIZE = 0> // Zero means no partitioning. __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -157,9 +153,7 @@ __device__ void paged_attention_kernel( constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using Quant_vec = typename Vec::Type; -#endif constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; @@ -223,21 +217,14 @@ __device__ void paged_attention_kernel( const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. - k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); -#elif defined(ENABLE_FP8_E4M3) - Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k - // cache vec to k vec in higher precision (FP16, BFloat16, etc.) - k_vecs[j] = fp8_e4m3::scaled_vec_conversion(k_vec_quant, kv_scale); -#else - assert(false); -#endif - } else { + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert(k_vec_quant, kv_scale); } } @@ -312,9 +299,7 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) using V_quant_vec = typename Vec::Type; -#endif using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; @@ -348,21 +333,13 @@ __device__ void paged_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec; - if constexpr (IS_FP8_KV_CACHE) { -#if defined(ENABLE_FP8_E5M2) + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); -#elif defined(ENABLE_FP8_E4M3) - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert - // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.) - v_vec = fp8_e4m3::scaled_vec_conversion(v_quant_vec, kv_scale); -#else - assert(false); -#endif - } else { - v_vec = *reinterpret_cast(v_ptr + offset); + v_vec = fp8::scaled_convert(v_quant_vec, kv_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, @@ -448,7 +425,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE> + vllm::Fp8KVCacheDataType KV_DTYPE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -464,7 +441,7 @@ __global__ void paged_attention_v1_kernel( const int kv_block_stride, const int kv_head_stride, const float kv_scale) { - paged_attention_kernel( + paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); @@ -477,7 +454,7 @@ template< int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -496,7 +473,7 @@ __global__ void paged_attention_v2_kernel( const int kv_block_stride, const int kv_head_stride, const float kv_scale) { - paged_attention_kernel( + paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); @@ -606,9 +583,9 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ + KV_DTYPE>), shared_mem_size); \ vllm::paged_attention_v1_kernel<<>>( \ + KV_DTYPE><<>>( \ out_ptr, \ query_ptr, \ key_cache_ptr, \ @@ -629,7 +606,7 @@ template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, @@ -706,36 +683,36 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v1( @@ -752,65 +729,44 @@ void paged_attention_v1( const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) } -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - seq_lens_ptr, \ +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + num_kv_heads, \ + scale, \ + block_tables_ptr, \ + seq_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, \ + kv_scale); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + seq_lens_ptr, \ max_num_partitions); template< typename T, typename CACHE_T, int BLOCK_SIZE, - bool IS_FP8_KV_CACHE, + vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( @@ -897,39 +853,39 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + seq_lens, \ + max_seq_len, \ + alibi_slopes, \ kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( @@ -949,29 +905,7 @@ void paged_attention_v2( const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8") { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) } #undef WARP_SIZE diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index d11dee91ebe87..2b32ce372a64f 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -3,14 +3,21 @@ #include "attention_generic.cuh" #include -#ifdef ENABLE_FP8_E5M2 +#ifdef ENABLE_FP8 +#ifndef USE_ROCM #include -#endif +#endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { -#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) -// fp8 vector types for quantization of kv cache +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache template<> struct Vec { using Type = uint8_t; @@ -30,6 +37,5 @@ template<> struct Vec { using Type = uint2; }; -#endif // ENABLE_FP8_E5M2 } // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index 212a3bf3ddc1c..8c176c452425e 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -34,5 +34,7 @@ void reshape_and_cache_flash( // Just for unittest void convert_fp8( + torch::Tensor& dst_cache, torch::Tensor& src_cache, - torch::Tensor& dst_cache); + const float scale, + const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 76db96f099c69..e5b74da6ad068 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -4,10 +4,11 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#if defined(ENABLE_FP8_E5M2) -#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" -#elif defined(ENABLE_FP8_E4M3) -#include "quantization/fp8/amd_detail/quant_utils.cuh" + +#ifdef USE_ROCM +#include "quantization/fp8/amd/quant_utils.cuh" +#else +#include "quantization/fp8/nvidia/quant_utils.cuh" #endif #include @@ -149,7 +150,7 @@ void copy_blocks( namespace vllm { -template +template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] @@ -194,19 +195,12 @@ __global__ void reshape_and_cache_kernel( + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; - if constexpr (is_fp8_kv_cache) { -#if defined(ENABLE_FP8_E5M2) - key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); - value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); -#elif defined(ENABLE_FP8_E4M3) - key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale); - value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale); -#else - assert(false); -#endif - } else { + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; + } else { + key_cache[tgt_key_idx] = fp8::scaled_convert(tgt_key, kv_scale); + value_cache[tgt_value_idx] = fp8::scaled_convert(tgt_value, kv_scale); } } } @@ -248,19 +242,22 @@ __global__ void reshape_and_cache_flash_kernel( } } // namespace vllm -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x, \ +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + key_stride, \ + value_stride, \ + num_heads, \ + head_size, \ + block_size, \ + x, \ kv_scale); void reshape_and_cache( @@ -285,25 +282,8 @@ void reshape_and_cache( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (kv_cache_dtype == "auto") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, float, false); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); - } - } else if (kv_cache_dtype == "fp8") { - if (key.dtype() == at::ScalarType::Float) { - CALL_RESHAPE_AND_CACHE(float, uint8_t, true); - } else if (key.dtype() == at::ScalarType::Half) { - CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); - } else if (key.dtype() == at::ScalarType::BFloat16) { - CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); - } - } else { - TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); - } + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) } void reshape_and_cache_flash( @@ -353,35 +333,34 @@ void reshape_and_cache_flash( namespace vllm { -template +template __global__ void convert_fp8_kernel( const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache, + const float kv_scale, const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; -#if defined(ENABLE_FP8_E5M2) - dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); -#elif defined(ENABLE_FP8_E4M3) - dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); -#else - assert(false); -#endif + dst_cache[idx] = fp8::scaled_convert(src_cache[idx], kv_scale); } } } // namespace vllm -#define CALL_CONVERT_FP8(Tout, Tin) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), \ + kv_scale, \ block_stride); +// Only for testing. void convert_fp8( + torch::Tensor& dst_cache, torch::Tensor& src_cache, - torch::Tensor& dst_cache) + const float kv_scale, + const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); @@ -399,17 +378,35 @@ void convert_fp8( dim3 block(std::min(block_stride, int64_t(512))); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); + if (kv_cache_dtype == "auto") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } + } else { + TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); } } diff --git a/csrc/quantization/fp8/amd_detail/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h similarity index 100% rename from csrc/quantization/fp8/amd_detail/hip_float8.h rename to csrc/quantization/fp8/amd/hip_float8.h diff --git a/csrc/quantization/fp8/amd_detail/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h similarity index 100% rename from csrc/quantization/fp8/amd_detail/hip_float8_impl.h rename to csrc/quantization/fp8/amd/hip_float8_impl.h diff --git a/csrc/quantization/fp8/amd_detail/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh similarity index 81% rename from csrc/quantization/fp8/amd_detail/quant_utils.cuh rename to csrc/quantization/fp8/amd/quant_utils.cuh index 894160972d9f4..df0329f79d361 100644 --- a/csrc/quantization/fp8/amd_detail/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -5,12 +5,17 @@ #include #include +#include "../../../attention/dtype_fp8.cuh" #include "../../../attention/dtype_float32.cuh" #include "../../../attention/dtype_bfloat16.cuh" namespace vllm { -namespace fp8_e4m3 { +#ifdef USE_ROCM + +namespace fp8 { +#ifdef ENABLE_FP8 + template __inline__ __device__ Tout vec_conversion(const Tin& x) { @@ -512,6 +517,58 @@ __inline__ __device__ float4 scaled_vec_conversion(const uint3 float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); return res; } +#endif // ENABLE_FP8 +template +__inline__ __device__ Tout convert(const Tin &x) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x); + } +#endif + assert(false); } + +template +__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale); + } +#endif + assert(false); +} + +// The following macro is used to dispatch the conversion function based on the +// data type of the key and value cache. The FN is a macro that calls a function +// with template. +#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // fp8 +#endif // USE_ROCM } // namespace vllm diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/common.cu similarity index 100% rename from csrc/quantization/fp8/fp8_cuda_kernels.cu rename to csrc/quantization/fp8/common.cu diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh new file mode 100644 index 0000000000000..4eeacf7a6f9d9 --- /dev/null +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -0,0 +1,568 @@ +#pragma once + +#include "../../../attention/attention_dtypes.h" +#include +#include +#include +#include + +namespace vllm { +#ifndef USE_ROCM + +namespace fp8 { +#ifdef ENABLE_FP8 + +#if 0 // Disable the following code to reduce the binary size. +template +__inline__ __device__ Tout +vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t vec_conversion( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = res.x; + tmp.u16[1] = res.y; + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a, fp8_type); + tmp.u32[1] = + vec_conversion((uint16_t)(a >> 16U), fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x, fp8_type); + tmp.u64[1] = vec_conversion(a.y, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type); + res.y = + vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float +vec_conversion(const uint8_t &a, + const __nv_fp8_interpretation_t fp8_type) { + // fp8 -> half + uint16_t tmp = vec_conversion(a, fp8_type); + // half -> float + return half_to_float(tmp); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = vec_conversion(a, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = vec_conversion((uint16_t)a, fp8_type); + res.y = vec_conversion((uint16_t)(a >> 16U), fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion( + const uint2 &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x, fp8_type); + tmp2 = vec_conversion(a.y, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp; + tmp.x = a; + __nv_fp8_storage_t res = + __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( + __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion( + const float &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion( + const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = vec_conversion(a, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +template <> +__inline__ __device__ uint32_t vec_conversion( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template <> +__inline__ __device__ uint2 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val, fp8_type); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val, fp8_type); + + return b; +} + +template <> +__inline__ __device__ float4 vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template <> +__inline__ __device__ uint4 vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + uint4 b; + b.x = vec_conversion(a.x, fp8_type); + b.y = vec_conversion(a.y, fp8_type); + b.z = vec_conversion(a.z, fp8_type); + b.w = vec_conversion(a.w, fp8_type); + return b; +} + +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>( + const float2 &a, const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_4_t vec_conversion( + const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t b; + from_float(b, a); + return b; +} + +template <> +__inline__ __device__ bf16_8_t vec_conversion( + const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) { + bf16_8_t b; + from_float(b, a); + return b; +} +#endif + +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains Convention of the scale in API, e.g: FP8_data = + Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 + Dequant(FP8) * scale => HP + */ + +template +__inline__ __device__ Tout scaled_vec_conversion( + const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); + return float_to_half(half_to_float(tmp.x) * scale); +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type); + tmp.u16[0] = float_to_half(half_to_float(res.x) * scale); + tmp.u16[1] = float_to_half(half_to_float(res.y) * scale); + return tmp.u32; +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = + scaled_vec_conversion((uint16_t)a, scale, fp8_type); + tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), + scale, fp8_type); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 +scaled_vec_conversion(const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale, fp8_type); + tmp.u64[1] = scaled_vec_conversion(a.y, scale, fp8_type); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp * scale); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), + scale, fp8_type); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, + fp8_type); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale, fp8_type); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t scaled_vec_conversion( + const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion( + const uint8_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); + uint16_t tmp = res.x; + + // half -> float + return half_to_float(tmp) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + // fp8x2 -> half2 + uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale, + fp8_type); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ scaled_vec_conversion( + const uint2 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); + tmp2 = scaled_vec_conversion(a.y, scale, fp8_type); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const uint16_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16 &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, + __NV_SATFINITE, fp8_type); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const float &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + __nv_fp8_storage_t res = + __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 scaled_vec_conversion( + const uint32_t &a, const float scale, + const __nv_fp8_interpretation_t fp8_type) { + Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} +#endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin &x) { +#if 0 // Disable the following code to reduce the binary size. + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return vec_conversion(x, __NV_E5M2); + } +#endif + assert(false); +} + +template +__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { +#ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale, __NV_E4M3); + } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { + return scaled_vec_conversion(x, scale, __NV_E5M2); + } +#endif + assert(false); +} + +// The following macro is used to dispatch the conversion function based on the +// data type of the key and value cache. The FN is a macro that calls a function +// with template. +#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh b/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh deleted file mode 100644 index 9bcab25db03cf..0000000000000 --- a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh +++ /dev/null @@ -1,277 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include "../../attention/attention_dtypes.h" -#include "../../attention/dtype_float32.cuh" -#include "../../attention/dtype_float16.cuh" -#include "../../attention/dtype_bfloat16.cuh" - - -namespace vllm { -#ifdef ENABLE_FP8_E5M2 -namespace fp8_e5m2_unscaled { - -template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; -} - -// fp8 -> half -template<> -__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) -{ - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - return res.x; -} - -// fp8x2 -> half2 -template<> -__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) -{ - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2); - tmp.u16[0] = res.x; - tmp.u16[1] = res.y; - return tmp.u32; -} - -// fp8x4 -> half2x2 -template<> -__inline__ __device__ uint2 vec_conversion(const uint32_t& a) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = vec_conversion((uint16_t)a); - tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); - return tmp.u32x2; -} - -// fp8x8 -> half2x4 -template<> -__inline__ __device__ uint4 vec_conversion(const uint2& a) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = vec_conversion(a.x); - tmp.u64[1] = vec_conversion(a.y); - return tmp.u64x2; -} - -// fp8 -> __nv_bfloat16 -template<> -__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) -{ - // Note there is no direct convert function from fp8 to bf16. - // fp8 -> half - __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); - // half -> float -> bf16 - float tmp = half_to_float(res.x); - return __float2bfloat16(tmp); -} - -// fp8x2 -> __nv_bfloat162 -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) -{ - __nv_bfloat162 res; - res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); - res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); - return res; -} - -// fp8x4 -> bf16_4_t -template<> -__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) -{ - bf16_4_t res; - res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); - res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> bf16_8_t -template<> -__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) -{ - bf16_4_t tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - -// fp8 -> float -template<> -__inline__ __device__ float vec_conversion(const uint8_t& a) -{ - // fp8 -> half - uint16_t tmp = vec_conversion(a); - // half -> float - return half_to_float(tmp); -} - -// fp8x2 -> float2 -template<> -__inline__ __device__ float2 vec_conversion(const uint16_t& a) -{ - // fp8x2 -> half2 - uint32_t tmp = vec_conversion(a); - // half2 -> float2 - return half2_to_float2(tmp); -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) -{ - Float4_ res; - res.x = vec_conversion((uint16_t)a); - res.y = vec_conversion((uint16_t)(a >> 16U)); - return res; -} - -// fp8x8 -> float8 -template<> -__inline__ __device__ Float8_ vec_conversion(const uint2& a) -{ - Float4_ tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; -} - - -// half -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) -{ - __half_raw tmp; - tmp.x = a; - __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// bf16 -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - assert(false); -#else - __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -#endif -} - -// float -> fp8 -template<> -__inline__ __device__ uint8_t vec_conversion(const float& a) -{ - __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2); - return (uint8_t)res; -} - -// fp8x4 -> float4 -template<> -__inline__ __device__ float4 vec_conversion(const uint32_t& a) -{ - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; -} - - -template<> -__inline__ __device__ uint32_t vec_conversion(const float2& a) -{ - union { - half2 float16; - uint32_t uint32; - }; - - float16 = __float22half2_rn(a); - return uint32; -} - -template<> -__inline__ __device__ uint2 vec_conversion(const Float4_& a) -{ - uint2 b; - float2 val; - val.x = a.x.x; - val.y = a.x.y; - b.x = vec_conversion(val); - - val.x = a.y.x; - val.y = a.y.y; - b.y = vec_conversion(val); - - return b; -} - -template<> -__inline__ __device__ float4 vec_conversion(const Float4_& a) -{ - float4 b; - b.x = a.x.x; - b.y = a.x.y; - b.z = a.y.x; - b.w = a.y.y; - return b; -} - -template<> -__inline__ __device__ uint4 vec_conversion(const Float8_& a) -{ - uint4 b; - b.x = vec_conversion(a.x); - b.y = vec_conversion(a.y); - b.z = vec_conversion(a.z); - b.w = vec_conversion(a.w); - return b; -} - -template<> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { - __nv_bfloat162 b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { - bf16_4_t b; - from_float(b, a); - return b; -} - -template<> -__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { - bf16_8_t b; - from_float(b, a); - return b; -} - -} // namespace fp8_e5m2_unscaled -#endif // ENABLE_FP8_E5M2 -} // namespace vllm diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 84539205e0ae3..28496f187d466 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -236,14 +236,14 @@ def test_paged_attention( dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) - ops.convert_fp8(key_cache, dequantized_key_cache) + ops.convert_fp8(dequantized_key_cache, key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) - ops.convert_fp8(value_cache, dequantized_value_cache) + ops.convert_fp8(dequantized_value_cache, value_cache) value_cache = dequantized_value_cache ref_output = torch.empty_like(query) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 4cae15c79c489..9f0cb60dc16e2 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,8 +5,6 @@ import torch from vllm import _custom_ops as ops -from vllm._C import cache_ops -from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -25,6 +23,8 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] + +# We assume fp8 is always enabled for testing. KV_CACHE_DTYPE = ["auto", "fp8"] @@ -124,8 +124,6 @@ def test_reshape_and_cache( device: str, kv_cache_dtype: str, ) -> None: - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -149,9 +147,9 @@ def test_reshape_and_cache( # Clone the KV caches. if kv_cache_dtype == "fp8": cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(key_cache, cloned_key_cache) + ops.convert_fp8(cloned_key_cache, key_cache) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(value_cache, cloned_value_cache) + ops.convert_fp8(cloned_value_cache, value_cache) else: cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() @@ -165,9 +163,9 @@ def test_reshape_and_cache( if kv_cache_dtype == "fp8": result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) - ops.convert_fp8(key_cache, result_key_cache) + ops.convert_fp8(result_key_cache, key_cache) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) - ops.convert_fp8(value_cache, result_value_cache) + ops.convert_fp8(result_value_cache, value_cache) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -255,8 +253,8 @@ def test_reshape_and_cache_flash( cloned_value_cache = value_cache.clone() # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) # Run the reference implementation. block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') @@ -299,8 +297,6 @@ def test_swap_blocks( ) -> None: if kv_cache_dtype == "fp8" and "cpu" in direction: pytest.skip() - if not is_hip() and kv_cache_dtype == "fp8": - pytest.skip() # This test is not tuned for e5m2 cuda precision random.seed(seed) torch.random.manual_seed(seed) if torch.cuda.is_available(): @@ -348,7 +344,6 @@ def test_swap_blocks( dist_value_caches[0][dst].cpu()) -@pytest.mark.skipif(not is_hip(), reason="FP8 conversion test requires e4m3") @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -357,7 +352,7 @@ def test_swap_blocks( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_fp8_conversion( +def test_fp8_e4m3_conversion( num_heads: int, head_size: int, block_size: int, @@ -377,9 +372,9 @@ def test_fp8_conversion( cache.uniform_(low, high) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) - ops.convert_fp8(cache, cache_fp8) + ops.convert_fp8(cache_fp8, cache) converted_cache = torch.empty_like(cache) - ops.convert_fp8(cache_fp8, converted_cache) + ops.convert_fp8(converted_cache, cache_fp8) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 829c47003ad0e..35a9f6329fc42 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -270,8 +270,11 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor, vllm_cache_ops.swap_blocks(src, dst, block_mapping) -def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None: - vllm_cache_ops.convert_fp8(output, input) +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0, + kv_dtype: str = "fp8") -> None: + vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) #TODO: cuda_utils, custom_ar diff --git a/vllm/utils.py b/vllm/utils.py index 6479a8dab320a..f0e71f5e99b64 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -329,7 +329,7 @@ def _generate_random_fp8( from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) - ops.convert_fp8(tensor_tmp, tensor) + ops.convert_fp8(tensor, tensor_tmp) del tensor_tmp From 208b71bcc1b94df1fdd2fc10da3e04c706340188 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 May 2024 19:48:43 -0700 Subject: [PATCH 248/413] [Core][Distributed] refactor pynccl (#4591) [Core][Distributed] refactor pynccl to hold multiple communicators (#4591) --- tests/distributed/test_pynccl.py | 78 ++--- vllm/distributed/communication_op.py | 28 +- .../device_communicators/pynccl.py | 294 +++++------------- .../device_communicators/pynccl_utils.py | 66 ---- .../device_communicators/pynccl_wrapper.py | 258 +++++++++++++++ vllm/distributed/parallel_state.py | 131 ++++---- vllm/worker/model_runner.py | 25 +- vllm/worker/worker.py | 21 -- 8 files changed, 467 insertions(+), 434 deletions(-) delete mode 100644 vllm/distributed/device_communicators/pynccl_utils.py create mode 100644 vllm/distributed/device_communicators/pynccl_wrapper.py diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index b6f461b76ed03..b3e30a0434423 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -1,15 +1,15 @@ import multiprocessing +import os import pytest import torch -import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils -from vllm.distributed.communication_op import tensor_model_parallel_all_reduce -from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, - ncclGetUniqueId) -from vllm.distributed.parallel_state import ( - ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group, - init_distributed_environment, with_pynccl_for_all_reduce) +from vllm.distributed.communication_op import ( # noqa + graph_capture_mode, tensor_model_parallel_all_reduce) +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) from vllm.utils import update_environment_variables @@ -41,6 +41,9 @@ def worker_fn_wrapper(fn): # and update the environment variables in the function def wrapped_fn(env): update_environment_variables(env) + local_rank = os.environ['LOCAL_RANK'] + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) init_distributed_environment() fn() @@ -49,11 +52,13 @@ def wrapped_fn(env): @worker_fn_wrapper def worker_fn(): - comm = NCCLCommunicator() - tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank) - comm.all_reduce(tensor) + pynccl_comm = PyNcclCommunicator() + tensor = torch.ones(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + with pynccl_comm.change_state(enable=True): + pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() - assert result == comm.world_size + assert result == pynccl_comm.world_size @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -70,37 +75,35 @@ def multiple_tp_worker_fn(): torch.distributed.new_group(ranks=[2, 3], backend="gloo") ] group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] - comm = NCCLCommunicator(group=group, device=device) + pynccl_comm = PyNcclCommunicator(group=group, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - # two groups can communicate independently - if torch.distributed.get_rank() in [0, 1]: - comm.all_reduce(tensor) - comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 4 - else: - comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 2 + with pynccl_comm.change_state(enable=True): + # two groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + pynccl_comm.all_reduce(tensor) + pynccl_comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + pynccl_comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test.") def test_pynccl_multiple_tp(): # this tests pynccl for multiple tp groups, in a standalone way - # i.e. call `comm.all_reduce` directly + # i.e. call `pynccl_comm.all_reduce` directly distributed_run(multiple_tp_worker_fn, 4) @worker_fn_wrapper def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") - torch.cuda.set_device(torch.distributed.get_rank()) ensure_model_parallel_initialized(2, 2) - pynccl_utils.init_process_group( - group=get_tensor_model_parallel_cpu_group()) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with with_pynccl_for_all_reduce(): + with graph_capture_mode(): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) @@ -125,19 +128,21 @@ def test_pynccl_multiple_tp_with_vllm(): def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() - comm = NCCLCommunicator() + pynccl_comm = PyNcclCommunicator() # run something in the default stream to initialize torch engine - a = torch.ones((4, 4), device=f'cuda:{comm.rank}') + a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph(graph, stream=comm.stream): + with torch.cuda.graph( + graph, stream=pynccl_comm.stream), pynccl_comm.change_state( + enable=True): # operation during the graph capture is recorded but not executed # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa - comm.all_reduce(a) - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**0 + pynccl_comm.all_reduce(a) + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**0 graph.replay() - comm.stream.synchronize() - assert a.mean().cpu().item() == comm.world_size**1 + pynccl_comm.stream.synchronize() + assert a.mean().cpu().item() == pynccl_comm.world_size**1 @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -147,7 +152,8 @@ def test_pynccl_with_cudagraph(): def test_ncclGetUniqueId(): - unique_id = ncclGetUniqueId() + lib = NCCLLibrary() + unique_id = lib.ncclGetUniqueId() # `list(unique_id.internal)` is something like this: # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0, # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 80d03129bdb9b..32ab5694e5390 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,4 +1,5 @@ from collections import namedtuple +from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -8,7 +9,26 @@ get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - is_pynccl_enabled_for_all_reduce) + get_tp_pynccl_communicator) + + +@contextmanager +def graph_capture_mode(): + # In graph capture, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the tensor size + # is too large, it will fallback to the next available option. + pynccl_comm = get_tp_pynccl_communicator() + assert pynccl_comm is not None + with pynccl_comm.change_state(enable=True, + stream=torch.cuda.current_stream()): + yield def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: @@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: TLDR: always assume this function modifies its input, but use the return value as the output. """ - from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( custom_all_reduce) @@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: out = custom_all_reduce(input_) if out is not None: return out - if is_pynccl_enabled_for_all_reduce(): - pynccl_utils.all_reduce(input_) + pynccl_comm = get_tp_pynccl_communicator() + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) else: torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 758994352e3de..168d4cc2df8a6 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -1,26 +1,4 @@ -# This file is a pure Python wrapper for the NCCL library. -# The main purpose is to use NCCL combined with CUDA graph. -# Before writing this script, we tried the following approach: -# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself -# often gets stuck when initializing the NCCL communicator. -# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` -# contains many other potential cuda APIs, that are not allowed during -# capturing the CUDA graph. For further details, please check -# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . -# -# Another rejected idea is to write a C/C++ binding for NCCL. It is usually -# doable, but we often encounter issues related with nccl versions, and need -# to switch between different versions of NCCL. See -# https://github.com/NVIDIA/nccl/issues/1234 for more details. -# A C/C++ binding is not flexible enough to handle this. It requires -# recompilation of the code every time we want to switch between different -# versions. This current implementation, with a **pure** Python wrapper, is -# more flexible. We can easily switch between different versions of NCCL by -# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` -# variable in the code. - -import ctypes -import platform +from contextlib import contextmanager from typing import Optional, Union # ===================== import region ===================== @@ -28,217 +6,70 @@ import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger -from vllm.utils import find_nccl_library, nccl_integrity_check logger = init_logger(__name__) -so_file = find_nccl_library() - -try: - # load the library in another process. - # if it core dumps, it will not crash the current process - nccl_integrity_check(so_file) - nccl = ctypes.CDLL(so_file) -except Exception as e: - logger.error( - "Failed to load NCCL library from %s ." - "It is expected if you are not running on NVIDIA/AMD GPUs." - "Otherwise, the nccl library might not exist, be corrupted " - "or it does not support the current platform %s." - "One solution is to download libnccl2 version 2.18 from " - "https://developer.download.nvidia.com/compute/cuda/repos/ " - "and extract the libnccl.so.2 file. If you already have the " - "library, please set the environment variable VLLM_NCCL_SO_PATH" - " to point to the correct nccl library path.", so_file, - platform.platform()) - raise e - -# === export types and functions from nccl to Python === -# for the original nccl definition, please check -# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in - -ncclResult_t = ctypes.c_int - -_c_ncclGetErrorString = nccl.ncclGetErrorString -_c_ncclGetErrorString.restype = ctypes.c_char_p -_c_ncclGetErrorString.argtypes = [ncclResult_t] - - -def NCCL_CHECK(result: ncclResult_t) -> None: - if result != 0: - error_str = _c_ncclGetErrorString(result) - error_str = error_str.decode("utf-8") - raise RuntimeError(f"NCCL error: {error_str}") - - -# equivalent to c declaration: -# ncclResult_t ncclGetVersion(int *version); -_c_ncclGetVersion = nccl.ncclGetVersion -_c_ncclGetVersion.restype = ctypes.c_int -_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] - - -def ncclGetVersion() -> str: - version = ctypes.c_int() - NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version))) - # something like 21903 --> "2.19.3" - version_str = str(version.value) - major = version_str[0].lstrip("0") - minor = version_str[1:3].lstrip("0") - patch = version_str[3:].lstrip("0") - return f"{major}.{minor}.{patch}" - - -class NcclUniqueId(ctypes.Structure): - _fields_ = [("internal", ctypes.c_byte * 128)] - - -# equivalent to c declaration: -# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); -_c_ncclGetUniqueId = nccl.ncclGetUniqueId -_c_ncclGetUniqueId.restype = ctypes.c_int -_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)] - - -def ncclGetUniqueId() -> NcclUniqueId: - unique_id = NcclUniqueId() - NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id))) - return unique_id - - -# equivalent to c declaration: -# ncclResult_t ncclCommInitRank( -# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); -# note that ncclComm_t is a pointer type, so the first argument -# is a pointer to a pointer -_c_ncclCommInitRank = nccl.ncclCommInitRank -_c_ncclCommInitRank.restype = ctypes.c_int -_c_ncclCommInitRank.argtypes = [ - ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int -] - -ncclDataType_t = ctypes.c_int - - -class ncclDataTypeEnum: - ncclInt8 = 0 - ncclChar = 0 - ncclUint8 = 1 - ncclInt32 = 2 - ncclInt = 2 - ncclUint32 = 3 - ncclInt64 = 4 - ncclUint64 = 5 - ncclFloat16 = 6 - ncclHalf = 6 - ncclFloat32 = 7 - ncclFloat = 7 - ncclFloat64 = 8 - ncclDouble = 8 - ncclBfloat16 = 9 - ncclNumTypes = 10 - @classmethod - def from_torch(cls, dtype: torch.dtype) -> int: - if dtype == torch.int8: - return cls.ncclInt8 - if dtype == torch.uint8: - return cls.ncclUint8 - if dtype == torch.int32: - return cls.ncclInt32 - if dtype == torch.int64: - return cls.ncclInt64 - if dtype == torch.float16: - return cls.ncclFloat16 - if dtype == torch.float32: - return cls.ncclFloat32 - if dtype == torch.float64: - return cls.ncclFloat64 - if dtype == torch.bfloat16: - return cls.ncclBfloat16 - raise ValueError(f"Unsupported dtype: {dtype}") - - -ncclRedOp_t = ctypes.c_int - - -class ncclRedOpTypeEnum: - ncclSum = 0 - ncclProd = 1 - ncclMax = 2 - ncclMin = 3 - ncclAvg = 4 - ncclNumOps = 5 - - @classmethod - def from_torch(cls, op: ReduceOp) -> int: - if op == ReduceOp.SUM: - return cls.ncclSum - if op == ReduceOp.PRODUCT: - return cls.ncclProd - if op == ReduceOp.MAX: - return cls.ncclMax - if op == ReduceOp.MIN: - return cls.ncclMin - if op == ReduceOp.AVG: - return cls.ncclAvg - raise ValueError(f"Unsupported op: {op}") - - -# equivalent to c declaration: -# ncclResult_t ncclAllReduce( -# const void* sendbuff, void* recvbuff, size_t count, -# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, -# udaStream_t stream); -# note that cudaStream_t is a pointer type, so the last argument is a pointer -_c_ncclAllReduce = nccl.ncclAllReduce -_c_ncclAllReduce.restype = ctypes.c_int -_c_ncclAllReduce.argtypes = [ - ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t, - ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p -] - -# be cautious! this is a collective call, it will block until all -# processes in the communicator have called this function. -# because Python object destruction can happen in random order, -# it is better not to call it at all. -# equivalent to c declaration: -# ncclResult_t ncclCommDestroy(ncclComm_t comm); -_c_ncclCommDestroy = nccl.ncclCommDestroy -_c_ncclCommDestroy.restype = ctypes.c_int -_c_ncclCommDestroy.argtypes = [ctypes.c_void_p] - - -class NCCLCommunicator: +class PyNcclCommunicator: def __init__( self, group: Optional[ProcessGroup] = None, device: Optional[Union[int, str, torch.device]] = None, + library_path: Optional[str] = None, ): """ Args: group: the process group to work on. If None, it will use the default process group. - device: the device to bind the NCCLCommunicator to. If None, + device: the device to bind the PyNcclCommunicator to. If None, it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. It is the caller's responsibility to make sure each communicator is bind to a unique device. """ assert dist.is_initialized() group = get_cpu_world_group() if group is None else group assert dist.get_backend(group) != dist.Backend.NCCL, ( - "NCCLCommunicator should be attached to a non-NCCL group.") + "PyNcclCommunicator should be attached to a non-NCCL group.") self.group = group # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + self.stream = None + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + self.stream = None + return + + self.available = True + self.disabled = False + + logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + if self.rank == 0: - self.unique_id = ncclGetUniqueId() + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() else: - self.unique_id = NcclUniqueId() + # construct an empty unique id + self.unique_id = ncclUniqueId() tensor = torch.ByteTensor(list(self.unique_id.internal)) ranks = dist.get_process_group_ranks(group) # arg `src` in `broadcast` is the global rank @@ -246,7 +77,6 @@ def __init__( byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte - self.comm = ctypes.c_void_p() if device is None: local_rank = get_local_rank() device = torch.device(f"cuda:{local_rank}") @@ -261,15 +91,25 @@ def __init__( # `torch.cuda.device` is a context manager that changes the # current cuda device to the specified one with torch.cuda.device(device): - NCCL_CHECK( - _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, - self.unique_id, self.rank)) + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) self.stream = torch.cuda.Stream() + # A small all_reduce for warmup. + self.all_reduce(torch.zeros(1, device=device)) + self.stream.synchronize() + + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, use under `with obj.change_state(enable=True)`, usually + # when we are using CUDA graph. + self.disabled = True + def all_reduce(self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None): + if self.disabled: + return # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" @@ -278,10 +118,32 @@ def all_reduce(self, f"but the input tensor is on {tensor.device}") if stream is None: stream = self.stream - NCCL_CHECK( - _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()), - ctypes.c_void_p(tensor.data_ptr()), - tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), - ncclRedOpTypeEnum.from_torch(op), self.comm, - ctypes.c_void_p(stream.cuda_stream))) + self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()), + buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + @contextmanager + def change_state(self, + enable: Optional[bool] = None, + stream: Optional[torch.cuda.Stream] = None): + """ + A context manager to change the state of the communicator. + """ + if enable is None: + # guess a default value when not specified + enable = self.available + + if stream is None: + stream = self.stream + + old_disable = self.disabled + old_stream = self.stream + + self.stream = stream + self.disabled = not enable + yield + + self.disabled = old_disable + self.stream = old_stream diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py deleted file mode 100644 index 44e4f39217a41..0000000000000 --- a/vllm/distributed/device_communicators/pynccl_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -import contextlib -from typing import Optional - -import torch -from torch.distributed import ProcessGroup, ReduceOp - -from vllm.logger import init_logger - -logger = init_logger(__name__) - -try: - from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, - ncclGetVersion) -except Exception as e: - # in non-NVIDIA environments, we can't import the nccl module - # e.g. when running on machines with AMD GPUs - logger.info("Failed to import NCCL library: %s", e) - logger.info("It is expected if you are not running on NVIDIA GPUs.") - pass - -comm: Optional["NCCLCommunicator"] = None - - -def is_initialized() -> bool: - """Returns whether the NCCL backend is initialized.""" - return comm is not None - - -@contextlib.contextmanager -def set_pynccl_stream(stream: torch.cuda.Stream): - """Set the cuda stream for communication""" - try: - assert comm is not None - comm.stream = stream - yield - finally: - pass - - -def init_process_group(group: Optional[ProcessGroup] = None) -> None: - assert not is_initialized() - global comm - logger.info("vLLM is using nccl==%s", ncclGetVersion()) - comm = NCCLCommunicator(group=group) - - -def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: - """All-reduces the input tensor across the process group.""" - assert input_.is_cuda, f"{input_} should be a cuda tensor" - assert comm is not None - comm.all_reduce(input_, op) - - -def destroy_process_group() -> None: - global comm - comm = None - - -def get_world_size() -> int: - """Returns the world size.""" - assert comm is not None - return comm.world_size - - -def get_nccl_backend() -> Optional["NCCLCommunicator"]: - return comm diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000000000..43d85674b23d0 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,258 @@ +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library, nccl_integrity_check + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + # load the library in another process. + # if it core dumps, it will not crash the current process + nccl_integrity_check(so_file) + except Exception as e: + logger.error( + "Failed to load NCCL library from %s ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s." + "One solution is to download libnccl2 version 2.18 from " + "https://developer.download.nvidia.com/compute/cuda/repos/ " + "and extract the libnccl.so.2 file. If you already have the " + "library, please set the environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index be5bb4e857caf..5075da11bb1b8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -3,10 +3,10 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Tensor and pipeline parallel groups.""" -import contextlib -from typing import Optional +from typing import List, Optional import torch +from torch.distributed import ProcessGroup import vllm.envs as envs from vllm.logger import init_logger @@ -14,10 +14,11 @@ logger = init_logger(__name__) # Tensor model parallel group that the current rank belongs to. -_TP_DEVICE_GROUP = None -_TP_CPU_GROUP = None +_TP_DEVICE_GROUP: Optional[ProcessGroup] = None +_TP_CPU_GROUP: Optional[ProcessGroup] = None +_TP_PYNCCL_COMMUNICATOR = None # Pipeline model parallel group that the current rank belongs to. -_PIPELINE_MODEL_PARALLEL_GROUP = None +_PP_DEVICE_GROUP: Optional[ProcessGroup] = None # when people blindly call `torch.distributed.all_reduce` etc, # it will use this group. It is initialized with the `backend` @@ -41,11 +42,16 @@ # A list of global ranks for each pipeline group to ease calculation of the # source rank when broadcasting from the first or last pipeline stage. -_PIPELINE_GLOBAL_RANKS = None +_PP_GLOBAL_RANKS: Optional[List[int]] = None _LOCAL_RANK = -1 +def get_tp_pynccl_communicator(): + global _TP_PYNCCL_COMMUNICATOR + return _TP_PYNCCL_COMMUNICATOR + + def get_local_rank(): global _LOCAL_RANK return _LOCAL_RANK @@ -80,10 +86,20 @@ def init_distributed_environment( # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 - if local_rank == -1 and distributed_init_method == "env://": - local_rank = envs.LOCAL_RANK + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank global _LOCAL_RANK _LOCAL_RANK = local_rank + # A small all_reduce for warmup. + data = torch.zeros(1) + if torch.cuda.is_available(): + data = data.to(device=f"cuda:{local_rank}") + torch.distributed.all_reduce(data) def initialize_model_parallel( @@ -133,29 +149,36 @@ def initialize_model_parallel( rank = torch.distributed.get_rank() # Build the tensor model-parallel groups. - global _TP_DEVICE_GROUP, _TP_CPU_GROUP + global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR assert _TP_DEVICE_GROUP is None, ( "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): - ranks = range(i * tensor_model_parallel_size, - (i + 1) * tensor_model_parallel_size) + ranks = list( + range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size)) group = torch.distributed.new_group(ranks, backend=backend) cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: _TP_DEVICE_GROUP = group _TP_CPU_GROUP = cpu_group + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + ) + # Build the pipeline model-parallel groups. - global _PIPELINE_MODEL_PARALLEL_GROUP - global _PIPELINE_GLOBAL_RANKS - assert _PIPELINE_MODEL_PARALLEL_GROUP is None, ( + global _PP_DEVICE_GROUP + global _PP_GLOBAL_RANKS + assert _PP_DEVICE_GROUP is None, ( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group = torch.distributed.new_group(ranks, backend=backend) if rank in ranks: - _PIPELINE_MODEL_PARALLEL_GROUP = group - _PIPELINE_GLOBAL_RANKS = ranks + _PP_DEVICE_GROUP = group + _PP_GLOBAL_RANKS = ranks def ensure_model_parallel_initialized( @@ -188,8 +211,7 @@ def ensure_model_parallel_initialized( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TP_DEVICE_GROUP is not None - and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None) def get_cpu_world_group(): @@ -214,9 +236,9 @@ def get_tensor_model_parallel_cpu_group(): def get_pipeline_model_parallel_group(): """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, ( + assert _PP_DEVICE_GROUP is not None, ( "pipeline model parallel group is not initialized") - return _PIPELINE_MODEL_PARALLEL_GROUP + return _PP_DEVICE_GROUP def get_tensor_model_parallel_world_size(): @@ -253,36 +275,36 @@ def get_tensor_model_parallel_src_rank(): def get_pipeline_model_parallel_first_rank(): """Return the global rank of the first process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") - return _PIPELINE_GLOBAL_RANKS[0] + return _PP_GLOBAL_RANKS[0] def get_pipeline_model_parallel_last_rank(): """Return the global rank of the last process in the pipeline for the current tensor parallel group""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PIPELINE_GLOBAL_RANKS[last_rank_local] + return _PP_GLOBAL_RANKS[last_rank_local] def get_pipeline_model_parallel_next_rank(): """Return the global rank that follows the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] def get_pipeline_model_parallel_prev_rank(): """Return the global rank that precedes the caller in the pipeline""" - assert _PIPELINE_GLOBAL_RANKS is not None, ( + assert _PP_GLOBAL_RANKS is not None, ( "Pipeline parallel group is not initialized") rank_in_pipeline = get_pipeline_model_parallel_rank() world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] def destroy_model_parallel(): @@ -295,45 +317,12 @@ def destroy_model_parallel(): if _TP_CPU_GROUP: torch.distributed.destroy_process_group(_TP_CPU_GROUP) _TP_CPU_GROUP = None - global _PIPELINE_MODEL_PARALLEL_GROUP - if _PIPELINE_MODEL_PARALLEL_GROUP: - torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) - _PIPELINE_MODEL_PARALLEL_GROUP = None - global _PIPELINE_GLOBAL_RANKS - _PIPELINE_GLOBAL_RANKS = None - from vllm.distributed.device_communicators import pynccl_utils - - # Destroy the pynccl states if any. - pynccl_utils.destroy_process_group() - - -# Whether to use pynccl for nccl all reduce. -# We use pynccl for all reduce when using CUDA graph, because torch.distributed -# is not well supported by CUDA graph. -_ENABLE_PYNCCL_FOR_ALL_REDUCE = False - - -@contextlib.contextmanager -def with_pynccl_for_all_reduce(): - from vllm.distributed.device_communicators import pynccl_utils - """use pynccl instead of torch.distributed for all reduce""" - tp_size = get_tensor_model_parallel_world_size() - if tp_size == 1: - # No-op. - # NOTE(woosuk): We don't initialize pynccl when tp_size is 1. - yield - else: - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - old = _ENABLE_PYNCCL_FOR_ALL_REDUCE - _ENABLE_PYNCCL_FOR_ALL_REDUCE = True - - stream = torch.cuda.current_stream() - with pynccl_utils.set_pynccl_stream(stream): - yield - _ENABLE_PYNCCL_FOR_ALL_REDUCE = old - - -def is_pynccl_enabled_for_all_reduce(): - """check if pynccl is enabled for all reduce""" - global _ENABLE_PYNCCL_FOR_ALL_REDUCE - return _ENABLE_PYNCCL_FOR_ALL_REDUCE + global _TP_PYNCCL_COMMUNICATOR + _TP_PYNCCL_COMMUNICATOR = None + + global _PP_DEVICE_GROUP + if _PP_DEVICE_GROUP: + torch.distributed.destroy_process_group(_PP_DEVICE_GROUP) + _PP_DEVICE_GROUP = None + global _PP_GLOBAL_RANKS + _PP_GLOBAL_RANKS = None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b5e582116297c..3fc76c6142165 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,3 @@ -import contextlib import time from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple @@ -12,9 +11,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) -from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce -from vllm.distributed.device_communicators import (custom_all_reduce, - pynccl_utils) +from vllm.distributed import broadcast_tensor_dict +from vllm.distributed.communication_op import graph_capture_mode +from vllm.distributed.device_communicators import custom_all_reduce from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -917,10 +916,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: Since it is used for decoding-only, it assumes there's only 1 token per sequence in the batch. """ - # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never - # deleted before the CUDA graphs. - self.pynccl_backend = pynccl_utils.get_nccl_backend() - assert not self.model_config.enforce_eager logger.info("Capturing the model for CUDA graphs. This may lead to " "unexpected consequences if the model is not static. To " @@ -1046,7 +1041,7 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with _maybe_pynccl(): + with graph_capture_mode(): self.model( input_ids, positions, @@ -1061,7 +1056,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 - with _maybe_pynccl(): + with graph_capture_mode(): hidden_states = self.model( input_ids, positions, @@ -1113,16 +1108,6 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) -@contextlib.contextmanager -def _maybe_pynccl(): - if pynccl_utils.is_initialized( - ) and not custom_all_reduce.is_initialized(): - with with_pynccl_for_all_reduce(): - yield - else: - yield - - def _get_graph_batch_size(batch_size: int) -> int: """Returns the padded batch size given actual batch size. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 43f6b2b443b70..0ca9c2b64cf30 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,9 +11,7 @@ VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, - get_tensor_model_parallel_cpu_group, init_distributed_environment) -from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( init_custom_ar) from vllm.lora.request import LoRARequest @@ -306,29 +304,10 @@ def init_worker_distributed_environment( ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - if pynccl_utils.is_initialized(): - pynccl_world_size = pynccl_utils.get_world_size() - if pynccl_world_size != parallel_config.world_size: - raise RuntimeError( - "pynccl is already initialized but the pynccl world " - "size does not match parallel_config.world_size " - f"({pynccl_world_size} vs. {parallel_config.world_size}).") - elif parallel_config.world_size > 1: - # NOTE(woosuk): We don't initialize pynccl process group when world size - # is 1. - # NOTE(kaichao): By default, pynccl is initialized for tp group. - pynccl_utils.init_process_group( - group=get_tensor_model_parallel_cpu_group()) - # Initialize a custom fast all-reduce implementation. if not parallel_config.disable_custom_all_reduce: init_custom_ar() - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) - if pynccl_utils.is_initialized(): - pynccl_utils.all_reduce(torch.zeros(1).cuda()) - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. From e965d4618430db94fed73d4305c44c9147e47150 Mon Sep 17 00:00:00 2001 From: "Allen.Dou" Date: Fri, 10 May 2024 12:42:38 +0800 Subject: [PATCH 249/413] [Misc] Keep only one implementation of the create_dummy_prompt function. (#4716) --- tests/test_sequence.py | 36 ++++-------------------------------- 1 file changed, 4 insertions(+), 32 deletions(-) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index b16bdc141e57c..53061278d5be4 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,36 +1,8 @@ -import time -from typing import Optional - import pytest -from vllm import SamplingParams -from vllm.lora.request import LoRARequest -from vllm.sequence import (SamplerOutput, Sequence, SequenceData, - SequenceGroup, SequenceGroupOutput, SequenceOutput) - - -def create_dummy_prompt( - request_id: str, - prompt_length: int, - block_size: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - use_beam_search: bool = False, - best_of: int = 1, -) -> SequenceGroup: - if not block_size: - block_size = prompt_length - - # Create dummy prompt sequence with tokens 0...block_size-1 - # and prompt "0 ... block_size". - prompt_tokens = list(range(prompt_length)) - prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id, [prompt], - SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - time.time(), lora_request) - - return seq_group +from tests.core.utils import create_dummy_prompt +from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, + SequenceOutput) @pytest.fixture @@ -102,7 +74,7 @@ def test_sequence_data_prefill(): def test_sequence_group_stage(): - seq_group = create_dummy_prompt("1", 12) + _, seq_group = create_dummy_prompt("1", 12) assert seq_group.is_prefill() is True seq_group.update_num_computed_tokens(6) assert seq_group.is_prefill() is True From 51d4094fda63b1d738f55ae9dd75d354b9c1143c Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 9 May 2024 22:13:23 -0700 Subject: [PATCH 250/413] chunked-prefill-doc-syntax (#4603) Fix the docs: https://docs.vllm.ai/en/latest/models/performance.html Co-authored-by: sang --- docs/source/models/performance.rst | 36 +++++++++++++++++------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst index 067757699f32a..589fce21056c2 100644 --- a/docs/source/models/performance.rst +++ b/docs/source/models/performance.rst @@ -7,7 +7,7 @@ Chunked Prefill --------------- vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. -You can enable the feature by specifying +You can enable the feature by specifying ``--enable-chunked-prefill`` in the command line or setting ``enable_chunked_prefill=True`` in the LLM constructor. .. code-block:: python @@ -16,23 +16,29 @@ You can enable the feature by specifying # NOTE: 512 is the default max_num_batched_tokens for chunked prefill. # llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512) -By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. This policy optimizes the TTFT (time to thefirst token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. +By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. +This policy optimizes the TTFT (time to the first token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. -Once chunked prefill is enabled, the policy is changed to +Once chunked prefill is enabled, the policy is changed to prioritize decode requests. +It batches all pending decode requests to the batch before scheduling any prefill. +When there are available token_budget (``max_num_batched_tokens``), it schedules pending prefills. +If a last pending prefill request cannot fit into ``max_num_batched_tokens``, it chunks it. -- prioritize decode requests. It batches all pending decode requests to the batch before scheduling any prefill. -- When there are available token_budget (`max_num_batched_tokens`), it schedules pending prefills. If a last pending prefill request cannot fit into `max_num_batched_tokens`, it chunks it. +This policy has two benefits: -This policy has two benefits. - -- It improves ITL (inter token latency) and generation decode because decode requests are prioritized. +- It improves ITL and generation decode because decode requests are prioritized. - It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. -You can tune the performance by changing `max_num_batched_tokens`. -By default, it is set to 512, which has the best ITL on A100 in the initial benchmark. -Smaller batch size achieves better ITL because there are fewer prefills interrupting decodes. -Higher batch size achieves better TTFT as you can put more prefill to the batch. -If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). -Note that the default batch size (512) is optimized for ITL, and it may have lower throughput than the default scheduler. We recommend you set `max_num_batched_tokens > 2048` for throughput. +You can tune the performance by changing ``max_num_batched_tokens``. +By default, it is set to 512, which has the best ITL on A100 in the initial benchmark (llama 70B and mixtral 8x22B). +Smaller ``max_num_batched_tokens`` achieves better ITL because there are fewer prefills interrupting decodes. +Higher ``max_num_batched_tokens`` achieves better TTFT as you can put more prefill to the batch. + +- If ``max_num_batched_tokens`` is the same as ``max_model_len``, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). +- Note that the default value (512) of ``max_num_batched_tokens`` is optimized for ITL, and it may have lower throughput than the default scheduler. + +We recommend you set ``max_num_batched_tokens > 2048`` for throughput. + +See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). -See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). +Please try out this feature and let us know your feedback via GitHub issues! \ No newline at end of file From 64b77dfd7e1378853ec7b189f3d7d0e51ce18855 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Fri, 10 May 2024 20:52:48 +0800 Subject: [PATCH 251/413] [Core]fix type annotation for `swap_blocks` (#4726) --- vllm/_custom_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 35a9f6329fc42..42dedfdf76c4f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple import torch @@ -266,7 +266,7 @@ def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, def swap_blocks(src: torch.Tensor, dst: torch.Tensor, - block_mapping: Dict[int, int]) -> None: + block_mapping: torch.Tensor) -> None: vllm_cache_ops.swap_blocks(src, dst, block_mapping) From dac6a3f6ed14ea4061b672f9290bfdf8bcdd996d Mon Sep 17 00:00:00 2001 From: Steve Grubb Date: Fri, 10 May 2024 09:37:05 -0400 Subject: [PATCH 252/413] [Misc] Apply a couple g++ cleanups (#4719) --- csrc/cpu/cache.cpp | 2 +- csrc/cpu/pos_encoding.cpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 620d11ef1ed6c..26e81685d623e 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -84,7 +84,7 @@ void reshape_and_cache_cpu_impl( void copy_blocks(std::vector &key_caches, std::vector &value_caches, const torch::Tensor& block_mapping) { - int num_layers = key_caches.size(); + unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { return; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index e9b3992204bb2..5dc1bde45ac5f 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -19,7 +19,6 @@ void rotary_embedding_impl( const int num_tokens) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); - constexpr int ELEM_SIZE = sizeof(scalar_t); const int embed_dim = rot_dim / 2; TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); From 6a0f617210dfba76f3db4db1155d1f1489609133 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 10 May 2024 23:54:32 +0900 Subject: [PATCH 253/413] [Core] Fix circular reference which leaked llm instance in local dev env (#4737) Storing exception frame is extremely prone to circular refernece because it contains the reference to objects. When tensorizer is not installed, it leaks llm instance because error frame has references to various modules which cause circular reference problem. I also found spec decoding has a circular reference issue, and I solved it using weakref.proxy. --- tests/basic_correctness/test_basic_correctness.py | 13 +++++++++++++ vllm/model_executor/model_loader/tensorizer.py | 10 +++++----- vllm/spec_decode/multi_step_worker.py | 3 ++- vllm/spec_decode/ngram_worker.py | 3 ++- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index d75279dd9cfa9..7d8117447ca0a 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -3,9 +3,12 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ import os +import weakref import pytest +from vllm import LLM + MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -13,6 +16,16 @@ VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" +def test_vllm_gc_ed(): + """Verify vllm instance is GC'ed when it is deleted""" + llm = LLM("facebook/opt-125m") + weak_llm = weakref.ref(llm) + del llm + # If there's any circular reference to vllm, this fails + # because llm instance is not GC'ed. + assert weak_llm() is None + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index af433b86e604d..219a2a392e129 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -tensorizer_load_fail = None +tensorizer_error_msg = None try: from tensorizer import (DecryptionParams, EncryptionParams, @@ -28,7 +28,7 @@ from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) except ImportError as e: - tensorizer_load_fail = e + tensorizer_error_msg = str(e) __all__ = [ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', @@ -254,11 +254,11 @@ class TensorizerAgent: def __init__(self, tensorizer_config: TensorizerConfig, quant_config: QuantizationConfig, **extra_kwargs): - if tensorizer_load_fail is not None: + if tensorizer_error_msg is not None: raise ImportError( "Tensorizer is not installed. Please install tensorizer " - "to use this feature with `pip install vllm[tensorizer]`." - ) from tensorizer_load_fail + "to use this feature with `pip install vllm[tensorizer]`. " + "Error message: {}".format(tensorizer_error_msg)) self.tensorizer_config = tensorizer_config self.tensorizer_args = ( diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 5044cc1ef85fd..20098ebaeea32 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,4 +1,5 @@ import copy +import weakref from typing import List, Tuple import torch @@ -32,7 +33,7 @@ def init_device(self): super().init_device() self._proposer = Top1Proposer( - self, + weakref.proxy(self), self.device, self.vocab_size, max_proposal_len=self.max_model_len, diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index f18f9387f5b23..6cd50fcc1a041 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -1,3 +1,4 @@ +import weakref from typing import List, Optional, Tuple import torch @@ -37,7 +38,7 @@ def init_device(self): # Current only support Top1Proposer self._proposer = Top1Proposer( - self, + weakref.proxy(self), device=self.device, vocab_size=self.vocab_size, ) From 706588a77d2099b118f53a53ef2dd7f8c2de9ffc Mon Sep 17 00:00:00 2001 From: "Allen.Dou" Date: Fri, 10 May 2024 23:00:56 +0800 Subject: [PATCH 254/413] [Bugfix] Fix CLI arguments in OpenAI server docs (#4729) --- docs/requirements-docs.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 0e76763a87b7c..ed569816200ee 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -10,3 +10,4 @@ pydantic torch py-cpuinfo transformers +openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args From 2e7796f2cf4537dd3b08d5de3aa8349f5db1a168 Mon Sep 17 00:00:00 2001 From: heeju-kim2 <157340754+heeju-kim2@users.noreply.github.com> Date: Sat, 11 May 2024 02:36:25 +0900 Subject: [PATCH 255/413] [Speculative decoding] CUDA graph support (#4295) Co-authored-by: Cade Daniel --- .../e2e/test_multistep_correctness.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 94d71fb012727..d2da039e84c07 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -611,3 +611,40 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Required for spec decode. + "use_v2_block_manager": True, + + # Verify equality when cuda graphs allowed. + "enforce_eager": False, + "model": "JackFram/llama-68m", + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Identical models. + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, + batch_size, output_len): + """Verify spec decode equality when cuda graphs are enabled. + """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) From fcc2994be657a897dd0732928754749048520b28 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Fri, 10 May 2024 16:01:01 -0600 Subject: [PATCH 256/413] [CI] Nits for bad initialization of SeqGroup in testing (#4748) --- tests/core/test_block_manager.py | 13 +++++++++---- tests/core/utils.py | 11 +++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 9db58e075196d..22a9f0cf47d32 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -142,8 +142,10 @@ def test_append_slot_cow(): child = prompt.fork(new_seq_id=2) # Allocate space for the sequence group. - seq_group = SequenceGroup("1", [prompt, child], SamplingParams(), - time.time(), time.perf_counter) + seq_group = SequenceGroup(request_id="1", + seqs=[prompt, child], + arrival_time=time.time(), + sampling_params=SamplingParams()) block_manager.allocate(seq_group) # Fork and append a new token id. We expect a COW to be scheduled. @@ -303,8 +305,11 @@ def test_sliding_window_multi_seq(): assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks parent = Sequence(1, "one two three", [0, 1, 2], block_size) - seq_group = SequenceGroup("1", [parent], SamplingParams(), time.time(), - None) + seq_group = SequenceGroup(request_id="1", + seqs=[parent], + arrival_time=time.time(), + sampling_params=SamplingParams(), + lora_request=None) block_manager.allocate(seq_group) # assert the number of blocks allocated is correct diff --git a/tests/core/utils.py b/tests/core/utils.py index 22c1d3826dff4..8fb13177a2d6c 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -22,10 +22,13 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup( - request_id, [prompt], - SamplingParams(use_beam_search=use_beam_search, best_of=best_of), - time.time(), lora_request) + seq_group = SequenceGroup(request_id=request_id, + seqs=[prompt], + arrival_time=time.time(), + sampling_params=SamplingParams( + use_beam_search=use_beam_search, + best_of=best_of), + lora_request=lora_request) return prompt, seq_group From 4e12131089f192334f6e09c8fe5cd85af1e25327 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 May 2024 15:14:40 -0700 Subject: [PATCH 257/413] [Core][Test] fix function name typo in custom allreduce (#4750) --- tests/distributed/test_custom_all_reduce.py | 4 ++-- vllm/distributed/device_communicators/custom_all_reduce.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 3b1cd1773af19..308b874280f55 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -25,7 +25,7 @@ def graph_allreduce(world_size, rank, distributed_init_port): init_test_distributed_environment(1, world_size, rank, distributed_init_port) - custom_all_reduce.init_custom_all_reduce() + custom_all_reduce.init_custom_ar() for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: with custom_all_reduce.capture(): @@ -61,7 +61,7 @@ def eager_allreduce(world_size, rank, distributed_init_port): distributed_init_port) sz = 1024 - custom_all_reduce.init_custom_all_reduce() + custom_all_reduce.init_custom_ar() fa = custom_all_reduce.get_handle() inp = torch.ones(sz, dtype=torch.float32, device=device) out = fa.all_reduce_unreg(inp) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index cc5f8166877ce..5d26254fb832a 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -52,6 +52,10 @@ def init_custom_ar() -> None: "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" " is set.") return + + # we only use a subset of GPUs here + # so we only need to check the nvlink connectivity of these GPUs + num_dev = world_size # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES From e254497b66dcd87038969b0ad34d34425edfc5fe Mon Sep 17 00:00:00 2001 From: Chang Su Date: Sat, 11 May 2024 11:30:37 -0700 Subject: [PATCH 258/413] [Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734) --- examples/offline_inference_embedding.py | 17 ++ examples/openai_embedding_client.py | 23 ++ requirements-dev.txt | 9 +- tests/conftest.py | 38 ++- .../output_processor/test_multi_step.py | 12 +- tests/entrypoints/openai/test_serving_chat.py | 1 + tests/entrypoints/test_openai_server.py | 96 ++++++- tests/models/test_embedding.py | 44 +++ tests/samplers/test_logits_processor.py | 6 +- tests/samplers/test_seeded_generate.py | 2 +- tests/spec_decode/utils.py | 6 +- tests/test_sequence.py | 12 +- vllm/__init__.py | 7 +- vllm/config.py | 15 + vllm/core/embedding_model_block_manager.py | 84 ++++++ vllm/core/interfaces.py | 5 + vllm/core/scheduler.py | 10 +- vllm/engine/arg_utils.py | 1 + vllm/engine/async_llm_engine.py | 158 +++++++++-- vllm/engine/llm_engine.py | 143 ++++++++-- vllm/entrypoints/llm.py | 150 ++++++++-- vllm/entrypoints/openai/api_server.py | 20 +- vllm/entrypoints/openai/protocol.py | 36 ++- vllm/entrypoints/openai/serving_embedding.py | 134 +++++++++ vllm/entrypoints/openai/serving_engine.py | 16 +- vllm/executor/gpu_executor.py | 10 +- vllm/model_executor/layers/pooler.py | 56 ++++ vllm/model_executor/layers/sampler.py | 7 +- vllm/model_executor/models/__init__.py | 12 +- vllm/model_executor/models/llama_embedding.py | 87 ++++++ vllm/model_executor/pooling_metadata.py | 69 +++++ vllm/outputs.py | 82 +++++- vllm/pooling_params.py | 20 ++ vllm/sequence.py | 89 +++++- vllm/spec_decode/util.py | 5 +- vllm/worker/embedding_model_runner.py | 266 ++++++++++++++++++ vllm/worker/model_runner.py | 25 +- vllm/worker/worker.py | 14 +- 38 files changed, 1627 insertions(+), 160 deletions(-) create mode 100644 examples/offline_inference_embedding.py create mode 100644 examples/openai_embedding_client.py create mode 100644 tests/models/test_embedding.py create mode 100644 vllm/core/embedding_model_block_manager.py create mode 100644 vllm/entrypoints/openai/serving_embedding.py create mode 100644 vllm/model_executor/layers/pooler.py create mode 100644 vllm/model_executor/models/llama_embedding.py create mode 100644 vllm/model_executor/pooling_metadata.py create mode 100644 vllm/pooling_params.py create mode 100644 vllm/worker/embedding_model_runner.py diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py new file mode 100644 index 0000000000000..7d5ef128bc8e0 --- /dev/null +++ b/examples/offline_inference_embedding.py @@ -0,0 +1,17 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create an LLM. +model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) +# Generate embedding. The output is a list of EmbeddingRequestOutputs. +outputs = model.encode(prompts) +# Print the outputs. +for output in outputs: + print(output.outputs.embedding) # list of 4096 floats diff --git a/examples/openai_embedding_client.py b/examples/openai_embedding_client.py new file mode 100644 index 0000000000000..b73360fe15a24 --- /dev/null +++ b/examples/openai_embedding_client.py @@ -0,0 +1,23 @@ +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + +responses = client.embeddings.create(input=[ + "Hello my name is", + "The best thing about vLLM is that it supports many different models" +], + model=model) + +for data in responses.data: + print(data.embedding) # list of float of len 4096 diff --git a/requirements-dev.txt b/requirements-dev.txt index e6d375cbafa39..796c9e37d0230 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -19,12 +19,15 @@ pytest-forked pytest-asyncio pytest-rerunfailures pytest-shard -httpx + +# testing utils +awscli einops # required for MPT +httpx +peft requests ray -peft -awscli +sentence-transformers # required for embedding # Benchmarking aiohttp diff --git a/tests/conftest.py b/tests/conftest.py index 1f2ad1cbd7298..b8117a19c75d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -133,6 +133,10 @@ def example_long_prompts() -> List[str]: "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration, } +_EMBEDDING_MODELS = [ + "intfloat/e5-mistral-7b-instruct", +] + class HfRunner: @@ -145,14 +149,7 @@ def __init__( assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] self.model_name = model_name - if model_name not in _VISION_LANGUAGE_MODELS: - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() - self.processor = None - else: + if model_name in _VISION_LANGUAGE_MODELS: self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( model_name, torch_dtype=torch_dtype, @@ -162,6 +159,20 @@ def __init__( model_name, torch_dtype=torch_dtype, ) + elif model_name in _EMBEDDING_MODELS: + # Lazy init required for AMD CI + from sentence_transformers import SentenceTransformer + self.model = SentenceTransformer( + model_name, + device="cpu", + ).to(dtype=torch_dtype).cuda() + else: + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ).cuda() + self.processor = None if tokenizer_name is None: tokenizer_name = model_name self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) @@ -334,6 +345,9 @@ def generate_greedy_logprobs_limit( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + return self.model.encode(prompts) + def __del__(self): del self.model cleanup() @@ -459,6 +473,14 @@ def generate_beam_search( outputs = self.generate(prompts, beam_search_params) return outputs + def encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs + def __del__(self): del self.model cleanup() diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py index 6da3da091db78..2bf4bf69da203 100644 --- a/tests/engine/output_processor/test_multi_step.py +++ b/tests/engine/output_processor/test_multi_step.py @@ -9,8 +9,8 @@ from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput, - SequenceStatus) +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter @@ -51,7 +51,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int): new_token_ids = list(range(num_new_tokens)) outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, @@ -103,7 +103,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int, new_token_ids = list(range(num_new_tokens)) outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, @@ -170,7 +170,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int, new_token_ids[eos_index] = eos_token_id outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, @@ -239,7 +239,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int, new_token_ids[eos_index] = eos_token_id outputs = [ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( parent_seq_id=seq.seq_id, diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 13e2e372cef33..74b49726734b5 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -14,6 +14,7 @@ class MockModelConfig: tokenizer_mode = "auto" max_model_len = 100 tokenizer_revision = None + embedding_mode = False @dataclass diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index e53e64a0c1ff8..c22ac4507658b 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -23,6 +23,7 @@ MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" # technically this needs Mistral-7B-v0.1 as base, but we're not testing # generation quality here LORA_NAME = "typeof/zephyr-7b-beta-lora" @@ -121,7 +122,7 @@ def zephyr_lora_files(): return snapshot_download(repo_id=LORA_NAME) -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(zephyr_lora_files): ray.init() server_runner = ServerRunner.remote([ @@ -150,6 +151,25 @@ def server(zephyr_lora_files): ray.shutdown() +@pytest.fixture(scope="module") +def embedding_server(zephyr_lora_files): + ray.shutdown() + ray.init() + server_runner = ServerRunner.remote([ + "--model", + EMBEDDING_MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + @pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( @@ -890,5 +910,79 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, + model_name: str): + input = [ + "The chef prepared a delicious meal.", + ] + + # test single embedding + embeddings = await client.embeddings.create( + model=model_name, + input=input, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 9 + assert embeddings.usage.total_tokens == 9 + + # test using token IDs + input = [1, 1, 1, 1, 1] + embeddings = await client.embeddings.create( + model=model_name, + input=input, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 1 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 5 + assert embeddings.usage.total_tokens == 5 + + +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI, + model_name: str): + # test List[str] + inputs = [ + "The cat sat on the mat.", "A feline was resting on a rug.", + "Stars twinkle brightly in the night sky." + ] + embeddings = await client.embeddings.create( + model=model_name, + input=inputs, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 3 + assert len(embeddings.data[0].embedding) == 4096 + + # test List[List[int]] + inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], + [25, 32, 64, 77]] + embeddings = await client.embeddings.create( + model=model_name, + input=inputs, + encoding_format="float", + ) + assert embeddings.id is not None + assert embeddings.data is not None and len(embeddings.data) == 4 + assert len(embeddings.data[0].embedding) == 4096 + assert embeddings.usage.completion_tokens == 0 + assert embeddings.usage.prompt_tokens == 17 + assert embeddings.usage.total_tokens == 17 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py new file mode 100644 index 0000000000000..59bf054913f7c --- /dev/null +++ b/tests/models/test_embedding.py @@ -0,0 +1,44 @@ +"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. + +Run `pytest tests/models/test_llama_embedding.py`. +""" +import pytest +import torch +import torch.nn.functional as F + +MODELS = [ + "intfloat/e5-mistral-7b-instruct", +] + + +def compare_embeddings(embeddings1, embeddings2): + similarities = [ + F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0) + for e1, e2 in zip(embeddings1, embeddings2) + ] + return similarities + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.encode(example_prompts) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.encode(example_prompts) + del vllm_model + + similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 3788e9e9752ff..be4c2ea1b7810 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -36,14 +36,14 @@ def pick_vllm(token_ids, logits): # test logits_processors when prompt_logprobs is not None vllm_model.model._add_request( prompt=example_prompts[0], - sampling_params=params_with_logprobs, + params=params_with_logprobs, prompt_token_ids=None, ) # test prompt_logprobs is not None vllm_model.model._add_request( prompt=example_prompts[1], - sampling_params=SamplingParams( + params=SamplingParams( prompt_logprobs=3, max_tokens=max_tokens, ), @@ -53,7 +53,7 @@ def pick_vllm(token_ids, logits): # test grouped requests vllm_model.model._add_request( prompt=example_prompts[2], - sampling_params=SamplingParams(max_tokens=max_tokens), + params=SamplingParams(max_tokens=max_tokens), prompt_token_ids=None, ) diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index 3cd659cef58da..ce4501bbf71e5 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -60,7 +60,7 @@ def test_random_sample_with_seed( llm._add_request( prompt=prompt, prompt_token_ids=None, - sampling_params=params, + params=params, ) results = llm._run_engine(use_tqdm=False) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f288652d51556..d52b22c30bd43 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -7,8 +7,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, SamplerOutput, SequenceData, - SequenceGroupMetadata, SequenceGroupOutput, +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine @@ -170,7 +170,7 @@ def create_sampler_output_list( return [ SamplerOutput(outputs=[ - SequenceGroupOutput( + CompletionSequenceGroupOutput( samples=[ SequenceOutput( output_token=token_id, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 53061278d5be4..b8ea1f6b77200 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,17 +1,17 @@ import pytest from tests.core.utils import create_dummy_prompt -from vllm.sequence import (SamplerOutput, SequenceData, SequenceGroupOutput, - SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput, + SequenceData, SequenceOutput) @pytest.fixture def sample_outputs(): return [ - SequenceGroupOutput(samples=[ + CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=0, output_token=i, logprobs={}) ], - prompt_logprobs=None) for i in range(5) + prompt_logprobs=None) for i in range(5) ] @@ -32,10 +32,10 @@ def test_sampler_output_getitem(sampler_output, sample_outputs): def test_sampler_output_setitem(sampler_output): - new_output = SequenceGroupOutput(samples=[ + new_output = CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=0, output_token=99, logprobs={}) ], - prompt_logprobs=None) + prompt_logprobs=None) sampler_output[2] = new_output assert sampler_output[2] == new_output diff --git a/vllm/__init__.py b/vllm/__init__.py index 59810da3ca411..74674ca0d12af 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -6,7 +6,9 @@ from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster from vllm.model_executor.models import ModelRegistry -from vllm.outputs import CompletionOutput, RequestOutput +from vllm.outputs import (CompletionOutput, EmbeddingOutput, + EmbeddingRequestOutput, RequestOutput) +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams __version__ = "0.4.2" @@ -17,9 +19,12 @@ "SamplingParams", "RequestOutput", "CompletionOutput", + "EmbeddingOutput", + "EmbeddingRequestOutput", "LLMEngine", "EngineArgs", "AsyncLLMEngine", "AsyncEngineArgs", "initialize_ray_cluster", + "PoolingParams", ] diff --git a/vllm/config.py b/vllm/config.py index 275814d72e6c3..fab9cfbf41a2d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) +from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron @@ -22,6 +23,7 @@ logger = init_logger(__name__) _GB = 1 << 30 +_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 class ModelConfig: @@ -126,6 +128,7 @@ def __init__( served_model_name) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() + self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() @@ -137,6 +140,11 @@ def _verify_tokenizer_mode(self) -> None: "either 'auto' or 'slow'.") self.tokenizer_mode = tokenizer_mode + def _verify_embedding_mode(self) -> None: + architectures = getattr(self.hf_config, "architectures", []) + self.embedding_mode = any( + ModelRegistry.is_embedding_model(arch) for arch in architectures) + def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = ["gptq", "squeezellm"] @@ -591,6 +599,7 @@ class SchedulerConfig: prompt latency) before scheduling next prompt. enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. + embedding_mode: Whether the running model is for embedding. """ def __init__( @@ -602,6 +611,7 @@ def __init__( num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, + embedding_mode: Optional[bool] = False, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -610,6 +620,10 @@ def __init__( # It is the values that have the best balance between ITL # and TTFT on A100. Note it is not optimized for throughput. self.max_num_batched_tokens = 512 + elif embedding_mode: + # For embedding, choose specific value for higher throughput + self.max_num_batched_tokens = max( + max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. @@ -623,6 +637,7 @@ def __init__( self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill + self.embedding_mode = embedding_mode self._verify_args() diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py new file mode 100644 index 0000000000000..a09d79ec3c420 --- /dev/null +++ b/vllm/core/embedding_model_block_manager.py @@ -0,0 +1,84 @@ +from typing import List, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup + + +class EmbeddingModelBlockSpaceManager(BlockSpaceManager): + """An embedding version of BlockSpaceManager for use in environments + with embedding models where block management is not required. + + This class provides the same interface as BlockSpaceManager, but its + methods perform no actions or return simple values like True in specific + actions. It's designed to be used in scenarios where the overhead of + block management is unnecessary, such as in an embedding environment. + """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return None # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: SequenceGroup) -> List[int]: + return None # type: ignore + + def mark_blocks_as_computed(self, seq_group: SequenceGroup): + pass diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index b2a5e41990f39..689cbc2179ee1 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -35,6 +35,11 @@ def get_block_space_manager_class(version: str): from vllm.core.block_manager_v2 import BlockSpaceManagerV2 return BlockSpaceManagerV2 + if version == "embedding": + from vllm.core.embedding_model_block_manager import ( + EmbeddingModelBlockSpaceManager) + return EmbeddingModelBlockSpaceManager + raise ValueError(f"Unknown version {version=}") @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 35e3db18f1c43..fb6e985b2f31c 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -270,9 +270,14 @@ def __init__( self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) + version = "v1" + if self.scheduler_config.use_v2_block_manager: + version = "v2" + if self.scheduler_config.embedding_mode: + version = "embedding" + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( - version="v2" if self.scheduler_config. - use_v2_block_manager else "v1") + version) # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( @@ -968,6 +973,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: sampling_params=seq_group.sampling_params, block_tables=block_tables, do_sample=do_sample, + pooling_params=seq_group.pooling_params, token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5c2acbef13129..163723b4be364 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -574,6 +574,7 @@ def create_engine_config(self, ) -> EngineConfig: speculative_config.num_lookahead_slots), delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, + embedding_mode=model_config.embedding_mode, ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 37a2dc77a3b50..a31f10b7748d3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -14,7 +14,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput from vllm.usage.usage_lib import UsageContext @@ -47,15 +48,16 @@ def _raise_exception_on_finish( class AsyncStream: - """A stream of RequestOutputs for a request that can be - iterated over asynchronously.""" + """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + that can be iterated over asynchronously.""" def __init__(self, request_id: str) -> None: self.request_id = request_id self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, Exception]) -> None: + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + Exception]) -> None: if self._finished: return self._queue.put_nowait(item) @@ -71,7 +73,7 @@ def finished(self) -> bool: def __aiter__(self): return self - async def __anext__(self) -> RequestOutput: + async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]: result = await self._queue.get() if isinstance(result, Exception): raise result @@ -108,7 +110,8 @@ def propagate_exception(self, self.abort_request(rid) def process_request_output(self, - request_output: RequestOutput, + request_output: Union[RequestOutput, + EmbeddingRequestOutput], *, verbose: bool = False) -> None: """Process a request output from the engine.""" @@ -196,7 +199,8 @@ def has_new_requests(self): class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" - async def step_async(self) -> List[RequestOutput]: + async def step_async( + self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -251,7 +255,7 @@ async def add_request_async( self, request_id: str, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -270,8 +274,8 @@ async def add_request_async( return self.add_request(request_id, prompt=prompt, + params=params, prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, arrival_time=arrival_time, lora_request=lora_request, multi_modal_data=multi_modal_data) @@ -511,7 +515,7 @@ async def add_request( self, request_id: str, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -528,9 +532,9 @@ async def add_request( max_log_len] logger.info( "Received request %s: prompt: %r, " - "sampling_params: %s, prompt_token_ids: %s, " - "lora_request: %s.", request_id, shortened_prompt, - sampling_params, shortened_token_ids, lora_request) + "params: %s, prompt_token_ids: %s, " + "lora_request: %s.", request_id, shortened_prompt, params, + shortened_token_ids, lora_request) if not self.is_running: if self.start_engine_loop: @@ -562,7 +566,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, prompt=prompt, - sampling_params=sampling_params, + params=params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, @@ -597,8 +601,8 @@ async def generate( multi_modal_data: Multi modal data per request. Yields: - The output `RequestOutput` objects from the LLMEngine for the - request. + The output `RequestOutput` objects from the LLMEngine + for the request. Details: - If the engine is not running, start the background loop, @@ -643,25 +647,123 @@ async def generate( >>> # Process and return the final output >>> ... """ - # Preprocess the request. - arrival_time = time.time() - - try: - stream = await self.add_request( + async for output in self.process_request( request_id, prompt, sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time, - lora_request=lora_request, - multi_modal_data=multi_modal_data, - ) + prompt_token_ids, + lora_request, + multi_modal_data, + ): + yield output + + async def encode( + self, + prompt: Optional[str], + pooling_params: PoolingParams, + request_id: str, + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None + ) -> AsyncIterator[EmbeddingRequestOutput]: + """Generate outputs for a request from an embedding model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + lora_request: LoRA request to use for generation, if any. + multi_modal_data: Multi modal data per request. + + Yields: + The output `EmbeddingRequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "input": "What is LLM?", + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.encode( + >>> example_input["input"], + >>> PoolingParams(), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... + """ + async for output in self.process_request( + request_id, + prompt, + pooling_params, + prompt_token_ids, + lora_request, + multi_modal_data, + ): + yield output + + async def process_request( + self, + request_id: str, + prompt: Optional[str], + params: Union[SamplingParams, PoolingParams], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: + """Common logic to process requests with SamplingParams or + PoolingParams.""" + arrival_time = time.time() + + stream = await self.add_request( + request_id, + prompt, + params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + ) + try: async for request_output in stream: yield request_output except (Exception, asyncio.CancelledError) as e: - # If there is an exception or coroutine is cancelled, abort the - # request. self._abort(request_id) raise e diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b9938b045ba2b..46fa41030b4a1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -20,9 +20,12 @@ from vllm.executor.ray_utils import initialize_ray_cluster from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, + RequestOutputFactory) +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput, +from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, + MultiModalData, PoolerOutput, SamplerOutput, Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer @@ -169,7 +172,8 @@ def __init__( load_config=load_config, ) - self._initialize_kv_caches() + if not self.model_config.embedding_mode: + self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): @@ -354,7 +358,7 @@ def add_request( self, request_id: str, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -370,7 +374,8 @@ def add_request( request_id: The unique ID of the request. prompt: The prompt string. Can be None if prompt_token_ids is provided. - sampling_params: The sampling parameters for text generation. + params: Parameters for sampling or pooling. SamplingParams + for text generation. PoolingParams for pooling. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use @@ -404,13 +409,6 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") - max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.logprobs - and sampling_params.logprobs > max_logprobs) or ( - sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_logprobs): - raise ValueError(f"Cannot request more than " - f"{max_logprobs} logprobs.") if arrival_time is None: arrival_time = time.time() prompt_token_ids = self.encode_request( @@ -432,6 +430,50 @@ def add_request( seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, eos_token_id, lora_request) + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time, + lora_request, + multi_modal_data, + ) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time, + lora_request, + multi_modal_data, + ) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def _create_sequence_group_with_sampling( + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> SequenceGroup: + """Creates a SequenceGroup with SamplingParams.""" + max_logprobs = self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() @@ -443,11 +485,35 @@ def add_request( self.generation_config_fields) # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, lora_request, multi_modal_data) + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + multi_modal_data=multi_modal_data) - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + return seq_group + + def _create_sequence_group_with_pooling( + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> SequenceGroup: + """Creates a SequenceGroup with PoolingParams.""" + # Defensive copy of PoolingParams, which are used by the pooler + pooling_params = pooling_params.clone() + # Create the sequence group. + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + pooling_params=pooling_params) + return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. @@ -484,13 +550,25 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() + def _process_sequence_group_outputs( + self, + seq_group: SequenceGroup, + outputs: List[EmbeddingSequenceGroupOutput], + ) -> None: + seq_group.embeddings = outputs[0].embeddings + + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_STOPPED + + return + def _process_model_outputs( self, - output: List[SamplerOutput], + output: List[Union[SamplerOutput, PoolerOutput]], scheduled_seq_groups: List[ScheduledSequenceGroup], ignored_seq_groups: List[SequenceGroup], seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> List[RequestOutput]: + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Apply the model output to the sequences in the scheduled seq groups. Returns RequestOutputs that can be returned to the client. @@ -510,6 +588,9 @@ def _process_model_outputs( seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) + if self.model_config.embedding_mode: + self._process_sequence_group_outputs(seq_group, outputs) + continue self.output_processor.process_prompt_logprob(seq_group, outputs) if seq_group_meta.do_sample: @@ -519,18 +600,19 @@ def _process_model_outputs( self.scheduler.free_finished_seq_groups() # Create the outputs. - request_outputs: List[RequestOutput] = [] + request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] for scheduled_seq_group in scheduled_seq_groups: seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutput.from_seq_group(seq_group) + request_output = RequestOutputFactory.create(seq_group) request_outputs.append(request_output) for seq_group in ignored_seq_groups: - request_output = RequestOutput.from_seq_group(seq_group) + request_output = RequestOutputFactory.create(seq_group) request_outputs.append(request_output) return request_outputs - def step(self) -> List[RequestOutput]: + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. .. figure:: https://i.imgur.com/sv2HssD.png @@ -570,7 +652,7 @@ def step(self) -> List[RequestOutput]: >>> while True: >>> if example_inputs: >>> req_id, prompt, sampling_params = example_inputs.pop(0) - >>> engine.add_request(str(req_id), prompt, sampling_params) + >>> engine.add_request(str(req_id),prompt,sampling_params) >>> >>> # continue the request processing >>> request_outputs = engine.step() @@ -637,12 +719,15 @@ def _get_stats( # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks - num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() - gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) + gpu_cache_usage_sys = 0. + if num_total_gpu is not None: + num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks( + ) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks cpu_cache_usage_sys = 0. - if num_total_cpu > 0: + if num_total_cpu is not None and num_total_cpu > 0: num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( ) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) @@ -716,8 +801,10 @@ def _get_stats( seq.get_output_len() for seq in seq_group.get_finished_seqs() ]) - best_of_requests.append(seq_group.sampling_params.best_of) - n_requests.append(seq_group.sampling_params.n) + if seq_group.sampling_params is not None: + best_of_requests.append( + seq_group.sampling_params.best_of) + n_requests.append(seq_group.sampling_params.n) finished_reason_requests.extend([ SequenceStatus.get_finished_reason(seq.status) for seq in seq_group.get_finished_seqs() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 71620139fba39..25f4428100b27 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,13 +6,17 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter +logger = init_logger(__name__) + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -164,8 +168,89 @@ def generate( multi_modal_data: Multi modal data. Returns: - A list of `RequestOutput` objects containing the generated - completions in the same order as the input prompts. + A list of `RequestOutput` objects containing the + generated completions in the same order as the input prompts. + """ + if sampling_params is None: + # Use default sampling params. + sampling_params = SamplingParams() + + requests_data = self._validate_and_prepare_requests( + prompts, + sampling_params, + prompt_token_ids, + lora_request, + multi_modal_data, + ) + + # Add requests to the engine and run the engine + for request_data in requests_data: + self._add_request(**request_data) + + return self._run_engine(use_tqdm) + + def encode( + self, + prompts: Optional[Union[str, List[str]]] = None, + pooling_params: Optional[Union[PoolingParams, + List[PoolingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + """Generates the completions for the input prompts. + + NOTE: This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: A list of prompts to generate completions for. + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. + prompt_token_ids: A list of token IDs for the prompts. If None, we + use the tokenizer to convert the prompts to token IDs. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + multi_modal_data: Multi modal data. + + Returns: + A list of `EmbeddingRequestOutput` objects containing the + generated embeddings in the same order as the input prompts. + """ + if pooling_params is None: + # Use default pooling params. + pooling_params = PoolingParams() + + requests_data = self._validate_and_prepare_requests( + prompts, + pooling_params, + prompt_token_ids, + lora_request, + multi_modal_data, + ) + + # Add requests to the engine and run the engine + for request_data in requests_data: + self._add_request(**request_data) + + return self._run_engine(use_tqdm) + + def _validate_and_prepare_requests( + self, + prompts: Optional[Union[str, List[str]]], + params: Union[Union[SamplingParams, PoolingParams], + List[Union[SamplingParams, + PoolingParams]]], # Unified parameter + prompt_token_ids: Optional[List[List[int]]] = None, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[dict]: + """Validates and prepares request data for adding to the engine. + + Ensures prompts and token IDs are consistent, and returns a list of + dictionaries with request data for further processing. """ if prompts is None and prompt_token_ids is None: raise ValueError("Either prompts or prompt_token_ids must be " @@ -188,40 +273,43 @@ def generate( assert prompt_token_ids is not None num_requests = len(prompt_token_ids) - if sampling_params is None: - # Use default sampling params. - sampling_params = SamplingParams() - - elif isinstance(sampling_params, - list) and len(sampling_params) != num_requests: - raise ValueError("The lengths of prompts and sampling_params " + if isinstance(params, list) and len(params) != num_requests: + raise ValueError("The lengths of prompts and params " "must be the same.") if multi_modal_data: multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. + requests_data = [] for i in range(num_requests): prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request( + + multi_modal_item = MultiModalData( + type=multi_modal_data.type, + data=multi_modal_data.data[i].unsqueeze(0), + ) if multi_modal_data else None + + requests_data.append({ + "prompt": prompt, - sampling_params[i] - if isinstance(sampling_params, list) else sampling_params, + "params": + params[i] if isinstance(params, list) else params, + "prompt_token_ids": token_ids, - lora_request=lora_request, - # Get ith image while maintaining the batch dim. - multi_modal_data=MultiModalData( - type=multi_modal_data.type, - data=multi_modal_data.data[i].unsqueeze(0)) - if multi_modal_data else None, - ) - return self._run_engine(use_tqdm) + "lora_request": + lora_request, + "multi_modal_data": + multi_modal_item, + }) + + return requests_data def _add_request( self, prompt: Optional[str], - sampling_params: SamplingParams, + params: Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, @@ -229,12 +317,14 @@ def _add_request( request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, prompt, - sampling_params, + params, prompt_token_ids, lora_request=lora_request, multi_modal_data=multi_modal_data) - def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: + def _run_engine( + self, use_tqdm: bool + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -245,7 +335,7 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: postfix=f"Generation Speed: {0:.2f} toks/s", ) # Run the engine. - outputs: List[RequestOutput] = [] + outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() @@ -253,10 +343,12 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: if output.finished: outputs.append(output) if use_tqdm: - total_toks += (sum( - len(stp.token_ids) for stp in output.outputs)) - spd = total_toks / pbar.format_dict["elapsed"] - pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + total_toks += sum( + len(stp.token_ids) for stp in output.outputs) + spd = total_toks / pbar.format_dict["elapsed"] + pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" pbar.update(1) if use_tqdm: pbar.close() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 362f28d05c3bb..7cd51b959a0ea 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -22,9 +22,11 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, - CompletionRequest, ErrorResponse) + CompletionRequest, + EmbeddingRequest, ErrorResponse) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext @@ -32,6 +34,8 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion +openai_serving_embedding: OpenAIServingEmbedding + logger = init_logger(__name__) _running_tasks: Set[asyncio.Task] = set() @@ -123,6 +127,17 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return JSONResponse(content=generator.model_dump()) +@app.post("/v1/embeddings") +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + generator = await openai_serving_embedding.create_embedding( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + else: + return JSONResponse(content=generator.model_dump()) + + if __name__ == "__main__": args = parse_args() @@ -190,7 +205,8 @@ async def authentication(request: Request, call_next): args.chat_template) openai_serving_completion = OpenAIServingCompletion( engine, model_config, served_model_names, args.lora_modules) - + openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, + served_model_names) app.root_path = args.root_path uvicorn.run(app, host=args.host, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3cd9ddad3b7b7..139c5716c7cea 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1,13 +1,14 @@ # Adapted from # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py import time -from typing import Dict, List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union import torch from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Annotated +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid @@ -363,6 +364,24 @@ def check_guided_decoding_count(cls, data): return data +class EmbeddingRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: str + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$') + dimensions: Optional[int] = None + user: Optional[str] = None + + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + + # doc: end-embedding-pooling-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + class LogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) @@ -416,6 +435,21 @@ class CompletionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) +class EmbeddingResponseData(BaseModel): + index: int + object: str = "embedding" + embedding: List[float] + + +class EmbeddingResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[EmbeddingResponseData] + usage: UsageInfo + + class ChatMessage(OpenAIBaseModel): role: str content: str diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py new file mode 100644 index 0000000000000..7a57be0c88915 --- /dev/null +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -0,0 +1,134 @@ +import time +from typing import AsyncIterator, List, Tuple + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, UsageInfo) +from vllm.entrypoints.openai.serving_completion import parse_prompt_format +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.logger import init_logger +from vllm.outputs import EmbeddingRequestOutput +from vllm.utils import merge_async_iterators, random_uuid + +logger = init_logger(__name__) + +TypeTokenIDs = List[int] + + +def request_output_to_embedding_response( + final_res_batch: List[EmbeddingRequestOutput], + request_id: str, + created_time: int, + model_name: str, +) -> EmbeddingResponse: + data = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + assert final_res is not None + prompt_token_ids = final_res.prompt_token_ids + + embedding_data = EmbeddingResponseData( + index=idx, embedding=final_res.outputs.embedding) + data.append(embedding_data) + + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return EmbeddingResponse( + id=request_id, + created=created_time, + model=model_name, + data=data, + usage=usage, + ) + + +class OpenAIServingEmbedding(OpenAIServing): + + def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, + served_model_names: List[str]): + super().__init__(engine=engine, + model_config=model_config, + served_model_names=served_model_names, + lora_modules=None) + self._check_embedding_mode(model_config.embedding_mode) + + async def create_embedding(self, request: EmbeddingRequest, + raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # Return error for unsupported features. + if request.encoding_format == "base64": + return self.create_error_response( + "base64 encoding is not currently supported") + if request.dimensions is not None: + return self.create_error_response( + "dimensions is currently not supported") + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.monotonic()) + + # Schedule the request and get the result generator. + generators = [] + try: + prompt_is_tokens, prompts = parse_prompt_format(request.input) + pooling_params = request.to_pooling_params() + + for i, prompt in enumerate(prompts): + if prompt_is_tokens: + prompt_formats = self._validate_prompt_and_tokenize( + request, prompt_ids=prompt) + else: + prompt_formats = self._validate_prompt_and_tokenize( + request, prompt=prompt) + + prompt_ids, prompt_text = prompt_formats + + generators.append( + self.engine.generate(prompt_text, + pooling_params, + f"{request_id}-{i}", + prompt_token_ids=prompt_ids)) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator: AsyncIterator[Tuple[ + int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) + + # Non-streaming response + final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(f"{request_id}-{i}") + # TODO: Use a vllm-specific Validation Error + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_embedding_response( + final_res_batch, request_id, created_time, model_name) + + return response + + def _check_embedding_mode(self, embedding_mode: bool): + if not embedding_mode: + logger.warning( + "embedding_mode is False. Embedding API will not work.") + else: + logger.info("Activating the server engine with embedding enabled.") diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index f10718c5f3d80..58a1c2f7e73fe 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -9,7 +9,8 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest, ErrorResponse, + CompletionRequest, + EmbeddingRequest, ErrorResponse, LogProbs, ModelCard, ModelList, ModelPermission) from vllm.logger import init_logger @@ -165,7 +166,8 @@ def _maybe_get_lora( def _validate_prompt_and_tokenize( self, - request: Union[ChatCompletionRequest, CompletionRequest], + request: Union[ChatCompletionRequest, CompletionRequest, + EmbeddingRequest], prompt: Optional[str] = None, prompt_ids: Optional[List[int]] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None @@ -191,6 +193,16 @@ def _validate_prompt_and_tokenize( prompt_ids) token_num = len(input_ids) + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, EmbeddingRequest): + if token_num > self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. Please reduce the length of the input.", ) + return input_ids, input_text + if request.max_tokens is None: if token_num >= self.max_model_len: raise ValueError( diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index fa3480fa64837..2b72b31b5f070 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,9 +1,9 @@ -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -123,8 +123,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest + ) -> List[Union[SamplerOutput, PoolerOutput]]: output = self.driver_worker.execute_model(execute_model_req) return output @@ -150,7 +150,7 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: + ) -> List[Union[SamplerOutput, PoolerOutput]]: output = await make_async(self.driver_worker.execute_model )(execute_model_req=execute_model_req, ) return output diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py new file mode 100644 index 0000000000000..445b30b8c6e9b --- /dev/null +++ b/vllm/model_executor/layers/pooler.py @@ -0,0 +1,56 @@ +from enum import IntEnum + +import torch +import torch.nn as nn + +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput + + +class PoolingType(IntEnum): + """Enumeration for different types of pooling methods.""" + LAST = 0 + + +class Pooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use (LAST, AVERAGE, MAX). + normalize: Whether to normalize the pooled data. + """ + + def __init__(self, pooling_type: PoolingType, normalize: bool): + super().__init__() + self.pooling_type = pooling_type + self.normalize = normalize + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools specific information from hidden states based on metadata.""" + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + if self.pooling_type == PoolingType.LAST: + last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 + pooled_data = hidden_states[last_token_flat_indices] + else: + raise ValueError(f"Invalid pooling type: {self.pooling_type}") + + if self.normalize: + pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data + ] + + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index e52e350d2726f..c8bab46c83eca 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -10,8 +10,9 @@ SamplingTensors, SequenceGroupToSample) from vllm.sampling_params import SamplingType -from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, - SamplerOutput, SequenceGroupOutput, SequenceOutput) +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SampleLogprobs, SamplerOutput, + SequenceOutput) # (num_token_ids, num_parent_ids) per sequence group. SampleResultType = List[Tuple[List[int], List[int]]] @@ -1019,7 +1020,7 @@ def _build_sampler_output( seq_outputs.append( SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) sampler_output.append( - SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs)) # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d5263b500fe0f..6aec104be8da4 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -9,7 +9,7 @@ logger = init_logger(__name__) # Architecture -> (module, class). -_MODELS = { +_GENERATION_MODELS = { "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b @@ -58,6 +58,12 @@ "XverseForCausalLM": ("xverse", "XverseForCausalLM"), } +_EMBEDDING_MODELS = { + "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), +} + +_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} + # Architecture -> type. # out of tree models _OOT_MODELS: Dict[str, Type[nn.Module]] = {} @@ -114,6 +120,10 @@ def register_model(model_arch: str, model_cls: Type[nn.Module]): global _OOT_MODELS _OOT_MODELS[model_arch] = model_cls + @staticmethod + def is_embedding_model(model_arch: str) -> bool: + return model_arch in _EMBEDDING_MODELS + __all__ = [ "ModelRegistry", diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py new file mode 100644 index 0000000000000..8f1c77da50d96 --- /dev/null +++ b/vllm/model_executor/models/llama_embedding.py @@ -0,0 +1,87 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import PoolerOutput + + +class LlamaEmbeddingModel(nn.Module): + """A model that uses Llama with additional embedding functionalities. + + This class encapsulates the LlamaModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of LlamaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = LlamaModel(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model.forward(input_ids, positions, kv_caches, + attn_metadata, inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py new file mode 100644 index 0000000000000..b86cafce85d12 --- /dev/null +++ b/vllm/model_executor/pooling_metadata.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch + +from vllm.pooling_params import PoolingParams +from vllm.utils import is_pin_memory_available + + +class PoolingMetadata: + """Metadata for pooling operations in the Pooler layer. + + This class holds the necessary information for pooling operations, + providing context for how to perform pooling and other related operations. + + Attributes: + seq_groups: List of (seq_ids, pooling_params). + seq_data: A mapping of sequence ID to additional sequence data. + prompt_lens: List of the lengths of each prompt. + """ + + def __init__( + self, + seq_groups: List[Tuple[List[int], PoolingParams]], + seq_data: Dict[int, Any], # Specific data related to sequences + prompt_lens: List[int], + ) -> None: + self.seq_groups = seq_groups + self.seq_data = seq_data + self.prompt_lens = prompt_lens + + def __repr__(self) -> str: + return ("PoolingMetadata(" + f"seq_groups={self.seq_groups}, " + f"seq_data={self.seq_data}, " + f"prompt_lens={self.prompt_lens})") + + +@dataclass +class PoolingTensors: + """Tensors for pooling.""" + + prompt_lens: torch.Tensor + + @classmethod + def from_pooling_metadata( + cls, + pooling_metadata: "PoolingMetadata", + device: torch.device, + ) -> "PoolingTensors": + """ + Create PoolingTensors from PoolingMetadata. + + Args: + pooling_metadata: PoolingMetadata instance to convert. + device: Device to store the tensors. + """ + # Convert prompt lengths to tensor + pin_memory = is_pin_memory_available() + + prompt_lens_t = torch.tensor( + pooling_metadata.prompt_lens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + + return cls(prompt_lens=prompt_lens_t.to(device=device, + non_blocking=True), ) diff --git a/vllm/outputs.py b/vllm/outputs.py index d01be0eb0efd2..f9bce9e683f22 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -57,8 +57,27 @@ def __repr__(self) -> str: f"stop_reason={self.stop_reason})") +class EmbeddingOutput: + """The output data of one completion output of a request. + + Args: + embedding: The embedding vector, which is a list of floats. The + length of vector depends on the model as listed in the embedding guide. + """ + + def __init__( + self, + embedding: List[float], + ) -> None: + self.embedding = embedding + + def __repr__(self) -> str: + return (f"EmbeddingOutput(" + f"embedding={len(self.embedding)}") + + class RequestOutput: - """The output data of a request to the LLM. + """The output data of a completion request to the LLM. Args: request_id: The unique ID of the request. @@ -93,6 +112,9 @@ def __init__( @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": + if seq_group.sampling_params is None: + raise ValueError( + "Sampling parameters are missing for a CompletionRequest.") seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs @@ -148,3 +170,61 @@ def __repr__(self) -> str: f"finished={self.finished}, " f"metrics={self.metrics}, " f"lora_request={self.lora_request})") + + +class EmbeddingRequestOutput: + """ + The output data of an embedding request to the LLM. + + Args: + request_id (str): A unique identifier for the embedding request. + outputs (EmbeddingOutput): The embedding results for the given input. + prompt_token_ids (List[int]): A list of token IDs used in the prompt. + finished (bool): A flag indicating whether the embedding is completed. + """ + + def __init__(self, request_id: str, outputs: 'EmbeddingOutput', + prompt_token_ids: List[int], finished: bool): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.finished = finished + self.outputs = outputs + + @classmethod + def from_seq_group(cls, + seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": + if seq_group.embeddings is None: + raise ValueError( + "Embeddings are missing in seq_group for EmbeddingRequest.") + output = EmbeddingOutput(seq_group.embeddings) + prompt_token_ids = seq_group.prompt_token_ids + finished = seq_group.is_finished() + + return cls(seq_group.request_id, output, prompt_token_ids, finished) + + def __repr__(self): + """ + Returns a string representation of an EmbeddingRequestOutput instance. + + The representation includes the request_id and the number of outputs, + providing a quick overview of the embedding request's results. + + Returns: + str: A string representation of the EmbeddingRequestOutput instance. + """ + return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " + f"outputs={repr(self.outputs)}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})") + + +class RequestOutputFactory: + + @staticmethod + def create(seq_group): + # Determine the type based on a condition, for example: + if hasattr(seq_group, + 'embeddings') and seq_group.embeddings is not None: + return EmbeddingRequestOutput.from_seq_group(seq_group) + else: + return RequestOutput.from_seq_group(seq_group) diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py new file mode 100644 index 0000000000000..3b95d73ddc2c5 --- /dev/null +++ b/vllm/pooling_params.py @@ -0,0 +1,20 @@ +from typing import Any, Optional + + +class PoolingParams: + """Pooling parameters for pooling. + + Attributes: + additional_data: Any additional data needed for pooling. + """ + + def __init__(self, additional_data: Optional[Any] = None): + self.additional_data = additional_data + + def clone(self) -> "PoolingParams": + """Returns a deep copy of the PoolingParams instance.""" + return PoolingParams(additional_data=self.additional_data, ) + + def __repr__(self) -> str: + return (f"PoolingParams(" + f"additional_metadata={self.additional_data})") diff --git a/vllm/sequence.py b/vllm/sequence.py index 3cebb85b49d27..46ac33b7ecabd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,11 +1,13 @@ """Sequence and its related classes.""" import copy import enum +from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams if TYPE_CHECKING: @@ -375,12 +377,12 @@ class SequenceGroupState: class MultiModalData: """Multi modal request. - + Args: type: The data type. data: The actual data. The required shape and semantic meaning of it depends on the vision - language config of the hosted model. + language config of the hosted model. See `VisionLanguageConfig` in `config.py`. """ @@ -402,16 +404,22 @@ class SequenceGroup: arrival_time: The arrival time of the request. lora_request: LoRA request. multi_modal_data: Multi modal data associated with the request. + embeddings: The embeddings vectors of the prompt of the sequence group + for an embedding model. + pooling_params: The pooling parameters used to generate the pooling + for an embedding model. """ def __init__( self, request_id: str, seqs: List[Sequence], - sampling_params: SamplingParams, arrival_time: float, + sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + embeddings: Optional[List[float]] = None, + pooling_params: Optional[PoolingParams] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -425,6 +433,8 @@ def __init__( self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() self.multi_modal_data = multi_modal_data + self.embeddings = embeddings + self.pooling_params = pooling_params @property def prompt(self) -> str: @@ -479,12 +489,13 @@ def set_finished_time(self, time: Optional[float]) -> None: def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" - if self.sampling_params.use_beam_search: + if self.sampling_params and self.sampling_params.use_beam_search: # For beam search, maximally there will always be `best_of` beam # candidates running in the future. return self.sampling_params.best_of else: - if self.sampling_params.best_of > self.num_seqs(): + if (self.sampling_params + and self.sampling_params.best_of > self.num_seqs()): # At prompt stage, the sequence group is not yet filled up # and only have one sequence running. However, in the # generation stage, we will have `best_of` sequences running. @@ -555,7 +566,7 @@ def is_finished(self) -> bool: return all(seq.is_finished() for seq in self.get_seqs()) def is_prefill(self) -> bool: - # Every sequences should be in the same stage. + # Every sequence should be in the same stage. return self.get_seqs()[0].is_prefill() def __repr__(self) -> str: @@ -594,6 +605,7 @@ def __init__( sampling_params: SamplingParams, block_tables: Dict[int, List[int]], do_sample: bool = True, + pooling_params: Optional[PoolingParams] = None, token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, @@ -605,6 +617,7 @@ def __init__( self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.pooling_params = pooling_params self.lora_request = lora_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data @@ -669,8 +682,20 @@ def __eq__(self, other: object) -> bool: return equal and log_probs_equal -class SequenceGroupOutput: - """The model output associated with a sequence group.""" +class SequenceGroupOutput(ABC): + """The base class for model outputs associated with a sequence group.""" + + @abstractmethod + def __repr__(self) -> str: + pass + + @abstractmethod + def __eq__(self, other: object) -> bool: + pass + + +class CompletionSequenceGroupOutput(SequenceGroupOutput): + """The model output associated with a completion sequence group.""" def __init__( self, @@ -682,26 +707,45 @@ def __init__( self.prompt_logprobs = prompt_logprobs def __repr__(self) -> str: - return (f"SequenceGroupOutput(samples={self.samples}, " + return (f"CompletionSequenceGroupOutput(samples={self.samples}, " f"prompt_logprobs={self.prompt_logprobs})") def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceGroupOutput): + if not isinstance(other, CompletionSequenceGroupOutput): raise NotImplementedError() return (self.samples == other.samples and self.prompt_logprobs == other.prompt_logprobs) +class EmbeddingSequenceGroupOutput(SequenceGroupOutput): + """The model output associated with an embedding sequence group.""" + + def __init__( + self, + embeddings: List[float], + ) -> None: + self.embeddings = embeddings + + def __repr__(self) -> str: + return (f"EmbeddingSequenceGroupOutput(" + f"embeddings_shape={len(self.embeddings)})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EmbeddingSequenceGroupOutput): + raise NotImplementedError() + return self.embeddings == other.embeddings + + @dataclass class SamplerOutput: """For each sequence group, we generate a list of SequenceOutput object, each of which contains one possible candidate for the next token. - This datastructure implements methods so it can be used like a list, but + This data structure implements methods, so it can be used like a list, but also has optional fields for device tensors. """ - outputs: List[SequenceGroupOutput] + outputs: List[CompletionSequenceGroupOutput] # On-device tensor containing probabilities of each token. sampled_token_probs: Optional["torch.Tensor"] = None @@ -742,6 +786,27 @@ def __repr__(self) -> str: f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") +@dataclass +class PoolerOutput: + """The output from a pooling operation in the embedding model.""" + outputs: List[EmbeddingSequenceGroupOutput] + + spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + @dataclass class ExecuteModelRequest: """The model execution request.""" diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index d6f80c82b80bf..4dc6c49eb58d2 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -4,7 +4,8 @@ import torch -from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SamplerOutput, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput) SeqId = int @@ -94,7 +95,7 @@ def create_sequence_group_output( for topk_logprob_index, _ in enumerate(topk_token_ids) }) - return SequenceGroupOutput( + return CompletionSequenceGroupOutput( samples=[ SequenceOutput(parent_seq_id=seq_id, output_token=token_id, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py new file mode 100644 index 0000000000000..2d3f160c60dc1 --- /dev/null +++ b/vllm/worker/embedding_model_runner.py @@ -0,0 +1,266 @@ +from typing import Dict, List, Optional, Set, Tuple + +import torch + +from vllm.attention import AttentionMetadata +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.pooling_params import PoolingParams +from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata +from vllm.worker.model_runner import BatchType, ModelRunner + +logger = init_logger(__name__) + + +class EmbeddingModelRunner(ModelRunner): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + vision_language_config: Optional[VisionLanguageConfig] = None, + ): + super().__init__(model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + vision_language_config=vision_language_config) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[torch.Tensor], + ) -> Optional[PoolerOutput]: + (input_tokens, input_positions, attn_metadata, pooling_metadata, + lora_requests, lora_mapping, multi_modal_input + ) = self.prepare_input_tensors(seq_group_metadata_list) + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + + # Currently cuda graph is only supported by the decode phase. + prefill_meta = attn_metadata.prefill_metadata + decode_meta = attn_metadata.decode_metadata + if prefill_meta is None and decode_meta.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + + num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + hidden_states = model_executable(**execute_model_kwargs) + + return self.model.pooler(hidden_states=hidden_states, + pooling_metadata=pooling_metadata) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, + Set[LoRARequest], LoRAMapping, torch.Tensor]: + if self.is_driver_worker: + prefill_reqs = [] + decode_reqs = [] + for seq_group_meta in seq_group_metadata_list: + if seq_group_meta.is_prompt: + prefill_reqs.append(seq_group_meta) + else: + decode_reqs.append(seq_group_meta) + + # Prepare input tensors. + ( + input_tokens, + input_positions, + prefill_attn_metadata, + prompt_lens, + subquery_lens, + lora_index_mapping, + lora_prompt_mapping, + lora_requests, + multi_modal_input, + slot_mapping, + ) = self._prepare_prompt(prefill_reqs) + ( + decode_input_tokens, + decode_input_positions, + decode_attn_metadata, + decode_lora_index_mapping, + decode_lora_prompt_mapping, + decode_lora_requests, + decode_slot_mapping, + ) = self._prepare_decode(decode_reqs) + + # Prepare PoolingMetadata + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + prompt_lens) + + if not self.scheduler_config.chunked_prefill_enabled: + assert (len(prefill_reqs) and len(decode_reqs)) == 0 + + num_prefills = len(prompt_lens) + num_prefill_tokens = len(input_tokens) + num_decode_tokens = len(decode_input_tokens) + + # Coalesce tensors. Note that attn_metadata is currently not + # coalesced for simplicity. + input_tokens.extend(decode_input_tokens) + input_positions.extend(decode_input_positions) + slot_mapping.extend(decode_slot_mapping) + lora_index_mapping.extend(decode_lora_index_mapping) + lora_prompt_mapping.extend(decode_lora_prompt_mapping) + lora_requests.update(decode_lora_requests) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + # Broadcast the metadata. + # If batch contains both prefill and decode, it sends 2 broadcasts. + # If it only contains 1 type, it triggers a single broadcast. + if (prefill_attn_metadata is not None + and decode_attn_metadata is not None): + batch_type = BatchType.MIXED + elif prefill_attn_metadata is not None: + batch_type = BatchType.PREFILL + else: + batch_type = BatchType.DECODE + + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + "multi_modal_input": multi_modal_input, + "num_prefill_tokens": num_prefill_tokens, + "num_decode_tokens": num_decode_tokens, + "slot_mapping": slot_mapping, + "num_prefills": num_prefills, + "batch_type": batch_type, + } + if prefill_attn_metadata is not None: + metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) + else: + assert decode_attn_metadata is not None + metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + + # Broadcast decode attn metadata for mixed batch type. + # The additional broadcast costs 300us overhead on 4 A10 GPUs. + # We can potentially reduce the overhead by coelescing tensors. + if batch_type == BatchType.MIXED: + assert decode_attn_metadata is not None + metadata_dict = decode_attn_metadata.asdict_zerocopy() + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + slot_mapping = metadata_dict.pop("slot_mapping") + num_prefills = metadata_dict.pop("num_prefills") + lora_mapping = metadata_dict.pop("lora_mapping") + lora_requests = metadata_dict.pop("lora_requests") + multi_modal_input = metadata_dict.pop("multi_modal_input") + num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") + num_decode_tokens = metadata_dict.pop("num_decode_tokens") + batch_type = metadata_dict.pop("batch_type") + + # Create an attention metadata. + prefill_attn_metadata = None + decode_attn_metadata = None + if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: + prefill_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + else: + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + + pooling_metadata = PoolingMetadata(seq_groups=None, + seq_data=None, + prompt_lens=None) + + # if it is a mixed batch, decode attn_metadata is broadcasted + # separately. + if batch_type == BatchType.MIXED: + metadata_dict = broadcast_tensor_dict(src=0) + decode_attn_metadata = self.attn_backend.make_metadata( + **metadata_dict) + + attn_metadata = AttentionMetadata( + num_prefills=num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + prefill_metadata=prefill_attn_metadata, + decode_metadata=decode_attn_metadata, + kv_cache_dtype=self.kv_cache_dtype, + ) + + return (input_tokens, input_positions, attn_metadata, pooling_metadata, + lora_requests, lora_mapping, multi_modal_input) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3fc76c6142165..21d76fd531e49 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,6 +1,6 @@ import time from enum import IntEnum -from typing import Dict, List, NamedTuple, Optional, Set, Tuple +from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np import torch @@ -287,18 +287,18 @@ def _prepare_prompt( lora_requests.add(seq_group_metadata.lora_request) lora_index_mapping += [lora_id] * (seq_len - context_len) - lora_prompt_mapping.extend( - [lora_id] * - (seq_len - context_len - if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + lora_prompt_mapping.extend([lora_id] * ( + seq_len - context_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( seq_group_metadata.multi_modal_data.data) - if seq_group_metadata.block_tables is None: + if _is_block_tables_empty(seq_group_metadata.block_tables): # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. + # In embeddings, the block tables are {seq_id: None}. slot_mapping.extend([_PAD_SLOT_ID] * seq_len) continue @@ -813,7 +813,6 @@ def profile_run(self) -> None: sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs - # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request @@ -1139,3 +1138,15 @@ def _prepare_fake_inputs( prompt_tokens = [0] * seq_len fake_image_input = None return SequenceData(prompt_tokens), fake_image_input + + +def _is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + if isinstance(block_tables, dict) and all( + value is None for value in block_tables.values()): + return True + return False diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0ca9c2b64cf30..e4fbc877b8c9f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch import torch.distributed @@ -16,8 +16,9 @@ init_custom_ar) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase @@ -68,7 +69,9 @@ def __init__( assert not self.lora_config, ( "To be tested: vision language model with LoRA settings.") - self.model_runner = ModelRunner( + ModelRunnerClass = (EmbeddingModelRunner if + self.model_config.embedding_mode else ModelRunner) + self.model_runner = ModelRunnerClass( model_config, parallel_config, scheduler_config, @@ -83,7 +86,8 @@ def __init__( # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: CacheEngine - self.gpu_cache: List[torch.Tensor] + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[torch.tensor]] = None def init_device(self) -> None: if self.device_config.device.type == "cuda": @@ -209,7 +213,7 @@ def cache_swap( def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: + ) -> List[Union[SamplerOutput, PoolerOutput]]: if execute_model_req is None: seq_group_metadata_list = None From 6eaccb7353cfe84d77981da726f6d82a8aefd2be Mon Sep 17 00:00:00 2001 From: Yikang Shen Date: Sun, 12 May 2024 00:27:24 -0400 Subject: [PATCH 259/413] [Model] Add support for IBM Granite Code models (#4636) --- vllm/model_executor/models/llama.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f6d7fc8733fce..127e4612b2e40 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -58,15 +58,16 @@ def __init__( intermediate_size: int, hidden_act: str, quant_config: Optional[QKVParallelLinear] = None, + bias: bool = False, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, - bias=False, + bias=bias, quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, - bias=False, + bias=bias, quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -209,6 +210,7 @@ def __init__( intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -348,6 +350,8 @@ def __init__( # compatibility if not lora_config else lora_config.lora_vocab_padding_size, ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, From a709e87a4f35c1637d2bbea4bc2c9e5fe7fd70b5 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sun, 12 May 2024 20:46:31 -0400 Subject: [PATCH 260/413] [CI/Build] Tweak Marlin Nondeterminism Issues (#4713) --- tests/models/test_gptq_marlin.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index 4d73843f970c4..b1c2b88bc99af 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -1,13 +1,11 @@ """Compares the outputs of gptq vs gptq_marlin Note: GPTQ and Marlin do not have bitwise correctness. As a result, in this test, we just confirm that the top selected tokens of the -Marlin/GPTQ models are in the top 3 selections of each other. +Marlin/GPTQ models are in the top 5 selections of each other. Note: Marlin internally uses locks to synchronize the threads. This can result in very slight nondeterminism for Marlin. As a result, we re-run the test up to 3 times to see if we pass. -Note: This test currently fails running with --forked with the following: - RuntimeError: Cannot re-initialize CUDA in forked subprocess. - To use CUDA with multiprocessing, you must use the 'spawn' start method + Run `pytest tests/models/test_gptq_marlin.py`. """ import os @@ -49,7 +47,7 @@ ] -@pytest.mark.flaky(reruns=2) +@pytest.mark.flaky(reruns=3) @pytest.mark.skipif(gptq_marlin_not_supported, reason="gptq_marlin is not supported on this GPU type.") @pytest.mark.parametrize("model", MODELS) @@ -75,7 +73,7 @@ def test_models( tensor_parallel_size=1) gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + example_prompts[:-1], max_tokens, num_logprobs) del gptq_marlin_model # Run gptq. @@ -85,7 +83,7 @@ def test_models( quantization="gptq", max_model_len=MAX_MODEL_LEN, tensor_parallel_size=1) - gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, + gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts[:-1], max_tokens, num_logprobs) del gptq_model From a7be4d00725db5ae4f738f70c3a89fd9dedaf7ec Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 12 May 2024 20:47:47 -0400 Subject: [PATCH 261/413] [CORE] Improvement in ranks code (#4718) --- vllm/model_executor/layers/sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c8bab46c83eca..a84f562909d50 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -681,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), indices] - return (x > vals[:, None]).long().sum(1).add_(1) + result = (x > vals[:, None]) + del vals + return result.sum(1).add_(1) def _get_logprobs( From 702bee461f448b0186eb9d673baad29fd923c884 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 May 2024 17:47:59 -0700 Subject: [PATCH 262/413] [Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754) --- tests/distributed/test_comm_ops.py | 22 +- tests/distributed/test_custom_all_reduce.py | 87 +++-- tests/distributed/test_pynccl.py | 4 +- vllm/distributed/communication_op.py | 45 ++- .../device_communicators/custom_all_reduce.py | 320 ++++++++++-------- .../device_communicators/pynccl.py | 4 +- vllm/distributed/parallel_state.py | 28 +- vllm/test_utils.py | 17 +- vllm/worker/model_runner.py | 15 +- vllm/worker/worker.py | 11 +- 10 files changed, 327 insertions(+), 226 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 9a7a1f07e1b8d..a4423bbfddf46 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -16,7 +16,7 @@ @ray.remote(num_gpus=1, max_calls=1) -def all_reduce_test_worker(tensor_parallel_size: int, rank: int, +def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs @@ -24,12 +24,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_elements = 8 all_tensors = [ torch.arange(num_elements, dtype=torch.float32, device="cuda") * - (r + 1) for r in range(tensor_parallel_size) + (r + 1) for r in range(tp_size) ] expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) t = all_tensors[rank] @@ -38,7 +38,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) -def all_gather_test_worker(tensor_parallel_size: int, rank: int, +def all_gather_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs @@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) num_dimensions = 3 tensor_size = list(range(2, num_dimensions + 2)) @@ -57,7 +57,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, all_tensors = [ torch.arange(total_size, dtype=torch.float32, device="cuda").reshape(tensor_size) * (r + 1) - for r in range(tensor_parallel_size) + for r in range(tp_size) ] expected = torch.cat(all_tensors, dim=all_gather_dimension) t = all_tensors[rank] @@ -66,7 +66,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @ray.remote(num_gpus=1, max_calls=1) -def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, +def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs @@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, tensor_parallel_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) test_dict = { # device tensor @@ -106,10 +106,10 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("tensor_parallel_size", [2]) +@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("test_target", [ all_reduce_test_worker, all_gather_test_worker, broadcast_tensor_dict_test_worker ]) -def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - multi_process_tensor_parallel(tensor_parallel_size, test_target) +def test_multi_process_tensor_parallel(tp_size, test_target): + multi_process_tensor_parallel(tp_size, 1, test_target) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 308b874280f55..bdca031e39be1 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -6,8 +6,10 @@ import torch import torch.distributed as dist -from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.distributed.device_communicators import custom_all_reduce +from vllm.distributed.communication_op import ( # noqa + graph_capture, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + get_tp_ca_communicator) from vllm.test_utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -18,17 +20,36 @@ @ray.remote(num_gpus=1, max_calls=1) -def graph_allreduce(world_size, rank, distributed_init_port): +def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) - custom_all_reduce.init_custom_ar() + group = get_tensor_model_parallel_group() + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. + data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 + for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with custom_all_reduce.capture(): + with graph_capture(): # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port): torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): - out1 = tensor_model_parallel_all_reduce(inp1) - # the input buffer is immediately modified to test - # synchronization - dist.all_reduce(inp1) - out2 = tensor_model_parallel_all_reduce(inp2) - dist.all_reduce(inp2) + for i in range(num_communication): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) graph.replay() assert torch.allclose(out1, inp1) assert torch.allclose(out2, inp2) @ray.remote(num_gpus=1, max_calls=1) -def eager_allreduce(world_size, rank, distributed_init_port): +def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): del os.environ["CUDA_VISIBLE_DEVICES"] device = torch.device(f"cuda:{rank}") torch.cuda.set_device(device) - init_test_distributed_environment(1, world_size, rank, + init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port) + # we use the first group to communicate once + # and the second group to communicate twice + # and so on + # this is used to demonstrate that each group can + # communicate independently + num_communication = rank // tp_size + 1 sz = 1024 - custom_all_reduce.init_custom_ar() - fa = custom_all_reduce.get_handle() + fa = get_tp_ca_communicator() inp = torch.ones(sz, dtype=torch.float32, device=device) - out = fa.all_reduce_unreg(inp) - assert torch.allclose(out, inp * world_size) + out = inp + for _ in range(num_communication): + out = fa.all_reduce_unreg(out) + assert torch.allclose(out, inp * (tp_size**num_communication)) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) - out = fa.all_reduce_unreg(inp) - assert torch.allclose(out, inp * world_size) + out = inp + for _ in range(num_communication): + out = fa.all_reduce_unreg(out) + assert torch.allclose(out, inp * (tp_size**num_communication)) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize("tensor_parallel_size", [2]) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1, 2]) @pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) -def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): - multi_process_tensor_parallel(tensor_parallel_size, test_target) - - -if __name__ == "__main__": - multi_process_tensor_parallel(2, graph_allreduce) +def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index b3e30a0434423..a0f7500bf0ee9 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -5,7 +5,7 @@ import torch from vllm.distributed.communication_op import ( # noqa - graph_capture_mode, tensor_model_parallel_all_reduce) + graph_mode, tensor_model_parallel_all_reduce) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, @@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with graph_capture_mode(): + with graph_mode(): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 32ab5694e5390..9cc776f8324f2 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,5 +1,5 @@ from collections import namedtuple -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -9,12 +9,13 @@ get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_ca_communicator, get_tp_pynccl_communicator) @contextmanager -def graph_capture_mode(): - # In graph capture, we have to be very careful about the collective +def graph_mode(): + # In graph mode, we have to be very careful about the collective # operations. The current status is: # allreduce \ Mode | Eager | Graph | # -------------------------------------------- @@ -24,10 +25,32 @@ def graph_capture_mode(): # # Note that custom allreduce will have a runtime check, if the tensor size # is too large, it will fallback to the next available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using CUDA + # graph, we use either custom all-reduce kernel or PyTorch NCCL. + # We always prioritize using custom all-reduce kernel but fall back + # to PyTorch or pynccl if it is disabled or not supported. pynccl_comm = get_tp_pynccl_communicator() - assert pynccl_comm is not None - with pynccl_comm.change_state(enable=True, - stream=torch.cuda.current_stream()): + if pynccl_comm is None: + context = nullcontext() + else: + context = pynccl_comm.change_state(enable=True, + stream=torch.cuda.current_stream()) + with context: + yield + + +@contextmanager +def graph_capture(): + """ + `graph_capture` is a context manager which should include the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. + """ + ca_comm = get_tp_ca_communicator() + context = nullcontext() if ca_comm is None else ca_comm.capture() + with context: yield @@ -43,15 +66,15 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: TLDR: always assume this function modifies its input, but use the return value as the output. """ - from vllm.distributed.device_communicators.custom_all_reduce import ( - custom_all_reduce) + ca_comm = get_tp_ca_communicator() # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: return input_ - out = custom_all_reduce(input_) - if out is not None: - return out + if ca_comm is not None: + out = ca_comm.custom_all_reduce(input_) + if out is not None: + return out pynccl_comm = get_tp_pynccl_communicator() if (pynccl_comm is not None and not pynccl_comm.disabled): pynccl_comm.all_reduce(input_) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 5d26254fb832a..30ee9d1f8a1e9 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,154 +1,42 @@ from contextlib import contextmanager -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as dist +from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm.distributed.parallel_state import ( + get_local_rank, get_tensor_model_parallel_cpu_group) from vllm.logger import init_logger try: import pynvml from vllm._C import custom_ar + + @contextmanager + def _nvml(): + try: + pynvml.nvmlInit() + yield + finally: + pynvml.nvmlShutdown() + except ImportError: # For AMD GPUs custom_ar = None pynvml = None -logger = init_logger(__name__) + @contextmanager + def _nvml(): + try: + yield + finally: + pass -_CA_HANDLE: Optional["CustomAllreduce"] = None -_IS_CAPTURING = False -_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] - - -def init_custom_ar() -> None: - from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) - - global _CA_HANDLE - if _CA_HANDLE is not None: - return - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - if world_size == 1: - # No need to initialize custom allreduce for single GPU case. - return - - if world_size not in _SUPPORTED_WORLD_SIZES: - logger.warning( - "Custom allreduce is disabled due to an unsupported world size: " - "%d. Supported world sizes: %s. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.", world_size, - str(_SUPPORTED_WORLD_SIZES)) - return - num_dev = torch.cuda.device_count() - # note: num dev can be larger than world_size if we're only using - # first few GPUs - if num_dev < world_size: - logger.warning( - "Cannot test GPU P2P because not all GPUs are visible to the " - "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'" - " is set.") - return - - # we only use a subset of GPUs here - # so we only need to check the nvlink connectivity of these GPUs - num_dev = world_size - # test nvlink first, this will filter out most of the cases - # where custom allreduce is not supported - cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES - if cuda_visible_devices: - device_ids = list(map(int, cuda_visible_devices.split(","))) - else: - device_ids = list(range(num_dev)) - # this checks hardware and driver support for NVLink - full_nvlink = _is_full_nvlink(device_ids) - if world_size > 2 and not full_nvlink: - logger.warning( - "Custom allreduce is disabled because it's not supported on more" - " than two PCIe-only GPUs. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return - # test P2P capability, this checks software/cudaruntime support - # this is expensive to compute at the first time - # then we cache the result - if not _can_p2p(rank, world_size): - logger.warning( - "Custom allreduce is disabled because your platform lacks GPU P2P" - " capability or P2P test failed. To silence this warning, specify" - " disable_custom_all_reduce=True explicitly.") - return - _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) - - -def begin_capture() -> None: - global _IS_CAPTURING - _IS_CAPTURING = True - - -def end_capture() -> None: - global _IS_CAPTURING - _IS_CAPTURING = False - - -def is_capturing() -> bool: - return _IS_CAPTURING and _CA_HANDLE is not None - - -def get_handle() -> Optional["CustomAllreduce"]: - return _CA_HANDLE - - -def is_initialized() -> bool: - return _CA_HANDLE is not None - - -@contextmanager -def capture(): - try: - begin_capture() - yield - finally: - end_capture() - handle = get_handle() - if handle is not None: - handle.register_graph_buffers() - - -def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: - ca_handle = get_handle() - # when custom allreduce is disabled, this will be None - if ca_handle is None: - return None - if is_capturing(): - if torch.cuda.is_current_stream_capturing(): - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_reg(input) - else: - if ca_handle.should_custom_ar(input): - # if warm up, mimic the allocation pattern - # since custom allreduce is out-of-place - return torch.empty_like(input) - else: - # note: outside of cuda graph context, - # custom allreduce incurs a cost of cudaMemcpy, which should - # be small(<=1% of overall latency) compared to the performance - # gains of using custom kernels - if ca_handle.should_custom_ar(input): - return ca_handle.all_reduce_unreg(input) - - return None - - -@contextmanager -def _nvml(): - try: - pynvml.nvmlInit() - yield - finally: - pynvml.nvmlShutdown() + +logger = init_logger(__name__) @_nvml() @@ -188,22 +76,112 @@ def _can_p2p(rank: int, world_size: int) -> bool: class CustomAllreduce: + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + # max_size: max supported allreduce size def __init__(self, - rank, - world_size, - full_nvlink, + group: Optional[ProcessGroup] = None, + device: Optional[Union[int, str, torch.device]] = None, max_size=8192 * 1024) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if custom_ar is None: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + + group = group or get_tensor_model_parallel_cpu_group() + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "CustomAllreduce should be attached to a non-NCCL group.") + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + return + + if device is None: + local_rank = get_local_rank() + device = torch.device(f"cuda:{local_rank}") + elif isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(torch.cuda.device_count())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + full_nvlink = _is_full_nvlink(physical_device_ids) + if world_size > 2 and not full_nvlink: + logger.warning( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. To silence this warning, " + "specify disable_custom_all_reduce=True explicitly.") + return + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + if not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.") + return + + self.disabled = False # buffers memory are owned by this Python class and passed to C++ # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate # allreduce results. self.meta = torch.zeros(custom_ar.meta_size() + max_size, dtype=torch.uint8, - device="cuda") + device=self.device) # This is a pre-registered IPC buffer. In eager mode, input tensors # are first copied into this buffer before allreduce is performed - self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda") + self.buffer = torch.empty(max_size, + dtype=torch.uint8, + device=self.device) # This is a buffer for storing the tuples of pointers pointing to # IPC buffers from all ranks. Each registered tuple has size of # 8*world_size bytes where world_size is at most 8. Allocating 8MB @@ -211,8 +189,9 @@ def __init__(self, # needs less than 10000 of registered tuples. self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, - device="cuda") + device=self.device) self.max_size = max_size + self.rank = rank self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink @@ -221,6 +200,21 @@ def __init__(self, self.full_nvlink) self.register_buffer(self.buffer) + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + def _get_ipc_meta(self, inp: torch.Tensor): data = inp.untyped_storage()._share_cuda_() shard_data = ( @@ -230,14 +224,29 @@ def _get_ipc_meta(self, inp: torch.Tensor): return self._gather_ipc_meta(shard_data) def _gather_ipc_meta(self, shard_data): - all_data: List[Optional[Any]] = [None] * self.world_size - dist.all_gather_object(all_data, shard_data) + # Note: don't use `[[None]] * self.world_size` here + # because it will create a list of the same reference + all_data: List[Optional[Any]] = [[None] + for i in range(self.world_size)] + all_data[self.rank][0] = shard_data + + ranks = dist.get_process_group_ranks(group=self.group) + ranks.sort() + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + + # we cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. handles = [] offsets = [] for i in range(len(all_data)): - handles.append(all_data[i][0]) # type: ignore - offsets.append(all_data[i][1]) # type: ignore + handles.append(all_data[i][0][0]) # type: ignore + offsets.append(all_data[i][0][1]) # type: ignore return handles, offsets def register_buffer(self, inp: torch.Tensor): @@ -269,8 +278,31 @@ def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + # when custom allreduce is disabled, this will be None + if self.disabled: + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + if self.should_custom_ar(input): + return self.all_reduce_reg(input) + else: + if self.should_custom_ar(input): + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + if self.should_custom_ar(input): + return self.all_reduce_unreg(input) + + return None + def close(self): - if self._ptr: + if not self.disabled and self._ptr: custom_ar.dispose(self._ptr) self._ptr = 0 diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 168d4cc2df8a6..092a0910329ad 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -96,8 +96,10 @@ def __init__( self.stream = torch.cuda.Stream() # A small all_reduce for warmup. - self.all_reduce(torch.zeros(1, device=device)) + data = torch.zeros(1, device=device) + self.all_reduce(data) self.stream.synchronize() + del data # by default it is disabled, e.g. in profiling models and prefill phase. # to use it, use under `with obj.change_state(enable=True)`, usually diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 5075da11bb1b8..d24104e3ed276 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -13,10 +13,13 @@ logger = init_logger(__name__) +_ENABLE_CUSTOM_ALL_REDUCE = True + # Tensor model parallel group that the current rank belongs to. _TP_DEVICE_GROUP: Optional[ProcessGroup] = None _TP_CPU_GROUP: Optional[ProcessGroup] = None _TP_PYNCCL_COMMUNICATOR = None +_TP_CA_COMMUNICATOR = None # Pipeline model parallel group that the current rank belongs to. _PP_DEVICE_GROUP: Optional[ProcessGroup] = None @@ -47,11 +50,21 @@ _LOCAL_RANK = -1 +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + def get_tp_pynccl_communicator(): global _TP_PYNCCL_COMMUNICATOR return _TP_PYNCCL_COMMUNICATOR +def get_tp_ca_communicator(): + global _TP_CA_COMMUNICATOR + return _TP_CA_COMMUNICATOR + + def get_local_rank(): global _LOCAL_RANK return _LOCAL_RANK @@ -100,6 +113,9 @@ def init_distributed_environment( if torch.cuda.is_available(): data = data.to(device=f"cuda:{local_rank}") torch.distributed.all_reduce(data) + if torch.cuda.is_available(): + torch.cuda.synchronize() + del data def initialize_model_parallel( @@ -149,7 +165,8 @@ def initialize_model_parallel( rank = torch.distributed.get_rank() # Build the tensor model-parallel groups. - global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR + global _TP_DEVICE_GROUP, _TP_CPU_GROUP + global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR assert _TP_DEVICE_GROUP is None, ( "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): @@ -168,6 +185,15 @@ def initialize_model_parallel( device=_LOCAL_RANK, ) + # Initialize a custom fast all-reduce implementation. + if _ENABLE_CUSTOM_ALL_REDUCE: + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + _TP_CA_COMMUNICATOR = CustomAllreduce( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + ) + # Build the pipeline model-parallel groups. global _PP_DEVICE_GROUP global _PP_GLOBAL_RANKS diff --git a/vllm/test_utils.py b/vllm/test_utils.py index 0cf23e4bb7e75..addd8ec1c26c8 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -6,24 +6,24 @@ def init_test_distributed_environment( - pipeline_parallel_size: int, - tensor_parallel_size: int, + tp_size: int, + pp_size: int, rank: int, distributed_init_port: str, local_rank: int = -1, ) -> None: distributed_init_method = f"tcp://localhost:{distributed_init_port}" init_distributed_environment( - world_size=pipeline_parallel_size * tensor_parallel_size, + world_size=pp_size * tp_size, rank=rank, distributed_init_method=distributed_init_method, local_rank=local_rank) - ensure_model_parallel_initialized(tensor_parallel_size, - pipeline_parallel_size) + ensure_model_parallel_initialized(tp_size, pp_size) def multi_process_tensor_parallel( - tensor_parallel_size: int, + tp_size: int, + pp_size: int, test_target, ) -> None: # Using ray helps debugging the error when it failed @@ -32,10 +32,9 @@ def multi_process_tensor_parallel( distributed_init_port = get_open_port() refs = [] - for rank in range(tensor_parallel_size): + for rank in range(tp_size * pp_size): refs.append( - test_target.remote(tensor_parallel_size, rank, - distributed_init_port)) + test_target.remote(tp_size, pp_size, rank, distributed_init_port)) ray.get(refs) ray.shutdown() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 21d76fd531e49..f46b475bdc2db 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -12,8 +12,7 @@ ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict -from vllm.distributed.communication_op import graph_capture_mode -from vllm.distributed.device_communicators import custom_all_reduce +from vllm.distributed.communication_op import graph_capture, graph_mode from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -942,13 +941,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce - # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using CUDA - # graph, we use either custom all-reduce kernel or PyTorch NCCL. - # We always prioritize using custom all-reduce kernel but fall back - # to PyTorch or pynccl if it is disabled or not supported. - with custom_all_reduce.capture(): + with graph_capture(): # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): @@ -1040,7 +1033,7 @@ def capture( # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with graph_capture_mode(): + with graph_mode(): self.model( input_ids, positions, @@ -1055,7 +1048,7 @@ def capture( # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 - with graph_capture_mode(): + with graph_mode(): hidden_states = self.model( input_ids, positions, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e4fbc877b8c9f..82cf58101a95b 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,9 +11,8 @@ VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.distributed.device_communicators.custom_all_reduce import ( - init_custom_ar) + init_distributed_environment, + set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput @@ -302,16 +301,14 @@ def init_worker_distributed_environment( local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) - # Initialize a custom fast all-reduce implementation. - if not parallel_config.disable_custom_all_reduce: - init_custom_ar() - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. From 350f9e107f0c00e59be1b970f96395494ed68b48 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 13 May 2024 22:50:09 +0800 Subject: [PATCH 263/413] [CI/Build] Move `test_utils.py` to `tests/utils.py` (#4425) Since #4335 was merged, I've noticed that the definition of ServerRunner in the tests is the same as in the test for OpenAI API. I have moved the class to the test utilities to avoid code duplication. (Although it only has been repeated twice so far, I will add another similar test suite in #4200 which would duplicate the code a third time) Also, I have moved the test utilities file (test_utils.py) to under the test directory (tests/utils.py), since none of its code is actually used in the main package. Note that I have added __init__.py to each test subpackage and updated the ray.init() call in the test utilities file in order to relative import tests/utils.py. --- .buildkite/test-pipeline.yaml | 24 +++-- tests/async_engine/__init__.py | 0 tests/async_engine/test_openapi_server_ray.py | 51 +---------- tests/basic_correctness/__init__.py | 0 tests/core/block/e2e/__init__.py | 0 tests/core/block/e2e/conftest.py | 3 +- tests/distributed/__init__.py | 0 .../test_basic_distributed_correctness.py | 6 +- tests/distributed/test_comm_ops.py | 5 +- tests/distributed/test_custom_all_reduce.py | 5 +- tests/engine/__init__.py | 0 tests/engine/output_processor/__init__.py | 0 .../output_processor/test_multi_step.py | 3 +- tests/entrypoints/__init__.py | 0 tests/entrypoints/test_openai_server.py | 47 +--------- tests/kernels/__init__.py | 0 tests/kernels/test_activation.py | 3 +- tests/kernels/test_attention.py | 3 +- tests/kernels/test_pos_encoding.py | 3 +- tests/metrics/__init__.py | 0 tests/model_executor/__init__.py | 0 tests/models/__init__.py | 0 tests/models/test_gptq_marlin.py | 3 +- tests/models/test_marlin.py | 3 +- tests/models/test_mistral.py | 2 +- tests/prefix_caching/__init__.py | 0 tests/quantization/__init__.py | 0 tests/samplers/__init__.py | 0 tests/samplers/test_logprobs.py | 3 +- tests/spec_decode/e2e/conftest.py | 3 +- tests/tensorizer_loader/test_tensorizer.py | 3 +- tests/test_sequence.py | 3 +- tests/utils.py | 89 +++++++++++++++++++ vllm/test_utils.py | 40 --------- 34 files changed, 138 insertions(+), 164 deletions(-) create mode 100644 tests/async_engine/__init__.py create mode 100644 tests/basic_correctness/__init__.py create mode 100644 tests/core/block/e2e/__init__.py create mode 100644 tests/distributed/__init__.py create mode 100644 tests/engine/__init__.py create mode 100644 tests/engine/output_processor/__init__.py create mode 100644 tests/entrypoints/__init__.py create mode 100644 tests/kernels/__init__.py create mode 100644 tests/metrics/__init__.py create mode 100644 tests/model_executor/__init__.py create mode 100644 tests/models/__init__.py create mode 100644 tests/prefix_caching/__init__.py create mode 100644 tests/quantization/__init__.py create mode 100644 tests/samplers/__init__.py create mode 100644 tests/utils.py delete mode 100644 vllm/test_utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2eeba904a209d..4feea786f38ba 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -24,28 +24,26 @@ steps: command: pytest -v -s core - label: Distributed Comm Ops Test - command: pytest -v -s test_comm_ops.py - working_dir: "/vllm-workspace/tests/distributed" + command: pytest -v -s distributed/test_comm_ops.py + working_dir: "/vllm-workspace/tests" num_gpus: 2 - label: Distributed Tests - working_dir: "/vllm-workspace/tests/distributed" - - num_gpus: 2 # only support 1 or 2 for now. + working_dir: "/vllm-workspace/tests" + num_gpus: 2 mirror_hardwares: [amd] - commands: - - pytest -v -s test_pynccl_library.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py + - pytest -v -s distributed/test_pynccl_library.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py - label: Distributed Tests (Multiple Groups) - working_dir: "/vllm-workspace/tests/distributed" + working_dir: "/vllm-workspace/tests" num_gpus: 4 commands: - - pytest -v -s test_pynccl.py + - pytest -v -s distributed/test_pynccl.py - label: Engine Test #mirror_hardwares: [amd] diff --git a/tests/async_engine/__init__.py b/tests/async_engine/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 4b97af88012b9..ace4c53916c71 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -1,61 +1,16 @@ -# imports for guided decoding tests -import os -import subprocess -import sys -import time - import openai # use the official client for correctness check import pytest # using Ray for overall ease of process management, parallel requests, # and debugging. import ray -import requests -MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds +from ..utils import ServerRunner + # any model with a chat template should work here MODEL_NAME = "facebook/opt-125m" -@ray.remote(num_gpus=1) -class ServerRunner: - - def __init__(self, args): - env = os.environ.copy() - env["PYTHONUNBUFFERED"] = "1" - self.proc = subprocess.Popen( - ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_server() - - def ready(self): - return True - - def _wait_for_server(self): - # run health check - start = time.time() - while True: - try: - if requests.get( - "http://localhost:8000/health").status_code == 200: - break - except Exception as err: - if self.proc.poll() is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_SERVER_START_WAIT_S: - raise RuntimeError( - "Server failed to start in time.") from err - - def __del__(self): - if hasattr(self, "proc"): - self.proc.terminate() - - -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(): ray.init() server_runner = ServerRunner.remote([ diff --git a/tests/basic_correctness/__init__.py b/tests/basic_correctness/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/core/block/e2e/__init__.py b/tests/core/block/e2e/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index 1d99cb5d32219..b0d62c8993d3f 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -1,9 +1,10 @@ import pytest -from tests.conftest import cleanup from vllm import LLM from vllm.model_executor.utils import set_random_seed +from ....conftest import cleanup + @pytest.fixture def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, diff --git a/tests/distributed/__init__.py b/tests/distributed/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 527452630c9f5..d63f015511ada 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -4,10 +4,12 @@ variables. Run: ```sh +cd $VLLM_PATH/tests + TEST_DIST_MODEL=facebook/opt-125m pytest \ - test_basic_distributed_correctness.py + distributed/test_basic_distributed_correctness.py TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \ - test_basic_distributed_correctness.py + distributed/test_basic_distributed_correctness.py ``` """ import os diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index a4423bbfddf46..53654dc40d10d 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,8 +11,9 @@ from vllm.distributed import (broadcast_tensor_dict, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.test_utils import (init_test_distributed_environment, - multi_process_tensor_parallel) + +from ..utils import (init_test_distributed_environment, + multi_process_tensor_parallel) @ray.remote(num_gpus=1, max_calls=1) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index bdca031e39be1..630b25a3b6132 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -10,8 +10,9 @@ graph_capture, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, get_tp_ca_communicator) -from vllm.test_utils import (init_test_distributed_environment, - multi_process_tensor_parallel) + +from ..utils import (init_test_distributed_environment, + multi_process_tensor_parallel) random.seed(42) test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] diff --git a/tests/engine/__init__.py b/tests/engine/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/engine/output_processor/__init__.py b/tests/engine/output_processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py index 2bf4bf69da203..4f32a622546f0 100644 --- a/tests/engine/output_processor/test_multi_step.py +++ b/tests/engine/output_processor/test_multi_step.py @@ -4,7 +4,6 @@ import pytest from transformers import PreTrainedTokenizer -from tests.core.utils import create_seq_group from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker @@ -14,6 +13,8 @@ from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter +from ...core.utils import create_seq_group + @pytest.mark.parametrize("seq_output_len", [128]) @pytest.mark.parametrize("num_new_tokens", [1, 12]) diff --git a/tests/entrypoints/__init__.py b/tests/entrypoints/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index c22ac4507658b..ee2f034fd2c46 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -1,10 +1,6 @@ # imports for guided decoding tests import json -import os import re -import subprocess -import sys -import time import jsonschema import openai # use the official client for correctness check @@ -12,7 +8,6 @@ # using Ray for overall ease of process management, parallel requests, # and debugging. import ray -import requests import torch # downloading lora to test lora requests from huggingface_hub import snapshot_download @@ -20,7 +15,8 @@ from vllm.transformers_utils.tokenizer import get_tokenizer -MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds +from ..utils import ServerRunner + # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" @@ -78,45 +74,6 @@ pytestmark = pytest.mark.asyncio -@ray.remote(num_gpus=1) -class ServerRunner: - - def __init__(self, args): - env = os.environ.copy() - env["PYTHONUNBUFFERED"] = "1" - self.proc = subprocess.Popen( - ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_server() - - def ready(self): - return True - - def _wait_for_server(self): - # run health check - start = time.time() - while True: - try: - if requests.get( - "http://localhost:8000/health").status_code == 200: - break - except Exception as err: - if self.proc.poll() is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_SERVER_START_WAIT_S: - raise RuntimeError( - "Server failed to start in time.") from err - - def __del__(self): - if hasattr(self, "proc"): - self.proc.terminate() - - @pytest.fixture(scope="session") def zephyr_lora_files(): return snapshot_download(repo_id=LORA_NAME) diff --git a/tests/kernels/__init__.py b/tests/kernels/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 86ecc6412c648..a624c4ca9ee62 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -2,11 +2,12 @@ import pytest import torch -from allclose_default import get_default_atol, get_default_rtol from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, NewGELU, SiluAndMul) +from .allclose_default import get_default_atol, get_default_rtol + DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 28496f187d466..fdf313262ca97 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -3,13 +3,14 @@ import pytest import torch -from allclose_default import get_default_atol, get_default_rtol from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from vllm import _custom_ops as ops from vllm.utils import get_max_shared_memory_bytes, is_hip +from .allclose_default import get_default_atol, get_default_rtol + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index bf1856972cf33..18c8e351aa778 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -3,10 +3,11 @@ import pytest import torch -from allclose_default import get_default_atol, get_default_rtol from vllm.model_executor.layers.rotary_embedding import get_rope +from .allclose_default import get_default_atol, get_default_rtol + IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] HEAD_SIZES = [64, 80, 96, 112, 128, 256] diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/model_executor/__init__.py b/tests/model_executor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index b1c2b88bc99af..db55d4488a374 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -13,9 +13,10 @@ import pytest import torch -from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from .utils import check_logprobs_close + os.environ["TOKENIZERS_PARALLELISM"] = "true" MAX_MODEL_LEN = 1024 diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index fa846d43d0e88..37c1664afec55 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -15,9 +15,10 @@ import pytest import torch -from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from .utils import check_logprobs_close + capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] marlin_not_supported = (capability < diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 33d28da85d9e7..d0a5bfbfcd922 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -4,7 +4,7 @@ """ import pytest -from tests.models.utils import check_logprobs_close +from .utils import check_logprobs_close MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", diff --git a/tests/prefix_caching/__init__.py b/tests/prefix_caching/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/quantization/__init__.py b/tests/quantization/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/samplers/__init__.py b/tests/samplers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 57d6d2a410ee5..40d054cd472b8 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -1,9 +1,10 @@ import pytest import torch -from tests.conftest import VllmRunner from vllm import SamplingParams +from ..conftest import VllmRunner + MODELS = ["facebook/opt-125m"] diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index eda7293ea7cee..da8b92711380e 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -9,7 +9,6 @@ from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit) -from tests.conftest import cleanup from vllm import LLM from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -21,6 +20,8 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, random_uuid +from ...conftest import cleanup + class AsyncLLM: """AsyncLLM diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index df1db4e6c4001..ad4748c5ebe96 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -9,12 +9,13 @@ import ray import torch -from tests.entrypoints.test_openai_server import ServerRunner from vllm import SamplingParams from vllm.model_executor.model_loader.tensorizer import ( EncryptionParams, TensorizerConfig, TensorSerializer, is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream) +from ..utils import ServerRunner + prompts = [ "Hello, my name is", "The president of the United States is", diff --git a/tests/test_sequence.py b/tests/test_sequence.py index b8ea1f6b77200..3136402518b9f 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,9 +1,10 @@ import pytest -from tests.core.utils import create_dummy_prompt from vllm.sequence import (CompletionSequenceGroupOutput, SamplerOutput, SequenceData, SequenceOutput) +from .core.utils import create_dummy_prompt + @pytest.fixture def sample_outputs(): diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000000..689d8c8c5ba8a --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,89 @@ +import os +import subprocess +import sys +import time + +import ray +import requests + +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.utils import get_open_port + +# Path to root of repository so that utilities can be imported by ray workers +VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)) + + +@ray.remote(num_gpus=1) +class ServerRunner: + MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds + + def __init__(self, args): + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self.proc = subprocess.Popen( + ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server() + + def ready(self): + return True + + def _wait_for_server(self): + # run health check + start = time.time() + while True: + try: + if requests.get( + "http://localhost:8000/health").status_code == 200: + break + except Exception as err: + if self.proc.poll() is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > self.MAX_SERVER_START_WAIT_S: + raise RuntimeError( + "Server failed to start in time.") from err + + def __del__(self): + if hasattr(self, "proc"): + self.proc.terminate() + + +def init_test_distributed_environment( + tp_size: int, + pp_size: int, + rank: int, + distributed_init_port: str, + local_rank: int = -1, +) -> None: + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=pp_size * tp_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=local_rank) + ensure_model_parallel_initialized(tp_size, pp_size) + + +def multi_process_tensor_parallel( + tp_size: int, + pp_size: int, + test_target, +) -> None: + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. + ray.init(runtime_env={"working_dir": VLLM_PATH}) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(tp_size * pp_size): + refs.append( + test_target.remote(tp_size, pp_size, rank, distributed_init_port)) + ray.get(refs) + + ray.shutdown() diff --git a/vllm/test_utils.py b/vllm/test_utils.py deleted file mode 100644 index addd8ec1c26c8..0000000000000 --- a/vllm/test_utils.py +++ /dev/null @@ -1,40 +0,0 @@ -import ray - -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.utils import get_open_port - - -def init_test_distributed_environment( - tp_size: int, - pp_size: int, - rank: int, - distributed_init_port: str, - local_rank: int = -1, -) -> None: - distributed_init_method = f"tcp://localhost:{distributed_init_port}" - init_distributed_environment( - world_size=pp_size * tp_size, - rank=rank, - distributed_init_method=distributed_init_method, - local_rank=local_rank) - ensure_model_parallel_initialized(tp_size, pp_size) - - -def multi_process_tensor_parallel( - tp_size: int, - pp_size: int, - test_target, -) -> None: - # Using ray helps debugging the error when it failed - # as compared to multiprocessing. - ray.init() - - distributed_init_port = get_open_port() - refs = [] - for rank in range(tp_size * pp_size): - refs.append( - test_target.remote(tp_size, pp_size, rank, distributed_init_port)) - ray.get(refs) - - ray.shutdown() From e7c46b9527c9a50253657fd0078a0b1f23560ce4 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Mon, 13 May 2024 23:50:44 +0900 Subject: [PATCH 264/413] [Scheduler] Warning upon preemption and Swapping (#4647) Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> --- docs/source/models/performance.rst | 19 ++++++++++ tests/basic_correctness/test_preemption.py | 44 +++++++++++++++++++++- tests/conftest.py | 16 ++++++++ tests/core/test_scheduler.py | 1 + vllm/core/scheduler.py | 18 +++++++++ vllm/engine/llm_engine.py | 4 +- vllm/engine/metrics.py | 9 ++++- 7 files changed, 108 insertions(+), 3 deletions(-) diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst index 589fce21056c2..d8750ddc34e8e 100644 --- a/docs/source/models/performance.rst +++ b/docs/source/models/performance.rst @@ -3,6 +3,25 @@ Performance and Tuning ====================== +Preemption +---------- +Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. +The vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes +available again. When this occurs, the following warning is printed: + +``` +WARNING 05-09 00:49:33 scheduler.py:1057] Sequence group 0 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_cumulative_preemption_cnt=1 +``` + +While this mechanism ensures system robustness, preemption and recomputation can adversely affect end-to-end latency. +If you frequently encounter preemptions from the vLLM engine, consider the following actions: + +- Increase `gpu_memory_utilization`. The vLLM pre-allocates GPU cache by using gpu_memory_utilization% of memory. By increasing this utilization, you can provide more KV cache space. +- Decrease `max_num_seqs` or `max_num_batched_tokens`. This can reduce the number of concurrent requests in a batch, thereby requiring less KV cache space. +- Increase `tensor_parallel_size`. This approach shards model weights, so each GPU has more memory available for KV cache. + +You can also monitor the number of preemption requests through Prometheus metrics exposed by the vLLM. Additionally, you can log the cumulative number of preemption requests by setting disable_log_stats=False. + Chunked Prefill --------------- vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index ffb0717b3bfdb..29a4c39cd25a1 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -6,6 +6,7 @@ pytest tests/basic_correctness/test_preemption.py`. """ import pytest +from prometheus_client import REGISTRY from vllm import SamplingParams from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, @@ -71,6 +72,7 @@ def test_chunked_prefill_recompute( @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) def test_preemption( + caplog_vllm, hf_runner, vllm_runner, example_prompts, @@ -87,10 +89,13 @@ def test_preemption( vllm_model = vllm_runner( model, dtype=dtype, + disable_log_stats=False, ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < ARTIFICIAL_PREEMPTION_MAX_CNT) + total_preemption = ( + vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) del vllm_model for i in range(len(example_prompts)): @@ -100,6 +105,20 @@ def test_preemption( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + assert ("is preempted by PreemptionMode.RECOMPUTE mode because there " + "is not enough KV cache space." in caplog_vllm.text) + # Ensure the count bucket of request-level histogram metrics matches + # the number of requests as a simple sanity check to ensure metrics are + # generated + preemption_metrics = None + for m in REGISTRY.collect(): + if m.name == "vllm:num_preemptions": + preemption_metrics = m + assert preemption_metrics is not None + total_recorded_preemption = 0 + for sample in preemption_metrics.samples: + total_recorded_preemption += sample.value + assert total_preemption == total_recorded_preemption @pytest.mark.parametrize("model", MODELS) @@ -107,6 +126,7 @@ def test_preemption( @pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("beam_width", [4]) def test_swap( + caplog_vllm, hf_runner, vllm_runner, example_prompts, @@ -122,11 +142,18 @@ def test_swap( max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, swap_space=10) + vllm_model = vllm_runner( + model, + dtype=dtype, + swap_space=10, + disable_log_stats=False, + ) vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, max_tokens) assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < ARTIFICIAL_PREEMPTION_MAX_CNT) + total_preemption = ( + vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) del vllm_model for i in range(len(example_prompts)): @@ -138,6 +165,21 @@ def test_swap( f"Test{i} output{j}:\nHF: {hf_output_ids}\n" f"vLLM: {vllm_output_ids}") + assert ("is preempted by PreemptionMode.SWAP mode because there " + "is not enough KV cache space." in caplog_vllm.text) + # Ensure the count bucket of request-level histogram metrics matches + # the number of requests as a simple sanity check to ensure metrics are + # generated + preemption_metrics = None + for m in REGISTRY.collect(): + if m.name == "vllm:num_preemptions": + preemption_metrics = m + assert preemption_metrics is not None + total_recorded_preemption = 0 + for sample in preemption_metrics.samples: + total_recorded_preemption += sample.value + assert total_preemption == total_recorded_preemption + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) diff --git a/tests/conftest.py b/tests/conftest.py index b8117a19c75d9..999ace2c3c699 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -499,3 +499,19 @@ def get_tokenizer_pool_config(tokenizer_group_type): pool_type="ray", extra_config={}) raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") + + +@pytest.fixture() +def temporary_enable_log_propagate(): + import logging + logger = logging.getLogger("vllm") + logger.propagate = True + yield + logger.propagate = False + + +@pytest.fixture() +def caplog_vllm(temporary_enable_log_propagate, caplog): + # To capture vllm log, we should enable propagate=True temporarily + # because caplog depends on logs propagated to the root logger. + yield caplog diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 6bcabc4f95fa9..07fc8731e1847 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -180,6 +180,7 @@ def test_scheduler_schedule_preempt_abort(): and not out.blocks_to_swap_out) assert len(seq_group_meta) == 1 assert scheduler.get_num_unfinished_seq_groups() == 2 + assert out.preempted == 1 # Abort seq group a. Re-schedule seq group b prompt with recomputation. scheduler.abort_seq_group("1") diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fb6e985b2f31c..fbde27f998233 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -129,6 +129,7 @@ class SchedulerOutputs: num_lookahead_slots: int # The number of requests in the running queue running_queue_size: int + preempted: int def __post_init__(self): # Swap in and swap out should never happen at the same time. @@ -310,6 +311,7 @@ def __init__( self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT if self.enable_artificial_preemption else 0) + self.num_cumulative_preemption: int = 0 @property def lora_enabled(self) -> bool: @@ -785,6 +787,8 @@ def _schedule_default(self) -> SchedulerOutputs: # Update swapped requests. self.swapped = remaining_swapped self.swapped.extend(running_scheduled.swapped_out) + preempted = (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)) # There should be no prefill from running queue because this policy # doesn't allow chunked prefills. @@ -804,6 +808,7 @@ def _schedule_default(self) -> SchedulerOutputs: swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), + preempted=preempted, ) def _schedule_chunked_prefill(self): @@ -891,6 +896,8 @@ def _schedule_chunked_prefill(self): ignored_seq_groups=prefills.ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), ) def _schedule(self) -> SchedulerOutputs: @@ -1057,6 +1064,17 @@ def _preempt( preemption_mode = PreemptionMode.RECOMPUTE else: preemption_mode = PreemptionMode.SWAP + + if self.num_cumulative_preemption % 50 == 0: + logger.warning( + "Sequence group %s is preempted by %s mode because there is " + "not enough KV cache space. This can affect the end-to-end " + "performance. Increase gpu_memory_utilization or " + "tensor_parallel_size to provide more KV cache memory. " + "total_num_cumulative_preemption=%d", seq_group.request_id, + preemption_mode, self.num_cumulative_preemption + 1) + self.num_cumulative_preemption += 1 + if preemption_mode == PreemptionMode.RECOMPUTE: self._preempt_by_recompute(seq_group) elif preemption_mode == PreemptionMode.SWAP: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 46fa41030b4a1..e258a3f4afd54 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -737,6 +737,8 @@ def _get_stats( num_generation_tokens_iter = 0 time_to_first_tokens_iter: List[float] = [] time_per_output_tokens_iter: List[float] = [] + num_preemption_iter = (0 if scheduler_outputs is None else + scheduler_outputs.preempted) # Request stats # Latency @@ -830,7 +832,6 @@ def _get_stats( return Stats( now=now, - # System stats # Scheduler State num_running_sys=num_running_sys, @@ -846,6 +847,7 @@ def _get_stats( time_to_first_tokens_iter=time_to_first_tokens_iter, time_per_output_tokens_iter=time_per_output_tokens_iter, spec_decode_metrics=spec_decode_metrics, + num_preemption_iter=num_preemption_iter, # Request stats # Latency diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 3c4aac91549a9..ae7ae144bc04f 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -61,6 +61,10 @@ def __init__(self, labelnames: List[str], max_model_len: int): labelnames=labelnames) # Iteration stats + self.counter_num_preemption = Counter( + name="vllm:num_preemptions_total", + documentation="Cumulative number of preemption from the engine.", + labelnames=labelnames) self.counter_prompt_tokens = Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", @@ -181,6 +185,7 @@ class Stats: num_generation_tokens_iter: int time_to_first_tokens_iter: List[float] time_per_output_tokens_iter: List[float] + num_preemption_iter: int # Request stats (should have _requests suffix) # Latency @@ -244,6 +249,8 @@ def _log_prometheus(self, stats: Stats) -> None: stats.cpu_cache_usage_sys) # Iteration level data + self._log_counter(self.metrics.counter_num_preemption, + stats.num_preemption_iter) self._log_counter(self.metrics.counter_prompt_tokens, stats.num_prompt_tokens_iter) self._log_counter(self.metrics.counter_generation_tokens, @@ -336,7 +343,7 @@ def log(self, stats: Stats) -> None: "Avg generation throughput: %.1f tokens/s, " "Running: %d reqs, Swapped: %d reqs, " "Pending: %d reqs, GPU KV cache usage: %.1f%%, " - "CPU KV cache usage: %.1f%%", + "CPU KV cache usage: %.1f%%.", prompt_throughput, generation_throughput, stats.num_running_sys, From 0fca3cdcf265cd375bca684d951702b6b7adf65a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 May 2024 10:47:25 -0700 Subject: [PATCH 265/413] [Misc] Enhance attention selector (#4751) --- tests/worker/test_model_runner.py | 1 - vllm/attention/__init__.py | 4 +- vllm/attention/backends/abstract.py | 5 +- vllm/attention/backends/flash_attn.py | 13 +++-- vllm/attention/backends/flashinfer.py | 33 +++++++---- vllm/attention/backends/rocm_flash_attn.py | 16 ++--- vllm/attention/backends/torch_sdpa.py | 28 +++++---- vllm/attention/backends/xformers.py | 12 ++-- vllm/attention/layer.py | 19 +++++- vllm/attention/selector.py | 28 +++++++-- vllm/model_executor/model_loader/__init__.py | 19 +++--- vllm/model_executor/model_loader/loader.py | 61 +++++++++++++------- vllm/model_executor/models/arctic.py | 16 ++++- vllm/model_executor/models/baichuan.py | 29 +++++++--- vllm/model_executor/models/bloom.py | 15 +++-- vllm/model_executor/models/chatglm.py | 20 +++++-- vllm/model_executor/models/commandr.py | 14 ++++- vllm/model_executor/models/dbrx.py | 17 ++++-- vllm/model_executor/models/decilm.py | 4 +- vllm/model_executor/models/deepseek.py | 16 ++++- vllm/model_executor/models/falcon.py | 15 +++-- vllm/model_executor/models/gemma.py | 14 +++-- vllm/model_executor/models/gpt2.py | 16 +++-- vllm/model_executor/models/gpt_bigcode.py | 14 +++-- vllm/model_executor/models/gpt_j.py | 20 +++++-- vllm/model_executor/models/gpt_neox.py | 16 +++-- vllm/model_executor/models/internlm2.py | 13 ++++- vllm/model_executor/models/jais.py | 12 +++- vllm/model_executor/models/llama.py | 17 ++++-- vllm/model_executor/models/llava.py | 6 +- vllm/model_executor/models/minicpm.py | 13 ++++- vllm/model_executor/models/mixtral.py | 13 ++++- vllm/model_executor/models/mixtral_quant.py | 31 ++++++---- vllm/model_executor/models/mpt.py | 18 ++++-- vllm/model_executor/models/olmo.py | 14 +++-- vllm/model_executor/models/opt.py | 16 +++-- vllm/model_executor/models/orion.py | 13 ++++- vllm/model_executor/models/phi.py | 16 +++-- vllm/model_executor/models/qwen.py | 15 ++++- vllm/model_executor/models/qwen2.py | 14 +++-- vllm/model_executor/models/qwen2_moe.py | 16 ++++- vllm/model_executor/models/stablelm.py | 14 +++-- vllm/model_executor/models/starcoder2.py | 18 +++++- vllm/model_executor/models/xverse.py | 14 +++-- vllm/worker/cache_engine.py | 14 ++++- vllm/worker/cpu_model_runner.py | 15 +++-- vllm/worker/cpu_worker.py | 10 +++- vllm/worker/embedding_model_runner.py | 1 - vllm/worker/model_runner.py | 15 +++-- 49 files changed, 573 insertions(+), 220 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 3e3d2e3f5c53d..c2d1c5769619b 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -307,7 +307,6 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) - assert attn_metadata.kv_cache_dtype == "auto" assert attn_metadata.num_prefills == prefill_batch_size if enforce_eager: assert attn_metadata.num_decode_tokens == decode_batch_size diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 7636b34a16fed..088f48def7668 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -5,9 +5,9 @@ from vllm.attention.selector import get_attn_backend __all__ = [ + "Attention", "AttentionBackend", "AttentionMetadata", - "Attention", - "get_attn_backend", "AttentionMetadataPerStage", + "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 64ccb309a0480..98d70fcab1a18 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]): # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot # in block 0, and 1st slot in block 1, respectively. slot_mapping: torch.Tensor - # The kv cache's data type. - kv_cache_dtype: str def __post_init__(self): if self.num_prefill_tokens > 0: @@ -116,6 +114,7 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: raise NotImplementedError @@ -127,6 +126,6 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4bad226512b69..f59715bd76ede 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -140,16 +140,18 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -167,7 +169,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata[FlashAttentionMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -196,8 +198,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -264,7 +265,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 36e162671f944..92d0fe0487516 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -149,20 +149,33 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: - if sliding_window is not None: - raise ValueError("Sliding window is not supported in FlashInfer.") - self.sliding_window = (-1, -1) - self.alibi_slopes = alibi_slopes - self.scale = scale self.num_heads = num_heads self.head_size = head_size + self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is not None: + raise ValueError("Sliding window is not supported in FlashInfer.") + self.sliding_window = (-1, -1) + self.kv_cache_dtype = kv_cache_dtype - def forward(self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata], - kv_scale: float): + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata[FlashInferMetadata], + kv_scale: float = 1.0, + ) -> torch.Tensor: + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -183,7 +196,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, kv_cache[:, 0], kv_cache[:, 1], attn_metadata.slot_mapping.flatten(), - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, ) if prefill_meta := attn_metadata.prefill_metadata: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 8fc1af1aa1e1c..539585b46c7aa 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -138,25 +138,27 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. @@ -229,7 +231,7 @@ def forward( key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, kv_scale, ) @@ -323,7 +325,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c29218dfd0cfc..2dd72a00c6e30 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -83,26 +83,32 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window if alibi_slopes is not None: - assert len(alibi_slopes) == num_heads alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - self.need_mask = (self.alibi_slopes is not None - or self.sliding_window is not None) + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") def forward( self, @@ -111,7 +117,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: TorchSDPAMetadata, # type: ignore - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -124,6 +130,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ + assert kv_scale == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -136,8 +143,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) if attn_metadata.is_prompt: assert attn_metadata.seq_lens is not None @@ -195,7 +201,7 @@ def forward( attn_metadata.block_tables, attn_metadata.seq_lens_tensor, attn_metadata.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 2a9150dea5875..cb2028553461f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -149,15 +149,17 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -175,7 +177,7 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata[XFormersMetadata], - kv_scale: float, + kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -188,7 +190,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - num_tokens, hidden_size = query.shape query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) @@ -203,8 +204,7 @@ def forward( PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, - attn_metadata.kv_cache_dtype, - kv_scale) + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -262,7 +262,7 @@ def forward( decode_meta.block_tables, decode_meta.seq_lens_tensor, decode_meta.max_seq_len, - attn_metadata.kv_cache_dtype, + self.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ee7be26c0876c..8a872dba8c877 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,6 +7,7 @@ from vllm.attention.backends.abstract import (AttentionMetadata, AttentionMetadataPerStage) from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig class Attention(nn.Module): @@ -29,10 +30,24 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() - self.backend = get_attn_backend(torch.get_default_dtype()) - impl_cls = self.backend.get_impl_cls() + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + if num_kv_heads is None: + num_kv_heads = num_heads + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) + impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index f4446bac6b8d2..06f99718a4dee 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,6 +1,6 @@ import enum from functools import lru_cache -from typing import Type +from typing import Optional, Type import torch @@ -21,8 +21,18 @@ class _Backend(enum.Enum): @lru_cache(maxsize=None) -def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: - backend = _which_attn_to_use(dtype) +def get_attn_backend( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, +) -> Type[AttentionBackend]: + backend = _which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) if backend == _Backend.FLASH_ATTN: logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 @@ -44,14 +54,22 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: return TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Eager mode is enforced for the Flashinfer backend. ") + logger.warning("Eager mode is enforced for the Flashinfer backend.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend else: raise ValueError("Invalid attention backend.") -def _which_attn_to_use(dtype: torch.dtype) -> _Backend: +def _which_attn_to_use( + num_heads: int, + head_size: int, + num_kv_heads: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, +) -> _Backend: """Returns which flash attention backend to use.""" if is_cpu(): return _Backend.TORCH_SDPA diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 6f90e49994fb2..e3e32d61ab04d 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -2,26 +2,29 @@ from torch import nn -from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) from vllm.model_executor.model_loader.loader import (BaseModelLoader, get_model_loader) from vllm.model_executor.model_loader.utils import ( get_architecture_class_name, get_model_architecture) -def get_model( - *, model_config: ModelConfig, load_config: LoadConfig, - device_config: DeviceConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: +def get_model(*, model_config: ModelConfig, load_config: LoadConfig, + device_config: DeviceConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: loader = get_model_loader(load_config) return loader.load_model(model_config=model_config, device_config=device_config, lora_config=lora_config, vision_language_config=vision_language_config, parallel_config=parallel_config, - scheduler_config=scheduler_config) + scheduler_config=scheduler_config, + cache_config=cache_config) __all__ = [ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bafa2de62e5df..fc9c8aa0af44b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -9,9 +9,9 @@ import torch from torch import nn -from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -77,15 +77,16 @@ def _get_model_initialization_kwargs( return extra_kwargs -def _initialize_model( - model_config: ModelConfig, load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: +def _initialize_model(model_config: ModelConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig) -> nn.Module: """Initialize a model with the given configurations.""" model_class = get_model_architecture(model_config)[0] quant_config = _get_quantization_config(model_config, load_config) return model_class(config=model_config.hf_config, + cache_config=cache_config, quant_config=quant_config, **_get_model_initialization_kwargs( model_class, lora_config, vision_language_config)) @@ -103,7 +104,8 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: """Load a model with the given configurations.""" ... @@ -216,11 +218,13 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, @@ -253,11 +257,13 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) @@ -286,9 +292,12 @@ def _get_weights_iterator( return tensorizer_weights_iterator(tensorizer_args) def _load_model_unserialized( - self, model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, ) -> nn.Module: """Load an unserialized model with tensorizer. @@ -299,15 +308,19 @@ def _load_model_unserialized( with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, - lora_config, vision_language_config) + lora_config, vision_language_config, + cache_config) model.load_weights(self._get_weights_iterator()) return model.eval() def _load_model_serialized( - self, model_config: ModelConfig, device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig] + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + cache_config: CacheConfig, ) -> nn.Module: """Load a serialized model with tensorizer. @@ -321,6 +334,7 @@ def _load_model_serialized( extra_kwargs = _get_model_initialization_kwargs( model_class, lora_config, vision_language_config) extra_kwargs["quant_config"] = quant_config + extra_kwargs["cache_config"] = cache_config tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class @@ -335,16 +349,19 @@ def load_model(self, *, model_config: ModelConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig) -> nn.Module: + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: self._verify_config(model_config, parallel_config) if is_vllm_serialized_tensorizer(self.tensorizer_config): return self._load_model_serialized(model_config, device_config, lora_config, - vision_language_config) + vision_language_config, + cache_config) return self._load_model_unserialized(model_config, device_config, lora_config, - vision_language_config) + vision_language_config, + cache_config) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 796cef7c4a735..cb99939cbb17a 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -5,6 +5,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -215,6 +216,7 @@ def __init__( self, config: ArcticConfig, layer_idx: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -265,7 +267,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -288,6 +291,7 @@ def __init__( self, config: ArcticConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -297,6 +301,7 @@ def __init__( self.use_residual = config.use_residual and is_moe_layer self.self_attn = ArcticAttention(config, layer_idx, + cache_config, quant_config=quant_config) self.block_sparse_moe = ArcticMoE( config, @@ -356,6 +361,7 @@ class ArcticModel(nn.Module): def __init__( self, config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -366,7 +372,10 @@ def __init__( config.hidden_size, org_num_embeddings=self.vocab_size) self.layers = nn.ModuleList([ - ArcticDecoderLayer(config, layer_idx, quant_config=quant_config) + ArcticDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self._attn_implementation = config._attn_implementation @@ -392,11 +401,12 @@ class ArcticForCausalLM(nn.Module): def __init__(self, config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, **kwargs) -> None: super().__init__() self.config = config - self.model = ArcticModel(config, quant_config) + self.model = ArcticModel(config, cache_config, quant_config) self.vocab_size = config.vocab_size self.lm_head = ParallelLMHead( self.vocab_size, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 186cee2584369..58b3405d319d1 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -26,7 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -111,6 +111,7 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -162,7 +163,10 @@ def __init__( base=self.rope_theta, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config) def forward( self, @@ -185,6 +189,7 @@ class BaiChuanDecoderLayer(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size @@ -197,6 +202,7 @@ def __init__(self, position_embedding=position_embedding, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = BaiChuanMLP( @@ -244,6 +250,7 @@ class BaiChuanModel(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -255,7 +262,8 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, quant_config) + BaiChuanDecoderLayer(config, position_embedding, cache_config, + quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -304,13 +312,15 @@ def __init__( self, config, position_embedding: str, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.model = BaiChuanModel(config, position_embedding, quant_config) + self.model = BaiChuanModel(config, position_embedding, cache_config, + quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -389,13 +399,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", quant_config, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", quant_config, lora_config) + super().__init__(config, "ALIBI", cache_config, quant_config, + lora_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -404,7 +417,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, "ROPE", quant_config, lora_config) + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 1d7e5d2517c72..fe2de87b20dc9 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -24,6 +24,7 @@ from transformers import BloomConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -71,6 +72,7 @@ class BloomAttention(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -108,7 +110,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + cache_config=cache_config) def forward( self, @@ -158,6 +161,7 @@ class BloomBlock(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -165,7 +169,8 @@ def __init__( self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, quant_config) + self.self_attention = BloomAttention(config, cache_config, + quant_config) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config, quant_config) @@ -214,6 +219,7 @@ class BloomModel(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -229,7 +235,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - BloomBlock(config, quant_config) + BloomBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -262,12 +268,13 @@ class BloomForCausalLM(nn.Module): def __init__( self, config: BloomConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = BloomModel(config, quant_config) + self.transformer = BloomModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e116af2ed080d..29c76682109c6 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -9,7 +9,7 @@ from torch.nn import LayerNorm from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -34,6 +34,7 @@ class GLMAttention(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -90,6 +91,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) def forward( @@ -167,6 +169,7 @@ class GLMBlock(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -181,7 +184,7 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config, quant_config) + self.self_attention = GLMAttention(config, cache_config, quant_config) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -237,6 +240,7 @@ class GLMTransformer(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -246,8 +250,10 @@ def __init__( self.num_layers = config.num_layers # Transformer layers. - self.layers = nn.ModuleList( - [GLMBlock(config, quant_config) for i in range(self.num_layers)]) + self.layers = nn.ModuleList([ + GLMBlock(config, cache_config, quant_config) + for i in range(self.num_layers) + ]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -282,6 +288,7 @@ class ChatGLMModel(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -292,7 +299,7 @@ def __init__( self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, quant_config) + self.encoder = GLMTransformer(config, cache_config, quant_config) self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) @@ -334,13 +341,14 @@ class ChatGLMForCausalLM(nn.Module): def __init__( self, config: ChatGLMConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config: ChatGLMConfig = config self.quant_config = quant_config - self.transformer = ChatGLMModel(config, quant_config) + self.transformer = ChatGLMModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.output_layer.weight self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 17c2f1223d96b..7354d11f98b15 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -29,6 +29,7 @@ from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -124,6 +125,7 @@ class CohereAttention(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -180,6 +182,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, @@ -219,11 +222,14 @@ class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, quant_config=quant_config) + self.self_attn = CohereAttention(config, + cache_config, + quant_config=quant_config) self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), @@ -258,6 +264,7 @@ class CohereModel(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -266,7 +273,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - CohereDecoderLayer(config, quant_config=quant_config) + CohereDecoderLayer(config, cache_config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = LayerNorm(param_shape=(config.hidden_size), @@ -299,6 +306,7 @@ class CohereForCausalLM(nn.Module): def __init__( self, config: CohereConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -306,7 +314,7 @@ def __init__( self.quant_config = quant_config self.logits_processor = LogitsProcessor(config.vocab_size, scale=config.logit_scale) - self.model = CohereModel(config, quant_config) + self.model = CohereModel(config, cache_config, quant_config) self.sampler = Sampler() @torch.no_grad() diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index a4a0ae50c645e..083ddf0159f71 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -5,6 +5,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -166,6 +167,7 @@ class DbrxAttention(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -221,6 +223,7 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, + cache_config=cache_config, ) def forward( @@ -279,10 +282,12 @@ class DbrxBlock(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config) + self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, + quant_config) self.ffn = DbrxExperts(config, quant_config) def forward( @@ -308,6 +313,7 @@ class DbrxModel(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -315,8 +321,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [DbrxBlock(config, quant_config) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + DbrxBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, @@ -349,13 +357,14 @@ class DbrxForCausalLM(nn.Module): def __init__( self, config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, quant_config) + self.transformer = DbrxModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index be9a6b6813f8f..e293ee491908d 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -28,7 +28,7 @@ import torch from transformers import PretrainedConfig -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -56,12 +56,14 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, config: Optional[PretrainedConfig] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: config.num_key_value_heads = max(config.num_key_value_heads_per_layer) delattr(config, "num_key_value_heads_per_layer") super().__init__(config=config, + cache_config=cache_config, quant_config=quant_config, lora_config=lora_config) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index e5f7ba086a35d..62e04f9649915 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -28,6 +28,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -178,6 +179,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -229,7 +231,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -252,6 +255,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -267,6 +271,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) if (config.n_routed_experts is not None @@ -321,6 +326,7 @@ class DeepseekModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -332,7 +338,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config) + DeepseekDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -360,12 +369,13 @@ class DeepseekForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = DeepseekModel(config, quant_config) + self.model = DeepseekModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 08dd69923dc6d..ab9e1994be426 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -27,6 +27,7 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -77,6 +78,7 @@ class FalconAttention(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -168,7 +170,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -229,12 +232,14 @@ class FalconDecoderLayer(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config, quant_config) + self.self_attention = FalconAttention(config, cache_config, + quant_config) self.mlp = FalconMLP(config, quant_config) self.config = config @@ -311,6 +316,7 @@ class FalconModel(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -327,7 +333,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - FalconDecoderLayer(config, quant_config) + FalconDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -359,12 +365,13 @@ class FalconForCausalLM(nn.Module): def __init__( self, config: FalconConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = FalconModel(config, quant_config) + self.transformer = FalconModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index bb73ff4d206da..d1502b718a773 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -22,7 +22,7 @@ from transformers import GemmaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul @@ -107,6 +107,7 @@ def __init__(self, head_dim: int, max_position_embeddings: int = 8192, rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -155,7 +156,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -177,6 +179,7 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -188,6 +191,7 @@ def __init__( head_dim=config.head_dim, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = GemmaMLP( @@ -236,6 +240,7 @@ class GemmaModel(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -246,7 +251,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, quant_config) + GemmaDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -309,6 +314,7 @@ class GemmaForCausalLM(nn.Module): def __init__( self, config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -316,7 +322,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = GemmaModel(config, quant_config) + self.model = GemmaModel(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 75eaebf0dbd15..0deaa58ed9eb5 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -24,6 +24,7 @@ from transformers import GPT2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPT2Attention(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -70,7 +72,10 @@ def __init__( bias=True, quant_config=quant_config, ) - self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config) def forward( self, @@ -122,6 +127,7 @@ class GPT2Block(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -130,7 +136,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, quant_config) + self.attn = GPT2Attention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(inner_dim, config, quant_config) @@ -163,6 +169,7 @@ class GPT2Model(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -174,7 +181,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block(config, quant_config) + GPT2Block(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -203,12 +210,13 @@ class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPT2Model(config, quant_config) + self.transformer = GPT2Model(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index d057fd928fdb5..c20fb3230c394 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,6 +25,7 @@ from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -46,6 +47,7 @@ class GPTBigCodeAttention(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -85,7 +87,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -143,6 +146,7 @@ class GPTBigCodeBlock(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -151,7 +155,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, quant_config) + self.attn = GPTBigCodeAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigMLP(inner_dim, config, quant_config) @@ -184,6 +188,7 @@ class GPTBigCodeModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -195,7 +200,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPTBigCodeBlock(config, quant_config) + GPTBigCodeBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -224,12 +229,13 @@ class GPTBigCodeForCausalLM(nn.Module): def __init__( self, config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(config, quant_config) + self.transformer = GPTBigCodeModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 8d7fe8a5beef7..5f4d8ec3d3a7a 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -23,6 +23,7 @@ from transformers import GPTJConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPTJAttention(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -83,7 +85,10 @@ def __init__( base=rope_theta, is_neox_style=False, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -135,13 +140,14 @@ class GPTJBlock(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() inner_dim = (4 * config.n_embd if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, quant_config) + self.attn = GPTJAttention(config, cache_config, quant_config) self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( @@ -169,6 +175,7 @@ class GPTJModel(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -178,8 +185,10 @@ def __init__( config.vocab_size, self.embed_dim, ) - self.h = nn.ModuleList( - [GPTJBlock(config, quant_config) for _ in range(config.n_layer)]) + self.h = nn.ModuleList([ + GPTJBlock(config, cache_config, quant_config) + for _ in range(config.n_layer) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -207,13 +216,14 @@ class GPTJForCausalLM(nn.Module): def __init__( self, config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(config, quant_config) + self.transformer = GPTJModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index bab563b9c5a39..dcb52ff666c95 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -23,6 +23,7 @@ from transformers import GPTNeoXConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -45,6 +46,7 @@ class GPTNeoXAttention(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -84,7 +86,10 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -134,6 +139,7 @@ class GPTNeoXLayer(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -142,7 +148,7 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, quant_config) + self.attention = GPTNeoXAttention(config, cache_config, quant_config) self.mlp = GPTNeoXMLP(config, quant_config) def forward( @@ -182,6 +188,7 @@ class GPTNeoXModel(nn.Module): def __init__( self, config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -192,7 +199,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GPTNeoXLayer(config, quant_config) + GPTNeoXLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, @@ -223,12 +230,13 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.gpt_neox = GPTNeoXModel(config, quant_config) + self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 5811cae83bf8b..65f7ddb8b082c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -6,6 +6,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -64,6 +65,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -114,7 +116,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -136,6 +139,7 @@ class InternLMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -151,6 +155,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.feed_forward = InternLM2MLP( @@ -196,6 +201,7 @@ class InternLM2Model(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -207,7 +213,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, quant_config) + InternLMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -239,12 +245,13 @@ class InternLM2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = InternLM2Model(config, quant_config) + self.model = InternLM2Model(config, cache_config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index bd6a180ec8dfc..df30fd1ba0a37 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -26,6 +26,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -69,6 +70,7 @@ class JAISAttention(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -108,6 +110,7 @@ def __init__( self.head_dim, scale=self.scale, alibi_slopes=alibi_slopes, + cache_config=cache_config, ) def forward( @@ -170,6 +173,7 @@ class JAISBlock(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -178,7 +182,7 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, quant_config) + self.attn = JAISAttention(config, cache_config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = JAISMLP(inner_dim, config, quant_config) @@ -211,6 +215,7 @@ class JAISModel(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -228,7 +233,7 @@ def __init__( else: self.embeddings_scale = config.mup_embeddings_scale self.h = nn.ModuleList([ - JAISBlock(config, quant_config) + JAISBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -262,12 +267,13 @@ class JAISLMHeadModel(nn.Module): def __init__( self, config: JAISConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = JAISModel(config, quant_config) + self.transformer = JAISModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 127e4612b2e40..ebdc64e0e220e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,7 +28,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul @@ -94,6 +94,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -153,7 +154,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + cache_config=cache_config) def forward( self, @@ -176,6 +178,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -204,6 +207,7 @@ def __init__( quant_config=quant_config, bias=attention_bias, sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -251,6 +255,7 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -267,7 +272,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) + LlamaDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -332,12 +337,16 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.model = LlamaModel(config, quant_config, lora_config=lora_config) + self.model = LlamaModel(config, + cache_config, + quant_config, + lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index dcde4dfa0795e..3b99b337a2765 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -7,7 +7,7 @@ from transformers import CLIPVisionModel, LlavaConfig from vllm.attention import AttentionMetadata -from vllm.config import VisionLanguageConfig +from vllm.config import CacheConfig, VisionLanguageConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( @@ -62,6 +62,7 @@ class LlavaForConditionalGeneration(nn.Module): def __init__(self, config: "LlavaConfig", vision_language_config: VisionLanguageConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional["QuantizationConfig"] = None) -> None: super().__init__() self.config = config @@ -85,7 +86,8 @@ def __init__(self, projector_hidden_act=config.projector_hidden_act) self.quant_config = quant_config - self.language_model = LlamaModel(config.text_config, quant_config) + self.language_model = LlamaModel(config.text_config, cache_config, + quant_config) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index c90bcfbfc4707..0b85cf1c94795 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -28,7 +28,7 @@ from torch import nn from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -181,6 +181,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -234,7 +235,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -259,6 +261,7 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -275,6 +278,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.num_experts = getattr(self.config, "num_experts", 0) @@ -330,6 +334,7 @@ class MiniCPMModel(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -346,7 +351,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, quant_config) + MiniCPMDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -413,6 +418,7 @@ class MiniCPMForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -421,6 +427,7 @@ def __init__( self.num_experts = getattr(self.config, "num_experts", 0) self.quant_config = quant_config self.model = MiniCPMModel(config, + cache_config, quant_config, lora_config=lora_config) unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index efa4de7516212..113abbaa6036d 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -29,7 +29,7 @@ from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -252,6 +252,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() @@ -313,6 +314,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -335,6 +337,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -348,6 +351,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, + cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, @@ -394,6 +398,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -410,7 +415,9 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, quant_config=quant_config) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -460,12 +467,14 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.model = MixtralModel(config, + cache_config, quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 38c62afced28a..ee2626b1c1aa2 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -30,6 +30,7 @@ from transformers import MixtralConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -157,14 +158,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -215,6 +219,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -237,6 +242,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -250,6 +256,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, + cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) @@ -292,6 +299,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -303,7 +311,9 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, quant_config=quant_config) + MixtralDecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -332,12 +342,13 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = MixtralModel(config, quant_config) + self.model = MixtralModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 6fa5c5bd3014a..716ac51cde94d 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -7,6 +7,7 @@ import torch.nn as nn from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn @@ -43,6 +44,7 @@ class MPTAttention(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -107,7 +109,8 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -166,12 +169,13 @@ class MPTBlock(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, quant_config) + self.attn = MPTAttention(config, cache_config, quant_config) self.norm_2 = nn.LayerNorm(hidden_size) self.ffn = MPTMLP(config, quant_config) @@ -201,6 +205,7 @@ class MPTModel(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -211,8 +216,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList( - [MPTBlock(config, quant_config) for _ in range(config.n_layers)]) + self.blocks = nn.ModuleList([ + MPTBlock(config, cache_config, quant_config) + for _ in range(config.n_layers) + ]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -246,6 +253,7 @@ class MPTForCausalLM(nn.Module): def __init__( self, config: MPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -253,7 +261,7 @@ def __init__( assert config.tie_word_embeddings self.quant_config = quant_config - self.transformer = MPTModel(config, quant_config) + self.transformer = MPTModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index f212ea2166e1d..69f23bbfb5d0a 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -28,6 +28,7 @@ from transformers import OlmoConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -55,6 +56,7 @@ class OlmoAttention(nn.Module): def __init__( self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -93,7 +95,8 @@ def __init__( self.scaling = self.head_dim**-0.5 self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config) # Attention output projection. self.o_proj = RowParallelLinear( @@ -175,10 +178,11 @@ class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, quant_config) + self.self_attn = OlmoAttention(config, cache_config, quant_config) # MLP block. self.mlp = OlmoMLP(config, quant_config) @@ -217,6 +221,7 @@ class OlmoModel(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -224,7 +229,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - OlmoDecoderLayer(config, quant_config) + OlmoDecoderLayer(config, cache_config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, @@ -271,10 +276,11 @@ class OlmoForCausalLM(nn.Module): def __init__(self, config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = OlmoModel(config, quant_config) + self.model = OlmoModel(config, cache_config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight else: diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 336f765ababaa..d241756e50f4a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -24,6 +24,7 @@ from transformers import OPTConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -61,6 +62,7 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -88,7 +90,8 @@ def __init__( ) self.attn = Attention(self.num_heads, self.head_dim, - scale=self.scaling) + scale=self.scaling, + cache_config=cache_config) def forward( self, @@ -108,6 +111,7 @@ class OPTDecoderLayer(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -117,6 +121,7 @@ def __init__( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, + cache_config=cache_config, quant_config=quant_config, ) self.do_layer_norm_before = config.do_layer_norm_before @@ -181,6 +186,7 @@ class OPTDecoder(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -226,7 +232,7 @@ def __init__( self.final_layer_norm = None self.layers = nn.ModuleList([ - OPTDecoderLayer(config, quant_config) + OPTDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -259,10 +265,11 @@ class OPTModel(nn.Module): def __init__( self, config: OPTConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.decoder = OPTDecoder(config, quant_config) + self.decoder = OPTDecoder(config, cache_config, quant_config) def forward( self, @@ -279,12 +286,13 @@ class OPTForCausalLM(nn.Module): def __init__( self, config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.model = OPTModel(config, quant_config) + self.model = OPTModel(config, cache_config, quant_config) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 9ab5dfb97c19a..59cd42e31b374 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -68,6 +69,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -118,7 +120,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -140,6 +143,7 @@ class OrionDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -155,6 +159,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) self.mlp = OrionMLP( @@ -202,6 +207,7 @@ class OrionModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -213,7 +219,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - OrionDecoderLayer(config, quant_config) + OrionDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -245,12 +251,13 @@ class OrionForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = OrionModel(config, quant_config) + self.model = OrionModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 4a45879201af3..ed25a232f4208 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,6 +42,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -63,6 +64,7 @@ class PhiAttention(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.total_num_heads = config.num_attention_heads @@ -105,7 +107,10 @@ def __init__(self, max_position=max_position_embeddings, base=rope_theta, ) - self.attn = Attention(self.num_heads, self.head_size, scaling) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config) def forward( self, @@ -155,11 +160,12 @@ class PhiLayer(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, quant_config) + self.self_attn = PhiAttention(config, cache_config, quant_config) self.mlp = PhiMLP(config, quant_config) def forward( @@ -186,6 +192,7 @@ class PhiModel(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -193,7 +200,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - PhiLayer(config, quant_config) + PhiLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layernorm = nn.LayerNorm(config.hidden_size, @@ -225,12 +232,13 @@ class PhiForCausalLM(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.quant_config = quant_config - self.model = PhiModel(config, quant_config) + self.model = PhiModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e5e0028888c88..d158846a3a1f5 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -11,6 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -68,6 +69,7 @@ def __init__( max_position_embeddings: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -101,7 +103,10 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = Attention(self.num_heads, self.head_dim, self.scaling) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config) def forward( self, @@ -123,6 +128,7 @@ class QWenBlock(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -135,6 +141,7 @@ def __init__( config.max_position_embeddings, rope_theta=rope_theta, rope_scaling=rope_scaling, + cache_config=cache_config, quant_config=quant_config) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -175,6 +182,7 @@ class QWenModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -186,7 +194,7 @@ def __init__( config.hidden_size, ) self.h = nn.ModuleList([ - QWenBlock(config, quant_config) + QWenBlock(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -218,12 +226,13 @@ class QWenLMHeadModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = QWenModel(config, quant_config) + self.transformer = QWenModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 62bc7fe22c367..31ba6441f9f7a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -29,7 +29,7 @@ from transformers import Qwen2Config from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -87,6 +87,7 @@ def __init__(self, max_position: int = 4096 * 32, rope_theta: float = 10000, use_sliding_window: bool = False, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() @@ -137,7 +138,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) + sliding_window=self.sliding_window, + cache_config=cache_config) def forward( self, @@ -160,6 +162,7 @@ def __init__( self, config: Qwen2Config, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -175,6 +178,7 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, use_sliding_window=use_sliding_window, + cache_config=cache_config, quant_config=quant_config, sliding_window=config.sliding_window) self.mlp = Qwen2MLP( @@ -222,6 +226,7 @@ class Qwen2Model(nn.Module): def __init__( self, config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -234,7 +239,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, quant_config) + Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -287,6 +292,7 @@ class Qwen2ForCausalLM(nn.Module): def __init__( self, config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -294,7 +300,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = Qwen2Model(config, quant_config) + self.model = Qwen2Model(config, cache_config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 8da89a2b7ba6c..2a3b0173adf8b 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -30,6 +30,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -187,6 +188,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -238,7 +240,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + cache_config=cache_config) def forward( self, @@ -261,6 +264,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -276,6 +280,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, ) if (config.num_experts is not None @@ -328,6 +333,7 @@ class Qwen2MoeModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -339,7 +345,10 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config) + Qwen2MoeDecoderLayer(config, + layer_idx, + cache_config, + quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -369,12 +378,13 @@ class Qwen2MoeForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = Qwen2MoeModel(config, quant_config) + self.model = Qwen2MoeModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 3d4f4f700f867..8b4a5507feade 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -26,6 +26,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -72,6 +73,7 @@ class StablelmAttention(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config @@ -124,7 +126,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - num_kv_heads=self.num_key_value_heads) + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config) def forward( self, @@ -146,10 +149,11 @@ class StablelmDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.self_attn = StablelmAttention(config) + self.self_attn = StablelmAttention(config, cache_config, quant_config) self.mlp = StablelmMLP(config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) @@ -188,6 +192,7 @@ class StableLMEpochModel(nn.Module): def __init__(self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( @@ -195,7 +200,7 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - StablelmDecoderLayer(config, quant_config) + StablelmDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) norm_eps = getattr(config, "norm_eps", @@ -227,12 +232,13 @@ class StablelmForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = StableLMEpochModel(config, quant_config) + self.model = StableLMEpochModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 33998e2aad5c5..3c19d63276a77 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -25,6 +25,7 @@ from transformers import Starcoder2Config from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -46,6 +47,7 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -101,6 +103,7 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, + cache_config=cache_config, ) def forward( @@ -150,10 +153,13 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, quant_config=quant_config) + self.self_attn = Starcoder2Attention(config, + cache_config, + quant_config=quant_config) self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -191,6 +197,7 @@ class Starcoder2Model(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -201,7 +208,9 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, quant_config=quant_config) + Starcoder2DecoderLayer(config, + cache_config, + quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -226,10 +235,13 @@ class Starcoder2ForCausalLM(nn.Module): def __init__(self, config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = Starcoder2Model(config, quant_config=quant_config) + self.model = Starcoder2Model(config, + cache_config, + quant_config=quant_config) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 0fb2662b2f715..6ef230a8ebbca 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -27,7 +27,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import LoRAConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -89,6 +89,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -133,7 +134,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window) + sliding_window=sliding_window, + cache_config=cache_config) def forward( self, @@ -155,6 +157,7 @@ class XverseDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -175,6 +178,7 @@ def __init__( quant_config=quant_config, bias=getattr(config, "bias", False), sliding_window=sliding_window, + cache_config=cache_config, ) self.mlp = XverseMLP( hidden_size=self.hidden_size, @@ -221,6 +225,7 @@ class XverseModel(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: @@ -237,7 +242,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - XverseDecoderLayer(config, quant_config) + XverseDecoderLayer(config, cache_config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -295,13 +300,14 @@ class XverseForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config=None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = XverseModel(config, quant_config) + self.model = XverseModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 1fb63a3e47921..07d51dca226bd 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -31,7 +31,7 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) - self.num_heads = model_config.get_num_kv_heads(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks @@ -43,7 +43,15 @@ def __init__( self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend( + model_config.get_num_attention_heads(parallel_config), + self.head_size, + self.num_kv_heads, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") @@ -56,7 +64,7 @@ def _allocate_kv_cache( ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_heads, self.head_size) + num_blocks, self.block_size, self.num_kv_heads, self.head_size) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 6c8b1685dadcf..0a0b0d70cfe21 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -53,7 +53,15 @@ def __init__( self.kv_cache_dtype = kv_cache_dtype self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size - self.attn_backend = get_attn_backend(self.model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) # Lazy initialization. self.model: nn.Module # Set after init_Model @@ -66,7 +74,8 @@ def load_model(self) -> None: vision_language_config=self.vision_language_config, lora_config=self.lora_config, parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) def _prepare_prompt( self, @@ -158,7 +167,6 @@ def _prepare_prompt( decode_metadata=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input) @@ -242,7 +250,6 @@ def _prepare_decode( prefill_metadata=None, decode_metadata=None, block_tables=block_tables, - kv_cache_dtype=self.kv_cache_dtype, ) return ( input_tokens, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 5e4ae564cb57e..3ee394f9912e9 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -53,7 +53,15 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] # Get attention backend. - self.attn_backend = get_attn_backend(model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + cache_config.cache_dtype, + self.block_size, + ) # Initialize the cache. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 2d3f160c60dc1..d04bebbdc31b6 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -235,7 +235,6 @@ def prepare_input_tensors( num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, pooling_metadata, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f46b475bdc2db..b5e1991717b13 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -141,10 +141,18 @@ def __init__( self.graph_block_tables = np.zeros( (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), dtype=np.int32) - self.attn_backend = get_attn_backend(self.model_config.dtype) + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) # Lazy initialization - self.model: torch.nn.Module # Set after load_model + self.model: nn.Module # Set after load_model # Set if the backend is flashinfer. self.flashinfer_workspace_buffer: torch.Tensor # Set after load_model. @@ -160,6 +168,7 @@ def load_model(self) -> None: vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, + cache_config=self.cache_config, ) self.model_memory_usage = m.consumed_memory @@ -753,7 +762,6 @@ def prepare_input_tensors( num_decode_tokens=num_decode_tokens, prefill_metadata=prefill_attn_metadata, decode_metadata=decode_attn_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, attn_metadata, @@ -965,7 +973,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: slot_mapping=slot_mapping[:batch_size], prefill_metadata=None, decode_metadata=decode_metadata, - kv_cache_dtype=self.kv_cache_dtype, ) if self.lora_config: From 8bc68e198c4c90ddc2e54fa76eb81c2c714bb1cd Mon Sep 17 00:00:00 2001 From: Sanger Steel Date: Mon, 13 May 2024 17:57:07 -0400 Subject: [PATCH 266/413] [Frontend] [Core] perf: Automatically detect vLLM-tensorized model, update `tensorizer` to version 2.9.0 (#4208) --- .buildkite/test-pipeline.yaml | 4 +- examples/tensorize_vllm_model.py | 200 ++++++-------- requirements-dev.txt | 2 +- setup.py | 2 +- .../tensorize_vllm_model_for_testing.py | 245 ------------------ tests/tensorizer_loader/test_tensorizer.py | 189 +++++--------- vllm/engine/arg_utils.py | 4 +- vllm/envs.py | 2 +- vllm/model_executor/model_loader/loader.py | 28 +- .../model_executor/model_loader/tensorizer.py | 106 ++++++-- 10 files changed, 259 insertions(+), 523 deletions(-) delete mode 100644 tests/tensorizer_loader/tensorize_vllm_model_for_testing.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4feea786f38ba..3c3da41c3abf3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -60,11 +60,13 @@ steps: mirror_hardwares: [amd] commands: # install aws cli for llava_example.py - - pip install awscli + # install tensorizer for tensorize_vllm_model.py + - pip install awscli tensorizer - python3 offline_inference.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 llava_example.py + - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - label: Kernels Test %N command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py index e2456168de9d5..8b74ae1d75a1d 100644 --- a/examples/tensorize_vllm_model.py +++ b/examples/tensorize_vllm_model.py @@ -1,23 +1,20 @@ import argparse import dataclasses +import json import os -import time import uuid from functools import partial -from typing import Type -import torch -import torch.nn as nn -from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, - TensorSerializer, stream_io) -from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor -from transformers import AutoConfig, PretrainedConfig +from tensorizer import stream_io -from vllm.distributed import initialize_model_parallel +from vllm import LLM +from vllm.distributed import (init_distributed_environment, + initialize_model_parallel) from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.model_executor.model_loader.tensorizer import TensorizerArgs -from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs, + TensorizerConfig, + serialize_vllm_model) # yapf conflicts with isort for this docstring # yapf: disable @@ -27,25 +24,25 @@ to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint, or locally. Tensor encryption and decryption is also supported, although libsodium must be installed to use it. Install vllm with tensorizer support -using `pip install vllm[tensorizer]`. +using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit +https://github.com/coreweave/tensorizer To serialize a model, install vLLM from source, then run something like this from the root level of this repository: python -m examples.tensorize_vllm_model \ - --model EleutherAI/gpt-j-6B \ - --dtype float16 \ + --model facebook/opt-125m \ serialize \ - --serialized-directory s3://my-bucket/ \ - --suffix vllm + --serialized-directory s3://my-bucket \ + --suffix v1 Which downloads the model from HuggingFace, loads it into vLLM, serializes it, and saves it to your S3 bucket. A local directory can also be used. This assumes your S3 credentials are specified as environment variables -in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. -To provide S3 credentials directly, you can provide `--s3-access-key-id` and -`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this -script. +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and +`S3_ENDPOINT_URL`. To provide S3 credentials directly, you can provide +`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint` +as CLI args to this script. You can also encrypt the model weights with a randomly-generated key by providing a `--keyfile` argument. @@ -57,7 +54,7 @@ --model EleutherAI/gpt-j-6B \ --dtype float16 \ deserialize \ - --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors Which downloads the model tensors from your S3 bucket and deserializes them. @@ -71,26 +68,30 @@ `python -m examples.tensorize_vllm_model deserialize --help`. -Once a model is serialized, it can be used to load the model when running the -OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing -the `--tensorizer-uri` CLI argument that is functionally the same as the -`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to -signify that the model to be deserialized is a vLLM model, rather than a -HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer -in the same inference server, albeit without the speed optimizations. To -deserialize an encrypted file, the `--encryption-keyfile` argument can be used -to provide the path to the keyfile used to encrypt the model weights. For -information on all the arguments that can be used to configure tensorizer's -deserialization, check out the tensorizer options argument group in the -`vllm/entrypoints/openai/api_server.py` script with `--help`. - -Tensorizer can also be invoked with the `LLM` class directly to load models: +Once a model is serialized, tensorizer can be invoked with the `LLM` class +directly to load models: llm = LLM(model="facebook/opt-125m", load_format="tensorizer", - tensorizer_uri=path_to_opt_tensors, - num_readers=3, - vllm_tensorized=True) + model_loader_extra_config=TensorizerConfig( + tensorizer_uri = path_to_tensors, + num_readers=3, + ) + ) + +A serialized model can be used during model loading for the vLLM OpenAI +inference server. `model_loader_extra_config` is exposed as the CLI arg +`--model-loader-extra-config`, and accepts a JSON string literal of the +TensorizerConfig arguments desired. + +In order to see all of the available arguments usable to configure +loading with tensorizer that are given to `TensorizerConfig`, run: + +`python -m examples.tensorize_vllm_model deserialize --help` + +under the `tensorizer options` section. These can also be used for +deserialization in this example script, although `--tensorizer-uri` and +`--path-to-tensors` are functionally the same in this case. """ @@ -158,95 +159,35 @@ def parse_args(): help=("Path to a binary key to use to decrypt the model weights," " if the model was serialized with encryption")) - return parser.parse_args() - - -def make_model_contiguous(model): - # Ensure tensors are saved in memory contiguously - for param in model.parameters(): - param.data = param.data.contiguous() - - -def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - architectures = getattr(config, "architectures", []) - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return model_cls - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def serialize(): - - eng_args_dict = {f.name: getattr(args, f.name) for f in - dataclasses.fields(EngineArgs)} - engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) - engine = LLMEngine.from_engine_args(engine_args) + TensorizerArgs.add_cli_args(deserialize_parser) - model = (engine.model_executor.driver_worker. - model_runner.model) - - encryption_params = EncryptionParams.random() if keyfile else None - if keyfile: - with _write_stream(keyfile) as stream: - stream.write(encryption_params.key) - - with _write_stream(model_path) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) - serializer.write_module(model) - serializer.close() + return parser.parse_args() - print("Serialization complete. Model tensors saved to", model_path) - if keyfile: - print("Key saved to", keyfile) def deserialize(): - config = AutoConfig.from_pretrained(model_ref) - - with no_init_or_tensor(): - model_class = _get_vllm_model_architecture(config) - model = model_class(config) - - before_mem = get_mem_usage() - start = time.time() - - if keyfile: - with _read_stream(keyfile) as stream: - key = stream.read() - decryption_params = DecryptionParams.from_key(key) - tensorizer_args.deserializer_params['encryption'] = \ - decryption_params - - with (_read_stream(model_path)) as stream, TensorDeserializer( - stream, **tensorizer_args.deserializer_params) as deserializer: - deserializer.load_into_module(model) - end = time.time() - - # Brag about how fast we are. - total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) - duration = end - start - per_second = convert_bytes(deserializer.total_tensor_bytes / duration) - after_mem = get_mem_usage() - print( - f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" + llm = LLM(model=args.model, + load_format="tensorizer", + model_loader_extra_config=tensorizer_config ) - print(f"Memory usage before: {before_mem}") - print(f"Memory usage after: {after_mem}") + return llm - return model args = parse_args() -s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") - or None) -s3_secret_access_key = (args.s3_secret_access_key - or os.environ.get("S3_SECRET_ACCESS_KEY") or None) +s3_access_key_id = (getattr(args, 's3_access_key_id', None) + or os.environ.get("S3_ACCESS_KEY_ID", None)) +s3_secret_access_key = (getattr(args, 's3_secret_access_key', None) + or os.environ.get("S3_SECRET_ACCESS_KEY", None)) +s3_endpoint = (getattr(args, 's3_endpoint', None) + or os.environ.get("S3_ENDPOINT_URL", None)) -s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) +credentials = { + "s3_access_key_id": s3_access_key_id, + "s3_secret_access_key": s3_secret_access_key, + "s3_endpoint": s3_endpoint +} _read_stream, _write_stream = (partial( stream_io.open_stream, @@ -263,20 +204,41 @@ def deserialize(): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "8080" -torch.distributed.init_process_group(world_size=1, rank=0) +init_distributed_environment(world_size=1, rank=0, local_rank=0) initialize_model_parallel() keyfile = args.keyfile if args.keyfile else None + +if args.model_loader_extra_config: + config = json.loads(args.model_loader_extra_config) + tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args() + tensorizer_args.tensorizer_uri = args.path_to_tensors +else: + tensorizer_args = None + if args.command == "serialize": + eng_args_dict = {f.name: getattr(args, f.name) for f in + dataclasses.fields(EngineArgs)} + + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) + engine = LLMEngine.from_engine_args(engine_args) + input_dir = args.serialized_directory.rstrip('/') suffix = args.suffix if args.suffix else uuid.uuid4().hex base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" model_path = f"{base_path}/model.tensors" - serialize() + tensorizer_config = TensorizerConfig( + tensorizer_uri=model_path, + **credentials) + serialize_vllm_model(engine, tensorizer_config, keyfile) elif args.command == "deserialize": - tensorizer_args = TensorizerArgs.from_cli_args(args) - model_path = args.path_to_tensors + if not tensorizer_args: + tensorizer_config = TensorizerConfig( + tensorizer_uri=args.path_to_tensors, + encryption_keyfile = keyfile, + **credentials + ) deserialize() else: raise ValueError("Either serialize or deserialize must be specified.") diff --git a/requirements-dev.txt b/requirements-dev.txt index 796c9e37d0230..4f6c27d95fe6a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,7 +14,7 @@ types-setuptools # testing pytest -tensorizer==2.9.0 +tensorizer>=2.9.0 pytest-forked pytest-asyncio pytest-rerunfailures diff --git a/setup.py b/setup.py index 0dc8818b44a9e..a66af2c5d556f 100644 --- a/setup.py +++ b/setup.py @@ -426,7 +426,7 @@ def _read_requirements(filename: str) -> List[str]: install_requires=get_requirements(), ext_modules=ext_modules, extras_require={ - "tensorizer": ["tensorizer==2.9.0"], + "tensorizer": ["tensorizer>=2.9.0"], }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, diff --git a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py deleted file mode 100644 index 0e113ab647e67..0000000000000 --- a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py +++ /dev/null @@ -1,245 +0,0 @@ -import argparse -import dataclasses -import os -import time -import uuid -from functools import partial -from typing import Type - -import torch.nn as nn -from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, - TensorSerializer, stream_io) -from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor -from transformers import AutoConfig, PretrainedConfig - -from vllm.distributed import (init_distributed_environment, - initialize_model_parallel) -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine -from vllm.model_executor.model_loader.tensorizer import TensorizerArgs -from vllm.model_executor.models import ModelRegistry - -# yapf conflicts with isort for this docstring -# yapf: disable -""" -tensorize_vllm_model.py is a script that can be used to serialize and -deserialize vLLM models. These models can be loaded using tensorizer directly -to the GPU extremely quickly. Tensor encryption and decryption is also -supported, although libsodium must be installed to use it. Install -vllm with tensorizer support using `pip install vllm[tensorizer]`. - -To serialize a model, you can run something like this: - -python tensorize_vllm_model.py \ - --model EleutherAI/gpt-j-6B \ - --dtype float16 \ - serialize \ - --serialized-directory s3://my-bucket/ \ - --suffix vllm - -Which downloads the model from HuggingFace, loads it into vLLM, serializes it, -and saves it to your S3 bucket. A local directory can also be used. - -You can also encrypt the model weights with a randomly-generated key by -providing a `--keyfile` argument. - -To deserialize a model, you can run something like this: - -python tensorize_vllm_model.py \ - --model EleutherAI/gpt-j-6B \ - --dtype float16 \ - deserialize \ - --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors - -Which downloads the model tensors from your S3 bucket and deserializes them. -To provide S3 credentials, you can provide `--s3-access-key-id` and -`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, -the OpenAI entrypoint, as arguments for LLM(), or as environment variables -in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. - - -You can also provide a `--keyfile` argument to decrypt the model weights if -they were serialized with encryption. - -For more information on the available arguments, run -`python tensorize_vllm_model.py --help`. -""" - - -def parse_args(): - parser = argparse.ArgumentParser( - description="An example script that can be used to serialize and " - "deserialize vLLM models. These models " - "can be loaded using tensorizer directly to the GPU " - "extremely quickly. Tensor encryption and decryption is " - "also supported, although libsodium must be installed to " - "use it.") - parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser)) - subparsers = parser.add_subparsers(dest='command') - - serialize_parser = subparsers.add_parser( - 'serialize', help="Serialize a model to `--serialized-directory`") - - serialize_parser.add_argument( - "--suffix", - type=str, - required=False, - help=( - "The suffix to append to the serialized model directory, which is " - "used to construct the location of the serialized model tensors, " - "e.g. if `--serialized-directory` is `s3://my-bucket/` and " - "`--suffix` is `v1`, the serialized model tensors will be " - "saved to " - "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " - "If none is provided, a random UUID will be used.")) - serialize_parser.add_argument( - "--serialized-directory", - type=str, - required=True) - - serialize_parser.add_argument( - "--keyfile", - type=str, - required=False, - help=("Encrypt the model weights with a randomly-generated binary key," - " and save the key at this path")) - - deserialize_parser = subparsers.add_parser( - 'deserialize', - help=("Deserialize a model from `--path-to-tensors`" - " to verify it can be loaded and used.")) - - deserialize_parser.add_argument( - "--path-to-tensors", - type=str, - required=True, - help="The local path or S3 URI to the model tensors to deserialize. ") - - deserialize_parser.add_argument( - "--keyfile", - type=str, - required=False, - help=("Path to a binary key to use to decrypt the model weights," - " if the model was serialized with encryption")) - - return parser.parse_args() - - -def make_model_contiguous(model): - # Ensure tensors are saved in memory contiguously - for param in model.parameters(): - param.data = param.data.contiguous() - - -def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - architectures = getattr(config, "architectures", []) - for arch in architectures: - model_cls = ModelRegistry.load_model_cls(arch) - if model_cls is not None: - return model_cls - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - -def serialize(): - eng_args_dict = {f.name: getattr(args, f.name) for f in - dataclasses.fields(EngineArgs)} - engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) - engine = LLMEngine.from_engine_args(engine_args) - - model = (engine.model_executor.driver_worker. - model_runner.model) - - encryption_params = EncryptionParams.random() if keyfile else None - if keyfile: - with _write_stream(keyfile) as stream: - stream.write(encryption_params.key) - - with _write_stream(model_path) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) - serializer.write_module(model) - serializer.close() - - print("Serialization complete. Model tensors saved to", model_path) - if keyfile: - print("Key saved to", keyfile) - - -def deserialize(): - config = AutoConfig.from_pretrained(model_ref) - - with no_init_or_tensor(): - model_class = _get_vllm_model_architecture(config) - model = model_class(config) - - before_mem = get_mem_usage() - start = time.time() - - if keyfile: - with _read_stream(keyfile) as stream: - key = stream.read() - decryption_params = DecryptionParams.from_key(key) - tensorizer_args.deserializer_params['encryption'] = \ - decryption_params - - with (_read_stream(model_path)) as stream, TensorDeserializer( - stream, **tensorizer_args.deserializer_params) as deserializer: - deserializer.load_into_module(model) - end = time.time() - - # Brag about how fast we are. - total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) - duration = end - start - per_second = convert_bytes(deserializer.total_tensor_bytes / duration) - after_mem = get_mem_usage() - print( - f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" - ) - print(f"Memory usage before: {before_mem}") - print(f"Memory usage after: {after_mem}") - - return model - - -args = parse_args() - -s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") - or None) -s3_secret_access_key = (args.s3_secret_access_key - or os.environ.get("S3_SECRET_ACCESS_KEY") or None) - -s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) - -_read_stream, _write_stream = (partial( - stream_io.open_stream, - mode=mode, - s3_access_key_id=s3_access_key_id, - s3_secret_access_key=s3_secret_access_key, - s3_endpoint=s3_endpoint, -) for mode in ("rb", "wb+")) - -model_ref = args.model - -model_name = model_ref.split("/")[1] - -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = "8080" - -init_distributed_environment(world_size=1, rank=0, local_rank=0) -initialize_model_parallel() - -keyfile = args.keyfile if args.keyfile else None - -if args.command == "serialize": - input_dir = args.serialized_directory.rstrip('/') - suffix = args.suffix if args.suffix else uuid.uuid4().hex - base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" - model_path = f"{base_path}/model.tensors" - serialize() -elif args.command == "deserialize": - tensorizer_args = TensorizerArgs.from_cli_args(args) - model_path = args.path_to_tensors - deserialize() -else: - raise ValueError("Either serialize or deserialize must be specified.") diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index ad4748c5ebe96..1579d53a7fe29 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -10,12 +10,19 @@ import torch from vllm import SamplingParams -from vllm.model_executor.model_loader.tensorizer import ( - EncryptionParams, TensorizerConfig, TensorSerializer, - is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream) +# yapf: disable +from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, + TensorSerializer, + is_vllm_tensorized, + load_with_tensorizer, + open_stream, + serialize_vllm_model) from ..utils import ServerRunner +# yapf conflicts with isort for this docstring + + prompts = [ "Hello, my name is", "The president of the United States is", @@ -40,7 +47,7 @@ def is_curl_installed(): @pytest.fixture(autouse=True) def tensorizer_config(): - config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True) + config = TensorizerConfig(tensorizer_uri="vllm") return config @@ -59,47 +66,6 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config): assert result == mock_agent_instance.deserialize.return_value -def test_is_vllm_model_with_vllm_in_uri(tensorizer_config): - tensorizer_config.vllm_tensorized = True - - result = is_vllm_serialized_tensorizer(tensorizer_config) - - assert result is True - - -def test_is_vllm_model_without_vllm_in_uri(tensorizer_config): - tensorizer_config.vllm_tensorized = False - - result = is_vllm_serialized_tensorizer(tensorizer_config) - - assert result is False - - -def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path): - vllm_model = vllm_runner(model_ref) - model_path = tmp_path / (model_ref + ".tensors") - outputs = vllm_model.generate(prompts, sampling_params) - model = (vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream) - serializer.write_module(model) - del vllm_model, model - gc.collect() - torch.cuda.empty_cache() - loaded_vllm_model = vllm_runner( - model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path, - num_readers=1, - vllm_tensorized=True), - ) - deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) - - # Assumes SamplingParams being seeded ensures the outputs are deterministic - assert outputs == deserialized_outputs - - @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_can_deserialize_s3(vllm_runner): model_ref = "EleutherAI/pythia-1.4b" @@ -110,7 +76,6 @@ def test_can_deserialize_s3(vllm_runner): model_loader_extra_config=TensorizerConfig( tensorizer_uri=tensorized_path, num_readers=1, - vllm_tensorized=False, s3_endpoint="object.ord1.coreweave.com", )) @@ -126,29 +91,26 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( model_path = tmp_path / (model_ref + ".tensors") key_path = tmp_path / (model_ref + ".key") outputs = vllm_model.generate(prompts, sampling_params) - model = (vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - encryption_params = EncryptionParams.random() - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) - serializer.write_module(model) - with open_stream(key_path, "wb+") as stream: - stream.write(encryption_params.key) - del vllm_model, model + config_for_serializing = TensorizerConfig(tensorizer_uri=model_path) + serialize_vllm_model(vllm_model.model.llm_engine, + config_for_serializing, + encryption_key_path=key_path) + + del vllm_model gc.collect() torch.cuda.empty_cache() - loaded_vllm_model = vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig( - tensorizer_uri=model_path, - encryption_keyfile=key_path, - num_readers=1, - vllm_tensorized=True)) + + config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, + encryption_keyfile=key_path) + + loaded_vllm_model = vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=config_for_deserializing) deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) - # Assumes SamplingParams being seeded ensures the outputs are deterministic assert outputs == deserialized_outputs @@ -169,7 +131,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, model_loader_extra_config=TensorizerConfig( tensorizer_uri=model_path, num_readers=1, - vllm_tensorized=False)) + )) deserialized_outputs = loaded_hf_model.generate_greedy( prompts, max_tokens=max_tokens) @@ -190,12 +152,11 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): # Serialize model before deserializing and binding LoRA adapters vllm_model = vllm_runner(model_ref, ) model_path = tmp_path / (model_ref + ".tensors") - model = (vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream) - serializer.write_module(model) - del vllm_model, model + + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) + + del vllm_model gc.collect() torch.cuda.empty_cache() loaded_vllm_model = vllm_runner( @@ -204,7 +165,6 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): model_loader_extra_config=TensorizerConfig( tensorizer_uri=model_path, num_readers=1, - vllm_tensorized=True, ), enable_lora=True, max_loras=1, @@ -220,58 +180,28 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): def test_load_without_tensorizer_load_format(vllm_runner): with pytest.raises(ValueError): - vllm_runner(model_ref, - model_loader_extra_config=TensorizerConfig( - tensorizer_uri="test", vllm_tensorized=False)) - - -@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_tensorize_vllm_model(tmp_path): - # Test serialize command - serialize_args = [ - "python3", tensorize_model_for_testing_script, "--model", model_ref, - "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, - "--suffix", "tests" - ] - result = subprocess.run(serialize_args, capture_output=True, text=True) - print(result.stdout) # Print the output of the serialize command - - assert result.returncode == 0, (f"Serialize command failed with output:" - f"\n{result.stdout}\n{result.stderr}") - - path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" - - # Test deserialize command - deserialize_args = [ - "python3", tensorize_model_for_testing_script, "--model", model_ref, - "--dtype", "float16", "deserialize", "--path-to-tensors", - path_to_tensors - ] - result = subprocess.run(deserialize_args, capture_output=True, text=True) - assert result.returncode == 0, (f"Deserialize command failed with output:" - f"\n{result.stdout}\n{result.stderr}") + vllm_runner( + model_ref, + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_openai_apiserver_with_tensorizer(tmp_path): +def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ## Serialize model - serialize_args = [ - "python3", tensorize_model_for_testing_script, "--model", model_ref, - "--dtype", "float16", "serialize", "--serialized-directory", tmp_path, - "--suffix", "tests" - ] - result = subprocess.run(serialize_args, capture_output=True, text=True) - print(result.stdout) # Print the output of the serialize command + vllm_model = vllm_runner(model_ref, ) + model_path = tmp_path / (model_ref + ".tensors") - assert result.returncode == 0, (f"Serialize command failed with output:" - f"\n{result.stdout}\n{result.stderr}") + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) - path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors" model_loader_extra_config = { - "tensorizer_uri": path_to_tensors, - "vllm_tensorized": True + "tensorizer_uri": str(model_path), } + del vllm_model + gc.collect() + torch.cuda.empty_cache() + ## Start OpenAI API server openai_args = [ "--model", model_ref, "--dtype", "float16", "--load-format", @@ -304,10 +234,10 @@ def test_openai_apiserver_with_tensorizer(tmp_path): def test_raise_value_error_on_invalid_load_format(vllm_runner): with pytest.raises(ValueError): - vllm_runner(model_ref, - load_format="safetensors", - model_loader_extra_config=TensorizerConfig( - tensorizer_uri="test", vllm_tensorized=False)) + vllm_runner( + model_ref, + load_format="safetensors", + model_loader_extra_config=TensorizerConfig(tensorizer_uri="test")) def test_tensorizer_with_tp(vllm_runner): @@ -321,8 +251,29 @@ def test_tensorizer_with_tp(vllm_runner): model_loader_extra_config=TensorizerConfig( tensorizer_uri=tensorized_path, num_readers=1, - vllm_tensorized=False, s3_endpoint="object.ord1.coreweave.com", ), tensor_parallel_size=2, ) + + +def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path)) + + vllm_model = vllm_runner(model_ref) + outputs = vllm_model.generate(prompts, sampling_params) + serialize_vllm_model(vllm_model.model.llm_engine, config) + + assert is_vllm_tensorized(config) + del vllm_model + gc.collect() + torch.cuda.empty_cache() + + loaded_vllm_model = vllm_runner(model_ref, + load_format="tensorizer", + model_loader_extra_config=config) + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + + assert outputs == deserialized_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 163723b4be364..fd5338c46c340 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -167,8 +167,8 @@ def add_cli_args( '* "dummy" will initialize the weights with random values, ' 'which is mainly for profiling.\n' '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave which assumes tensorizer_uri is set to the location of ' - 'the serialized weights.') + 'CoreWeave. See the Tensorize vLLM Model script in the Examples' + 'section for more information.\n') parser.add_argument( '--dtype', type=str, diff --git a/vllm/envs.py b/vllm/envs.py index 91cc8f3be775c..68d8a074d0914 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -145,7 +145,7 @@ # S3 access information, used for tensorizer to load model from S3 "S3_ACCESS_KEY_ID": - lambda: os.environ.get("S3_ACCESS_KEY", None), + lambda: os.environ.get("S3_ACCESS_KEY_ID", None), "S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), "S3_ENDPOINT_URL": diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index fc9c8aa0af44b..b14824a359b6d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer, + TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, tensorizer_weights_iterator) from vllm.model_executor.model_loader.utils import (get_model_architecture, set_default_torch_dtype) @@ -291,7 +291,7 @@ def _get_weights_iterator( tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) - def _load_model_unserialized( + def _load_model_serialized_cpu( self, model_config: ModelConfig, device_config: DeviceConfig, @@ -299,11 +299,12 @@ def _load_model_unserialized( vision_language_config: Optional[VisionLanguageConfig], cache_config: CacheConfig, ) -> nn.Module: - """Load an unserialized model with tensorizer. + """Load a serialized model with tensorizer to the CPU. - Unserialized here means "not serialized with tensorizer". This - should still be faster than default HuggingFace loading, but will - be slower than loading a tensorizer-serialized model. + This is only necessary when the model isn't vLLM-tensorized (see + examples/tensorize_vllm_model.py) This should still be faster than + default HuggingFace loading, but will be slower than loading a + vLLM-tensorized model. """ with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): @@ -324,8 +325,9 @@ def _load_model_serialized( ) -> nn.Module: """Load a serialized model with tensorizer. - See the examples/tensorize_vllm_model.py example " - script for serializing vLLM models.""" + Expects a vLLM-tensorized model. See the + examples/tensorize_vllm_model.py example script + for serializing vLLM models.""" with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] @@ -353,15 +355,15 @@ def load_model(self, *, model_config: ModelConfig, cache_config: CacheConfig) -> nn.Module: self._verify_config(model_config, parallel_config) - if is_vllm_serialized_tensorizer(self.tensorizer_config): + if is_vllm_tensorized(self.tensorizer_config): return self._load_model_serialized(model_config, device_config, lora_config, vision_language_config, cache_config) - return self._load_model_unserialized(model_config, device_config, - lora_config, - vision_language_config, - cache_config) + return self._load_model_serialized_cpu(model_config, device_config, + lora_config, + vision_language_config, + cache_config) def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 219a2a392e129..2cf4ce5f88521 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -5,6 +5,7 @@ import time import typing from dataclasses import dataclass +from functools import partial from typing import Generator, Optional, Tuple, Type, Union import torch @@ -13,6 +14,7 @@ import vllm.envs as envs from vllm.config import ModelConfig, ParallelConfig +from vllm.engine.llm_engine import LLMEngine from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -27,6 +29,11 @@ from tensorizer.stream_io import open_stream from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) + + _read_stream, _write_stream = (partial( + open_stream, + mode=mode, + ) for mode in ("rb", "wb+")) except ImportError as e: tensorizer_error_msg = str(e) @@ -43,7 +50,7 @@ class TensorizerConfig: tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, str, bytes, os.PathLike, int] - vllm_tensorized: bool + vllm_tensorized: Optional[bool] = False verify_hash: Optional[bool] = False num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None @@ -93,17 +100,11 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig, return tensorizer.deserialize() -def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool: - if tensorizer_config is None: - return False - return tensorizer_config.vllm_tensorized - - @dataclass class TensorizerArgs: tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, str, bytes, os.PathLike, int] - vllm_tensorized: bool + vllm_tensorized: Optional[bool] = False verify_hash: Optional[bool] = False num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None @@ -121,7 +122,9 @@ class TensorizerArgs: vLLM model. This is used to determine the behavior of the TensorDeserializer when loading tensors from a serialized model. It is far faster to deserialize a vLLM model as it utilizes - tensorizer's optimized GPU loading. + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. verify_hash: If True, the hashes of each tensor will be verified against the hashes stored in the metadata. A `HashMismatchError` will be raised if any of the hashes do not match. @@ -158,6 +161,7 @@ def __post_init__(self): "encryption": self.encryption_keyfile, "num_readers": self.num_readers } + if self.encryption_keyfile: with open_stream( self.encryption_keyfile, @@ -177,7 +181,14 @@ def add_cli_args( 'tensorizer options', description=('Options for configuring the behavior of the' ' tensorizer deserializer when ' - '--load-format=tensorizer')) + 'load_format=tensorizer is specified when ' + 'initializing an LLMEngine, either via the CLI ' + 'when running the vLLM OpenAI inference server ' + 'with a JSON string passed to ' + '--model-loader-extra-config or as arguments given ' + 'to TensorizerConfig when passed to ' + 'model_loader_extra_config in the constructor ' + 'for LLMEngine.')) group.add_argument( "--tensorizer-uri", @@ -222,13 +233,6 @@ def add_cli_args( help="The endpoint for the S3 bucket. Can also be set via the " "S3_ENDPOINT_URL environment variable.", ) - group.add_argument( - "--vllm-tensorized", - action="store_true", - help="If enabled, indicates that the serialized model is a vLLM " - "model. This is used to determine the behavior of the " - "TensorDeserializer when loading tensors from a " - "serialized model.") return parser @@ -322,10 +326,9 @@ def deserialize(self): """ before_mem = get_mem_usage() start = time.perf_counter() - with open_stream( - self.tensorizer_args.tensorizer_uri, - mode="rb", - **self.tensorizer_args.stream_params, + with _read_stream( + self.tensorizer_config.tensorizer_uri, + **self.tensorizer_args.stream_params ) as stream, TensorDeserializer( stream, dtype=self.tensorizer_config.dtype, @@ -345,6 +348,7 @@ def deserialize(self): self._check_tensors_on_meta_device() self._resize_lora_embeddings() + del self.model.vllm_tensorized_marker return self.model.eval() @@ -366,3 +370,63 @@ def tensorizer_weights_iterator( for name, param in state.items(): yield name, param del state + + +def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: + """ + Infer if the model is a vLLM model by checking the weights for + a vLLM tensorized marker. + + Args: + tensorizer_config: The TensorizerConfig object containing the + tensorizer_uri to the serialized model. + + Returns: + bool: True if the model is a vLLM model, False otherwise. + """ + tensorizer_args = tensorizer_config._construct_tensorizer_args() + deserializer = TensorDeserializer(open_stream( + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), + **tensorizer_args.deserializer_params, + lazy_load=True) + if tensorizer_config.vllm_tensorized: + logger.warning( + "Please note that newly serialized vLLM models are automatically " + "inferred as vLLM models, so setting vllm_tensorized=True is " + "only necessary for models serialized prior to this change.") + return True + if (".vllm_tensorized_marker" in deserializer): + return True + return False + + +def get_pretensorized_vllm_model(engine: "LLMEngine") -> nn.Module: + model = (engine.model_executor.driver_worker.model_runner.model) + model.register_parameter( + "vllm_tensorized_marker", + nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + return model + + +def serialize_vllm_model(engine: "LLMEngine", + tensorizer_config : TensorizerConfig, + encryption_key_path: Optional[str] = None) \ + -> nn.Module: + + model = get_pretensorized_vllm_model(engine) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + encryption_params = None + if encryption_key_path is not None: + encryption_params = EncryptionParams.random() + with _write_stream(encryption_key_path, + **tensorizer_args.stream_params) as stream: + stream.write(encryption_params.key) + + with _write_stream(tensorizer_args.tensorizer_uri, + **tensorizer_args.stream_params) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + logger.info("Successfully serialized model to %s", + str(tensorizer_args.tensorizer_uri)) + return model From ce532ff45c8008c7157eb448860c13bcdd44823f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 13 May 2024 15:00:13 -0700 Subject: [PATCH 267/413] [Speculative decoding] Improve n-gram efficiency (#4724) --- tests/spec_decode/test_ngram_worker.py | 17 ++--- vllm/config.py | 13 ++-- vllm/spec_decode/ngram_worker.py | 97 ++++++++++++++------------ vllm/spec_decode/top1_proposer.py | 63 +++++++++++++++++ 4 files changed, 132 insertions(+), 58 deletions(-) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py index de305c4030aa9..88b40d1eb4674 100644 --- a/tests/spec_decode/test_ngram_worker.py +++ b/tests/spec_decode/test_ngram_worker.py @@ -34,8 +34,8 @@ def test_ngram_algo_correctness_for_single_no_match(): max_proposal_len=20, ) - # set ngram window (0, 3], which is window=1/2/3 - ngram_worker.set_ngram_window_size(0, 3) + # set ngram window [1, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(1, 3) prompts = [ # shall find no candidate @@ -90,8 +90,8 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): max_proposal_len=20, ) - # set ngram window (0, 3], which is window=1/2/3 - ngram_worker.set_ngram_window_size(0, 3) + # set ngram window [1, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(1, 3) prompts = [ # shall find no candidate @@ -128,11 +128,12 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len]) assert proposals.proposal_lens.shape == torch.Size([5]) + # the first sequence has no match so proposal_len should be overwritten to 0 assert proposals.proposal_lens.tolist( - ) == [proposal_len for _ in range(4)] + [0] + ) == [0] + [proposal_len for _ in range(3)] + [0] for i in range(proposal_len): - assert proposals.proposal_token_ids[0][i] == 0 + assert proposals.proposal_token_ids[0][i] == -1 assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1] assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3] assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5] @@ -167,8 +168,8 @@ def test_ngram_algo_correctness_for_batches_match_all(): max_proposal_len=20, ) - # set ngram window (0, 3], which is window=1/2/3 - ngram_worker.set_ngram_window_size(0, 3) + # set ngram window [0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(1, 3) prompts = [ # shall find candidate 12,13,14,15,16 diff --git a/vllm/config.py b/vllm/config.py index fab9cfbf41a2d..435f47dc9459a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -784,12 +784,15 @@ def maybe_create_spec_config( draft_quantization = None if speculative_model == "[ngram]": - assert (ngram_prompt_lookup_max is not None - and ngram_prompt_lookup_max > 0) if ngram_prompt_lookup_min is None: - ngram_prompt_lookup_min = 0 - else: - assert ngram_prompt_lookup_max > ngram_prompt_lookup_min + ngram_prompt_lookup_min = 1 + if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1: + raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0") + if ngram_prompt_lookup_min < 1: + raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0") + if ngram_prompt_lookup_min > ngram_prompt_lookup_max: + raise ValueError(f"{ngram_prompt_lookup_min=} cannot be " + f"larger than {ngram_prompt_lookup_max=}") # TODO: current we still need extract vocab_size from target model # config, in future, we may try refactor it out, and set diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 6cd50fcc1a041..9628f7af5315a 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -77,9 +77,11 @@ def sampler_output( """ self._raise_if_unsupported(execute_model_req) - arr = [] has_spec_out = False - for seq_group_metadata in execute_model_req.seq_group_metadata_list: + token_id_list = [] + token_prob_list = [] + for idx, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): seq_data = next(iter(seq_group_metadata.seq_data.values())) input_ids = torch.as_tensor(seq_data.get_token_ids(), @@ -89,59 +91,64 @@ def sampler_output( for ngram_size in range( min(self.ngram_prompt_lookup_max, input_length - 1), - self.ngram_prompt_lookup_min, + self.ngram_prompt_lookup_min - 1, -1, ): - ngram_tensor = input_ids[-1 * ngram_size:] - windows = input_ids.unfold(dimension=0, - size=ngram_size, - step=1) - matches = (windows == ngram_tensor).all(dim=1) - match_indices = matches.nonzero(as_tuple=True)[0] - if match_indices.size()[0] > 1: + ngram_tensor = input_ids[-ngram_size:] + proposal_start_idx = None + if ngram_size == 1: + # Do not match itself and do not use unfold and all + matches = (input_ids[:-1] == ngram_tensor) + else: + windows = input_ids.unfold(dimension=0, + size=ngram_size, + step=1) + # Do not match itself + matches = (windows[:-1] == ngram_tensor).all(dim=-1) + + # first_match includes "values" (bool), indicating whether + # the match is found, and "indices", indicating the index + # of the first match. + # Note that "first_match.values.item()" triggers GPU-CPU + # sync so it is a bit inefficient, but we have not found + # a better way to do this. + first_match = matches.max(dim=-1) + if first_match.values.item(): + proposal_start_idx = first_match.indices.add_(ngram_size) + spec_indices = ( + proposal_start_idx).repeat(sample_len) + torch.arange( + sample_len, device=self.device) + spec_indices.clamp_(max=input_ids.shape[-1] - 1) + res = input_ids.gather(dim=-1, index=spec_indices) + token_id_list.append(res) + token_prob_list.append( + torch.nn.functional.one_hot( + res, + num_classes=self.vocab_size).to(torch.float32)) has_spec_out = True - res = seq_data.get_token_ids() - res = res[match_indices[0] + ngram_size:match_indices[0] + - ngram_size + sample_len] - res_len = len(res) - # pad 0 towards output as sample_len tokens required - res += [0] * (sample_len - res_len) - break else: - # if no candidate found, fill with 0 - res = [0] * sample_len - - arr.append(res) + token_id_list.append(None) + token_prob_list.append(None) if not has_spec_out: return None, False - outputs = [] - token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device) - indices = token_ids.unsqueeze(2) + outputs: List[Optional[SamplerOutput]] = [] + for idx in range(len(execute_model_req.seq_group_metadata_list)): + if token_id_list[idx] is None: + outputs.append(None) + else: + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_prob_list[idx], + logprobs=torch.zeros((sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device), + sampled_token_ids=token_id_list[idx], + )) - token_probs = torch.zeros( - (len(execute_model_req.seq_group_metadata_list), sample_len, - self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - token_probs.scatter_(2, indices, 1) - token_logprobs = torch.zeros( - (len(execute_model_req.seq_group_metadata_list), sample_len, - self.vocab_size), - dtype=torch.float32, - device=self.device, - ) - for i in range(len(execute_model_req.seq_group_metadata_list)): - outputs.append( - SamplerOutput( - outputs=None, - sampled_token_probs=token_probs[i], - logprobs=token_logprobs[i], - sampled_token_ids=token_ids[i], - )) return outputs, False def get_spec_proposals( diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index ee9462b68dae8..6c7e22207f6b2 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -73,6 +73,14 @@ def get_proposals( execute_model_req=nonzero_execute_model_req, sample_len=proposal_len, ) + ( + proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + ) = self._remove_no_proposal_seqs(proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + transposed) else: # If no sequences can be speculated, set sampler output to None. maybe_sampler_output = None @@ -140,6 +148,61 @@ def _split_by_proposal_len( nonzero_proposal_len_indices, ) + def _remove_no_proposal_seqs(self, proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices, transposed): + """Remove sequences from nonzero_proposal_len_indices and reset + their proposal_len to 0 the draft worker does not provide a proposal + (maybe_sampler_output=None). This can avoid scoring overheads. + """ + + # If maybe_sampler_output is None, then the draft worker did not + # provide a proposal for any sequence and thus no action needed. + # Also we do not support transposed maybe_sampler_output for now + # because it seems not straightforward for draft workers outputting + # transposed sampler outputs to handle the case of no proposal. + if maybe_sampler_output is None or transposed: + return (proposal_lens, maybe_sampler_output, + nonzero_proposal_len_indices) + + new_proposal_lens: List[int] = [] + new_nonzero_proposal_len_indices: List[int] = [] + new_maybe_sampler_output: List[SamplerOutput] = [] + nonzero_proposal_len_idx_ptr = 0 + seq_idx = 0 + while seq_idx < len( + proposal_lens) and nonzero_proposal_len_idx_ptr < len( + nonzero_proposal_len_indices): + if seq_idx < nonzero_proposal_len_indices[ + nonzero_proposal_len_idx_ptr]: + # Sequence is not in the original nonzero_proposal_len_indices, + # meaning that it has a proposal length of 0 before sending to + # the draft worker. + assert proposal_lens[seq_idx] == 0 + new_proposal_lens.append(0) + else: + # Sequence is in the original nonzero_proposal_len_indices + if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None: + # but does not have a proposal from the draft worker. + new_proposal_lens.append(0) + else: + # and has a proposal from the draft worker. Add it to the + # new nonzero proposal list and keep the sampler output. + new_proposal_lens.append(proposal_lens[seq_idx]) + new_nonzero_proposal_len_indices.append(seq_idx) + new_maybe_sampler_output.append( + maybe_sampler_output[nonzero_proposal_len_idx_ptr]) + nonzero_proposal_len_idx_ptr += 1 + seq_idx += 1 + + # The remaining sequences should have proposal length of 0. + new_proposal_lens.extend(proposal_lens[seq_idx:]) + + # We assume sampler_output will not be a list of all Nones. + # In this case this function should not be called. + assert new_maybe_sampler_output + return (new_proposal_lens, new_maybe_sampler_output, + new_nonzero_proposal_len_indices) + def _merge_outputs( self, batch_size: int, From 1356df53bd5d6877358aff3d2bbd95f28f8009a4 Mon Sep 17 00:00:00 2001 From: Stephen Krider <72541272+skrider@users.noreply.github.com> Date: Mon, 13 May 2024 15:50:33 -0700 Subject: [PATCH 268/413] [Kernel] Use flash-attn for decoding (#3648) Co-authored-by: Woosuk Kwon Co-authored-by: LiuXiaoxuanPKU --- tests/kernels/test_flash_attn.py | 209 ++++++++++++++++++++++++++ tests/models/test_big_models.py | 2 +- tests/models/test_fp8.py | 10 +- vllm/attention/backends/flash_attn.py | 128 +++++++++------- vllm/attention/selector.py | 14 ++ vllm/worker/model_runner.py | 15 +- 6 files changed, 313 insertions(+), 65 deletions(-) create mode 100644 tests/kernels/test_flash_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py new file mode 100644 index 0000000000000..89bdacc67fbc4 --- /dev/null +++ b/tests/kernels/test_flash_attn.py @@ -0,0 +1,209 @@ +from typing import List, Optional, Tuple + +import pytest +import torch +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +NUM_HEADS = [(16, 16), (32, 8), (64, 8)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens: List[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = torch.triu(empty_mask, + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_flash_attn_with_paged_kv( + kv_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_blocks = 128 + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn(num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + ).squeeze(1) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_varlen_with_paged_kv( + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_blocks = 128 + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = ((sliding_window, + sliding_window) if sliding_window is not None else + (-1, -1)) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_cache = torch.randn(num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + # Normalize the scale of the key and value caches to mitigate + # numerical instability. + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + cu_kv_lens = torch.tensor([0] + kv_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index c02204f16ac68..10e7c64e34e75 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -12,7 +12,7 @@ # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - "mosaicml/mpt-7b", + # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, ] diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index e87a1783a83f1..664e951a89f2a 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -25,18 +25,18 @@ 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', ], "meta-llama/Meta-Llama-3-8B-Instruct": [ 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', + 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f59715bd76ede..11ecb2792ea9d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,20 +1,16 @@ -"""Attention layer with Flash and PagedAttention. - -NOTE(woosuk): At the moment, this file includes a lot of duplicated code from -XFormers backend. The duplicated code will be removed once we use flash-attn or -flashinfer for all the attention operations. -""" +"""Attention layer with FlashAttention.""" from dataclasses import dataclass from typing import List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) + +_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] class FlashAttentionBackend(AttentionBackend): @@ -38,8 +34,9 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -47,19 +44,26 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadataPerStage): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -105,6 +109,14 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + class FlashAttentionImpl(AttentionImpl): """ @@ -156,11 +168,15 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + if head_size not in _SUPPORTED_HEAD_SIZES: raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") def forward( self, @@ -171,17 +187,20 @@ def forward( attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. + """Forward pass with FlashAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -189,16 +208,20 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + key_cache = kv_cache[0] + value_cache = kv_cache[1] # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + ) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -218,7 +241,8 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or prefill_meta.block_tables.numel() == 0: + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -239,38 +263,32 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - # TODO(Hai) this triton kernel has regression issue (broke) to - # deal with different data types between KV and FP8 KV cache, - # to be addressed separately. - output[:num_prefill_tokens] = PagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.subquery_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], + output[:num_prefill_tokens] = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.subquery_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=prefill_meta.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, ) + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = PagedAttention.forward_decode( - decode_query, + output[num_prefill_tokens:] = flash_attn_with_kvcache( + decode_query.unsqueeze(1), key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - kv_scale, - ) + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ).squeeze(1) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 06f99718a4dee..5140c3cc86a31 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -93,6 +93,20 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS + if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") + return _Backend.XFORMERS + + if block_size % 16 != 0: + logger.info("Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + return _Backend.XFORMERS + + if sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window.") + return _Backend.XFORMERS + try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b5e1991717b13..3f7e87c1de48c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -266,20 +266,27 @@ def _prepare_prompt( # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] - prefix_block_tables.append(computed_block_nums) + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) else: # The first prefill. - prefix_block_tables.append([]) + block_table = [] else: - prefix_block_tables.append([]) + block_table = [] # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 + prefix_block_tables.append(block_table) # actual prompt lens context_lens.append(context_len) From 33d3914b1e6d85a855da1a69193030c1915cb6f9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 13 May 2024 16:00:27 -0700 Subject: [PATCH 269/413] [Bugfix] Fix dynamic FP8 quantization for Mixtral (#4793) --- vllm/model_executor/models/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 113abbaa6036d..e3ac33e0452fe 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -95,7 +95,7 @@ def __init__( params_dtype=self.params_dtype, quant_config=None) - if self.use_fp8: + if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn self.w13_weight = nn.Parameter( From ac1fbf7fd2d1fdddc7b4953eeb3acae35c62766f Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 13 May 2024 16:23:54 -0700 Subject: [PATCH 270/413] [Doc] Shorten README by removing supported model list (#4796) --- README.md | 47 +++++-------------------- docs/source/models/supported_models.rst | 25 +++++++++---- 2 files changed, 28 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 524d027137aba..b704441eada9f 100644 --- a/README.md +++ b/README.md @@ -51,41 +51,14 @@ vLLM is flexible and easy to use with: - (Experimental) Prefix caching support - (Experimental) Multi-lora support -vLLM seamlessly supports many Hugging Face models, including the following architectures: - -- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) -- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) -- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) -- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) -- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.) -- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.) -- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) -- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) -- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.) -- GPT-2 (`gpt2`, `gpt2-xl`, etc.) -- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) -- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.) -- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.) -- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) -- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.) -- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.) -- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) -- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) -- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) -- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.) -- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) -- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.) -- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) -- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.) -- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) -- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.) -- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) -- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.) -- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.) -- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) -- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.) -- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.) -- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) +vLLM seamlessly supports most popular open-source models on HuggingFace, including: +- Transformer-like LLMs (e.g., Llama) +- Mixture-of-Expert LLMs (e.g., Mixtral) +- Multi-modal LLMs (e.g., LLaVA) + +Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). + +## Getting Started Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): @@ -93,9 +66,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get pip install vllm ``` -## Getting Started - -Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started. +Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index ceb658bbd5c66..142c8f8573e2f 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -16,13 +16,21 @@ Alongside each architecture, we include some popular models that use it. - Example HuggingFace Models - :ref:`LoRA ` * - :code:`AquilaForCausalLM` - - Aquila + - Aquila & Aquila2 - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc. - ✅︎ + * - :code:`ArcticForCausalLM` + - Arctic + - :code:`Snowflake/snowflake-arctic-base`, :code:`Snowflake/snowflake-arctic-instruct`, etc. + - * - :code:`BaiChuanForCausalLM` - - Baichuan + - Baichuan & Baichuan2 - :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc. - ✅︎ + * - :code:`BloomForCausalLM` + - BLOOM, BLOOMZ, BLOOMChat + - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. + - * - :code:`ChatGLMModel` - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. @@ -39,10 +47,6 @@ Alongside each architecture, we include some popular models that use it. - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - - * - :code:`BloomForCausalLM` - - BLOOM, BLOOMZ, BLOOMChat - - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. - - * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. @@ -135,6 +139,15 @@ Alongside each architecture, we include some popular models that use it. - StableLM - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. - + * - :code:`Starcoder2ForCausalLM` + - Starcoder2 + - :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc. + - + * - :code:`XverseForCausalLM` + - Xverse + - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. + - + If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. Otherwise, please refer to :ref:`Adding a New Model ` for instructions on how to implement support for your model. From 4bfa7e7f75eb5b1a397c93aeea1dea1afa867b2a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 08:47:42 +0800 Subject: [PATCH 271/413] [Doc] Add API reference for offline inference (#4710) --- docs/source/index.rst | 8 +++++++- docs/source/offline_inference/llm.rst | 6 ++++++ .../source/{dev => offline_inference}/sampling_params.rst | 4 ++-- docs/source/serving/openai_compatible_server.md | 4 ++-- 4 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 docs/source/offline_inference/llm.rst rename docs/source/{dev => offline_inference}/sampling_params.rst (55%) diff --git a/docs/source/index.rst b/docs/source/index.rst index 4022c590843e6..e1e81778dbdb7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -67,6 +67,13 @@ Documentation getting_started/quickstart getting_started/examples/examples_index +.. toctree:: + :maxdepth: 1 + :caption: Offline Inference + + offline_inference/llm + offline_inference/sampling_params + .. toctree:: :maxdepth: 1 :caption: Serving @@ -101,7 +108,6 @@ Documentation :maxdepth: 2 :caption: Developer Documentation - dev/sampling_params dev/engine/engine_index dev/kernel/paged_attention dev/dockerfile/dockerfile diff --git a/docs/source/offline_inference/llm.rst b/docs/source/offline_inference/llm.rst new file mode 100644 index 0000000000000..1a443ea406994 --- /dev/null +++ b/docs/source/offline_inference/llm.rst @@ -0,0 +1,6 @@ +LLM Class +========== + +.. autoclass:: vllm.LLM + :members: + :show-inheritance: diff --git a/docs/source/dev/sampling_params.rst b/docs/source/offline_inference/sampling_params.rst similarity index 55% rename from docs/source/dev/sampling_params.rst rename to docs/source/offline_inference/sampling_params.rst index ef3d1509bda6d..f645941a6c022 100644 --- a/docs/source/dev/sampling_params.rst +++ b/docs/source/offline_inference/sampling_params.rst @@ -1,5 +1,5 @@ -Sampling Params -=============== +Sampling Parameters +=================== .. autoclass:: vllm.SamplingParams :members: diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 15a8761eb5738..a775c6addf1d9 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -48,7 +48,7 @@ completion = client.chat.completions.create( ``` ### Extra Parameters for Chat API -The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python @@ -65,7 +65,7 @@ The following extra parameters are supported: ``` ### Extra Parameters for Completions API -The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python From c579b750a083931ad03ecac898aca5ad67c6c59c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 13 May 2024 18:48:00 -0700 Subject: [PATCH 272/413] [Doc] Add meetups to the doc (#4798) --- docs/source/community/meetups.rst | 12 ++++++++++++ docs/source/index.rst | 7 +++++++ 2 files changed, 19 insertions(+) create mode 100644 docs/source/community/meetups.rst diff --git a/docs/source/community/meetups.rst b/docs/source/community/meetups.rst new file mode 100644 index 0000000000000..fa1a26521814d --- /dev/null +++ b/docs/source/community/meetups.rst @@ -0,0 +1,12 @@ +.. _meetups: + +vLLM Meetups +============ + +We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: + +- `The third vLLM meetup `_, with Roblox, April 2nd 2024, `slides `_. +- `The second vLLM meetup `_, with IBM Research, January 31st 2024, `slides `_, `video (vLLM Update) `_, `video (IBM Research & torch.compile) `_. +- `The first vLLM meetup `_, with a16z, October 5th, 2023, `slides `_. + +We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at `vllm-questions@lists.berkeley.edu `_. diff --git a/docs/source/index.rst b/docs/source/index.rst index e1e81778dbdb7..bab00e28e4018 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -50,6 +50,7 @@ For more information, check out the following: * `vLLM announcing blog post `_ (intro to PagedAttention) * `vLLM paper `_ (SOSP 2023) * `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency `_ by Cade Daniel et al. +* :ref:`vLLM Meetups `. @@ -112,6 +113,12 @@ Documentation dev/kernel/paged_attention dev/dockerfile/dockerfile +.. toctree:: + :maxdepth: 2 + :caption: Community + + community/meetups + Indices and tables ================== From ccb63a8245bceb9e6ba260eeef41b54ca8bdb370 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Tue, 14 May 2024 05:34:33 -0700 Subject: [PATCH 273/413] [Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies (#4696) --- benchmarks/overheads/benchmark_hashing.py | 63 +++++++++++++++++++++++ vllm/sequence.py | 16 +++++- 2 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 benchmarks/overheads/benchmark_hashing.py diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py new file mode 100644 index 0000000000000..c846e47de1fcf --- /dev/null +++ b/benchmarks/overheads/benchmark_hashing.py @@ -0,0 +1,63 @@ +import argparse +import cProfile +import pstats + +from vllm import LLM, SamplingParams + +# A very long prompt, total number of tokens is about 15k. +LONG_PROMPT = ["You are an expert in large language models, aren't you?" + ] * 1000 +LONG_PROMPT = ' '.join(LONG_PROMPT) + + +def main(args): + llm = LLM( + model=args.model, + enforce_eager=True, + enable_prefix_caching=True, + tensor_parallel_size=args.tensor_parallel_size, + use_v2_block_manager=args.use_v2_block_manager, + ) + + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + profiler = cProfile.Profile() + + print("------warm up------") + for i in range(3): + output = llm.generate(LONG_PROMPT, sampling_params) + print(output[0].outputs[0].text) + + print("------start generating------") + for i in range(3): + profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', + globals(), locals()) + + # analyze the runtime of hashing function + stats = pstats.Stats(profiler) + stats.sort_stats('cumulative') + total_time = 0 + total_calls = 0 + for func in stats.stats: + if 'hash_of_block' in func[2]: + total_time = stats.stats[func][3] + total_calls = stats.stats[func][0] + percentage = (total_time / stats.total_tt) * 100 + print(f"Hashing took {total_time:.2f} seconds," + f"{percentage:.2f}% of the total runtime.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Benchmark the performance of hashing function in' + 'automatic prefix caching.') + parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) + parser.add_argument('--enable-prefix-caching', + action='store_true', + help='enable prefix caching') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='Use BlockSpaceMangerV2') + args = parser.parse_args() + main(args) diff --git a/vllm/sequence.py b/vllm/sequence.py index 46ac33b7ecabd..12e930c27173e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -121,6 +121,7 @@ def __init__( output_token_ids = [] self.prompt_token_ids = prompt_token_ids + self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) self.output_token_ids = output_token_ids self.cumulative_logprob = 0.0 # The number of tokens that are computed (that run against the model). @@ -143,6 +144,17 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids + def get_prefix_token_ids( + self, num_tokens: int + ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: + """Get prefix tokens, and make the return value hashable""" + prompt_length = len(self.prompt_token_ids) + if num_tokens > prompt_length: + return (self._prompt_token_ids_tuple, + tuple(self.output_token_ids[:num_tokens - prompt_length])) + else: + return (self._prompt_token_ids_tuple[:num_tokens], None) + def get_num_computed_tokens(self) -> int: """Return the number of prefill tokens that are already computed.""" return self._num_computed_tokens @@ -253,8 +265,8 @@ def hash_of_block(self, logical_idx: int) -> int: # TODO: The current hashing function is O(L^2). We should optimize # this in the future. num_tokens = self.num_hashed_tokens_of_block(logical_idx) - return hash( - (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id)) + hashed_tokens = self.data.get_prefix_token_ids(num_tokens) + return hash((hashed_tokens, self.lora_int_id)) def num_hashed_tokens_of_block(self, logical_idx: int): return logical_idx * self.block_size + self.block_size From dc72402b5785a6ffadff59d4e018661278d4b028 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 00:57:08 +0800 Subject: [PATCH 274/413] [Bugfix][Doc] Fix CI failure in docs (#4804) This PR fixes the CI failure introduced by #4798. The failure originates from having duplicate target names in reST, and is fixed by changing the ref targets to anonymous ones. For more information, see this discussion. I have also changed the format of the links to be more distinct from each other. --- docs/source/community/meetups.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/community/meetups.rst b/docs/source/community/meetups.rst index fa1a26521814d..f371194781de3 100644 --- a/docs/source/community/meetups.rst +++ b/docs/source/community/meetups.rst @@ -5,8 +5,8 @@ vLLM Meetups We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: -- `The third vLLM meetup `_, with Roblox, April 2nd 2024, `slides `_. -- `The second vLLM meetup `_, with IBM Research, January 31st 2024, `slides `_, `video (vLLM Update) `_, `video (IBM Research & torch.compile) `_. -- `The first vLLM meetup `_, with a16z, October 5th, 2023, `slides `_. +- `The third vLLM meetup `__, with Roblox, April 2nd 2024. `[Slides] `__ +- `The second vLLM meetup `__, with IBM Research, January 31st 2024. `[Slides] `__ `[Video (vLLM Update)] `__ `[Video (IBM Research & torch.compile)] `__ +- `The first vLLM meetup `__, with a16z, October 5th 2023. `[Slides] `__ -We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at `vllm-questions@lists.berkeley.edu `_. +We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at `vllm-questions@lists.berkeley.edu `__. From 676a99982fe9aabe72fd52a91e08988a653a7359 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 14 May 2024 10:38:59 -0700 Subject: [PATCH 275/413] [Core] Add MultiprocessingGPUExecutor (#4539) Co-authored-by: SAHIL SUNEJA --- .buildkite/test-pipeline.yaml | 12 +- .../test_basic_distributed_correctness.py | 17 ++- .../test_chunked_prefill_distributed.py | 4 + tests/lora/test_mixtral.py | 3 +- vllm/config.py | 38 +++-- vllm/engine/arg_utils.py | 16 +- vllm/engine/async_llm_engine.py | 16 +- vllm/engine/llm_engine.py | 10 +- vllm/executor/multiproc_gpu_executor.py | 140 ++++++++++++++++++ vllm/executor/ray_gpu_executor.py | 4 +- vllm/executor/ray_utils.py | 4 +- 11 files changed, 225 insertions(+), 39 deletions(-) create mode 100644 vllm/executor/multiproc_gpu_executor.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3c3da41c3abf3..eeb6eaa2165bc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -34,10 +34,14 @@ steps: mirror_hardwares: [amd] commands: - pytest -v -s distributed/test_pynccl_library.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_basic_distributed_correctness.py - - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s distributed/test_chunked_prefill_distributed.py - - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py + - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - label: Distributed Tests (Multiple Groups) working_dir: "/vllm-workspace/tests" diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index d63f015511ada..3ba5cea389c2f 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -20,6 +20,7 @@ MODELS = [ os.environ["TEST_DIST_MODEL"], ] +DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @@ -36,19 +37,21 @@ def test_models( dtype: str, max_tokens: int, ) -> None: - enforce_eager = False + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) - if backend_by_env_var == "FLASHINFER": - enforce_eager = True + enforce_eager = backend_by_env_var == "FLASHINFER" hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, - dtype=dtype, - tensor_parallel_size=2, - enforce_eager=enforce_eager) + vllm_model = vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=enforce_eager, + distributed_executor_backend=distributed_executor_backend) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index 737b1f3169519..db938cc613c6b 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -19,6 +19,7 @@ MODELS = [ os.environ["TEST_DIST_MODEL"], ] +DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND" @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -36,6 +37,8 @@ def test_models( max_tokens: int, chunked_prefill_token_size: int, ) -> None: + distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND) + # Add a chunked prefill config. max_num_seqs = min(chunked_prefill_token_size, 256) assert chunked_prefill_token_size != -1 @@ -53,6 +56,7 @@ def test_models( max_num_seqs=max_num_seqs, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index 4d74722aaa926..f6a8a50fa9e50 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -38,8 +38,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size): enable_lora=True, max_num_seqs=16, max_loras=4, - tensor_parallel_size=tp_size, - worker_use_ray=True) + tensor_parallel_size=tp_size) expected_lora_output = [ "give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501 diff --git a/vllm/config.py b/vllm/config.py index 435f47dc9459a..26edd4567b9ac 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -521,9 +521,7 @@ class ParallelConfig: Args: pipeline_parallel_size: Number of pipeline parallel groups. tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Whether to use Ray for model workers. Will be set to - True if either pipeline_parallel_size or tensor_parallel_size is - greater than 1. + worker_use_ray: Deprecated, use distributed_executor_backend instead. max_parallel_loading_workers: Maximum number of multiple batches when load model sequentially. To avoid RAM OOM when using tensor parallel and large models. @@ -533,22 +531,27 @@ class ParallelConfig: If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + distributed_executor_backend: Backend to use for distributed model + workers, either "ray" or "mp" (multiprocessing). If either + pipeline_parallel_size or tensor_parallel_size is greater than 1, + will default to "ray" if Ray is installed or "mp" otherwise. """ def __init__( self, pipeline_parallel_size: int, tensor_parallel_size: int, - worker_use_ray: bool, + worker_use_ray: Optional[bool] = None, max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, ray_workers_use_nsight: bool = False, placement_group: Optional["PlacementGroup"] = None, + distributed_executor_backend: Optional[str] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size - self.worker_use_ray = worker_use_ray + self.distributed_executor_backend = distributed_executor_backend self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce self.tokenizer_pool_config = tokenizer_pool_config @@ -556,14 +559,29 @@ def __init__( self.placement_group = placement_group self.world_size = pipeline_parallel_size * self.tensor_parallel_size - if self.world_size > 1: - self.worker_use_ray = True + if worker_use_ray: + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + elif self.distributed_executor_backend != "ray": + raise ValueError(f"worker-use-ray can't be used with " + f"distributed executor backend " + f"'{self.distributed_executor_backend}'.") + + if self.distributed_executor_backend is None and self.world_size > 1: + from vllm.executor import ray_utils + ray_found = ray_utils.ray is not None + self.distributed_executor_backend = "ray" if ray_found else "mp" + self._verify_args() def _verify_args(self) -> None: if self.pipeline_parallel_size > 1: raise NotImplementedError( "Pipeline parallelism is not supported yet.") + if self.distributed_executor_backend not in ("ray", "mp", None): + raise ValueError( + "Unrecognized distributed executor backend. Supported values " + "are 'ray' or 'mp'.") if not self.disable_custom_all_reduce and self.world_size > 1: if is_hip(): self.disable_custom_all_reduce = True @@ -575,7 +593,8 @@ def _verify_args(self) -> None: logger.info( "Disabled the custom all-reduce kernel because it is not " "supported with pipeline parallelism.") - if self.ray_workers_use_nsight and not self.worker_use_ray: + if self.ray_workers_use_nsight and ( + not self.distributed_executor_backend == "ray"): raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") @@ -887,7 +906,8 @@ def create_draft_parallel_config( pipeline_parallel_size=target_parallel_config. pipeline_parallel_size, tensor_parallel_size=target_parallel_config.tensor_parallel_size, - worker_use_ray=target_parallel_config.worker_use_ray, + distributed_executor_backend=target_parallel_config. + distributed_executor_backend, max_parallel_loading_workers=target_parallel_config. max_parallel_loading_workers, disable_custom_all_reduce=target_parallel_config. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fd5338c46c340..195d9e1b33e3c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -34,6 +34,7 @@ class EngineArgs: seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False + distributed_executor_backend: Optional[str] = None pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -221,10 +222,17 @@ def add_cli_args( ' Can be overridden per request via guided_decoding_backend' ' parameter.') # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='Use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=EngineArgs.distributed_executor_backend, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--worker-use-ray', + action='store_true', + help='Deprecated, use --distributed-executor-backend=ray.') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a31f10b7748d3..8a37bac02823a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -348,27 +348,31 @@ def from_engine_args( """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. engine_config = engine_args.create_engine_config() + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync elif engine_config.device_config.device_type == "cpu": - assert not engine_config.parallel_config.worker_use_ray, ( - "Ray is not supported with the CPU backend.") + assert distributed_executor_backend is None, ( + "Distributed execution is not supported with the CPU backend.") from vllm.executor.cpu_executor import CPUExecutorAsync executor_class = CPUExecutorAsync - elif engine_config.parallel_config.worker_use_ray: + elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutorAsync) + executor_class = MultiprocessingGPUExecutorAsync else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutorAsync executor_class = GPUExecutorAsync # Create the async LLM engine. engine = cls( - engine_config.parallel_config.worker_use_ray, + distributed_executor_backend == "ray", engine_args.engine_use_ray, **engine_config.to_dict(), executor_class=executor_class, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e258a3f4afd54..f6a5284093c1c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -274,6 +274,8 @@ def from_engine_args( """Creates an LLM engine from the engine arguments.""" # Create the engine configs. engine_config = engine_args.create_engine_config() + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. if engine_config.device_config.device_type == "neuron": @@ -282,13 +284,15 @@ def from_engine_args( elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor - elif engine_config.parallel_config.worker_use_ray: + elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor) + executor_class = MultiprocessingGPUExecutor else: - assert engine_config.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py new file mode 100644 index 0000000000000..2a7b99c9dcbe1 --- /dev/null +++ b/vllm/executor/multiproc_gpu_executor.py @@ -0,0 +1,140 @@ +import asyncio +import os +from functools import partial +from typing import Any, Dict, Optional, Tuple + +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + get_vllm_instance_id, make_async) + +logger = init_logger(__name__) + + +class MultiprocessingGPUExecutor(DistributedGPUExecutor): + """Python multiprocessing-based multi-GPU executor""" + + def _init_executor(self) -> None: + assert ( + not self.speculative_config + ), "Speculative decoding not yet supported for MultiProcGPU backend." + + # Create the parallel GPU workers. + world_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = (",".join( + map(str, range(world_size)))) + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + from torch.cuda import device_count + assert world_size <= device_count(), ( + "please set tensor_parallel_size to less than max local gpu count") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + + if world_size == 1: + self.workers = [] + else: + result_handler = ResultHandler() + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + distributed_init_method=distributed_init_method, + )) for rank in range(1, world_size) + ] + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + self.driver_worker = self._create_worker( + distributed_init_method=distributed_init_method) + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*driver_args, + **driver_kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if not self.worker_monitor.is_alive(): + raise RuntimeError("Worker processes are not running") + + +class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor, + DistributedGPUExecutorAsync): + + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + driver_executor = make_async(getattr(self.driver_worker, method)) + + # Run all the workers asynchronously. + coros = [driver_executor(*driver_args, **driver_kwargs)] + [ + worker.execute_method_async(method, *args, **kwargs) + for worker in self.workers + ] + + return await asyncio.gather(*coros) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index afc1c886722e6..9cb03ec8c3f5a 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -31,7 +31,7 @@ def _init_executor(self) -> None: assert (not self.speculative_config ), "Speculative decoding not yet supported for RayGPU backend." - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group # Disable Ray usage stats collection. @@ -264,7 +264,7 @@ def _compiled_ray_dag(self): f"required, but found {current_version}") from ray.dag import InputNode, MultiOutputNode - assert self.parallel_config.worker_use_ray + assert self.parallel_config.distributed_executor_backend == "ray" # Right now, compiled DAG requires at least 1 arg. We send # a dummy value for now. It will be fixed soon. diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 9db3ae2ff8298..4704f5f1b1a10 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -44,7 +44,7 @@ def execute_model_compiled_dag_remote(self, ignored): except ImportError as e: logger.warning( - "Failed to import Ray with %r. For distributed inference, " + "Failed to import Ray with %r. For multi-node inference, " "please install Ray with `pip install ray`.", e) ray = None # type: ignore RayWorkerWrapper = None # type: ignore @@ -67,7 +67,7 @@ def initialize_ray_cluster( """ if ray is None: raise ImportError( - "Ray is not installed. Please install Ray to use distributed " + "Ray is not installed. Please install Ray to use multi-node " "serving.") # Connect to a ray cluster. From 29bc01bf3bc26642e4cee15ebd36a6ce5799326d Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 14 May 2024 15:33:06 -0700 Subject: [PATCH 276/413] Add 4th meetup announcement to readme (#4817) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b704441eada9f..4fe5d9630f858 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Easy, fast, and cheap LLM serving for everyone

*Latest News* 🔥 +- [2024/05] We are hosting [the fourth vLLM meetup](https://lu.ma/event/manage/evt-A064fGpj52fviSn) with BentoML and Cloudflare on June 11! Please register [here](https://lu.ma/agivllm). - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). - [2024/01] Added ROCm 6.0 support to vLLM. From 8a7cc254a064b8d42bf4de7a9c3f29552240dfd9 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 15 May 2024 11:52:45 +0900 Subject: [PATCH 277/413] Revert "[Kernel] Use flash-attn for decoding (#3648)" (#4820) Lora 3 & 4 test seems to have illegal memory access failure after this commit; [2024-05-14 23:51:18,182 E 22 22] logging.cc:101: Unhandled exception: N3c105ErrorE. what(): CUDA error: an illegal memory access was encountered
Exmaple: https://buildkite.com/vllm/ci/builds/7382#018f793d-1527-4e1c-ab59-c3a34ec55241 This reverts commit 1356df5. FILL IN THE PR DESCRIPTION HERE FIX #xxxx (link existing issues this PR will resolve) --- tests/kernels/test_flash_attn.py | 209 -------------------------- tests/models/test_big_models.py | 2 +- tests/models/test_fp8.py | 10 +- vllm/attention/backends/flash_attn.py | 128 +++++++--------- vllm/attention/selector.py | 14 -- vllm/worker/model_runner.py | 15 +- 6 files changed, 65 insertions(+), 313 deletions(-) delete mode 100644 tests/kernels/test_flash_attn.py diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py deleted file mode 100644 index 89bdacc67fbc4..0000000000000 --- a/tests/kernels/test_flash_attn.py +++ /dev/null @@ -1,209 +0,0 @@ -from typing import List, Optional, Tuple - -import pytest -import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache - -NUM_HEADS = [(16, 16), (32, 8), (64, 8)] -HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 32] -DTYPES = [torch.float16, torch.bfloat16] - - -def ref_paged_attn( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - query_lens: List[int], - kv_lens: List[int], - block_tables: torch.Tensor, - scale: float, - sliding_window: Optional[int] = None, -) -> torch.Tensor: - num_seqs = len(query_lens) - block_tables = block_tables.cpu().numpy() - _, block_size, num_kv_heads, head_size = key_cache.shape - - outputs = [] - start_idx = 0 - for i in range(num_seqs): - query_len = query_lens[i] - kv_len = kv_lens[i] - q = query[start_idx:start_idx + query_len] - q *= scale - - num_kv_blocks = (kv_len + block_size - 1) // block_size - block_indices = block_tables[i, :num_kv_blocks] - - k = key_cache[block_indices].view(-1, num_kv_heads, head_size) - k = k[:kv_len] - v = value_cache[block_indices].view(-1, num_kv_heads, head_size) - v = v[:kv_len] - - if q.shape[1] != k.shape[1]: - k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) - v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) - attn = torch.einsum("qhd,khd->hqk", q, k).float() - empty_mask = torch.ones(query_len, kv_len) - mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() - if sliding_window is not None: - sliding_window_mask = torch.triu(empty_mask, - diagonal=kv_len - - (query_len + sliding_window) + - 1).bool().logical_not() - mask |= sliding_window_mask - attn.masked_fill_(mask, float("-inf")) - attn = torch.softmax(attn, dim=-1).to(v.dtype) - out = torch.einsum("hqk,khd->qhd", attn, v) - - outputs.append(out) - start_idx += query_len - - return torch.cat(outputs, dim=0) - - -@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_flash_attn_with_paged_kv( - kv_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(kv_lens) - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_kv_len = max(kv_lens) - scale = head_size**-0.5 - - query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - ).squeeze(1) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" - - -@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None]) -@pytest.mark.parametrize("dtype", DTYPES) -@torch.inference_mode -def test_varlen_with_paged_kv( - seq_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - sliding_window: Optional[int], - dtype: torch.dtype, - block_size: int, -) -> None: - torch.set_default_device("cuda") - torch.cuda.manual_seed_all(0) - num_blocks = 128 - num_seqs = len(seq_lens) - query_lens = [x[0] for x in seq_lens] - kv_lens = [x[1] for x in seq_lens] - num_query_heads = num_heads[0] - num_kv_heads = num_heads[1] - assert num_query_heads % num_kv_heads == 0 - max_query_len = max(query_lens) - max_kv_len = max(kv_lens) - window_size = ((sliding_window, - sliding_window) if sliding_window is not None else - (-1, -1)) - scale = head_size**-0.5 - - query = torch.randn(sum(query_lens), - num_query_heads, - head_size, - dtype=dtype) - key_cache = torch.randn(num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype) - value_cache = torch.randn_like(key_cache) - # Normalize the scale of the key and value caches to mitigate - # numerical instability. - key_cache /= head_size**0.5 - value_cache /= head_size**0.5 - cu_query_lens = torch.tensor([0] + query_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - cu_kv_lens = torch.tensor([0] + kv_lens, - dtype=torch.int32).cumsum(dim=0, - dtype=torch.int32) - - max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size - block_tables = torch.randint(0, - num_blocks, - (num_seqs, max_num_blocks_per_seq), - dtype=torch.int32) - - output = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - ) - - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - sliding_window=sliding_window, - ) - assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ - f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 10e7c64e34e75..c02204f16ac68 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -12,7 +12,7 @@ # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - # "mosaicml/mpt-7b", # Broken + "mosaicml/mpt-7b", # "Qwen/Qwen1.5-0.5B" # Broken, ] diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 664e951a89f2a..e87a1783a83f1 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -25,18 +25,18 @@ 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', + 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' ], "meta-llama/Meta-Llama-3-8B-Instruct": [ 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', + 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 11ecb2792ea9d..f59715bd76ede 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,16 +1,20 @@ -"""Attention layer with FlashAttention.""" +"""Attention layer with Flash and PagedAttention. + +NOTE(woosuk): At the moment, this file includes a lot of duplicated code from +XFormers backend. The duplicated code will be removed once we use flash-attn or +flashinfer for all the attention operations. +""" from dataclasses import dataclass from typing import List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm_flash_attn import flash_attn_varlen_func -from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) - -_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) class FlashAttentionBackend(AttentionBackend): @@ -34,9 +38,8 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -44,26 +47,19 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + PagedAttention.copy_blocks(kv_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage): +class FlashAttentionMetadata(AttentionMetadataPerStage, + PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -109,14 +105,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - class FlashAttentionImpl(AttentionImpl): """ @@ -168,15 +156,11 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if sliding_window is not None: - # NOTE(woosuk): flash-attn's sliding window does not work with - # paged KV cache. - raise ValueError( - "Sliding window is not supported in FlashAttention.") - if head_size not in _SUPPORTED_HEAD_SIZES: + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") def forward( self, @@ -187,20 +171,17 @@ def forward( attn_metadata: AttentionMetadata[FlashAttentionMetadata], kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention. + """Forward pass with FlashAttention and PagedAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -208,20 +189,16 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache = kv_cache[0] - value_cache = kv_cache[1] + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - ) + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -241,8 +218,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): + if kv_cache is None or prefill_meta.block_tables.numel() == 0: # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -263,32 +239,38 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - output[:num_prefill_tokens] = flash_attn_varlen_func( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.subquery_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=prefill_meta.max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. + output[:num_prefill_tokens] = PagedAttention.forward_prefix( + query, + key, + value, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.subquery_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], ) - if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = flash_attn_with_kvcache( - decode_query.unsqueeze(1), + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, key_cache, value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - ).squeeze(1) + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + ) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 5140c3cc86a31..06f99718a4dee 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -93,20 +93,6 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS - if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): - logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") - return _Backend.XFORMERS - - if block_size % 16 != 0: - logger.info("Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - return _Backend.XFORMERS - - if sliding_window is not None: - logger.info( - "Cannot use FlashAttention-2 backend due to sliding window.") - return _Backend.XFORMERS - try: import vllm_flash_attn # noqa: F401 except ImportError: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3f7e87c1de48c..b5e1991717b13 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -266,27 +266,20 @@ def _prepare_prompt( # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size prompt_tokens = prompt_tokens[context_len:] - if self.attn_backend.get_name() == "flash-attn": - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - # TODO(woosuk): This is a temporary fix. We should - # provide a unified interface for different backends. - block_table = seq_group_metadata.block_tables[seq_id] - else: - block_table = computed_block_nums + prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: # Prefill has chunked before. block_table = seq_group_metadata.block_tables[seq_id] + prefix_block_tables.append(block_table) else: # The first prefill. - block_table = [] + prefix_block_tables.append([]) else: - block_table = [] + prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. assert context_len == 0 - prefix_block_tables.append(block_table) # actual prompt lens context_lens.append(context_len) From 65bf2ac165734fb6339210c4b2b8ce68d2391b77 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 15 May 2024 14:00:10 +0900 Subject: [PATCH 278/413] [Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681) This PR combines prepare_prompt and prepare_decode into a single API. This PR also coelsce the attn metadata for prefill/decode to a single class and allow to slice them when running attn backend. It also refactors subquery_start_loc which was not refactored in the previous PR --- tests/worker/test_model_runner.py | 123 ++- vllm/attention/__init__.py | 5 +- vllm/attention/backends/abstract.py | 68 +- vllm/attention/backends/flash_attn.py | 95 ++- vllm/attention/backends/flashinfer.py | 38 +- vllm/attention/backends/rocm_flash_attn.py | 98 ++- vllm/attention/backends/torch_sdpa.py | 28 +- vllm/attention/backends/xformers.py | 92 ++- vllm/attention/layer.py | 5 +- vllm/attention/ops/paged_attn.py | 10 +- vllm/engine/arg_utils.py | 5 + .../layers/rejection_sampler.py | 1 + vllm/sequence.py | 3 +- vllm/spec_decode/batch_expansion.py | 23 +- vllm/spec_decode/multi_step_worker.py | 1 + vllm/worker/cpu_model_runner.py | 10 +- vllm/worker/embedding_model_runner.py | 130 +-- vllm/worker/model_runner.py | 772 ++++++++---------- 18 files changed, 777 insertions(+), 730 deletions(-) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index c2d1c5769619b..92de545acd53d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices.append(selected_token_start_idx + seq_len - 1) selected_token_start_idx += seq_len - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, - _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + slot_mapping = model_input.slot_mapping assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is True + assert attn_metadata.num_prefills > 0 + assert attn_metadata.num_decode_tokens == 0 assert torch.allclose( attn_metadata.seq_lens_tensor, torch.tensor(seq_lens, device=device, dtype=torch.int)) assert attn_metadata.seq_lens == seq_lens - assert attn_metadata.max_seq_len == max(seq_lens) + assert attn_metadata.max_prefill_seq_len == max(seq_lens) + assert attn_metadata.max_decode_seq_len == 0 # Test subquery start locs. start_idx = 0 @@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size): start_idx += seq_len start_loc.append(start_idx) assert torch.allclose( - attn_metadata.subquery_start_loc, + attn_metadata.query_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) # Test seq start locs. Note that for normal prefill it is - # equivalent to subquery_start_loc. + # equivalent to query_start_loc. start_idx = 0 seq_start_loc = [start_idx] for seq_len in seq_lens: @@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size): device=actual.device, dtype=actual.dtype) torch.testing.assert_close(actual, expected) - assert input_tokens == input_positions + torch.allclose(input_tokens, input_positions) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, @@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size): enable_chunked_prefill=False, ) - seq_lens = [] + context_lens = [] seq_group_metadata_list = [] + # Assume each seq group finishes prefill. for i in range(batch_size): # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - seq_lens.append(seq_len) - seq_data = list(range(seq_len)) + context_len = i % (model_runner.block_size - 1) + 1 + context_lens.append(context_len) + seq_data = list(range(context_len)) seq_data = SequenceData(seq_data) + seq_data.update_num_computed_tokens(context_len) + # Append one token ID since prefill is finished. + seq_data.append_token_id(1, 0) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size): assert seq_group_metadata.token_chunk_size == 1 seq_group_metadata_list.append(seq_group_metadata) - input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( - model_runner._prepare_decode(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, model_input.input_positions, + model_input.attn_metadata, model_input.slot_mapping) assert len(slot_mapping) == len(input_tokens) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device - assert attn_metadata.is_prompt is False - assert attn_metadata.seq_lens is None - assert attn_metadata.subquery_start_loc is None - assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_seq_len == max(seq_lens) + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + seq_lens = [context_len + 1 for context_len in context_lens] + # seq_lens are padded to expected_bs + for _ in range(expected_bs - len(seq_lens)): + seq_lens.append(1) + assert attn_metadata.seq_lens == seq_lens + start_idx = 0 + start_loc = [start_idx] + for _ in context_lens: + # decode has only 1 token for query. + start_idx += 1 + start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.query_start_loc, + torch.tensor(start_loc, dtype=torch.int32, device=device)) + + start_idx = 0 + seq_start_loc = [start_idx] + for seq_len in seq_lens: + start_idx += seq_len + seq_start_loc.append(start_idx) + assert torch.allclose( + attn_metadata.seq_start_loc, + torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) + + assert torch.allclose( + attn_metadata.context_lens_tensor, + torch.tensor(context_lens, dtype=torch.int, device=device)) + assert attn_metadata.max_decode_seq_len == max(seq_lens) assert torch.allclose( attn_metadata.seq_lens_tensor[:len(seq_lens)], torch.tensor(seq_lens, dtype=torch.int, device=device)) @@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size): # It is padded up to assert attn_metadata.block_tables.shape[1] == ( model_runner.get_max_block_per_batch()) - # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is True assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs - assert input_tokens == input_positions + torch.allclose(input_tokens, input_positions) # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for seq_len in seq_lens: + for _ in context_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, - query_lens=seq_lens, + # query lens is all 1 for decode. + query_lens=[1 for _ in range(len(context_lens))], device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -220,15 +257,27 @@ def test_empty_seq_group(): enforce_eager=False, ) seq_group_metadata_list = [] - input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = ( - model_runner._prepare_decode(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, slot_mapping = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + ) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, - _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + model_input = model_runner._prepare_model_input(seq_group_metadata_list) + (input_tokens, input_positions, attn_metadata, slot_mapping, + return_seq_lens) = ( + model_input.input_tokens, + model_input.input_positions, + model_input.attn_metadata, + model_input.slot_mapping, + model_input.seq_lens, + ) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None @@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Add decode requests for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block - seq_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(seq_len)) + context_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(context_len)) seq_data = SequenceData(prompt_toks) + seq_data.append_token_id(1, 0) + seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert len(attn_metadata.slot_mapping) == len(input_tokens) assert len(input_positions) == len(input_tokens) assert attn_metadata.num_prefills == prefill_batch_size - if enforce_eager: - assert attn_metadata.num_decode_tokens == decode_batch_size - else: - assert attn_metadata.num_decode_tokens == _get_graph_batch_size( - decode_batch_size) + assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. - prefill_meta = model_runner._prepare_prompt( - prefill_metadata_list).attn_metadata - decode_meta = model_runner._prepare_decode( - decode_metadata_list).attn_metadata + attn_metadata = model_runner._prepare_model_input( + seq_group_metadata_list).attn_metadata - for attr_expected, attr_actual in zip(vars(prefill_meta), + for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata), vars(prefill_meta_actual)): assert attr_expected[1] == attr_actual[1] - for attr_expected, attr_actual in zip(vars(decode_meta), + for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata), vars(decode_meta_actual)): assert attr_expected[1] == attr_actual[1] diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py index 088f48def7668..f6bce9a187c64 100644 --- a/vllm/attention/__init__.py +++ b/vllm/attention/__init__.py @@ -1,6 +1,5 @@ from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend @@ -8,6 +7,6 @@ "Attention", "AttentionBackend", "AttentionMetadata", - "AttentionMetadataPerStage", + "Attention", "get_attn_backend", ] diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 98d70fcab1a18..94ab64de30a94 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -21,7 +21,7 @@ def get_impl_cls() -> Type["AttentionImpl"]: @staticmethod @abstractmethod - def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage": + def make_metadata(*args, **kwargs) -> "AttentionMetadata": raise NotImplementedError @staticmethod @@ -53,8 +53,34 @@ def copy_blocks( @dataclass -class AttentionMetadataPerStage: - """Attention metadata for a specific stage. I.e., prefill or decode.""" +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -70,40 +96,10 @@ def asdict_zerocopy(self, } -T = TypeVar("T", bound=AttentionMetadataPerStage) - - -@dataclass -class AttentionMetadata(Generic[T]): - """Attention metadata for prefill and decode batched together.""" - # Total number of prefill requests. - num_prefills: int - # Number of prefill tokens. - num_prefill_tokens: int - # Number of decode tokens. Note that it is equivalent to the number of - # decode requests. - num_decode_tokens: int - # The attention metadata for prefill requests in a batch. - # None if there's no prefill requests in a batch. - prefill_metadata: Optional[T] - # The attention metadata for decode requests in a batch. - # None if there's no decode requests in a batch. - decode_metadata: Optional[T] - # (num_tokens,). The indices of the token slots that input tokens will be - # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size - # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot - # in block 0, and 1st slot in block 1, respectively. - slot_mapping: torch.Tensor - - def __post_init__(self): - if self.num_prefill_tokens > 0: - assert self.num_prefills > 0 - assert self.prefill_metadata is not None - if self.num_decode_tokens > 0: - assert self.decode_metadata is not None +T = TypeVar("T", bound=AttentionMetadata) -class AttentionImpl(ABC): +class AttentionImpl(ABC, Generic[T]): @abstractmethod def __init__( @@ -125,7 +121,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: T, kv_scale: float = 1.0, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index f59715bd76ede..5d1f65819ed4e 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -11,8 +11,7 @@ from vllm_flash_attn import flash_attn_varlen_func from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -58,8 +57,7 @@ def copy_blocks( @dataclass -class FlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -67,9 +65,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -84,14 +79,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. @@ -105,6 +104,70 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + class FlashAttentionImpl(AttentionImpl): """ @@ -168,7 +231,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata[FlashAttentionMetadata], + attn_metadata: FlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -228,8 +291,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -249,7 +312,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -264,7 +327,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 92d0fe0487516..5f9fd586fb70e 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -8,8 +8,7 @@ from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) class FlashInferBackend(AttentionBackend): @@ -56,9 +55,10 @@ def get_supported_head_sizes() -> List[int]: @dataclass -class FlashInferMetadata(AttentionMetadataPerStage): - - is_prompt: bool +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int use_cuda_graph: bool = False @@ -67,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage): # Metadata for the prefill stage since we still # use flash attention for prefill. seq_start_loc: Optional[torch.Tensor] = None - max_seq_len: Optional[int] = None block_tables: Optional[torch.Tensor] = None # Metadata for the decode stage @@ -113,7 +112,8 @@ def __post_init__(self): # When using flashinfer, we are also creating the FlashInferMetadata, # which will also call post_init by default, here we want to skip the # post_init if it's the prefill phase. - if not self.is_prompt: + if self.num_prefills == 0: + assert self.num_decode_tokens > 0 self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, "NHD") self.decode_wrapper.begin_forward( @@ -138,6 +138,24 @@ def asdict_zerocopy(self, skip_fields.add('decode_wrapper') return super().asdict_zerocopy(skip_fields) + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + class FlashInferImpl(AttentionImpl): @@ -172,7 +190,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[FlashInferMetadata], + attn_metadata: FlashInferMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: assert kv_scale == 1.0 @@ -208,8 +226,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 539585b46c7aa..1a94dc3596358 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -6,8 +6,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -56,8 +55,7 @@ def copy_blocks( @dataclass -class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, - PagedAttentionMetadata): +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -65,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -82,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. @@ -102,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata class ROCmFlashAttentionImpl(AttentionImpl): @@ -198,7 +260,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata], + attn_metadata: ROCmFlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -266,8 +328,8 @@ def forward( None, prefill_meta.seq_start_loc, prefill_meta.seq_start_loc, - prefill_meta.max_seq_len, - prefill_meta.max_seq_len, + prefill_meta.max_prefill_seq_len, + prefill_meta.max_prefill_seq_len, True, self.scale, ) @@ -290,8 +352,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_seq_len, - max_seqlen_k=prefill_meta.max_seq_len, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, ) @@ -308,7 +370,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -324,7 +386,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 2dd72a00c6e30..a3f72b9c94566 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,8 +7,7 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) @@ -54,8 +53,7 @@ def copy_blocks( @dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, - AttentionMetadataPerStage): +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts @@ -72,8 +70,26 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[torch.Tensor]] = None + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self -class TorchSDPABackendImpl(AttentionImpl): + return None + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): def __init__( self, @@ -200,7 +216,7 @@ def forward( value_cache, attn_metadata.block_tables, attn_metadata.seq_lens_tensor, - attn_metadata.max_seq_len, + attn_metadata.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index cb2028553461f..fc46af054de4f 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -9,8 +9,7 @@ LowerTriangularMaskWithTensorBias) from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataPerStage) + AttentionMetadata) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm.logger import init_logger @@ -59,7 +58,7 @@ def copy_blocks( @dataclass -class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): """Metadata for XFormersbackend. NOTE: Any python object stored here is not updated when it is @@ -67,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): dynamically, it should be stored in tensor. The tensor has to be updated from `CUDAGraphRunner.forward` API. """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. seq_lens: Optional[List[int]] @@ -83,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| - # Maximum query length in the batch. + # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] # FIXME: It is for flash attn. - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - subquery_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] # FIXME: It is for flash attn. # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is @@ -105,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None def __post_init__(self): # Set during the execution of the first attention op. @@ -114,8 +116,68 @@ def __post_init__(self): # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[AttentionBias]] = None - -class XFormersImpl(AttentionImpl): + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None, + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| @@ -176,7 +238,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[XFormersMetadata], + attn_metadata: "XFormersMetadata", kv_scale: float = 1.0, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -244,7 +306,7 @@ def forward( key_cache, value_cache, prefill_meta.block_tables, - prefill_meta.subquery_start_loc, + prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens_tensor, prefill_meta.max_query_len, @@ -261,7 +323,7 @@ def forward( value_cache, decode_meta.block_tables, decode_meta.seq_lens_tensor, - decode_meta.max_seq_len, + decode_meta.max_decode_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 8a872dba8c877..126692d8c9b40 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,8 +4,7 @@ import torch import torch.nn as nn -from vllm.attention.backends.abstract import (AttentionMetadata, - AttentionMetadataPerStage) +from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig @@ -57,7 +56,7 @@ def forward( key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], - attn_metadata: AttentionMetadata[AttentionMetadataPerStage], + attn_metadata: AttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 3c010b67b3120..30feaa4da254d 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -16,8 +16,8 @@ class PagedAttentionMetadata: # (batch_size,). The length of sequences (entire tokens seen so far) per # sequence. seq_lens_tensor: Optional[torch.Tensor] - # Maximum sequence length in the batch. - max_seq_len: Optional[int] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -166,7 +166,7 @@ def forward_prefix( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, + query_start_loc: torch.Tensor, seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, max_query_len: int, @@ -182,8 +182,8 @@ def forward_prefix( key_cache, value_cache, block_tables, - # subquery_start_loc is (batch_size + 1,) - subquery_start_loc[:-1], + # query_start_loc is (batch_size + 1,) + query_start_loc[:-1], seq_lens_tensor, context_lens, max_query_len, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 195d9e1b33e3c..bd44c2470182b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -618,6 +618,11 @@ def create_engine_config(self, ) -> EngineConfig: decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) + if (model_config.get_sliding_window() is not None + and scheduler_config.chunked_prefill_enabled): + raise ValueError( + "Chunked prefill is not supported with sliding window.") + return EngineConfig(model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index b5f1e55d0e839..1f2ab7e2870ca 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -122,6 +122,7 @@ def forward( draft_token_ids, bonus_token_ids, ) + return output_token_ids def _batch_modified_rejection_sampling( diff --git a/vllm/sequence.py b/vllm/sequence.py index 12e930c27173e..aa759448d82b1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -654,8 +654,9 @@ def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @property - def token_chunk_size(self) -> Optional[int]: + def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" + assert self._token_chunk_size is not None return self._token_chunk_size diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index d5fd96907ddd7..7792f3a3425cc 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -293,21 +293,30 @@ def _create_single_target_seq_group_metadata( prompt_token_ids = seq_data.get_prompt_token_ids() new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + new_seq_data_dict = { + target_seq_id: + SequenceData( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ), + } + # This is a hack. Technically, spec decoding should compute + # num_lookahead slots at one shot, but instead, it expands the batch + # and evaluate one by one right now. context_len is seq_len - 1 because + # the kv cache is filled by a previous batch in the batch expansion. + for data in new_seq_data_dict.values(): + data.update_num_computed_tokens(data.get_len() - 1) + return SequenceGroupMetadata( request_id=seq_group_metadata.request_id, is_prompt=seq_group_metadata.is_prompt, - seq_data={ - target_seq_id: - SequenceData( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, - ), - }, + seq_data=new_seq_data_dict, sampling_params=seq_group_metadata.sampling_params, block_tables={ target_seq_id: seq_group_metadata.block_tables[seq_id], }, lora_request=None, + token_chunk_size=1, ) def _split_scoring_output( diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 20098ebaeea32..b5a805278d273 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -114,6 +114,7 @@ def _append_new_tokens( token_logprob = seq_output.logprobs[token_id] seq.append_token_id(token_id, token_logprob.logprob) + seq.update_num_computed_tokens(1) def _shallow_copy_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 0a0b0d70cfe21..bc88f2c5bed6c 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -159,12 +159,10 @@ def _prepare_prompt( is_prompt=True, seq_lens=seq_lens, seq_lens_tensor=None, - max_seq_len=None, + max_decode_seq_len=None, num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, - prefill_metadata=None, - decode_metadata=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, ) @@ -213,7 +211,7 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - max_seq_len = max(seq_lens) + max_decode_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -243,12 +241,10 @@ def _prepare_decode( slot_mapping=slot_mapping, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_seq_len=max_seq_len, + max_decode_seq_len=max_decode_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), num_prefills=0, - prefill_metadata=None, - decode_metadata=None, block_tables=block_tables, ) return ( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index d04bebbdc31b6..91f30978ead87 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -13,7 +13,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import BatchType, ModelRunner +from vllm.worker.model_runner import ModelRunner logger = init_logger(__name__) @@ -88,85 +88,24 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - prefill_reqs = [] - decode_reqs = [] - for seq_group_meta in seq_group_metadata_list: - if seq_group_meta.is_prompt: - prefill_reqs.append(seq_group_meta) - else: - decode_reqs.append(seq_group_meta) - # Prepare input tensors. ( input_tokens, input_positions, - prefill_attn_metadata, - prompt_lens, - subquery_lens, - lora_index_mapping, - lora_prompt_mapping, + attn_metadata, + seq_lens, + _, + lora_mapping, lora_requests, multi_modal_input, slot_mapping, - ) = self._prepare_prompt(prefill_reqs) - ( - decode_input_tokens, - decode_input_positions, - decode_attn_metadata, - decode_lora_index_mapping, - decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping, - ) = self._prepare_decode(decode_reqs) - + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) # Prepare PoolingMetadata pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - prompt_lens) - - if not self.scheduler_config.chunked_prefill_enabled: - assert (len(prefill_reqs) and len(decode_reqs)) == 0 - - num_prefills = len(prompt_lens) - num_prefill_tokens = len(input_tokens) - num_decode_tokens = len(decode_input_tokens) - - # Coalesce tensors. Note that attn_metadata is currently not - # coalesced for simplicity. - input_tokens.extend(decode_input_tokens) - input_positions.extend(decode_input_positions) - slot_mapping.extend(decode_slot_mapping) - lora_index_mapping.extend(decode_lora_index_mapping) - lora_prompt_mapping.extend(decode_lora_prompt_mapping) - lora_requests.update(decode_lora_requests) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - # Broadcast the metadata. - # If batch contains both prefill and decode, it sends 2 broadcasts. - # If it only contains 1 type, it triggers a single broadcast. - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): - batch_type = BatchType.MIXED - elif prefill_attn_metadata is not None: - batch_type = BatchType.PREFILL - else: - batch_type = BatchType.DECODE + seq_lens) metadata_dict = { "input_tokens": input_tokens, @@ -178,65 +117,26 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, - "batch_type": batch_type, } - if prefill_attn_metadata is not None: - metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) - else: - assert decode_attn_metadata is not None - metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) - - # Broadcast decode attn metadata for mixed batch type. - # The additional broadcast costs 300us overhead on 4 A10 GPUs. - # We can potentially reduce the overhead by coelescing tensors. - if batch_type == BatchType.MIXED: - assert decode_attn_metadata is not None - metadata_dict = decode_attn_metadata.asdict_zerocopy() - broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") - slot_mapping = metadata_dict.pop("slot_mapping") - num_prefills = metadata_dict.pop("num_prefills") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") - num_decode_tokens = metadata_dict.pop("num_decode_tokens") - batch_type = metadata_dict.pop("batch_type") - - # Create an attention metadata. - prefill_attn_metadata = None - decode_attn_metadata = None - if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: - prefill_attn_metadata = self.attn_backend.make_metadata( + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( **metadata_dict) else: - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - + attn_metadata = None pooling_metadata = PoolingMetadata(seq_groups=None, seq_data=None, prompt_lens=None) - # if it is a mixed batch, decode attn_metadata is broadcasted - # separately. - if batch_type == BatchType.MIXED: - metadata_dict = broadcast_tensor_dict(src=0) - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - - attn_metadata = AttentionMetadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - prefill_metadata=prefill_attn_metadata, - decode_metadata=decode_attn_metadata, - ) - return (input_tokens, input_positions, attn_metadata, pooling_metadata, lora_requests, lora_mapping, multi_modal_input) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b5e1991717b13..dcdd4b962454e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,13 +1,11 @@ import time -from enum import IntEnum from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np import torch import torch.nn as nn -from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, - get_attn_backend) +from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) @@ -37,66 +35,38 @@ ] -class PreparePromptMetadata(NamedTuple): - input_tokens: List[int] - input_positions: List[int] - attn_metadata: Optional[AttentionMetadataPerStage] +class ModelInput(NamedTuple): + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: Optional[AttentionMetadata] seq_lens: List[int] query_lens: List[int] - lora_index_mapping: List[int] - lora_prompt_mapping: List[int] + lora_mapping: Optional[LoRAMapping] lora_requests: Set[LoRARequest] multi_modal_input: Optional[torch.Tensor] - slot_mapping: List[int] + slot_mapping: torch.Tensor + num_prefill_tokens: int + num_decode_tokens: int + num_prefills: int @classmethod - def empty(cls): - return PreparePromptMetadata( - input_tokens=[], - input_positions=[], + def empty(cls, device): + return ModelInput( + input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), attn_metadata=None, seq_lens=[], query_lens=[], - lora_index_mapping=[], - lora_prompt_mapping=[], + lora_mapping=None, lora_requests=set(), multi_modal_input=None, - slot_mapping=[], - ) - - -class PrepareDecodeMetadata(NamedTuple): - input_tokens: List[int] - input_positions: List[int] - attn_metadata: Optional[AttentionMetadata] - lora_index_mapping: List[int] - lora_prompt_mapping: List[int] - lora_requests: Set[LoRARequest] - slot_mapping: List[int] - - @classmethod - def empty(cls): - return PrepareDecodeMetadata( - input_tokens=[], - input_positions=[], - attn_metadata=None, - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - slot_mapping=[], + slot_mapping=torch.empty(0, device=device), + num_prefill_tokens=0, + num_decode_tokens=0, + num_prefills=0, ) -# How batches are constructed. -class BatchType(IntEnum): - # Every batch is prefill. - PREFILL = 0 - # Every batch is decode. - DECODE = 1 - # Batch is a mixture of prefill and decode. - MIXED = 2 - - class ModelRunner: def __init__( @@ -216,10 +186,22 @@ def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size - def _prepare_prompt( + def _prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> PreparePromptMetadata: + ) -> ModelInput: + """Prepare the model input based on a given sequence group. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] @@ -228,212 +210,16 @@ def _prepare_prompt( lora_requests: Set[LoRARequest] = set() seq_lens: List[int] = [] + prefill_seq_lens: List[int] = [] + decode_seq_lens: List[int] = [] context_lens: List[int] = [] query_lens: List[int] = [] - prefix_block_tables: List[List[int]] = [] - multi_modal_input_list: List[torch.Tensor] = [] - - if len(seq_group_metadata_list) == 0: - return PreparePromptMetadata.empty() - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - token_chunk_size = seq_group_metadata.token_chunk_size - seq_data = seq_group_metadata.seq_data[seq_id] - context_len = seq_data.get_num_computed_tokens() - # We should use get_len here because in case of preemption - # it contains output tokens. - seq_len = min(seq_data.get_len(), context_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] - seq_lens.append(seq_len) - - # NOTE: This only works for oooooooxxx style attention. - if computed_block_nums is not None and len( - computed_block_nums) > 0 and self.sliding_window is None: - # Prefix is not supported with sliding_window - context_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[context_len:] - prefix_block_tables.append(computed_block_nums) - elif self.scheduler_config.chunked_prefill_enabled: - if seq_group_metadata.block_tables is not None: - # Prefill has chunked before. - block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) - else: - # The first prefill. - prefix_block_tables.append([]) - else: - prefix_block_tables.append([]) - # Right now, prefill start is always 0. However, this - # assumption can be changed once chunked prefill is introduced. - assert context_len == 0 - - # actual prompt lens - context_lens.append(context_len) - query_lens.append(seq_len - context_len) - - input_tokens.extend(prompt_tokens) - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.extend(list(range(context_len, seq_len))) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * (seq_len - context_len) - lora_prompt_mapping.extend([lora_id] * ( - seq_len - context_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs else 1)) - - if seq_group_metadata.multi_modal_data: - multi_modal_input_list.append( - seq_group_metadata.multi_modal_data.data) - - if _is_block_tables_empty(seq_group_metadata.block_tables): - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - assert context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention") - start_idx = max(0, seq_len - self.sliding_window) - - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - max_query_len = max(query_lens) - max_seq_len = max(seq_lens) - assert max_query_len > 0 - - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) - - if multi_modal_input_list: - assert self.vision_language_config, ( - "Multi-modal inputs are only supported by " - "vision language models.") - multi_modal_input = torch.cat(multi_modal_input_list, - dim=0).to(self.device) - else: - multi_modal_input = None - - # Prepare prefix block tables - max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) - block_tables = make_tensor_with_pad( - prefix_block_tables, - max_len=max_prompt_block_table_len, - pad=0, - dtype=torch.int, - device=self.device, - ) - - # Query length can be shorter than key (i.e., prompt) when prefill - # is chunked or prefix cached. - query_lens_tensor = torch.tensor(query_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, - dtype=torch.int32, - device=self.device) - - torch.cumsum(query_lens_tensor, - dim=0, - dtype=subquery_start_loc.dtype, - out=subquery_start_loc[1:]) - - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - - if self.attn_backend.get_name() == "flashinfer": - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - use_cuda_graph=False, - seq_start_loc=seq_start_loc, - max_seq_len=max_seq_len, - block_tables=block_tables) - else: - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_seq_len=max_seq_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) - - return PreparePromptMetadata( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, - lora_requests=lora_requests, - multi_modal_input=multi_modal_input, - slot_mapping=slot_mapping, - ) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> PrepareDecodeMetadata: - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] block_tables: List[List[int]] = [] - lora_index_mapping: List[int] = [] - lora_prompt_mapping: List[int] = [] - lora_requests: Set[LoRARequest] = set() + multi_modal_input_list: List[torch.Tensor] = [] + decode_only = True + num_prefills = 0 + num_prefill_tokens = 0 + num_decode_tokens = 0 # The following fields are only for flashinfer # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout @@ -454,60 +240,186 @@ def _prepare_decode( paged_kv_last_page_len: List[int] = [] if len(seq_group_metadata_list) == 0: - return PrepareDecodeMetadata.empty() + return ModelInput.empty(self.device) for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 - seq_ids = list(seq_group_metadata.seq_data.keys()) - lora_id = seq_group_metadata.lora_int_id - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) + is_prompt = seq_group_metadata.is_prompt for seq_id in seq_ids: + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + context_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + context_len + seq_group_metadata.token_chunk_size) + if is_prompt: + tokens = seq_data.get_token_ids()[context_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + context_len = len(computed_block_nums) * self.block_size + tokens = tokens[context_len:] + if self.attn_backend.get_name() == "flash-attn": + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + # TODO(woosuk): This is a temporary fix. We should + # provide a unified interface for different backends. + block_table = seq_group_metadata.block_tables[seq_id] + else: + block_table = computed_block_nums + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if self.sliding_window is not None: + # chunked prefill doesn't support sliding window. + assert (not self.scheduler_config. + chunked_prefill_enabled) + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + + if self.attn_backend.get_name() == "flashinfer": + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + + len(block_table)) + last_page_len = seq_data.get_len( + ) % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # Prefill without chunked prefill or memory profiling. + block_table = [] + block_tables.append(block_table) - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append(position) + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + seq_len = min(seq_len, self.sliding_window) + context_len = seq_len - 1 - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) seq_lens.append(seq_len) + context_lens.append(context_len) + query_len = seq_len - context_len + query_lens.append(query_len) + input_tokens.extend(tokens) + input_positions.extend(list(range(context_len, seq_len))) + lora_id = seq_group_metadata.lora_int_id + + if is_prompt: + assert len(seq_ids) == 1 + num_prefills += 1 + num_prefill_tokens += len(tokens) + decode_only = False + prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + num_decode_tokens += query_len + decode_seq_lens.append(seq_len) + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping += [lora_id] * (seq_len - context_len) + lora_prompt_mapping.extend( + [lora_id] * + (seq_len - + context_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + + if _is_block_tables_empty(seq_group_metadata.block_tables): + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - lora_index_mapping.append(lora_id) - lora_prompt_mapping.append(lora_id) + # Mask the [0, start_idx) tokens of the prompt with + # _PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) + if is_prompt: + assert context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + # It is an optimization. When it is decoding, it is always + # 0. When prefill, we use it to not write slots to kv cache + # to save memory. + start_idx = max(0, query_len - self.sliding_window) + + for i in range(context_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) - paged_kv_indices.extend(block_table) - paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) - last_page_len = seq_data.get_len() % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - paged_kv_last_page_len.append(last_page_len) + batch_size = len(input_tokens) + max_query_len = max(query_lens) + max_prefill_seq_len = max(prefill_seq_lens, default=0) + max_decode_seq_len = max(decode_seq_lens, default=0) - # vLLM uses cuda graph only for decoding requests. + # If cuda graph can be used, pad tensors accordingly. # See `capture_model` API for more details. - # For decoding requests, batch_size == input_tokens. - batch_size = len(input_tokens) - max_seq_len = max(seq_lens) - use_captured_graph = (not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_seq_len <= self.max_seq_len_to_capture) + # vLLM uses cuda graph only for decoding requests. + use_captured_graph = ( + decode_only and not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -519,18 +431,9 @@ def _prepare_decode( block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) + num_decode_tokens = batch_size if use_captured_graph: - # When using cuda-graph all these tensors should be - # padded. - assert seq_lens_tensor.shape[0] == len(input_tokens) - assert seq_lens_tensor.shape[0] == len(input_positions) - assert seq_lens_tensor.shape[0] == len(slot_mapping) - # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.graph_block_tables[:batch_size] @@ -548,6 +451,57 @@ def _prepare_decode( dtype=torch.int, device=self.device, ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=self.device) + + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) if self.attn_backend.get_name() == "flashinfer": if not hasattr(self, "flashinfer_workspace_buffer"): @@ -555,53 +509,75 @@ def _prepare_decode( # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html self.flashinfer_workspace_buffer = torch.empty( 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) - paged_kv_indptr = torch.tensor(paged_kv_indptr, - dtype=torch.int, - device=self.device) - paged_kv_indices = torch.tensor(paged_kv_indices, - dtype=torch.int, - device=self.device) - paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len, + paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr, dtype=torch.int, device=self.device) + paged_kv_indices_tensor = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len_tensor = torch.tensor( + paged_kv_last_page_len, dtype=torch.int, device=self.device) kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, use_cuda_graph=False, + max_prefill_seq_len=max_prefill_seq_len, + block_tables=block_tables, workspace_buffer=self.flashinfer_workspace_buffer, - paged_kv_indptr=paged_kv_indptr, - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, num_qo_heads=self.model_config.get_num_attention_heads( self.parallel_config), num_kv_heads=self.model_config.get_num_kv_heads( self.parallel_config), head_dim=self.model_config.get_head_size(), - page_size=self.block_size, + page_size=16, + seq_start_loc=seq_start_loc, data_type=kv_cache_dtype) else: attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - seq_lens=None, + num_prefills=num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, - max_query_len=None, - max_seq_len=max_seq_len, - subquery_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) - return PrepareDecodeMetadata( - input_tokens=input_tokens, - input_positions=input_positions, + + if self.lora_config: + lora_mapping = LoRAMapping( + lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + return ModelInput( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, attn_metadata=attn_metadata, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, lora_requests=lora_requests, - slot_mapping=slot_mapping, + multi_modal_input=multi_modal_input, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, ) def prepare_input_tensors( @@ -610,85 +586,25 @@ def prepare_input_tensors( ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: - prefill_reqs = [] - decode_reqs = [] - for seq_group_meta in seq_group_metadata_list: - if seq_group_meta.is_prompt: - prefill_reqs.append(seq_group_meta) - else: - decode_reqs.append(seq_group_meta) - # Prepare input tensors. ( input_tokens, input_positions, - prefill_attn_metadata, + attn_metadata, seq_lens, query_lens, - lora_index_mapping, - lora_prompt_mapping, + lora_mapping, lora_requests, multi_modal_input, slot_mapping, - ) = self._prepare_prompt(prefill_reqs) - ( - decode_input_tokens, - decode_input_positions, - decode_attn_metadata, - decode_lora_index_mapping, - decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping, - ) = self._prepare_decode(decode_reqs) + num_prefill_tokens, + num_decode_tokens, + num_prefills, + ) = self._prepare_model_input(seq_group_metadata_list) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, query_lens, self.device, self.pin_memory) - if not self.scheduler_config.chunked_prefill_enabled: - assert (len(prefill_reqs) and len(decode_reqs)) == 0 - - num_prefills = len(seq_lens) - num_prefill_tokens = len(input_tokens) - num_decode_tokens = len(decode_input_tokens) - - # Coalesce tensors. Note that attn_metadata is currently not - # coalesced for simplicity. - input_tokens.extend(decode_input_tokens) - input_positions.extend(decode_input_positions) - slot_mapping.extend(decode_slot_mapping) - lora_index_mapping.extend(decode_lora_index_mapping) - lora_prompt_mapping.extend(decode_lora_prompt_mapping) - lora_requests.update(decode_lora_requests) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - if self.lora_config: - lora_mapping = LoRAMapping( - lora_index_mapping, - lora_prompt_mapping, - ) - else: - lora_mapping = None - - # Broadcast the metadata. - # If batch contains both prefill and decode, it sends 2 broadcasts. - # If it only contains 1 type, it triggers a single broadcast. - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): - batch_type = BatchType.MIXED - elif prefill_attn_metadata is not None: - batch_type = BatchType.PREFILL - else: - batch_type = BatchType.DECODE - metadata_dict = { "input_tokens": input_tokens, "input_positions": input_positions, @@ -701,46 +617,24 @@ def prepare_input_tensors( "num_decode_tokens": num_decode_tokens, "slot_mapping": slot_mapping, "num_prefills": num_prefills, - "batch_type": batch_type, } - if prefill_attn_metadata is not None: - metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) - else: - assert decode_attn_metadata is not None - metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) + if attn_metadata: + metadata_dict.update(attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) - - # Broadcast decode attn metadata for mixed batch type. - # The additional broadcast costs 300us overhead on 4 A10 GPUs. - # We can potentially reduce the overhead by coelescing tensors. - if batch_type == BatchType.MIXED: - assert decode_attn_metadata is not None - metadata_dict = decode_attn_metadata.asdict_zerocopy() - broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict.pop("input_tokens") input_positions = metadata_dict.pop("input_positions") - slot_mapping = metadata_dict.pop("slot_mapping") - num_prefills = metadata_dict.pop("num_prefills") selected_token_indices = metadata_dict.pop( "selected_token_indices") lora_mapping = metadata_dict.pop("lora_mapping") lora_requests = metadata_dict.pop("lora_requests") multi_modal_input = metadata_dict.pop("multi_modal_input") - num_prefill_tokens = metadata_dict.pop("num_prefill_tokens") - num_decode_tokens = metadata_dict.pop("num_decode_tokens") - batch_type = metadata_dict.pop("batch_type") - - # Create an attention metadata. - prefill_attn_metadata = None - decode_attn_metadata = None - if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED: - prefill_attn_metadata = self.attn_backend.make_metadata( + if metadata_dict: + attn_metadata = self.attn_backend.make_metadata( **metadata_dict) else: - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) + attn_metadata = None sampling_metadata = SamplingMetadata( seq_groups=None, selected_token_indices=selected_token_indices, @@ -748,22 +642,6 @@ def prepare_input_tensors( num_prompts=0, ) - # if it is a mixed batch, decode attn_metadata is broadcasted - # separately. - if batch_type == BatchType.MIXED: - metadata_dict = broadcast_tensor_dict(src=0) - decode_attn_metadata = self.attn_backend.make_metadata( - **metadata_dict) - - attn_metadata = AttentionMetadata( - num_prefills=num_prefills, - slot_mapping=slot_mapping, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - prefill_metadata=prefill_attn_metadata, - decode_metadata=decode_attn_metadata, - ) - return (input_tokens, input_positions, attn_metadata, sampling_metadata, lora_requests, lora_mapping, multi_modal_input) @@ -954,26 +832,22 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): # Create dummy attn_metadata. - decode_metadata = self.attn_backend.make_metadata( - is_prompt=False, + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping[:batch_size], seq_lens=None, seq_lens_tensor=seq_lens[:batch_size], max_query_len=None, - max_seq_len=self.max_seq_len_to_capture, - subquery_start_loc=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_seq_len_to_capture, + query_start_loc=None, seq_start_loc=None, context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, ) - attn_metadata = AttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping[:batch_size], - prefill_metadata=None, - decode_metadata=decode_metadata, - ) if self.lora_config: lora_mapping = LoRAMapping( From e9cdd2b1e20beb1c21c55441d0e6a4ed86f4e292 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 14:38:40 +0800 Subject: [PATCH 279/413] [CI/Build] Further decouple HuggingFace implementation from ours during tests (#4166) --- tests/conftest.py | 77 +++++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 36 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 999ace2c3c699..c1a44a606e1bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,21 @@ import contextlib import gc import os -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import pytest import torch from PIL import Image -from transformers import (AutoModelForCausalLM, AutoProcessor, - LlavaForConditionalGeneration) +from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer, + LlavaConfig, LlavaForConditionalGeneration) from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel +from vllm.logger import init_logger from vllm.sequence import MultiModalData -from vllm.transformers_utils.tokenizer import get_tokenizer + +logger = init_logger(__name__) _TEST_DIR = os.path.dirname(__file__) _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] @@ -129,9 +131,7 @@ def example_long_prompts() -> List[str]: "float": torch.float, } -_VISION_LANGUAGE_MODELS = { - "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration, -} +AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration) _EMBEDDING_MODELS = [ "intfloat/e5-mistral-7b-instruct", @@ -143,23 +143,14 @@ class HfRunner: def __init__( self, model_name: str, - tokenizer_name: Optional[str] = None, dtype: str = "half", ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + self.model_name = model_name - if model_name in _VISION_LANGUAGE_MODELS: - self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() - self.processor = AutoProcessor.from_pretrained( - model_name, - torch_dtype=torch_dtype, - ) - elif model_name in _EMBEDDING_MODELS: + + if model_name in _EMBEDDING_MODELS: # Lazy init required for AMD CI from sentence_transformers import SentenceTransformer self.model = SentenceTransformer( @@ -172,10 +163,24 @@ def __init__( torch_dtype=torch_dtype, trust_remote_code=True, ).cuda() - self.processor = None - if tokenizer_name is None: - tokenizer_name = model_name - self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + try: + self.processor = AutoProcessor.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + except Exception: + logger.warning( + "Unable to auto-load processor from HuggingFace for " + "model %s. Using tokenizer instead.", model_name) + self.processor = self.tokenizer def generate( self, @@ -187,19 +192,19 @@ def generate( if images: assert len(prompts) == len(images) for i, prompt in enumerate(prompts): - if self.model_name not in _VISION_LANGUAGE_MODELS: - input_ids = self.tokenizer(prompt, - return_tensors="pt").input_ids - inputs = {"input_ids": input_ids.cuda()} - else: - image = images[i] if images else None - inputs = self.processor(text=prompt, - images=image, - return_tensors="pt") - inputs = { - key: value.cuda() if value is not None else None - for key, value in inputs.items() - } + processor_kwargs: Dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] + + inputs = self.processor(**processor_kwargs) + inputs = { + key: value.cuda() if value is not None else None + for key, value in inputs.items() + } + output_ids = self.model.generate( **inputs, use_cache=True, From a5675d348b126e53928e139d1ed5b2c00a0044e8 Mon Sep 17 00:00:00 2001 From: zifeitong Date: Wed, 15 May 2024 07:22:09 -0700 Subject: [PATCH 280/413] [Bugfix] Properly set distributed_executor_backend in ParallelConfig (#4816) --- vllm/config.py | 1 + vllm/engine/arg_utils.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 26edd4567b9ac..2eb5bdd18d812 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -531,6 +531,7 @@ class ParallelConfig: If None, will use synchronous tokenization. ray_workers_use_nsight: Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + placement_group: ray distributed model workers placement group. distributed_executor_backend: Backend to use for distributed model workers, either "ray" or "mp" (multiprocessing). If either pipeline_parallel_size or tensor_parallel_size is greater than 1, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bd44c2470182b..dab86b7c9eb35 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -548,14 +548,18 @@ def create_engine_config(self, ) -> EngineConfig: model_config.get_sliding_window(), self.enable_prefix_caching) parallel_config = ParallelConfig( - self.pipeline_parallel_size, self.tensor_parallel_size, - self.worker_use_ray, self.max_parallel_loading_workers, + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, self.disable_custom_all_reduce, TokenizerPoolConfig.create_config( self.tokenizer_pool_size, self.tokenizer_pool_type, self.tokenizer_pool_extra_config, - ), self.ray_workers_use_nsight) + ), + self.ray_workers_use_nsight, + distributed_executor_backend=self.distributed_executor_backend) speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, From 361c461a128a5df2faefeb70ffa98e61e4feda55 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 15 May 2024 11:38:49 -0700 Subject: [PATCH 281/413] [Doc] Highlight the fourth meetup in the README (#4842) --- README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4fe5d9630f858..fc3f71b00c3c5 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,18 @@ Easy, fast, and cheap LLM serving for everyone

+--- + +**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)** + +We are thrilled to announce our fourth vLLM Meetup! +The vLLM team will share recent updates and roadmap. +We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM. +Please register [here](https://lu.ma/agivllm) and join us! + +--- + *Latest News* 🔥 -- [2024/05] We are hosting [the fourth vLLM meetup](https://lu.ma/event/manage/evt-A064fGpj52fviSn) with BentoML and Cloudflare on June 11! Please register [here](https://lu.ma/agivllm). - [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). - [2024/01] Added ROCm 6.0 support to vLLM. From fc0d9dfc3afcea2e23649ef8eb8bbe0446682813 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 16 May 2024 05:58:46 +0800 Subject: [PATCH 282/413] [Frontend] Re-enable custom roles in Chat Completions API (#4758) --- tests/entrypoints/test_openai_server.py | 30 +++++++++++ vllm/entrypoints/openai/protocol.py | 38 +++++++++++++- vllm/entrypoints/openai/serving_chat.py | 66 ++++++++++++++++--------- 3 files changed, 108 insertions(+), 26 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index ee2f034fd2c46..1b04e3205c4b8 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -783,6 +783,36 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI): assert content == "2" +async def test_custom_role(server, client: openai.AsyncOpenAI): + # Not sure how the model handles custom roles so we just check that + # both string and complex message content are handled in the same way + + resp1 = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "my-custom-role", + "content": "what is 1+1?", + }], # type: ignore + temperature=0, + seed=0) + + resp2 = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "my-custom-role", + "content": [{ + "type": "text", + "text": "what is 1+1?" + }] + }], # type: ignore + temperature=0, + seed=0) + + content1 = resp1.choices[0].message.content + content2 = resp2.choices[0].message.content + assert content1 == content2 + + async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 139c5716c7cea..35dfa09ac12ba 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3,16 +3,50 @@ import time from typing import Any, Dict, List, Literal, Optional, Union +import openai.types.chat import torch -from openai.types.chat import ChatCompletionMessageParam from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated +# pydantic needs the TypedDict from typing_extensions +from typing_extensions import Annotated, Required, TypedDict from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid +class CustomChatCompletionContentPartParam(TypedDict, total=False): + __pydantic_config__ = ConfigDict(extra="allow") # type: ignore + + type: Required[str] + """The type of the content part.""" + + +ChatCompletionContentPartParam = Union[ + openai.types.chat.ChatCompletionContentPartParam, + CustomChatCompletionContentPartParam] + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + +ChatCompletionMessageParam = Union[ + openai.types.chat.ChatCompletionMessageParam, + CustomChatCompletionMessageParam] + + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields model_config = ConfigDict(extra="forbid") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1b469fc59b076..65824a2206be9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,15 +1,16 @@ import codecs import time -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, - Optional, Tuple, TypedDict, Union, final) +from dataclasses import dataclass +from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional, + TypedDict, Union, cast, final) from fastapi import Request -from openai.types.chat import (ChatCompletionContentPartParam, - ChatCompletionRole) +from openai.types.chat import ChatCompletionContentPartTextParam from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( + ChatCompletionContentPartParam, ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, @@ -31,6 +32,11 @@ class ConversationMessage(TypedDict): content: str +@dataclass(frozen=True) +class ChatMessageParseResult: + messages: List[ConversationMessage] + + class OpenAIServingChat(OpenAIServing): def __init__(self, @@ -77,27 +83,40 @@ def _load_chat_template(self, chat_template: Optional[str]): logger.warning( "No chat template provided. Chat API will not work.") - def _parse_chat_message_content( + def _parse_chat_message_content_parts( self, - role: ChatCompletionRole, - content: Optional[Union[str, - Iterable[ChatCompletionContentPartParam]]], - ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]: - if content is None: - return [], [] - if isinstance(content, str): - return [ConversationMessage(role=role, content=content)], [] - + role: str, + parts: Iterable[ChatCompletionContentPartParam], + ) -> ChatMessageParseResult: texts: List[str] = [] - for _, part in enumerate(content): - if part["type"] == "text": - text = part["text"] + + for _, part in enumerate(parts): + part_type = part["type"] + if part_type == "text": + text = cast(ChatCompletionContentPartTextParam, part)["text"] texts.append(text) else: - raise NotImplementedError(f"Unknown part type: {part['type']}") + raise NotImplementedError(f"Unknown part type: {part_type}") + + messages = [ConversationMessage(role=role, content="\n".join(texts))] + + return ChatMessageParseResult(messages=messages) + + def _parse_chat_message_content( + self, + message: ChatCompletionMessageParam, + ) -> ChatMessageParseResult: + role = message["role"] + content = message.get("content") + + if content is None: + return ChatMessageParseResult(messages=[]) + if isinstance(content, str): + messages = [ConversationMessage(role=role, content=content)] + return ChatMessageParseResult(messages=messages) - return [ConversationMessage(role=role, content="\n".join(texts))], [] + return self._parse_chat_message_content_parts(role, content) async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Request @@ -119,11 +138,10 @@ async def create_chat_completion( try: conversation: List[ConversationMessage] = [] - for m in request.messages: - messages, _ = self._parse_chat_message_content( - m["role"], m["content"]) + for msg in request.messages: + parsed_msg = self._parse_chat_message_content(msg) - conversation.extend(messages) + conversation.extend(parsed_msg.messages) prompt = self.tokenizer.apply_chat_template( conversation=conversation, @@ -387,4 +405,4 @@ async def chat_completion_full_generator( usage=usage, ) - return response \ No newline at end of file + return response From 52f8107cf2e5b3cc1a6a4a96c22b24505f02df01 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Wed, 15 May 2024 19:13:36 -0400 Subject: [PATCH 283/413] [Frontend] Support OpenAI batch file format (#4794) Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> --- examples/offline_inference_openai.md | 172 +++++++++++++++++++++ examples/openi_example_batch.jsonl | 2 + requirements-common.txt | 1 + tests/entrypoints/test_openai_run_batch.py | 53 +++++++ vllm/entrypoints/openai/protocol.py | 41 +++++ vllm/entrypoints/openai/run_batch.py | 141 +++++++++++++++++ vllm/entrypoints/openai/serving_chat.py | 8 +- 7 files changed, 415 insertions(+), 3 deletions(-) create mode 100644 examples/offline_inference_openai.md create mode 100644 examples/openi_example_batch.jsonl create mode 100644 tests/entrypoints/test_openai_run_batch.py create mode 100644 vllm/entrypoints/openai/run_batch.py diff --git a/examples/offline_inference_openai.md b/examples/offline_inference_openai.md new file mode 100644 index 0000000000000..40462ce1eb78c --- /dev/null +++ b/examples/offline_inference_openai.md @@ -0,0 +1,172 @@ +# Offline Inference with the OpenAI Batch file format + + **NOTE:** This is a guide to performing batch inference using the OpenAI batch file format, **NOT** the complete Batch (REST) API. + + ## File Format + + The OpenAI batch file format consists of a series of json objects on new lines. + + [See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/openai_example_batch.jsonl) + + Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. + + **NOTE:** We currently only support to `/v1/chat/completions` endpoint (embeddings and completions coming soon). + + ## Pre-requisites + +* Ensure you are using `vllm >= 0.4.3`. You can check by running `python -c "import vllm; print(vllm.__version__)"`. +* The examples in this document use `meta-llama/Meta-Llama-3-8B-Instruct`. + - Create a [user access token](https://huggingface.co/docs/hub/en/security-tokens) + - Install the token on your machine (Run `huggingface-cli login`). + - Get access to the gated model by [visiting the model card](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) and agreeing to the terms and conditions. + + + ## Example: Running with a local file + + ### Step 1: Create your batch file + + To follow along with this example, you can download the example batch, or create your own batch file in your working directory. + + ``` + wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl + ``` + + Once you've created your batch file it should look like this + + ``` + $ cat openai_example_batch.jsonl +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} + ``` + + ### Step 2: Run the batch + +The batch running tool is designed to be used from the command line. + +You can run the batch with the following command, which will write its results to a file called `results.jsonl` + +``` +python -m vllm.entrypoints.openai.run_batch -i openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +### Step 3: Check your results + +You should now have your results at `results.jsonl`. You can check your results by running `cat results.jsonl` + +``` +$ cat ../results.jsonl +{"id":"vllm-383d1c59835645aeb2e07d004d62a826","custom_id":"request-1","response":{"id":"cmpl-61c020e54b964d5a98fa7527bfcdd378","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Hello! It's great to meet you! I'm here to help with any questions or tasks you may have. What's on your mind today?"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":25,"total_tokens":56,"completion_tokens":31}},"error":null} +{"id":"vllm-42e3d09b14b04568afa3f1797751a267","custom_id":"request-2","response":{"id":"cmpl-f44d049f6b3a42d4b2d7850bb1e31bcc","object":"chat.completion","created":1715633336,"model":"meta-llama/Meta-Llama-3-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"*silence*"},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":27,"total_tokens":32,"completion_tokens":5}},"error":null} +``` + +## Example 2: Using remote files + +The batch runner supports remote input and output urls that are accessible via http/https. + +For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl`, you can run + +``` +python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +## Example 3: Integrating with AWS S3 + +To integrate with cloud blob storage, we recommend using presigned urls. + +[Learn more about S3 presigned urls here] + +### Additional prerequisites + +* [Create an S3 bucket](https://docs.aws.amazon.com/AmazonS3/latest/userguide/creating-bucket.html). +* The `awscli` package (Run `pip install awscli`) to configure your credentials and interactively use s3. + - [Configure your credentials](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-quickstart.html). +* The `boto3` python package (Run `pip install boto3`) to generate presigned urls. + +### Step 1: Upload your input script + +To follow along with this example, you can download the example batch, or create your own batch file in your working directory. + + ``` + wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/openai_example_batch.jsonl + ``` + + Once you've created your batch file it should look like this + + ``` + $ cat openai_example_batch.jsonl +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} + ``` + +Now upload your batch file to your S3 bucket. + +``` +aws s3 cp openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl +``` + + +### Step 2: Generate your presigned urls + +Presigned put urls can only be generated via the SDK. You can run the following python script to generate your presigned urls. Be sure to replace the `MY_BUCKET`, `MY_INPUT_FILE.jsonl`, and `MY_OUTPUT_FILE.jsonl` placeholders with your bucket and file names. + +(The script is adapted from https://github.com/awsdocs/aws-doc-sdk-examples/blob/main/python/example_code/s3/s3_basics/presigned_url.py) + +``` +import boto3 +from botocore.exceptions import ClientError + +def generate_presigned_url(s3_client, client_method, method_parameters, expires_in): + """ + Generate a presigned Amazon S3 URL that can be used to perform an action. + + :param s3_client: A Boto3 Amazon S3 client. + :param client_method: The name of the client method that the URL performs. + :param method_parameters: The parameters of the specified client method. + :param expires_in: The number of seconds the presigned URL is valid for. + :return: The presigned URL. + """ + try: + url = s3_client.generate_presigned_url( + ClientMethod=client_method, Params=method_parameters, ExpiresIn=expires_in + ) + except ClientError: + raise + return url + + +s3_client = boto3.client("s3") +input_url = generate_presigned_url( + s3_client, "get_object", {"Bucket": "MY_BUCKET", "Key": "MY_INPUT_FILE.jsonl"}, 3600 +) +output_url = generate_presigned_url( + s3_client, "put_object", {"Bucket": "MY_BUCKET", "Key": "MY_OUTPUT_FILE.jsonl"}, 3600 +) +print(f"{input_url=}") +print(f"{output_url=}") +``` + +This script should output + +``` +input_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' +output_url='https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091' +``` + +### Step 3: Run the batch runner using your presigned urls + +You can now run the batch runner, using the urls generated in the previous section. + +``` +python -m vllm.entrypoints.openai.run_batch \ + -i "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_INPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ + -o "https://s3.us-west-2.amazonaws.com/MY_BUCKET/MY_OUTPUT_FILE.jsonl?AWSAccessKeyId=ABCDEFGHIJKLMNOPQRST&Signature=abcdefghijklmnopqrstuvwxyz12345&Expires=1715800091" \ + --model --model meta-llama/Meta-Llama-3-8B-Instruct +``` + +### Step 4: View your results + +Your results are now on S3. You can view them in your terminal by running + +``` +aws s3 cp s3://MY_BUCKET/MY_OUTPUT_FILE.jsonl - +``` diff --git a/examples/openi_example_batch.jsonl b/examples/openi_example_batch.jsonl new file mode 100644 index 0000000000000..5aa7e185c180a --- /dev/null +++ b/examples/openi_example_batch.jsonl @@ -0,0 +1,2 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} diff --git a/requirements-common.txt b/requirements-common.txt index bd779d5acb68e..cc4b15d877d0f 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,6 +8,7 @@ py-cpuinfo transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3. tokenizers >= 0.19.1 # Required for Llama 3. fastapi +aiohttp openai uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. diff --git a/tests/entrypoints/test_openai_run_batch.py b/tests/entrypoints/test_openai_run_batch.py new file mode 100644 index 0000000000000..5de28513ca391 --- /dev/null +++ b/tests/entrypoints/test_openai_run_batch.py @@ -0,0 +1,53 @@ +import subprocess +import sys +import tempfile + +from vllm.entrypoints.openai.protocol import BatchRequestOutput + +# ruff: noqa: E501 +INPUT_BATCH = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" + +INVALID_INPUT_BATCH = """{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "NousResearch/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}""" + + +def test_e2e(): + with tempfile.NamedTemporaryFile( + "w") as input_file, tempfile.NamedTemporaryFile( + "r") as output_file: + input_file.write(INPUT_BATCH) + input_file.flush() + proc = subprocess.Popen([ + sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i", + input_file.name, "-o", output_file.name, "--model", + "NousResearch/Meta-Llama-3-8B-Instruct" + ], ) + proc.communicate() + proc.wait() + assert proc.returncode == 0, f"{proc=}" + + contents = output_file.read() + for line in contents.strip().split("\n"): + # Ensure that the output format conforms to the openai api. + # Validation should throw if the schema is wrong. + BatchRequestOutput.model_validate_json(line) + + +def test_e2e_invalid_input(): + """ + Ensure that we fail when the input doesn't conform to the openai api. + """ + with tempfile.NamedTemporaryFile( + "w") as input_file, tempfile.NamedTemporaryFile( + "r") as output_file: + input_file.write(INVALID_INPUT_BATCH) + input_file.flush() + proc = subprocess.Popen([ + sys.executable, "-m", "vllm.entrypoints.openai.run_batch", "-i", + input_file.name, "-o", output_file.name, "--model", + "NousResearch/Meta-Llama-3-8B-Instruct" + ], ) + proc.communicate() + proc.wait() + assert proc.returncode != 0, f"{proc=}" diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 35dfa09ac12ba..41e2f77fe56f1 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -526,3 +526,44 @@ class ChatCompletionStreamResponse(OpenAIBaseModel): model: str choices: List[ChatCompletionResponseStreamChoice] usage: Optional[UsageInfo] = Field(default=None) + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameteters of the request. + body: Union[ChatCompletionRequest, ] + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[ChatCompletionResponse] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. + error: Optional[Any] diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py new file mode 100644 index 0000000000000..99f1b2d6d091b --- /dev/null +++ b/vllm/entrypoints/openai/run_batch.py @@ -0,0 +1,141 @@ +import argparse +import asyncio +import sys +from io import StringIO + +import aiohttp + +import vllm +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (BatchRequestInput, + BatchRequestOutput, + ChatCompletionResponse) +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="vLLM OpenAI-Compatible batch runner.") + parser.add_argument( + "-i", + "--input-file", + required=True, + type=str, + help= + "The path or url to a single input file. Currently supports local file " + "paths, or the http protocol (http or https). If a URL is specified, " + "the file should be available via HTTP GET.") + parser.add_argument( + "-o", + "--output-file", + required=True, + type=str, + help="The path or url to a single output file. Currently supports " + "local file paths, or web (http or https) urls. If a URL is specified," + " the file should be available via HTTP PUT.") + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=true`.") + + parser = AsyncEngineArgs.add_cli_args(parser) + return parser.parse_args() + + +async def read_file(path_or_url: str) -> str: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.get(path_or_url) as resp: + return await resp.text() + else: + with open(path_or_url, "r") as f: + return f.read() + + +async def write_file(path_or_url: str, data: str) -> None: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.put(path_or_url, data=data.encode("utf-8")): + pass + else: + # We should make this async, but as long as this is always run as a + # standalone program, blocking the event loop won't effect performance + # in this particular case. + with open(path_or_url, "w") as f: + f.write(data) + + +async def run_request(chat_serving: OpenAIServingChat, + request: BatchRequestInput) -> BatchRequestOutput: + chat_request = request.body + chat_response = await chat_serving.create_chat_completion(chat_request) + if isinstance(chat_response, ChatCompletionResponse): + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=chat_response, + error=None, + ) + else: + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=None, + error=chat_response, + ) + return batch_output + + +async def main(args): + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + + # When using single vLLM without engine_use_ray + model_config = await engine.get_model_config() + + openai_serving_chat = OpenAIServingChat( + engine, + model_config, + served_model_names, + args.response_role, + ) + + # Submit all requests in the file to the engine "concurrently". + response_futures = [] + for request_json in (await read_file(args.input_file)).strip().split("\n"): + request = BatchRequestInput.model_validate_json(request_json) + response_futures.append(run_request(openai_serving_chat, request)) + + responses = await asyncio.gather(*response_futures) + + output_buffer = StringIO() + for response in responses: + print(response.model_dump_json(), file=output_buffer) + + output_buffer.seek(0) + await write_file(args.output_file, output_buffer.read().strip()) + + # Temporary workaround for https://github.com/vllm-project/vllm/issues/4789 + sys.exit(0) + + +if __name__ == "__main__": + args = parse_args() + + logger.info("vLLM API server version %s", vllm.__version__) + logger.info("args: %s", args) + + asyncio.run(main(args)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 65824a2206be9..c86e41c601be0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -119,7 +119,9 @@ def _parse_chat_message_content( return self._parse_chat_message_content_parts(role, content) async def create_chat_completion( - self, request: ChatCompletionRequest, raw_request: Request + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None ) -> Union[ErrorResponse, AsyncGenerator[str, None], ChatCompletionResponse]: """Completion API similar to OpenAI's API. @@ -337,7 +339,7 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, request: ChatCompletionRequest, raw_request: Request, + self, request: ChatCompletionRequest, raw_request: Optional[Request], result_generator: AsyncIterator[RequestOutput], request_id: str, conversation: List[ConversationMessage] ) -> Union[ErrorResponse, ChatCompletionResponse]: @@ -347,7 +349,7 @@ async def chat_completion_full_generator( final_res: Optional[RequestOutput] = None async for res in result_generator: - if await raw_request.is_disconnected(): + if raw_request is not None and await raw_request.is_disconnected(): # Abort the request if the client disconnects. await self.engine.abort(request_id) return self.create_error_response("Client disconnected") From 30e754390c2a8a7198f472386d35ee1ec9443e4a Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 16 May 2024 01:11:54 -0400 Subject: [PATCH 284/413] [Core] Implement sharded state loader (#4690) Co-authored-by: Woosuk Kwon --- examples/save_sharded_state.py | 75 +++++++++++ tests/test_sharded_state_loader.py | 90 +++++++++++++ vllm/config.py | 1 + vllm/executor/distributed_gpu_executor.py | 11 ++ vllm/model_executor/model_loader/loader.py | 148 +++++++++++++++++++++ vllm/worker/model_runner.py | 14 ++ vllm/worker/worker.py | 12 ++ 7 files changed, 351 insertions(+) create mode 100644 examples/save_sharded_state.py create mode 100644 tests/test_sharded_state_loader.py diff --git a/examples/save_sharded_state.py b/examples/save_sharded_state.py new file mode 100644 index 0000000000000..c595d98ba2750 --- /dev/null +++ b/examples/save_sharded_state.py @@ -0,0 +1,75 @@ +""" +Saves each worker's model state dict directly to a checkpoint, which enables a +fast load path for large tensor-parallel models where each worker only needs to +read its own shard rather than the entire checkpoint. + +Example usage: + +python save_sharded_state.py \ + --model /path/to/load \ + --quantization deepspeedfp \ + --tensor-parallel-size 8 \ + --output /path/to/save + +Then, the model can be loaded with + +llm = LLM( + model="/path/to/save", + load_format="sharded_state", + quantization="deepspeedfp", + tensor_parallel_size=8, +) +""" +import argparse +import dataclasses +import os +import shutil +from pathlib import Path + +from vllm import LLM, EngineArgs + +parser = argparse.ArgumentParser() +EngineArgs.add_cli_args(parser) +parser.add_argument("--output", + "-o", + required=True, + type=str, + help="path to output checkpoint") +parser.add_argument("--file-pattern", + type=str, + help="string pattern of saved filenames") +parser.add_argument("--max-file-size", + type=str, + default=5 * 1024**3, + help="max size (in bytes) of each safetensors file") + + +def main(args): + engine_args = EngineArgs.from_cli_args(args) + if engine_args.enable_lora: + raise ValueError("Saving with enable_lora=True is not supported!") + model_path = engine_args.model + if not Path(model_path).is_dir(): + raise ValueError("model path must be a local directory") + # Create LLM instance from arguments + llm = LLM(**dataclasses.asdict(engine_args)) + # Prepare output directory + Path(args.output).mkdir(exist_ok=True) + # Dump worker states to output directory + model_executor = llm.llm_engine.model_executor + model_executor.save_sharded_state(path=args.output, + pattern=args.file_pattern, + max_size=args.max_file_size) + # Copy metadata files to output directory + for file in os.listdir(model_path): + if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): + if os.path.isdir(os.path.join(model_path, file)): + shutil.copytree(os.path.join(model_path, file), + os.path.join(args.output, file)) + else: + shutil.copy(os.path.join(model_path, file), args.output) + + +if __name__ == "__main__": + args = parser.parse_args() + main(args) diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py new file mode 100644 index 0000000000000..8540e98da366a --- /dev/null +++ b/tests/test_sharded_state_loader.py @@ -0,0 +1,90 @@ +import os +import shutil +from tempfile import TemporaryDirectory + +import pytest +import torch +from huggingface_hub import snapshot_download + +from vllm import LLM, SamplingParams +from vllm.model_executor.model_loader.loader import ShardedStateLoader + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create a sampling params object. +sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + seed=0, + max_tokens=256, + ignore_eos=True, +) + + +def test_filter_subtensors(): + state_dict = { + "a": torch.empty(2), + "b": torch.empty((2, 4)), + "c": torch.empty((2, 4, 8)), + } + state_dict.update({ + "x": state_dict["b"], + "y": state_dict["c"][1, 2, :], + "z": state_dict["c"][1, :, 4], + }) + filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) + assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") + for key, tensor in filtered_state_dict.items(): + assert tensor.equal(state_dict[key]) + + +@pytest.mark.parametrize("enable_lora", [False, True]) +def test_sharded_state_loader(enable_lora): + weights_patterns = ("*.bin", "*.pt", "*.safetensors") + + with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir: + input_dir = snapshot_download("meta-llama/Llama-2-7b-hf", + cache_dir=cache_dir) + + llm = LLM( + model=input_dir, + worker_use_ray=True, + gpu_memory_utilization=0.3, + ) + + # Dump worker states to output directory + model_executor = llm.llm_engine.model_executor + model_executor.save_sharded_state(path=output_dir) + # Copy metadata files to output directory + for file in os.listdir(input_dir): + if not any(file.endswith(ext) for ext in weights_patterns): + shutil.copy(f"{input_dir}/{file}", output_dir) + del llm.llm_engine.model_executor + + llm_before = LLM( + model=input_dir, + worker_use_ray=True, + enable_lora=enable_lora, + gpu_memory_utilization=0.3, + ) + gen_before = llm_before.generate(prompts, sampling_params) + out_before = [gen.outputs[0].__dict__ for gen in gen_before] + del llm_before.llm_engine.model_executor + + llm_after = LLM( + model=output_dir, + worker_use_ray=True, + enable_lora=enable_lora, + gpu_memory_utilization=0.3, + load_format="sharded_state", + ) + gen_after = llm_after.generate(prompts, sampling_params) + out_after = [gen.outputs[0].__dict__ for gen in gen_after] + del llm_after.llm_engine.model_executor + + assert out_before == out_after diff --git a/vllm/config.py b/vllm/config.py index 2eb5bdd18d812..91f590aaf79eb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -463,6 +463,7 @@ class LoadFormat(str, enum.Enum): NPCACHE = "npcache" DUMMY = "dummy" TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" @dataclass diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 4c922ef63ee04..c5b1e61112afb 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -77,6 +77,17 @@ def remove_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self._run_workers("list_loras") + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self._run_workers("save_sharded_state", + path=path, + pattern=pattern, + max_size=max_size) + @abstractmethod def _run_workers( self, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index b14824a359b6d..dc568928b2859 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1,4 +1,5 @@ # ruff: noqa: SIM117 +import collections import copy import glob import os @@ -366,6 +367,150 @@ def load_model(self, *, model_config: ModelConfig, cache_config) +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. See + `examples/save_sharded_states.py` for creating a sharded checkpoint. + """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = ({} if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy()) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}") + + @staticmethod + def _filter_subtensors( + tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. + """ + same_storage_groups = collections.defaultdict(list) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + from safetensors.torch import safe_open + + from vllm.distributed import get_tensor_model_parallel_rank + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, vision_language_config, + cache_config) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + model_config.model, + self.pattern.format(rank=rank, part="*"), + ) + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!") + state_dict = self._filter_subtensors(model.state_dict()) + for path in filepaths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. + param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", tensor.shape, + key, param_shape) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + return model.eval() + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + + from vllm.distributed import get_tensor_model_parallel_rank + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + + def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" @@ -378,4 +523,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.TENSORIZER: return TensorizerLoader(load_config) + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index dcdd4b962454e..623a0bc32211c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -182,6 +182,20 @@ def load_model(self) -> None: "but the KV cache data type is not FP8. " "KV cache scaling factors will not be used.") + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.model_executor.model_loader.loader import ShardedStateLoader + ShardedStateLoader.save_model( + self.model, + path, + pattern=pattern, + max_size=max_size, + ) + def get_max_block_per_batch(self) -> int: block_size = self.block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 82cf58101a95b..faea50fbfbf50 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -119,6 +119,18 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.model_runner.save_sharded_state( + path, + pattern=pattern, + max_size=max_size, + ) + @torch.inference_mode() def determine_num_available_blocks(self) -> Tuple[int, int]: """Profiles the peak memory usage of the model to determine how many From 973617ae02a4e8e6190674cf1cdb0c0803b65ae6 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 16 May 2024 00:53:51 -0700 Subject: [PATCH 285/413] [Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840) Co-authored-by: Cade Daniel Co-authored-by: Cade Daniel --- .buildkite/test-pipeline.yaml | 1 + benchmarks/benchmark_latency.py | 6 + tests/spec_decode/e2e/test_compatibility.py | 50 ----- tests/spec_decode/e2e/test_integration.py | 44 +++++ .../spec_decode/e2e/test_integration_dist.py | 65 +++++++ .../e2e/test_multistep_correctness.py | 37 ---- vllm/distributed/communication_op.py | 10 +- vllm/executor/gpu_executor.py | 71 ++------ vllm/executor/ray_gpu_executor.py | 19 +- vllm/spec_decode/spec_decode_worker.py | 171 +++++++++++++++--- vllm/worker/worker.py | 3 +- vllm/worker/worker_base.py | 2 +- 12 files changed, 297 insertions(+), 182 deletions(-) create mode 100644 tests/spec_decode/e2e/test_integration.py create mode 100644 tests/spec_decode/e2e/test_integration_dist.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index eeb6eaa2165bc..aa74672f4bf67 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -42,6 +42,7 @@ steps: - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py + - pytest -v -s spec_decode/e2e/test_integration_dist.py - label: Distributed Tests (Multiple Groups) working_dir: "/vllm-workspace/tests" diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 44da3bad8d840..8f3168c115ae6 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -18,6 +18,8 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. llm = LLM(model=args.model, + speculative_model=args.speculative_model, + num_speculative_tokens=args.num_speculative_tokens, tokenizer=args.tokenizer, quantization=args.quantization, tensor_parallel_size=args.tensor_parallel_size, @@ -28,6 +30,7 @@ def main(args: argparse.Namespace): quantization_param_path=args.quantization_param_path, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, + use_v2_block_manager=args.use_v2_block_manager, enable_chunked_prefill=args.enable_chunked_prefill, download_dir=args.download_dir, block_size=args.block_size) @@ -99,6 +102,8 @@ def run_to_completion(profile_dir: Optional[str] = None): description='Benchmark the latency of processing a single batch of ' 'requests till completion.') parser.add_argument('--model', type=str, default='facebook/opt-125m') + parser.add_argument('--speculative-model', type=str, default=None) + parser.add_argument('--num-speculative-tokens', type=int, default=None) parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', @@ -181,6 +186,7 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') + parser.add_argument('--use-v2-block-manager', action='store_true') parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index 60c20ed7db7a3..81f91c5e10b0d 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -5,56 +5,6 @@ from .conftest import get_output_from_llm_generator -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": "JackFram/llama-68m", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - - # Required for spec decode. - "use_v2_block_manager": True - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - # Expect failure as spec decode not supported by - # Ray backend. - "worker_use_ray": True, - }, - ]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail_ray(test_llm_generator): - """Verify that speculative decoding with Ray fails. - """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - try: - with pytest.raises( - AssertionError, - match="Speculative decoding not yet supported for "): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) - finally: - # we need to free up ray resource, - # so that latter test could use the gpu we allocated here - import ray - ray.shutdown() - - @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py new file mode 100644 index 0000000000000..4a2b62151f8cd --- /dev/null +++ b/tests/spec_decode/e2e/test_integration.py @@ -0,0 +1,44 @@ +"""Tests which cover integration of the speculative decoding framework with +other features, e.g. cuda graphs. +""" + +import pytest + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Required for spec decode. + "use_v2_block_manager": True, + + # Verify equality when cuda graphs allowed. + "enforce_eager": False, + "model": "JackFram/llama-68m", + }]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [ + { + # Identical models. + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, + ]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("output_len", [32]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, + batch_size, output_len): + """Verify spec decode equality when cuda graphs are enabled. + """ + run_greedy_equality_correctness_test( + baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + ) diff --git a/tests/spec_decode/e2e/test_integration_dist.py b/tests/spec_decode/e2e/test_integration_dist.py new file mode 100644 index 0000000000000..d444ef24cbfda --- /dev/null +++ b/tests/spec_decode/e2e/test_integration_dist.py @@ -0,0 +1,65 @@ +"""Tests which cover integration of the speculative decoding framework with +tensor parallelism. +""" + +import pytest +import torch + +from vllm.utils import is_hip + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "tensor_parallel_size": 2, + + # Use AsyncLLM engine, so that the engine runs in its own process. + # Otherwise, since vLLM does not follow true SPMD, the test runner + # process will have both the engine and the rank0 worker. NCCL is not + # cleaned up properly, and its server host thread leaks, causing the + # second run of the test to fail with internal NCCL error. + "use_async": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify greedy equality when tensor parallelism is used. + """ + if is_hip(): + pytest.skip("hip is not well-supported yet") + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index d2da039e84c07..94d71fb012727 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Required for spec decode. - "use_v2_block_manager": True, - - # Verify equality when cuda graphs allowed. - "enforce_eager": False, - "model": "JackFram/llama-68m", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - # Identical models. - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("output_len", [32]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator, - batch_size, output_len): - """Verify spec decode equality when cuda graphs are enabled. - """ - run_greedy_equality_correctness_test( - baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len=output_len, - force_output_len=True, - ) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 9cc776f8324f2..f8ee0f9796bcd 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -219,16 +219,16 @@ def broadcast_tensor_dict( to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, dtypes). """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() + or torch.distributed.get_world_size(group=group) == 1): + return tensor_dict + group = group or torch.distributed.group.WORLD metadata_group = metadata_group or get_cpu_world_group() ranks = torch.distributed.get_process_group_ranks(group) assert src in ranks, f"Invalid src rank ({src})" - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size(group=group) - if world_size == 1: - return tensor_dict - rank = torch.distributed.get_rank() if rank == src: metadata_list: List[Tuple[Any, Any]] = [] diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 2b72b31b5f070..3ad201f4757ec 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase): def _init_executor(self) -> None: """Initialize the worker and load the model. - - If speculative decoding is enabled, we instead create the speculative - worker. """ - if self.speculative_config is None: - self._init_non_spec_worker() - else: - self._init_spec_worker() + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() def _get_worker_kwargs( self, @@ -45,6 +44,7 @@ def _get_worker_kwargs( distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, + speculative_config=self.speculative_config, is_driver_worker=rank == 0, ) @@ -52,59 +52,22 @@ def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None): + + if self.speculative_config is None: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + else: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + wrapper = WorkerWrapperBase( - worker_module_name="vllm.worker.worker", - worker_class_name="Worker", + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, ) wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, distributed_init_method)) return wrapper.worker - def _init_non_spec_worker(self): - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = self._create_worker() - self.driver_worker.init_device() - self.driver_worker.load_model() - - def _init_spec_worker(self): - """Initialize a SpecDecodeWorker, using a draft model for proposals. - """ - assert self.speculative_config is not None - - from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker - - target_worker = self._create_worker() - - draft_worker_kwargs = self._get_worker_kwargs() - # Override draft-model specific worker args. - draft_worker_kwargs.update( - model_config=self.speculative_config.draft_model_config, - parallel_config=self.speculative_config.draft_parallel_config, - ngram_prompt_lookup_max=self.speculative_config. - ngram_prompt_lookup_max, - ngram_prompt_lookup_min=self.speculative_config. - ngram_prompt_lookup_min, - # TODO allow draft-model specific load config. - #load_config=self.load_config, - ) - - spec_decode_worker = SpecDecodeWorker.create_worker( - scorer_worker=target_worker, - draft_worker_kwargs=draft_worker_kwargs, - disable_by_batch_size=self.speculative_config. - speculative_disable_by_batch_size, - ) - - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - self.driver_worker = spec_decode_worker - - # Load model handled in spec decode worker. - self.driver_worker.init_device() - def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 9cb03ec8c3f5a..dd3ee60682d30 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -28,9 +28,6 @@ class RayGPUExecutor(DistributedGPUExecutor): def _init_executor(self) -> None: - assert (not self.speculative_config - ), "Speculative decoding not yet supported for RayGPU backend." - assert self.parallel_config.distributed_executor_backend == "ray" placement_group = self.parallel_config.placement_group @@ -90,14 +87,22 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", placement_group_capture_child_tasks=True, placement_group_bundle_index=bundle_id, ) + + if self.speculative_config is not None: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + worker = ray.remote( num_cpus=0, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, )(RayWorkerWrapper).remote( - worker_module_name="vllm.worker.worker", - worker_class_name="Worker", + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) @@ -107,8 +112,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - worker_module_name="vllm.worker.worker", - worker_class_name="Worker", + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) else: diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index a4e759095b294..ef17b8c1e2cc0 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -3,6 +3,7 @@ import torch +from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.sequence import (ExecuteModelRequest, SamplerOutput, @@ -17,11 +18,43 @@ get_all_num_logprobs, get_all_seq_ids, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) +from vllm.worker.worker import Worker from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase logger = init_logger(__name__) +def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": + """Helper method that is the entrypoint for Executors which use + WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. + """ + assert "speculative_config" in kwargs + speculative_config = kwargs.get("speculative_config") + assert speculative_config is not None + + target_worker = Worker(*args, **kwargs) + + draft_worker_kwargs = kwargs.copy() + # Override draft-model specific worker args. + draft_worker_kwargs.update( + model_config=speculative_config.draft_model_config, + parallel_config=speculative_config.draft_parallel_config, + ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, + # TODO allow draft-model specific load config. + #load_config=load_config, + ) + + spec_decode_worker = SpecDecodeWorker.create_worker( + scorer_worker=target_worker, + draft_worker_kwargs=draft_worker_kwargs, + disable_by_batch_size=speculative_config. + speculative_disable_by_batch_size, + ) + + return spec_decode_worker + + class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. @@ -142,6 +175,9 @@ def init_device(self) -> None: self._configure_model_sampler_for_spec_decode() + def load_model(self, *args, **kwargs): + pass + def _configure_model_sampler_for_spec_decode(self): """Configure model sampler to emit GPU tensors. This allows spec decode to keep data on device without transferring to CPU and serializing, @@ -195,39 +231,97 @@ def initialize_cache(self, num_gpu_blocks: int, self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + def _broadcast_control_flow_decision( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + disable_all_speculation: bool = False) -> Tuple[int, bool]: + """Broadcast how many lookahead slots are scheduled for this step, and + whether all speculation is disabled, to all non-driver workers. + + This is required as if the number of draft model runs changes + dynamically, the non-driver workers won't know unless we perform a + communication to inform then. + + Returns the broadcasted num_lookahead_slots and disable_all_speculation. + """ + + if self.rank == self._driver_rank: + assert execute_model_req is not None + + broadcast_dict = dict( + num_lookahead_slots=execute_model_req.num_lookahead_slots, + disable_all_speculation=disable_all_speculation, + ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) + else: + assert execute_model_req is None + broadcast_dict = broadcast_tensor_dict(src=self._driver_rank) + + return (broadcast_dict["num_lookahead_slots"], + broadcast_dict["disable_all_speculation"]) + @torch.inference_mode() def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ - assert execute_model_req.seq_group_metadata_list is not None, ( - "speculative decoding " - "requires non-None seq_group_metadata_list") + disable_all_speculation = False + if self.rank == self._driver_rank: + disable_all_speculation = self._should_disable_all_speculation( + execute_model_req) + + (num_lookahead_slots, + disable_all_speculation) = self._broadcast_control_flow_decision( + execute_model_req, disable_all_speculation) + + if self.rank == self._driver_rank: + assert execute_model_req is not None + assert execute_model_req.seq_group_metadata_list is not None, ( + "speculative decoding requires non-None seq_group_metadata_list" + ) + + self._maybe_disable_speculative_tokens( + disable_all_speculation, + execute_model_req.seq_group_metadata_list) + + # If no spec tokens, call the proposer and scorer workers normally. + # Used for prefill. + if num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list) == 0: + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all_speculation) + + return self._run_speculative_decoding_step(execute_model_req, + num_lookahead_slots) + else: + self._run_non_driver_rank(num_lookahead_slots) + return [] + def _should_disable_all_speculation( + self, execute_model_req: ExecuteModelRequest) -> bool: # When the batch size is too large, disable speculative decoding # to stop trading off throughput for latency. - disable_all = (execute_model_req.running_queue_size >= - self.disable_by_batch_size) - if disable_all: - for seq_group_metadata in execute_model_req.seq_group_metadata_list: - # Once num_speculative_tokens is set to 0, the spec decode - # of this request will be disabled forever. - # TODO(comaniac): We currently store spec decoding specific - # state in the global data structure, but we should maintain - # this state within spec decode worker. - seq_group_metadata.num_speculative_tokens = 0 - - # If no spec tokens, call the proposer and scorer workers normally. - # This happens for prefill, or when the spec decode is disabled - # for this batch. - if execute_model_req.num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list) == 0: - return self._run_no_spec(execute_model_req, - skip_proposer=disable_all) - - return self._run_speculative_decoding_step(execute_model_req) + disable_all_speculation = (execute_model_req.running_queue_size >= + self.disable_by_batch_size) + + return disable_all_speculation + + def _maybe_disable_speculative_tokens( + self, disable_all_speculation: bool, + seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: + if not disable_all_speculation: + return + + for seq_group_metadata in seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. + # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec(self, execute_model_req: ExecuteModelRequest, @@ -252,10 +346,28 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, sampler_output.logprobs = None return [sampler_output] + def _run_non_driver_rank(self, num_lookahead_slots: int) -> None: + """Run proposer and verifier model in non-driver workers. This is used + for both speculation cases (num_lookahead_slots>0) and non-speculation + cases (e.g. prefill). + """ + # In non-driver workers the input is None + execute_model_req = None + + # Even if num_lookahead_slots is zero, we want to run the proposer model + # as it may have KV. + # + # We run the proposer once per lookahead slot. In the future we should + # delegate how many times it runs to the proposer. + for _ in range(max(num_lookahead_slots, 1)): + self.proposer_worker.execute_model(execute_model_req) + + self.scorer_worker.execute_model(execute_model_req) + @nvtx_range("spec_decode_worker._run_speculative_decoding_step") def _run_speculative_decoding_step( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, execute_model_req: ExecuteModelRequest, + num_lookahead_slots: int) -> List[SamplerOutput]: """Execute a single step of speculative decoding. This invokes the proposer worker to get k speculative tokens for each @@ -264,6 +376,7 @@ def _run_speculative_decoding_step( Returns a list of SamplerOutput, each containing a single token per sequence. """ + assert num_lookahead_slots == execute_model_req.num_lookahead_slots # Generate proposals using draft worker. proposals = self.proposer_worker.get_spec_proposals(execute_model_req) @@ -455,6 +568,10 @@ def rank(self): def device(self): return self.scorer_worker.device + @property + def _driver_rank(self) -> int: + return 0 + def get_cache_block_size_bytes(self): """Return the size of a cache block in bytes. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index faea50fbfbf50..97b3873b2a9f6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,7 +8,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + SpeculativeConfig, VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment, @@ -43,6 +43,7 @@ def __init__( distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index fb32feaca0c94..1f04f821eb0f0 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -121,7 +121,7 @@ def update_environment_variables(envs: Dict[str, str]) -> None: def init_worker(self, *args, **kwargs): """ Actual initialization of the worker class, and set up - function tracing if required. + function tracing if required. Arguments are passed to the worker class constructor. """ enable_trace_function_call_for_thread() From 5c342570d7e4c73bfaa4c4057174b92117d322bb Mon Sep 17 00:00:00 2001 From: alexm-nm <59768536+alexm-nm@users.noreply.github.com> Date: Thu, 16 May 2024 09:36:49 -0400 Subject: [PATCH 286/413] Add marlin unit tests and marlin benchmark script (#4815) --- benchmarks/kernels/benchmark_marlin.py | 183 ++++++++++++++++++ benchmarks/kernels/benchmark_shapes.py | 75 +++++++ tests/kernels/test_marlin_gemm.py | 158 +++++++++++++++ .../layers/quantization/utils/__init__.py | 0 .../layers/quantization/utils/marlin_utils.py | 174 +++++++++++++++++ .../layers/quantization/utils/quant_utils.py | 146 ++++++++++++++ 6 files changed, 736 insertions(+) create mode 100644 benchmarks/kernels/benchmark_marlin.py create mode 100644 benchmarks/kernels/benchmark_shapes.py create mode 100644 tests/kernels/test_marlin_gemm.py create mode 100644 vllm/model_executor/layers/quantization/utils/__init__.py create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_utils.py create mode 100644 vllm/model_executor/layers/quantization/utils/quant_utils.py diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py new file mode 100644 index 0000000000000..5dcffc284f3d4 --- /dev/null +++ b/benchmarks/kernels/benchmark_marlin.py @@ -0,0 +1,183 @@ +import argparse + +import torch +import torch.utils.benchmark as benchmark +from benchmark_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MarlinWorkspace, marlin_quantize) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, quantize_weights, sort_weights) + +DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + + +def bench_run(results, model, act_order, is_k_full, num_bits, group_size, + size_m, size_k, size_n): + label = "Quant Matmul" + + sub_label = ("{}, act={} k_full={}, b={}, g={}, " + "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits, + group_size, size_m, size_k, size_n)) + + print(f"Testing: {sub_label}") + + a = torch.randn(size_m, size_k).to(torch.half).cuda() + b = torch.rand(size_k, size_n).to(torch.half).cuda() + + a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) + + # Marlin quant + ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_g_idx, + marlin_sort_indices, + marlin_rand_perm, + ) = marlin_quantize(b, num_bits, group_size, act_order) + + # GPTQ quant + (w_ref, q_w, s, g_idx, + rand_perm) = quantize_weights(b, num_bits, group_size, act_order) + q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" + # so that group ids are increasing + repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) + if act_order: + (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) + + # Prepare + marlin_workspace = MarlinWorkspace(size_n) + + globals = { + "marlin_w_ref": marlin_w_ref, + "marlin_q_w": marlin_q_w, + "marlin_s": marlin_s, + "marlin_g_idx": marlin_g_idx, + "marlin_sort_indices": marlin_sort_indices, + "marlin_rand_perm": marlin_rand_perm, + "q_w_gptq": q_w_gptq, + "repack_sort_indices": repack_sort_indices, + "num_bits": num_bits, + "group_size": group_size, + "size_m": size_m, + "size_n": size_n, + "size_k": size_k, + "is_k_full": is_k_full, + "a": a, + "a_tmp": a_tmp, + "gptq_marlin_gemm": ops.gptq_marlin_gemm, + "gptq_marlin_repack": ops.gptq_marlin_repack, + "marlin_workspace": marlin_workspace, + } + + min_run_time = 1 + + # Warmup pytorch + for i in range(5): + torch.matmul(a, marlin_w_ref) + + results.append( + benchmark.Timer( + stmt="torch.matmul(a, marlin_w_ref)", + globals=globals, + label=label, + sub_label=sub_label, + description="pytorch_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + + results.append( + benchmark.Timer( + stmt= + "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_repack", + ).blocked_autorange(min_run_time=min_run_time)) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results = [] + + for model in args.models: + for layer in WEIGHT_SHAPES[model]: + size_k = layer[0] + size_n = layer[1] + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for act_order in ACT_ORDER_OPTS: + for is_k_full in K_FULL_OPTS: + for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + if len( + args.limit_group_size + ) > 0 and group_size not in args.limit_group_size: + continue + + # For act_order, the group_size must be less than + # size_k + if act_order and (group_size == size_k + or group_size == -1): + continue + + for size_m in args.batch_sizes: + bench_run(results, model, act_order, is_k_full, + num_bits, group_size, size_m, size_k, + size_n) + + compare = benchmark.Compare(results) + compare.print() + + +# For quick benchmarking use: +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501 +# +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark Marlin across specified models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py new file mode 100644 index 0000000000000..4eeeca35a37cc --- /dev/null +++ b/benchmarks/kernels/benchmark_shapes.py @@ -0,0 +1,75 @@ +WEIGHT_SHAPES = { + "ideal": [[4 * 256 * 32, 256 * 32]], + "mistralai/Mistral-7B-v0.1/TP1": [ + [4096, 6144], + [4096, 4096], + [4096, 28672], + [14336, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP2": [ + [4096, 3072], + [2048, 4096], + [4096, 14336], + [7168, 4096], + ], + "mistralai/Mistral-7B-v0.1/TP4": [ + [4096, 1536], + [1024, 4096], + [4096, 7168], + [3584, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP1": [ + [4096, 12288], + [4096, 4096], + [4096, 22016], + [11008, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP2": [ + [4096, 6144], + [2048, 4096], + [4096, 11008], + [5504, 4096], + ], + "meta-llama/Llama-2-7b-hf/TP4": [ + [4096, 3072], + [1024, 4096], + [4096, 5504], + [2752, 4096], + ], + "meta-llama/Llama-2-13b-hf/TP1": [ + [5120, 15360], + [5120, 5120], + [5120, 27648], + [13824, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP2": [ + [5120, 7680], + [2560, 5120], + [5120, 13824], + [6912, 5120], + ], + "meta-llama/Llama-2-13b-hf/TP4": [ + [5120, 3840], + [1280, 5120], + [5120, 6912], + [3456, 5120], + ], + "meta-llama/Llama-2-70b-hf/TP1": [ + [8192, 10240], + [8192, 8192], + [8192, 57344], + [28672, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP2": [ + [8192, 5120], + [4096, 8192], + [8192, 28672], + [14336, 8192], + ], + "meta-llama/Llama-2-70b-hf/TP4": [ + [8192, 2560], + [2048, 8192], + [8192, 14336], + [7168, 8192], + ], +} diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py new file mode 100644 index 0000000000000..b0ad85c25c572 --- /dev/null +++ b/tests/kernels/test_marlin_gemm.py @@ -0,0 +1,158 @@ +"""Tests for the marlin kernel. + +Run `pytest tests/kernels/marlin/test_marlin_gemm.py`. +""" +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, quantize_weights, sort_weights) + +ACT_ORDER_OPTS = [False, True] +K_FULL_OPTS = [False, True] + +K_CHUNKS = [128, 256] +N_CHUNKS = [64, 128, 256] + +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (1, 7 * 4, 5 * 1), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), +] + + +def rand_data(shape): + data = torch.rand(shape).to(torch.half).cuda() + return data + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", K_CHUNKS) +@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, + mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Create input + b_weight = rand_data((size_k, size_n)) + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits, + group_size, act_order) + + # Pack to GPTQ format + q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Pack to Marlin format + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits) + + # Run Marlin repack GPU kernel + marlin_q_w_2 = ops.gptq_marlin_repack( + q_w_gptq, + sort_indices, + size_k, + size_n, + num_bits, + ) + torch.cuda.synchronize() + + assert torch.allclose(marlin_q_w_1, marlin_q_w_2) + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", K_CHUNKS) +@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) +@pytest.mark.parametrize("is_k_full", K_FULL_OPTS) +def test_marlin_gemm( + k_chunk, + n_chunk, + num_bits, + group_size, + mnk_factors, + act_order, + is_k_full, +): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + if act_order: + if group_size == -1: + return + if group_size == size_k: + return + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight, num_bits, group_size, act_order) + + workspace = MarlinWorkspace(size_n) + + output = ops.gptq_marlin_gemm( + a_input, + marlin_q_w, + marlin_s, + g_idx, + sort_indices, + workspace.scratch, + num_bits, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + is_k_full, + ) + output_ref = torch.matmul(a_input, w_ref) + + torch.cuda.synchronize() + + assert torch.allclose(output, output_ref, rtol=1e-2) diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000..33b3169983475 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -0,0 +1,174 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_pack_factor, quantize_weights, sort_weights) + +__cuda_arch = torch.cuda.get_device_capability() + + +def is_marlin_supported(): + return __cuda_arch[0] >= 8 + + +# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 +# +# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def _get_perms(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +_perm = {} +_scale_perm = {} +_scale_perm_single = {} +for num_bits in [4, 8]: + perm, scale_perm, scale_perm_single = _get_perms(num_bits) + _perm[num_bits] = perm + _scale_perm[num_bits] = scale_perm + _scale_perm_single[num_bits] = scale_perm_single + + +def marlin_permute_weights(q_w, + size_k, + size_n, + num_bits, + tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape( + (-1, _perm[num_bits].numel()))[:, _perm[num_bits]].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(_scale_perm[num_bits])))[:, + _scale_perm[num_bits]] + else: + s = s.reshape( + (-1, + len(_scale_perm_single[num_bits])))[:, + _scale_perm_single[num_bits]] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, + act_order: bool, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, + act_order) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +class MarlinWorkspace: + + def __init__(self, out_features): + assert (out_features % GPTQ_MARLIN_MIN_THREAD_N == 0), ( + "out_features = {} is undivisible by GPTQ_MARLIN_MIN_THREAD_N = {}" + .format(out_features, GPTQ_MARLIN_MIN_THREAD_N)) + + max_workspace_size = ((out_features // GPTQ_MARLIN_MIN_THREAD_N) * + GPTQ_MARLIN_MAX_PARALLEL) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000..177cb23f63cf4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -0,0 +1,146 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + +SUPPORTED_NUM_BITS = [4, 8] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +def get_pack_factor(num_bits): + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size, ), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, + act_order: bool): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + s = s.reshape((-1, size_n)).contiguous() + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) + + w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to( + dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res From 99caa4910651754f3f68de518ca42349c8c424d1 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Thu, 16 May 2024 21:55:29 +0800 Subject: [PATCH 287/413] [Kernel] add bfloat16 support for gptq marlin kernel (#4788) --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 240 +++++++++++++----- .../gptq_marlin/gptq_marlin_dtypes.cuh | 62 +++++ tests/models/test_gptq_marlin.py | 9 +- .../layers/quantization/gptq_marlin.py | 8 +- 4 files changed, 246 insertions(+), 73 deletions(-) create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 9c6bff000e916..fdc0ebef4672e 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -20,6 +20,11 @@ */ #include "gptq_marlin.cuh" +#include "gptq_marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); template inline std::string str(T x) { return std::to_string(x); } @@ -32,7 +37,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, int4 *__restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template ; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, - FragC &frag_c) { +template +__device__ inline void mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) { const uint32_t *a = reinterpret_cast(&a_frag); const uint32_t *b = reinterpret_cast(&frag_b); float *c = reinterpret_cast(&frag_c); - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { +template +__device__ inline void ldsm4(typename ScalarType::FragA &frag_a, const void *smem_ptr) { uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" @@ -129,8 +140,15 @@ __device__ inline uint32_t prmt(uint32_t a) { // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant_4bit(int q) { +// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +template +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -142,7 +160,7 @@ __device__ inline FragB dequant_4bit(int q) { const int SUB = 0x64086408; const int MUL = 0x2c002c00; const int ADD = 0xd480d480; - FragB frag_b; + typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), @@ -151,7 +169,41 @@ __device__ inline FragB dequant_4bit(int q) { return frag_b; } -__device__ inline FragB dequant_8bit(int q) { +template <> +__device__ inline typename ScalarType::FragB dequant_4bit(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or bf16 +// Reference: +// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +template +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); +} + +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; @@ -161,7 +213,7 @@ __device__ inline FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - FragB frag_b; + typename ScalarType::FragB frag_b; frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); frag_b[1] = __hsub2(*reinterpret_cast(&hi), @@ -169,34 +221,69 @@ __device__ inline FragB dequant_8bit(int q) { return frag_b; } +template <> +__device__ inline typename ScalarType::FragB dequant_8bit(int q) { + typename ScalarType::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t * fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +template +__device__ inline void scale(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2, - FragS &frag_s_3, FragS &frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i]; +template +__device__ inline void scale4(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s_1, + typename ScalarType::FragS &frag_s_2, + typename ScalarType::FragS &frag_s_3, + typename ScalarType::FragS &frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } // Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float *c, FragS &s) { - __half *s_ptr = reinterpret_cast<__half *>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +template +__device__ inline void scale_float(float *c, typename ScalarType::FragS &s) { + scalar_t *s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. @@ -287,7 +374,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, } } -template ; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; constexpr int pack_factor = 32 / num_bits; @@ -691,7 +785,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int4 *sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); int4 *sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll @@ -835,43 +929,43 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int b_quant = frag_b_quant[k % 2][0][j]; int b_quant_shift = b_quant >> 8; - frag_b0 = dequant_4bit(b_quant); - frag_b1 = dequant_4bit(b_quant_shift); + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); } else { int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); } // Apply scale to frag_b0 if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); } else { if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); + scale(frag_b0, frag_s[k % 2][j], 0); } } // Apply scale to frag_b1 if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); } else { if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); + scale(frag_b1, frag_s[k % 2][j], 1); } } #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); } } }; @@ -979,15 +1073,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk for (int j = 0; j < 2 * 4; j++) { reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half *>(&c_red)[j]); + Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half *>(&c)[j] = - __float2half(reinterpret_cast( + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -1022,7 +1116,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns auto write = [&](int idx, float c0, float c1, FragS &s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) @@ -1030,7 +1124,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk res = __hmul2(res, s[0]); } - ((half2 *)sh)[idx] = res; + ((scalar_t2 *)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { @@ -1192,14 +1286,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); } } @@ -1255,10 +1349,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ - Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ <<>>( \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ @@ -1462,6 +1556,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) +template void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, void *g_idx, void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k, void *workspace, int num_bits, @@ -1731,14 +1826,25 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, gptq_marlin::max_par); + if (a.scalar_type() == at::ScalarType::Half) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, gptq_marlin::max_par); + } else if (a.scalar_type() == at::ScalarType::BFloat16) { + gptq_marlin::marlin_mm_f16i4( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, gptq_marlin::max_par); + } else { + TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + } return c; } -#endif +#endif \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh new file mode 100644 index 0000000000000..7881abbe4cbbf --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh @@ -0,0 +1,62 @@ + +#ifndef _data_types_cuh +#define _data_types_cuh +#include "gptq_marlin.cuh" +#include +#include + + +namespace gptq_marlin { + +template +class ScalarType { +}; + +template <> +class ScalarType { +public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { return __half2float(x); } + + static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } + + static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } +}; + +template <> +class ScalarType { +public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } +#endif +}; + +} + +#endif diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index db55d4488a374..1fc0b3f239127 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -14,6 +14,7 @@ import torch from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT from .utils import check_logprobs_close @@ -52,7 +53,7 @@ @pytest.mark.skipif(gptq_marlin_not_supported, reason="gptq_marlin is not supported on this GPU type.") @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["half", "bfloat16"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models( @@ -76,11 +77,15 @@ def test_models( gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) del gptq_marlin_model + _ROPE_DICT.clear() # clear rope cache to avoid rope dtype error # Run gptq. + # The naive gptq kernel doesn't support bf16 yet. + # Here we always compare fp16/bf16 gpt marlin kernel + # to fp16 gptq kernel. gptq_model = vllm_runner(model_name=model_name, revision=revision, - dtype=dtype, + dtype="half", quantization="gptq", max_model_len=MAX_MODEL_LEN, tensor_parallel_size=1) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index e2464008a875f..354bb55d09e24 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -99,7 +99,7 @@ def get_name(cls) -> str: @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] + return [torch.half, torch.bfloat16] @classmethod def get_min_capability(cls) -> int: @@ -186,9 +186,9 @@ def create_weights( group_size = input_size # Validate dtype - if params_dtype != torch.float16: - raise ValueError( - f"The params dtype must be float16, but got {params_dtype}") + if params_dtype not in [torch.float16, torch.bfloat16]: + raise ValueError(f"The params dtype must be float16 " + f"or bfloat16, but got {params_dtype}") # Validate output_size_per_partition output_size_per_partition = sum(output_partition_sizes) From dbc0754ddfc33e60454ec4a9cdba945f172a39ef Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Thu, 16 May 2024 11:42:17 -0400 Subject: [PATCH 288/413] [docs] Fix typo in examples filename openi -> openai (#4864) --- .../{openi_example_batch.jsonl => openai_example_batch.jsonl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{openi_example_batch.jsonl => openai_example_batch.jsonl} (100%) diff --git a/examples/openi_example_batch.jsonl b/examples/openai_example_batch.jsonl similarity index 100% rename from examples/openi_example_batch.jsonl rename to examples/openai_example_batch.jsonl From 5e0391c0406ec225b2c58bc22d5be864a432fe40 Mon Sep 17 00:00:00 2001 From: Alex Wu Date: Thu, 16 May 2024 11:42:41 -0400 Subject: [PATCH 289/413] [Frontend] Separate OpenAI Batch Runner usage from API Server (#4851) --- vllm/entrypoints/openai/run_batch.py | 2 +- vllm/usage/usage_lib.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 99f1b2d6d091b..731f4f4a4028a 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -101,7 +101,7 @@ async def main(args): engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) # When using single vLLM without engine_use_ray model_config = await engine.get_model_config() diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 9029a5b16af72..40a954a29493e 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -90,6 +90,7 @@ class UsageContext(str, Enum): LLM_CLASS = "LLM_CLASS" API_SERVER = "API_SERVER" OPENAI_API_SERVER = "OPENAI_API_SERVER" + OPENAI_BATCH_RUNNER = "OPENAI_BATCH_RUNNER" ENGINE_CONTEXT = "ENGINE_CONTEXT" From 9216b9cc38e8753de442877863ad651425902e1f Mon Sep 17 00:00:00 2001 From: Pierre Dulac Date: Thu, 16 May 2024 18:42:21 +0200 Subject: [PATCH 290/413] [Bugfix] Bypass authorization API token for preflight requests (#4862) --- vllm/entrypoints/openai/api_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7cd51b959a0ea..97b35262329ee 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -154,6 +154,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): @app.middleware("http") async def authentication(request: Request, call_next): root_path = "" if args.root_path is None else args.root_path + if request.method == "OPTIONS": + return await call_next(request) if not request.url.path.startswith(f"{root_path}/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: From 6979ade3840703d055402e78168194644e3cbdd8 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Thu, 16 May 2024 12:56:15 -0400 Subject: [PATCH 291/413] Add GPTQ Marlin 2:4 sparse structured support (#4790) Co-authored-by: Robert Shaw --- CMakeLists.txt | 3 +- csrc/ops.h | 11 + csrc/pybind.cpp | 3 +- csrc/quantization/marlin/{ => dense}/LICENSE | 0 .../marlin/{ => dense}/marlin_cuda_kernel.cu | 0 csrc/quantization/marlin/sparse/LICENSE | 203 +++ csrc/quantization/marlin/sparse/common/base.h | 49 + csrc/quantization/marlin/sparse/common/mem.h | 132 ++ csrc/quantization/marlin/sparse/common/mma.h | 175 +++ .../marlin/sparse/marlin_24_cuda_kernel.cu | 1110 +++++++++++++++++ tests/models/test_gptq_marlin_24.py | 81 ++ vllm/_custom_ops.py | 10 + vllm/config.py | 48 +- .../layers/quantization/__init__.py | 11 +- .../layers/quantization/base_config.py | 11 + .../layers/quantization/gptq_marlin.py | 23 + .../layers/quantization/gptq_marlin_24.py | 280 +++++ .../layers/quantization/marlin.py | 22 + 18 files changed, 2131 insertions(+), 41 deletions(-) rename csrc/quantization/marlin/{ => dense}/LICENSE (100%) rename csrc/quantization/marlin/{ => dense}/marlin_cuda_kernel.cu (100%) create mode 100644 csrc/quantization/marlin/sparse/LICENSE create mode 100644 csrc/quantization/marlin/sparse/common/base.h create mode 100644 csrc/quantization/marlin/sparse/common/mem.h create mode 100644 csrc/quantization/marlin/sparse/common/mma.h create mode 100644 csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu create mode 100644 tests/models/test_gptq_marlin_24.py create mode 100644 vllm/model_executor/layers/quantization/gptq_marlin_24.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c7dfe0c048b0..2051d7560be25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -176,7 +176,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/marlin/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/custom_all_reduce.cu") diff --git a/csrc/ops.h b/csrc/ops.h index 9541adcb3de88..ef37131c962f8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -125,6 +125,17 @@ torch::Tensor marlin_gemm( int64_t size_n, int64_t size_k); +torch::Tensor gptq_marlin_24_gemm( + torch::Tensor &a, + torch::Tensor &b_q_weight, + torch::Tensor &b_meta, + torch::Tensor &b_scales, + torch::Tensor &workspace, + int64_t num_bits, + int64_t size_m, + int64_t size_n, + int64_t size_k); + torch::Tensor gptq_marlin_gemm( torch::Tensor &a, torch::Tensor &b_q_weight, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 173e0b1732e13..0339eba70c013 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -66,7 +66,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); + ops.def("marlin_gemm", &marlin_gemm, "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); diff --git a/csrc/quantization/marlin/LICENSE b/csrc/quantization/marlin/dense/LICENSE similarity index 100% rename from csrc/quantization/marlin/LICENSE rename to csrc/quantization/marlin/dense/LICENSE diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu similarity index 100% rename from csrc/quantization/marlin/marlin_cuda_kernel.cu rename to csrc/quantization/marlin/dense/marlin_cuda_kernel.cu diff --git a/csrc/quantization/marlin/sparse/LICENSE b/csrc/quantization/marlin/sparse/LICENSE new file mode 100644 index 0000000000000..ca75fb15e660a --- /dev/null +++ b/csrc/quantization/marlin/sparse/LICENSE @@ -0,0 +1,203 @@ +Contains code from https://github.com/IST-DASLab/Sparse-Marlin/ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/csrc/quantization/marlin/sparse/common/base.h b/csrc/quantization/marlin/sparse/common/base.h new file mode 100644 index 0000000000000..929b39d7642f1 --- /dev/null +++ b/csrc/quantization/marlin/sparse/common/base.h @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace marlin_24 { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template struct Vec { + T elems[n]; + __device__ T &operator[](int i) { return elems[i]; } +}; + +template struct ShapeBase { + static constexpr int M = M_, N = N_, K = K_; +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragM = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mem.h b/csrc/quantization/marlin/sparse/common/mem.h new file mode 100644 index 0000000000000..a49d15ca544eb --- /dev/null +++ b/csrc/quantization/marlin/sparse/common/mem.h @@ -0,0 +1,132 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" + +namespace marlin_24 { +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred_zfill(void *smem_ptr, + const void *glob_ptr, + bool pred = true, + const bool zfill = false) { + const int BYTES = 16; + int src_in_bytes = (zfill ? 0 : BYTES); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); +} + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template __device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_m); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int *lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int *lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h new file mode 100644 index 0000000000000..9319456677d36 --- /dev/null +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -0,0 +1,175 @@ +/* + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "base.h" + +namespace marlin_24 { + +// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1, + const FragA &frag_b, FragC &frag_c, FragM &frag_m, + const int psel) { + const uint32_t *a0 = reinterpret_cast(&a_frag0); + const uint32_t *a1 = reinterpret_cast(&a_frag1); + const uint32_t *b = reinterpret_cast(&frag_b); + const uint32_t *e = reinterpret_cast(&frag_m); + float *c = reinterpret_cast(&frag_c); + if (psel == 0) { + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), + "f"(c[2]), "f"(c[3]), "r"(e[0])); + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), + "f"(c[6]), "f"(c[7]), "r"(e[0])); + } else { + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), + "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), + "f"(c[2]), "f"(c[3]), "r"(e[0])); + asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), + "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), + "f"(c[6]), "f"(c[7]), "r"(e[0])); + } +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template __device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2, + float c3) { + uint2 r; + asm("{\n\t" + ".reg .f16 a, b, c, d; \n\t" + "cvt.rn.f16.f32 a, %2; \n\t" + "cvt.rn.f16.f32 b, %3; \n\t" + "cvt.rn.f16.f32 c, %4; \n\t" + "cvt.rn.f16.f32 d, %5; \n\t" + "mov.b32 %0, {a, b}; \n\t" + "mov.b32 %1, {c, d}; \n\t" + "}" + : "=r"(r.x), "=r"(r.y) + : "f"(c0), "f"(c1), "f"(c2), "f"(c3)); + return r; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_4bit(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3, + FragS &s0, float *c4, float *c5, float *c6, + float *c7, FragS &s1) { + *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); + *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); + *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); + *c3 = __fmul_rn(*c3, __half2float(s0[1].y)); + + *c4 = __fmul_rn(*c4, __half2float(s1[0].x)); + *c5 = __fmul_rn(*c5, __half2float(s1[0].y)); + *c6 = __fmul_rn(*c6, __half2float(s1[1].x)); + *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); +} + +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu new file mode 100644 index 0000000000000..42b0566183a8d --- /dev/null +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -0,0 +1,1110 @@ +/* + * Notice: This file was modified by Neuralmagic inc to include 8-bit support + * + * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All + * Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +#include + +#include "common/base.h" + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +#else + +#include "common/mem.h" +#include "common/mma.h" + +#endif + +template inline std::string str(T x) { return std::to_string(x); } + +namespace marlin_24 { + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int THREADS = 256; +static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 128; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void Marlin_24( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4 + *__restrict__ meta, // 2bit metadata information about 2:4 format on B + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 + *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) {} + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_meta, + torch::Tensor &b_scales, + torch::Tensor &workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "gptq_marlin_24_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void Marlin_24( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4 + *__restrict__ meta, // 2bit metadata information about 2:4 format on B + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 + *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + // number of thread_k_blocks in k-dim + int k_tiles = prob_k / 32 / thread_k_blocks; + // number of thread_n_blocks in n-dim + int n_tiles = prob_n / 16 / thread_n_blocks; + // iters needed to cover all slices + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. + if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + // number of threadblock tiles in the current slice + int slice_iters; + // total number of active threadblocks in the current slice + int slice_count = 0; + // index of threadblock in current slice; numbered bottom to top + int slice_idx; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 32 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 32 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads //RLC: 2 * #warps k-dim + constexpr int a_sh_rd_delta_o = 4 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + constexpr int pack_factor = 32 / num_bits; + + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + constexpr int m_sh_stride = + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; + int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); + constexpr int m_sh_wr_delta = threads / 2; + constexpr int m_sh_rd_delta = threads / 2; + constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; + constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + + (threadIdx.x % (m_sh_stride)); + m_gl_rd += (m_sh_stride)*slice_col; + m_gl_rd += m_gl_rd_delta_o * slice_row; + int m_sh_wr = threadIdx.x; + int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; + + int s_gl_rd; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + if (group_blocks != -1) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } else { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[0][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[1][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; + const int4 *meta_ptr[m_sh_iters]; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_s = sh_b + (stages * b_sh_stage); + int4 *sh_m = sh_s + (stages * s_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks][2]; + I4 frag_b_quant[2][b_thread_vecs]; + FragM frag_m[2][2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], + B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) { + if (m_sh_wr_pred) + cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], + meta_ptr[i]); + meta_ptr[i] += m_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) + cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticeable drop in performance. + if (group_blocks != -1) { + int4 *sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + ldsm4(frag_a[k % 2][i][0], + &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i][1], + &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); + } + + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + + // Load meta with ldsm4 + int4 *sh_m_stage = sh_m + m_sh_stage * pipe; + ldsm4_m(frag_m[k % 2][0], + &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); + + } else { + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } + + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], + frag_m[k % 2][j / 2], j % 2); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + +// Parallel logarithmic shared memory reduction. We make sure to avoid any +// unnecessary read or write iterations, e.g., for two warps we write only once +// by warp 1 and read only once by warp 0. +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped partitioning + // minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; + int c_gl_wr_delta_i = + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int col = 2 * ((threadIdx.x % 32) % 4); + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up the +// compiler and lead to slowdowns, hence we also use async-copies even though +// these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j2 = 0; j2 < 2; j2++) { +#pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2] += + __half2float( + reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]); + } + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j2 = 0; j2 < 2; j2++) { +#pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2]); + } + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + + int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + + constexpr int c_sh_rd_delta = + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + + (threadIdx.x % (2 * 2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, + float c4, float c5, float c6, float c7, FragS &s1) { + uint2 res[2]; + res[0] = to_half4(c0, c1, c2, c3); + res[1] = to_half4(c4, c5, c6, c7); + half2 *tmp = (half2 *)&res; + // for per-column quantization we finally apply the scale here + if constexpr (group_blocks == -1 && num_bits == 4) { + tmp[0] = __hmul2(tmp[0], s0[0]); + tmp[1] = __hmul2(tmp[1], s0[1]); + tmp[2] = __hmul2(tmp[2], s1[0]); + tmp[3] = __hmul2(tmp[3], s1[1]); + } + ((int4 *)sh)[idx] = *((int4 *)&res[0]); + }; + + // RLC: only warp 0 and 1 baseline example + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + int wr = c_sh_wr; + write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], + frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], + frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], + frag_s[0][2]); + write(wr + c_sh_stride, frag_c[i][0][0][1], frag_c[i][1][0][1], + frag_c[i][2][0][1], frag_c[i][3][0][1], frag_s[0][0], + frag_c[i][0][0][3], frag_c[i][1][0][3], frag_c[i][2][0][3], + frag_c[i][3][0][3], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2, frag_c[i][0][1][0], frag_c[i][1][1][0], + frag_c[i][2][1][0], frag_c[i][3][1][0], frag_s[0][0], + frag_c[i][0][1][2], frag_c[i][1][1][2], frag_c[i][2][1][2], + frag_c[i][3][1][2], frag_s[0][2]); + write(wr + 4 * c_sh_stride_2 + c_sh_stride, frag_c[i][0][1][1], + frag_c[i][1][1][1], frag_c[i][2][1][1], frag_c[i][3][1][1], + frag_s[0][0], frag_c[i][0][1][3], frag_c[i][1][1][3], + frag_c[i][2][1][3], frag_c[i][3][1][3], frag_s[0][2]); + + c_sh_wr += 8 * c_sh_stride_2; + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { +#pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { +// We unroll over both the global fetch and the register load pipeline to ensure +// all shared memory accesses are static. Note that both pipelines have even +// length meaning that the next iteration will always start at index 0. +#pragma unroll + for (int pipe = 0; pipe < stages;) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + wait_for_stage(); + + fetch_to_registers(pipe + 1, (pipe + 1) % stages); + matmul(pipe); + + pipe++; + slice_iters--; + if (slice_iters == 0) + break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + if (s_sh_wr_pred) + cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) + cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + } + } + thread_block_reduce(); + + if constexpr (group_blocks == -1) { + if constexpr (num_bits == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + } + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], + &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], + &frag_c[i][0][0][2], &frag_c[i][1][0][2], + &frag_c[i][2][0][2], &frag_c[i][3][0][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][0][1], &frag_c[i][1][0][1], + &frag_c[i][2][0][1], &frag_c[i][3][0][1], frag_s[0][0], + &frag_c[i][0][0][3], &frag_c[i][1][0][3], + &frag_c[i][2][0][3], &frag_c[i][3][0][3], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][0], &frag_c[i][1][1][0], + &frag_c[i][2][1][0], &frag_c[i][3][1][0], frag_s[0][0], + &frag_c[i][0][1][2], &frag_c[i][1][1][2], + &frag_c[i][2][1][2], &frag_c[i][3][1][2], + frag_s[0][2]); + + scale_floats(&frag_c[i][0][1][1], &frag_c[i][1][1][1], + &frag_c[i][2][1][1], &frag_c[i][3][1][1], frag_s[0][0], + &frag_c[i][0][1][3], &frag_c[i][1][1][3], + &frag_c[i][2][1][3], &frag_c[i][3][1][3], + frag_s[0][2]); + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] -= m_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + +#endif + +#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + cudaFuncSetAttribute( \ + Marlin_24, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin_24 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, prob_n, \ + prob_m, prob_k, locks); \ + } + +void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, + void *s, int prob_m, int prob_n, int prob_k, + void *workspace, int num_bits, int groupsize = -1, + int dev = 0, cudaStream_t stream = 0, int thread_k = -1, + int thread_m = -1, int sms = -1, int max_par = 16) { + int tot_n = prob_n; + int tot_n_blocks = ceildiv(tot_n, 16); + int pad = 16 * tot_n_blocks - tot_n; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + TORCH_CHECK(sms > 0); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (thread_k == -1 || thread_m == -1) { + if (prob_n <= 16) { + // For small batchizes, better partitioningif is slightly more important than + // better compute utilization + thread_k = 128; + thread_m = 128; + } else { + thread_k = 64; + thread_m = 256; + } + } + + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_m_blocks = thread_m / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + TORCH_CHECK(prob_m % thread_m == 0, "prob_m = ", prob_m, + " is not divisible by thread_m = ", thread_m); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + if (group_blocks != -1) { + TORCH_CHECK((prob_k / 2) % group_blocks == 0, "prob_k/2 = ", prob_k / 2, + " is not divisible by group_blocks = ", group_blocks); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + const int4 *meta_ptr = (const int4 *)meta; + int4 *C_ptr = (int4 *)C; + const int4 *s_ptr = (const int4 *)s; + + int *locks = (int *)workspace; + for (int i = 0; i < tot_n_blocks; i += 4) { + int thread_n_blocks = tot_n_blocks - i; + prob_n = tot_n - 16 * i; + int par = 1; + if (thread_n_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_n_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_n = 64 * par; + i += 4 * (par - 1); + thread_n_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. + + // the false is start of the CALL_IF macros + if (false) { + } // BMxBNxBK, group + // 4-bit + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 16, 2, 2, 4) + CALL_IF_2_4(4, 16, 3, 2, -1) + CALL_IF_2_4(4, 16, 3, 2, 4) + CALL_IF_2_4(4, 16, 4, 2, -1) + CALL_IF_2_4(4, 16, 4, 2, 4) + + // 8-bit + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 16, 2, 2, 4) + CALL_IF_2_4(8, 16, 3, 2, -1) + CALL_IF_2_4(8, 16, 3, 2, 4) + CALL_IF_2_4(8, 16, 4, 2, -1) + CALL_IF_2_4(8, 16, 4, 2, 4) + else { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + + ", " + str(prob_k) + ", " + str(prob_n) + "]" + + ", groupsize = " + str(groupsize) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par; + } +} + +} // namespace marlin_24 + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_meta, + torch::Tensor &b_scales, + torch::Tensor &workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + + // Verify M + TORCH_CHECK(size_m == a.size(0), + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + + // Verify K + TORCH_CHECK(size_k == a.size(1), + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + TORCH_CHECK(size_k % marlin_24::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(marlin_24::tile_size)); + TORCH_CHECK((size_k / marlin_24::tile_size / 2) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + + ", tile_size = " + str(marlin_24::tile_size)); + + // Verify N + TORCH_CHECK(b_scales.size(1) == size_n, + "b_scales.size(1) = " + str(b_scales.size(1)) + + ", size_n = " + str(size_n)); + TORCH_CHECK( + b_q_weight.size(1) % marlin_24::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(marlin_24::tile_size)); + + int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, + "size_n = " + str(size_n) + + ", actual_size_n = " + str(actual_size_n)); + + // Verify meta + TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, + "b_meta.size(0) = ", b_meta.size(0), + " is not size_k / 8 / 2 / 2 = ", size_k / 8 / 2 / 2); + TORCH_CHECK(b_meta.size(1) == size_n * 2, "b_meta.size(1) = ", b_meta.size(1), + " is not size_n * 2 = ", size_n * 2); + + // Verify A device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + // Verify B device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + // Verify b_meta device and strides + TORCH_CHECK(b_meta.device().is_cuda(), "b_meta is not on GPU"); + TORCH_CHECK(b_meta.is_contiguous(), "b_meta is not contiguous"); + + // Verify scales device and strides + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc C matrix + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + int thread_k = -1; + int thread_m = -1; + int sms = -1; + int max_par = 16; + + int groupsize = -1; + if (b_scales.size(0) > 1) { + TORCH_CHECK(size_k % b_scales.size(0) == 0, + "size_k = " + str(size_k) + + ", is not divisible by b_scales.size(0) = " + + str(b_scales.size(0))); + groupsize = size_k / b_scales.size(0); + groupsize /= 2; // Because of 24 + } + + // Verify groupsize + TORCH_CHECK(groupsize == -1 || groupsize == 64, + "Unexpected groupsize = " + str(groupsize)); + + // Verify workspace size + TORCH_CHECK(size_n % marlin_24::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(marlin_24::min_thread_n)); + int min_workspace_size = + (size_n / marlin_24::min_thread_n) * marlin_24::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + int dev = a.get_device(); + marlin_24::marlin_cuda_2_4( + a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), + num_bits, groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_m, sms, max_par); + + return c; +} diff --git a/tests/models/test_gptq_marlin_24.py b/tests/models/test_gptq_marlin_24.py new file mode 100644 index 0000000000000..3e6ffb7f90fcc --- /dev/null +++ b/tests/models/test_gptq_marlin_24.py @@ -0,0 +1,81 @@ +"""Compare the outputs of a GPTQ model to a Marlin_24 model. + +Note: GPTQ and Marlin_24 do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. + +Run `pytest tests/models/test_marlin_24.py`. +""" +from dataclasses import dataclass + +import pytest +import torch + +from tests.models.utils import check_logprobs_close +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +marlin_not_supported = (capability < + QUANTIZATION_METHODS["marlin"].get_min_capability()) + + +@dataclass +class ModelPair: + model_marlin: str + model_gptq: str + + +model_pairs = [ + # 4-bit, group_size == 128 + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128"), + # 4-bit, group_size == channelwise + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-channelwise", + model_gptq="alexm-nm/tinyllama-24-gptq-4bit-channelwise"), + + # 8-bit, group_size == 128 + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128", + model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128"), + # 8-bit, group_size == channelwise + ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-channelwise", + model_gptq="alexm-nm/tinyllama-24-gptq-8bit-channelwise"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(marlin_not_supported, + reason="Marlin24 is not supported on this GPU type.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [8]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + marlin_24_model = vllm_runner(model_pair.model_marlin, + dtype=dtype, + quantization="gptq_marlin_24") + marlin_24_outputs = marlin_24_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + del marlin_24_model + + gptq_model = vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="gptq") + gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) + del gptq_model + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=marlin_24_outputs, + name_0="gptq", + name_1="marlin_24", + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 42dedfdf76c4f..95baa84262658 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -153,6 +153,16 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) +# marlin_24 +def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, num_bits: int, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, + workspace, num_bits, size_m, size_n, + size_k) + + # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, diff --git a/vllm/config.py b/vllm/config.py index 91f590aaf79eb..77ce8c318d8f1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -7,14 +7,11 @@ from transformers import PretrainedConfig from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, - get_quantization_config) +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron -GPTQMarlinConfig = get_quantization_config("gptq_marlin") - if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -155,37 +152,15 @@ def _verify_quantization(self) -> None: quant_cfg = getattr(self.hf_config, "quantization_config", None) if quant_cfg is not None: quant_method = quant_cfg.get("quant_method", "").lower() - # compat: autogptq >=0.8.0 use checkpoint_format: str - # compat: autogptq <=0.7.1 is_marlin_format: bool - is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin" - or quant_cfg.get("is_marlin_format", False)) - - # Check which LinearMethod the GPTQ model should use. - if quant_method == "gptq": - # If serialized in Marlin format, use MarlinLinearMethod. - # TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod. - if is_format_marlin: - logger.info("The model is serialized in Marlin format. " - "Using Marlin kernel.") - quant_method = "marlin" - if self.quantization == "gptq": - self.quantization = quant_method - - # If convertible to Marlin format, use GPTQMarlinLinearMethod - # unless the user explicitly specified GPTQLinearMethod. - elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg): - if self.quantization == "gptq": - logger.warning( - "The model is convertible to Marlin format, but " - "you specified quantization=gptq. Use " - "quantization=marlin for faster inference.") - else: - logger.info( - "The model is convertible to Marlin format. " - "Using Marlin kernel.") - quant_method = "gptq_marlin" - if self.quantization == "marlin": - self.quantization = quant_method + + # Detect which checkpoint is it + for name, method in QUANTIZATION_METHODS.items(): + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break # Verify quantization configurations. if self.quantization is None: @@ -207,7 +182,8 @@ def _verify_quantization(self) -> None: raise ValueError( f"{self.quantization} quantization is currently not " f"supported in ROCm.") - if (self.quantization not in ["marlin", "gptq_marlin"]): + if (self.quantization + not in ["marlin", "gptq_marlin_24", "gptq_marlin"]): logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 5798bc359dcf2..f938e7d37ec5f 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -10,18 +10,23 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQMarlin24Config) from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, "fp8": Fp8Config, + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) + "marlin": MarlinConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "gptq_marlin": GPTQMarlinConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, - "gptq_marlin": GPTQMarlinConfig, - "marlin": MarlinConfig, - "deepspeedfp": DeepSpeedFPConfig } diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index ff5cf0b2bd61a..e7de283b562a6 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -66,6 +66,17 @@ def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": """Create a config class from the model's quantization config.""" raise NotImplementedError + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + """ + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances + """ + return None + @staticmethod def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: """Get a value from the model's quantization config.""" diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 354bb55d09e24..4374fd98012f6 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -6,11 +6,14 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +logger = init_logger(__name__) + GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_K = 128 @@ -117,6 +120,26 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": is_sym = cls.get_from_keys(config, ["sym"]) return cls(weight_bits, group_size, desc_act, is_sym) + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_marlin_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_marlin for" + " faster inference") + return None + def get_quant_method( self, layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py new file mode 100644 index 0000000000000..1bd6127104654 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -0,0 +1,280 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class GPTQMarlin24Config(QuantizationConfig): + """Config class for Marlin24. + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + + if self.weight_bits != 4 and self.weight_bits != 8: + raise ValueError("weight_bits must be 4 or 8. Got = {}".format( + self.weight_bits)) + + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) " + "is supported for Marlin24, but got group_size of " + f"{self.group_size}") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.weight_bits + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = 128 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = 16 + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return "Marlin24Config(weight_bits={}, group_size={})".format( + self.weight_bits, self.group_size) + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin_24" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + is_marlin_24_format = ( + hf_quant_cfg.get("checkpoint_format") == "marlin_24") + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "gptq_marlin_24") + + if is_marlin_24_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. " + "Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQMarlin24LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class GPTQMarlin24LinearMethod(LinearMethodBase): + """Linear method for Marlin24. + + Args: + quant_config: The Marlin24 quantization config. + """ + + def __init__(self, quant_config: GPTQMarlin24Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.tile_size // 2, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + "marlin_tile_size": self.quant_config.tile_size, + }, + ) + + # Meta + meta = Parameter( + torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + device="cuda", + dtype=torch.int16, + ), + requires_grad=False, + ) + set_weight_attrs( + meta, + { + "input_dim": 0, + "packed_dim": 1, + "pack_factor": 1, + "output_dim": 1, + "marlin_tile_size": 2, + }, + ) + + # Determine if channelwise or not + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) + + scales = Parameter( + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "input_dim": None if input_groups == 1 else 0, + "output_dim": 1, + }, + ) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + workspace = Parameter(torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + requires_grad=False) + + layer.register_parameter("B_24", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("B_meta", meta) + set_weight_attrs(meta, extra_weight_attrs) + layer.register_parameter("s", scales) + set_weight_attrs(scales, extra_weight_attrs) + layer.register_parameter("workspace", workspace) + set_weight_attrs(workspace, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B_24 + meta = layer.B_meta + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, + self.quant_config.weight_bits, + size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 94aba620ea083..3613c9d9ecf2a 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -4,11 +4,14 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs +logger = init_logger(__name__) + class MarlinConfig(QuantizationConfig): """Config class for Marlin. @@ -72,6 +75,25 @@ def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) return cls(group_size) + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" + or hf_quant_cfg.get("is_marlin_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "marlin") + + if is_marlin_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". + format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + def get_quant_method( self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: if isinstance(layer, LinearBase): From f09edd8a25d54c48eb804abe391e98d0b85b9ea2 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 16 May 2024 10:02:56 -0700 Subject: [PATCH 292/413] Add JSON output support for benchmark_latency and benchmark_throughput (#4848) --- .buildkite/run-benchmarks.sh | 7 ++++--- benchmarks/benchmark_latency.py | 20 ++++++++++++++++++-- benchmarks/benchmark_throughput.py | 17 +++++++++++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index 7fbad1c4bd950..1efc96395933f 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.." (which wget && which curl) || (apt-get update && apt-get install -y wget curl) # run python-based benchmarks and upload the result to buildkite -python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt +python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt bench_latency_exit_code=$? -python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt +python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt bench_throughput_exit_code=$? # run server-based benchmarks and upload the result to buildkite @@ -74,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then exit $bench_serving_exit_code fi -/workspace/buildkite-agent artifact upload openai-*.json +rm ShareGPT_V3_unfiltered_cleaned_split.json +/workspace/buildkite-agent artifact upload "*.json" diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 8f3168c115ae6..f84e3453947c9 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,5 +1,6 @@ """Benchmark the latency of processing a single batch of requests.""" import argparse +import json import time from pathlib import Path from typing import Optional @@ -96,6 +97,16 @@ def run_to_completion(profile_dir: Optional[str] = None): for percentage, percentile in zip(percentages, percentiles): print(f'{percentage}% percentile latency: {percentile} seconds') + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + if __name__ == '__main__': parser = argparse.ArgumentParser( @@ -149,8 +160,8 @@ def run_to_completion(profile_dir: Optional[str] = None): help= 'Data type for kv cache storage. If "auto", will use model data type. ' 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') parser.add_argument( '--quantization-param-path', type=str, @@ -197,5 +208,10 @@ def run_to_completion(profile_dir: Optional[str] = None): default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the latency results in JSON format.') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 695d06e7b243d..41f443968c3c4 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -242,6 +242,18 @@ def main(args: argparse.Namespace): print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s") + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Benchmark the throughput.") @@ -353,6 +365,11 @@ def main(args: argparse.Namespace): default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model From b5853f99639afd82cb18f131dce7f8c41eda74bd Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Thu, 16 May 2024 13:46:52 -0400 Subject: [PATCH 293/413] [ROCm][AMD][Bugfix] adding a missing triton autotune config (#4845) --- vllm/attention/ops/triton_flash_attention.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 1147664183ff1..f94211116a746 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -239,6 +239,16 @@ def _attn_fwd_inner( num_stages=1, num_warps=8, ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), triton.Config( { "BLOCK_M": 128, From e08188081be890f72a8d3dda66e7e1ce0a45216c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 16 May 2024 10:59:52 -0700 Subject: [PATCH 294/413] [Core][Distributed] remove graph mode function (#4818) --- tests/distributed/test_custom_all_reduce.py | 5 +- tests/distributed/test_pynccl.py | 4 +- vllm/distributed/communication_op.py | 70 ++++++++++++--------- vllm/worker/model_runner.py | 38 ++++++----- 4 files changed, 63 insertions(+), 54 deletions(-) diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 630b25a3b6132..186f9faa6bfb6 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): for sz in test_sizes: for dtype in [torch.float32, torch.float16, torch.bfloat16]: - with graph_capture(): + with graph_capture() as graph_capture_context: # use integers so result matches NCCL exactly inp1 = torch.randint(1, 16, (sz, ), @@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): device=torch.cuda.current_device()) torch.cuda.synchronize() graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): + with torch.cuda.graph(graph, + stream=graph_capture_context.stream): for i in range(num_communication): out1 = tensor_model_parallel_all_reduce(inp1) # the input buffer is immediately modified to test diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index a0f7500bf0ee9..529e75fb2c9e3 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -5,7 +5,7 @@ import torch from vllm.distributed.communication_op import ( # noqa - graph_mode, tensor_model_parallel_all_reduce) + graph_capture, tensor_model_parallel_all_reduce) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, @@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - with graph_mode(): + with graph_capture(): # two tp groups can communicate independently if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index f8ee0f9796bcd..937fd4d392713 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,5 +1,6 @@ from collections import namedtuple from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -13,45 +14,54 @@ get_tp_pynccl_communicator) -@contextmanager -def graph_mode(): - # In graph mode, we have to be very careful about the collective - # operations. The current status is: - # allreduce \ Mode | Eager | Graph | - # -------------------------------------------- - # custom allreduce | enabled | enabled | - # PyNccl | disabled| enabled | - # torch.distributed | enabled | disabled| - # - # Note that custom allreduce will have a runtime check, if the tensor size - # is too large, it will fallback to the next available option. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using CUDA - # graph, we use either custom all-reduce kernel or PyTorch NCCL. - # We always prioritize using custom all-reduce kernel but fall back - # to PyTorch or pynccl if it is disabled or not supported. - pynccl_comm = get_tp_pynccl_communicator() - if pynccl_comm is None: - context = nullcontext() - else: - context = pynccl_comm.change_state(enable=True, - stream=torch.cuda.current_stream()) - with context: - yield +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream @contextmanager def graph_capture(): """ - `graph_capture` is a context manager which should include the code that + `graph_capture` is a context manager which should surround the code that is capturing the CUDA graph. Its main purpose is to ensure that the some operations will be run after the graph is captured, before the graph - is replayed. + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. """ + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) ca_comm = get_tp_ca_communicator() - context = nullcontext() if ca_comm is None else ca_comm.capture() - with context: - yield + maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the tensor + # size is too large, it will fallback to the next available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using CUDA + # graph, we use either custom all-reduce kernel or PyTorch NCCL. + # We always prioritize using custom all-reduce kernel but fall back + # to PyTorch or pynccl if it is disabled or not supported. + pynccl_comm = get_tp_pynccl_communicator() + if pynccl_comm is None: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_pynccl_context: + yield graph_capture_context def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 623a0bc32211c..460f98d7a826e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,7 +10,7 @@ ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict -from vllm.distributed.communication_op import graph_capture, graph_mode +from vllm.distributed.communication_op import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -841,7 +841,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - with graph_capture(): + with graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for batch_size in reversed(batch_size_capture_list): @@ -877,6 +877,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: kv_caches, attn_metadata, memory_pool=self.graph_memory_pool, + stream=graph_capture_context.stream, ) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -921,15 +922,27 @@ def capture( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - memory_pool, + memory_pool: Optional[Tuple[int, int]], + stream: torch.cuda.Stream, **kwargs, ) -> None: assert self._graph is None # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - with graph_mode(): - self.model( + self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) + torch.cuda.synchronize() + + # Capture the graph. + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): + hidden_states = self.model( input_ids, positions, kv_caches, @@ -938,21 +951,6 @@ def capture( ) torch.cuda.synchronize() - # Capture the graph. - # NOTE(woosuk): Python 3.8 does not support multi-line with statements. - # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement - self._graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 - with graph_mode(): - hidden_states = self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - **kwargs, - ) - torch.cuda.synchronize() - # Save the input and output buffers. self.input_buffers = { "input_ids": input_ids, From 10fa9eea21ae757d17c1369afa6172598db3be92 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 16 May 2024 11:07:41 -0700 Subject: [PATCH 295/413] [Misc] remove old comments (#4866) --- vllm/worker/model_runner.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 460f98d7a826e..cd7af25654b52 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -887,16 +887,6 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # This usually takes < 10 seconds. logger.info("Graph capturing finished in %.0f secs.", elapsed_time) - def __del__(self) -> None: - # Delete the CUDA graphs before deleting the pynccl communicator. - # NOTE(woosuk): This is necessary because otherwise deadlocks can - # happen. - # FIXME(woosuk): This is a bit hacky. Find a more robust solution. - # TODO(youkaichao): when we get enough user feedback that pynccl is - # more stable than cupy, we can remove this, e.g. in v0.4.1. - self.graph_runners.clear() - self.pynccl_backend = None - @property def vocab_size(self) -> int: return self.model_config.get_vocab_size() From 8435b207af398cb6cec961f8ac8e1d8bb5164b3e Mon Sep 17 00:00:00 2001 From: Silencio <19430328+Silencioo@users.noreply.github.com> Date: Fri, 17 May 2024 02:16:09 +0800 Subject: [PATCH 296/413] [Kernel] Add punica dimension for Qwen1.5-32B LoRA (#4850) Co-authored-by: Silencio --- csrc/punica/bgmv/bgmv_config.h | 2 ++ tests/lora/test_punica.py | 1 + 2 files changed, 3 insertions(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 19c058cacfbc4..98ac8de779e13 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -53,6 +53,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 27392) \ + f(in_T, out_T, W_T, narrow, 27648) \ f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ @@ -121,6 +122,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 22016, narrow) \ f(in_T, out_T, W_T, 24576, narrow) \ f(in_T, out_T, W_T, 27392, narrow) \ + f(in_T, out_T, W_T, 27648, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index fd2a1b75f460c..193e3906997c4 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -79,6 +79,7 @@ def _lora_ref_impl( 22016, 24576, 27392, + 27648, 32000, 32256, 32512, From 2060e93659f1f63a3d2a76aee61559ccb1fe732e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Thu, 16 May 2024 18:32:50 -0400 Subject: [PATCH 297/413] [Kernel] Add w8a8 CUTLASS kernels (#4749) --- CMakeLists.txt | 27 +- csrc/ops.h | 8 + csrc/pybind.cpp | 1 + csrc/quantization/cutlass_w8a8/common.hpp | 12 + .../cutlass_visitor_2x_broadcast_epilogue.hpp | 340 ++++++++++++++++++ .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 296 +++++++++++++++ .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 240 +++++++++++++ .../cutlass_w8a8/scaled_mm_dq_entry.cu | 65 ++++ tests/kernels/test_cutlass.py | 192 ++++++++++ vllm/_custom_ops.py | 18 +- 10 files changed, 1197 insertions(+), 2 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/common.hpp create mode 100644 csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu create mode 100644 tests/kernels/test_cutlass.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 2051d7560be25..35846fd1cfa99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,6 +173,16 @@ set(VLLM_EXT_SRC "csrc/pybind.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") + include(FetchContent) + SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) + FetchContent_Declare( + cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + # CUTLASS 3.5.0 + GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc + ) + FetchContent_MakeAvailable(cutlass) + list(APPEND VLLM_EXT_SRC "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" @@ -180,7 +190,21 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/custom_all_reduce.cu") + "csrc/custom_all_reduce.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu") + + # + # The CUTLASS kernels for Hopper require sm90a to be enabled. + # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. + # That adds an extra 17MB to compiled binary, so instead we selectively enable it. + set_source_files_properties( + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() define_gpu_extension_target( @@ -190,6 +214,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} WITH_SOABI) # diff --git a/csrc/ops.h b/csrc/ops.h index ef37131c962f8..8c2c2ae6e1f5a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -155,6 +155,14 @@ torch::Tensor gptq_marlin_repack( int64_t size_k, int64_t size_n, int64_t num_bits); + +int cutlass_scaled_mm_dq( + torch::Tensor& out, + torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + #endif void squeezellm_gemm( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 0339eba70c013..f5b4865506568 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); + ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."); #endif ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp new file mode 100644 index 0000000000000..999b7b251ab33 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/common.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "cutlass/cutlass.h" + +/** + * Helper function for checking CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + cutlassGetStatusString(status)) \ + } diff --git a/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp b/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp new file mode 100644 index 0000000000000..ddbee15e54ab6 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass It's beem modified to support either +// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x. +// Important because this saves us a factor 4x on the number of kernels +// compiled. +// +#pragma once + +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cute/tensor.hpp" + +// clang-format on + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + struct Arguments { + Element const* ptr_row = nullptr; + Element null_default = Element(0); + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load(dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = params_ptr->null_default; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if(get<1>(coord_v(i)) < n) + { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + struct Arguments { + Element const* ptr_col = nullptr; + Element null_default = Element(0); + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->ptr_col) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if(pred(i)){ + dst_v(i) = params_ptr->null_default; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu new file mode 100644 index 0000000000000..3ec454f78c654 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -0,0 +1,296 @@ +#include +#include + +// clang-format will break include orders +// clang-format off +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/util/device_memory.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm_coord.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" + +#include "cutlass_visitor_2x_broadcast_epilogue.hpp" +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This defines a quantized GEMM operation with dequantized output, similar to + torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for + NVIDIA GPUs with SM versions prior to sm90 (Hopper). + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ + +namespace { + +template +struct cutlass_2x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using Operator = + typename std::conditional, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type; + + using OutputTileThreadMap = + cutlass::epilogue::threadblock::OutputTileThreadLayout< + TileShape, WarpShape, float, 4, 1 /* epilogue stages */ + >; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<0>, Int<0>>>; + + using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, float, Stride, Int<1>, Int<0>>>; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::threadblock::Sm80EVT; + + using D = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest, + Stride, Int<0>>>; + + using EVTD = cutlass::epilogue::threadblock::Sm80EVT; + + // clang-format off + using RowMajor = typename cutlass::layout::RowMajor; + using ColumnMajor = typename cutlass::layout::ColumnMajor; + using KernelType = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, + ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, + float, cutlass::layout::RowMajor, 4, + ElementAcc, float, cutlass::arch::OpClassTensorOp, + Arch, + TileShape, WarpShape, InstructionShape, + EVTD, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + MainLoopStages, Operator, + 1 /* epilogue stages */ + >::GemmKernel; + // clang-format on + + using Op = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + cutlass::gemm::GemmCoord problem_size{m, n, k}; + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideC = Stride, Int<0>>; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); + + auto a_scales_ptr = a_scales.data_ptr(); + auto b_scales_ptr = b_scales.data_ptr(); + + // If A and B are quantized per-tensor, then these scale tensors are scalars, + // and they are passed in via the second argument. + using ScaleAArgs = typename Gemm::ScaleA::Arguments; + ScaleAArgs a_args = a_scales.numel() == 1 + ? ScaleAArgs{nullptr, a_scales.item(), {}} + : ScaleAArgs{a_scales.data_ptr(), {}, {}}; + + using ScaleBArgs = typename Gemm::ScaleB::Arguments; + ScaleBArgs b_args = b_scales.numel() == 1 + ? ScaleBArgs{nullptr, b_scales.item(), {}} + : ScaleBArgs{b_scales.data_ptr(), {}, {}}; + + typename Gemm::EVTCompute0::Arguments evt0_compute_args{b_args}; + + typename Gemm::EVTCompute1::Arguments evt1_compute_args{a_args, + evt0_compute_args}; + typename Gemm::D::Arguments d_args{c_ptr, c_stride}; + + typename Gemm::EVTD::Arguments epilogue_args{ + evt1_compute_args, + d_args, + }; + + typename Gemm::Op::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode + problem_size, // problem size + 1, // batch count + epilogue_args, + a_ptr, + b_ptr, + nullptr, + nullptr, + 0, + 0, + 0, + 0, + lda, + ldb, + ldc, + ldc}; + + // Launch the CUTLASS GEMM kernel. + typename Gemm::Op gemm_op; + size_t workspace_size = gemm_op.get_workspace_size(args); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(args)); + cutlass::Status status = gemm_op(args, workspace.get()); + CUTLASS_CHECK(status); +} + +} // namespace + +void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } +} + +void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>(out, a, b, a_scales, + b_scales); + } +} + +void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; + + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (a.dtype() == torch::kInt8) { + TORCH_CHECK(b.dtype() == torch::kInt8); + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } else { + assert(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher< + cutlass_2x_gemm>( + out, a, b, a_scales, b_scales); + } + } else { + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + return cutlass_scaled_mm_dq_dispatcher>(out, a, b, a_scales, + b_scales); + } + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu new file mode 100644 index 0000000000000..37b096de23e3b --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -0,0 +1,240 @@ +#include + +#include +#include +#include + +// clang-format will break include orders +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This defines a quantized GEMM operation with dequantized output, similar to + torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for + NVIDIA GPUs with sm90a (Hopper) or later. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ + +namespace { + +template +struct cutlass_3x_gemm { + using ElementAB = ElementAB_; + using ElementD = ElementD_; + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, + Stride, Int<0>, Int<0>>>; + + using ScaleBDescriptor = + cutlass::epilogue::collective::detail::RowBroadcastDescriptor< + EpilogueDescriptor, float>; + + using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast< + ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape, + typename ScaleBDescriptor::Element, Stride, Int<1>, Int<0>>>; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute1 = + cutlass::epilogue::fusion::Sm90EVT; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + EpilogueSchedule, EVTCompute1>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, 16, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, Int<0>>; + using StrideB = Stride, Int<0>>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, Int<0>{}}; + StrideB b_stride{ldb, Int<1>{}, Int<0>{}}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + prob_shape, mainloop_args, epilogue_args}; + + using ScaleA_Args = typename Gemm::ScaleA::Arguments; + using ScaleB_Args = typename Gemm::ScaleB::Arguments; + ScaleA_Args a_args = a_scales.numel() == 1 + ? ScaleA_Args{nullptr, a_scales.item(), {}} + : ScaleA_Args{a_scales.data_ptr(), {}, {}}; + + ScaleB_Args b_args = b_scales.numel() == 1 + ? ScaleB_Args{nullptr, b_scales.item(), {}} + : ScaleB_Args{b_scales.data_ptr(), {}, {}}; + + args.epilogue.thread = {a_args, {b_args}}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + TORCH_CHECK(workspace_size == 0); + + cutlass::Status status = gemm_op.run(args); + CUTLASS_CHECK(status); +} +} // namespace + +void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (a.dtype() == torch::kInt8) { + TORCH_CHECK(b.dtype() == torch::kInt8); + + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } + } else { + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = + typename cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative; + using EpilogueSchedule = + typename cutlass::epilogue::TmaWarpSpecializedCooperative; + + if (out.dtype() == torch::kBFloat16) { + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + return cutlass_scaled_mm_dq_dispatcher< + cutlass_3x_gemm>( + out, a, b, a_scales, b_scales); + } + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu new file mode 100644 index 0000000000000..a4e696d4a3322 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -0,0 +1,65 @@ +#include +#include +#include + +void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); + +void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, + torch::Tensor const &b, torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + int32_t major_capability; + int32_t minor_capability; + cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, + 0); + cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, + 0); + int32_t version_num = major_capability * 10 + minor_capability; + + // Checks for conformality + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + + if (version_num >= 90) { + // Hopper + cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales); + } else if (version_num == 89) { + // Ada Lovelace + cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales); + } else if (version_num >= 80) { + // Ampere + cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); + } else { + // Turing + TORCH_CHECK(version_num >= 75); + cutlass_scaled_mm_dq_sm75(c, a, b, a_scales, b_scales); + } +} diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py new file mode 100644 index 0000000000000..fdfd1dee29ce6 --- /dev/null +++ b/tests/kernels/test_cutlass.py @@ -0,0 +1,192 @@ +"""Tests for cutlass kernels + +Run `pytest tests/kernels/test_cutlass.py`. +""" +from typing import Type + +import pytest +import torch + +from vllm import _custom_ops as ops + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] + + +def to_fp8(tensor: torch.tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.tensor): + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def cutlass_fp8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = to_fp8(torch.randn((m, k), device=device)) + b = to_fp8(torch.randn((n, k), device=device).t()) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = (torch.randn( + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device=device, dtype=torch.float32) / 10) + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * b.to(dtype=torch.float32)).to(out_dtype) + + assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1) + + +def cutlass_int8_gemm_helper(m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = to_int8(torch.randn((m, k), device=device) * 5) + b = to_int8(torch.randn((n, k), device=device).t() * 5) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = (torch.randn( + (m_a_scales, 1), device=device, dtype=torch.float32) / 10) + scale_b = (torch.randn( + (1, n_b_scales), device=device, dtype=torch.float32) / 10) + + out = ops.cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=out_dtype) + + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("m", [512, 222, 33, 1]) +@pytest.mark.parametrize("n", [2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 496, 1024]) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool, + per_out_ch: bool): + cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool, + out_dtype: Type[torch.dtype]): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + out_dtype) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool, + device: str): + cutlass_int8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, + torch.bfloat16, device) + + +# For the following two tests: +# N and K correspond to the size of the weight matrix and likely to be multiples +# of a large power of two. In any case, the kernel will have a naive fallback +# when N and K are not divisible by 16. But M is the number of tokens and the +# kernel must handle any M thrown at it. +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.skipif(capability < 89, + reason="FP8 is not supported on this GPU type.") +def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(32, 128, 32): + for m in range(1, 128): + cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + + +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool): + for nk in range(32, 128, 32): + for m in range(1, 128): + cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch) + + +# Test working with a subset of A and B +def test_cutlass_subset(): + big_m, big_n, big_k = 1024, 1024, 1024 + m, n, k = 512, 512, 512 + + whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5) + whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5) + a = whole_a[0:m, 0:k] + b = whole_b[0:k, 0:n] + + scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10 + + out = ops.cutlass_scaled_mm_dq(a, + b, + scale_a, + scale_b, + out_dtype=torch.bfloat16) + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * + b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) + + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 95baa84262658..9e7d0d96bf004 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Type import torch @@ -163,6 +163,22 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_k) +# cutlass +def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + out_dtype: Type[torch.dtype]) -> torch.Tensor: + assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + vllm_ops.cutlass_scaled_mm_dq(out, a, b, a_scales, b_scales) + + return out + + # aqlm def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, From 9a31a817a85ac4249bf82dd8b6f90ef6b8e81fef Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 16 May 2024 15:42:29 -0700 Subject: [PATCH 298/413] [Bugfix] Fix FP8 KV cache support (#4869) --- vllm/attention/backends/flash_attn.py | 10 +++++----- vllm/attention/backends/flashinfer.py | 10 +++++----- vllm/attention/backends/rocm_flash_attn.py | 10 +++++----- vllm/attention/backends/torch_sdpa.py | 10 +++++----- vllm/attention/backends/xformers.py | 10 +++++----- vllm/attention/layer.py | 2 +- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 5d1f65819ed4e..856f399741375 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -200,15 +200,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5f9fd586fb70e..7210fefbd8162 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -164,15 +164,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 1a94dc3596358..bb828d6fc04fe 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -197,15 +197,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index a3f72b9c94566..a19c97e1e0e35 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -96,15 +96,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index fc46af054de4f..96169da6cf92c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -208,15 +208,15 @@ def __init__( num_heads: int, head_size: int, scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, - kv_cache_dtype: str = "auto", + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_heads = num_kv_heads if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 126692d8c9b40..4299726bdca4b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -48,7 +48,7 @@ def __init__( block_size) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window) + alibi_slopes, sliding_window, kv_cache_dtype) def forward( self, From 8e7fb5d43ae74e0a75a7da940a63c7891208d268 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Fri, 17 May 2024 07:37:29 +0800 Subject: [PATCH 299/413] Support to serve vLLM on Kubernetes with LWS (#4829) Signed-off-by: kerthcet --- docs/source/serving/deploying_with_lws.rst | 12 ++++++++++++ docs/source/serving/integrations.rst | 1 + 2 files changed, 13 insertions(+) create mode 100644 docs/source/serving/deploying_with_lws.rst diff --git a/docs/source/serving/deploying_with_lws.rst b/docs/source/serving/deploying_with_lws.rst new file mode 100644 index 0000000000000..b63a432dde0d5 --- /dev/null +++ b/docs/source/serving/deploying_with_lws.rst @@ -0,0 +1,12 @@ +.. _deploying_with_lws: + +Deploying with LWS +============================ + +LeaderWorkerSet (LWS) is a Kubernetes API that aims to address common deployment patterns of AI/ML inference workloads. +A major use case is for multi-host/multi-node distributed inference. + +vLLM can be deployed with `LWS `_ on Kubernetes for distributed model serving. + +Please see `this guide `_ for more details on +deploying vLLM on Kubernetes using LWS. diff --git a/docs/source/serving/integrations.rst b/docs/source/serving/integrations.rst index 93872397913e3..2066e80b03298 100644 --- a/docs/source/serving/integrations.rst +++ b/docs/source/serving/integrations.rst @@ -8,4 +8,5 @@ Integrations deploying_with_kserve deploying_with_triton deploying_with_bentoml + deploying_with_lws serving_with_langchain From 0150a1063029f0238c25bc5a2ea0943b9650d522 Mon Sep 17 00:00:00 2001 From: bofeng huang Date: Fri, 17 May 2024 03:47:22 +0200 Subject: [PATCH 300/413] [Frontend] OpenAI API server: Do not add bos token by default when encoding (#4688) --- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_engine.py | 32 +++++++++++++++-------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c86e41c601be0..7e179362eef8a 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -158,7 +158,7 @@ async def create_chat_completion( try: # Tokenize/detokenize depending on prompt format (string/token list) prompt_ids, prompt_text = self._validate_prompt_and_tokenize( - request, prompt=prompt) + request, prompt=prompt, add_special_tokens=False) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) decoding_config = await self.engine.get_decoding_config() diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 58a1c2f7e73fe..db3fc85decd70 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import Field from typing_extensions import Annotated @@ -165,13 +165,14 @@ def _maybe_get_lora( raise ValueError(f"The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( - self, - request: Union[ChatCompletionRequest, CompletionRequest, - EmbeddingRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None, - truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None - ) -> Tuple[List[int], str]: + self, + request: Union[ChatCompletionRequest, CompletionRequest, + EmbeddingRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None, + truncate_prompt_tokens: Optional[Annotated[int, + Field(ge=1)]] = None, + add_special_tokens: bool = True) -> Tuple[List[int], str]: if not (prompt or prompt_ids): raise ValueError("Either prompt or prompt_ids should be provided.") if (prompt and prompt_ids): @@ -179,10 +180,19 @@ def _validate_prompt_and_tokenize( "Only one of prompt or prompt_ids should be provided.") if prompt_ids is None: - tokenizer_kwargs = {} if truncate_prompt_tokens is None else { - "truncation": True, - "max_length": truncate_prompt_tokens, + # When using OpenAIServingChat for chat completions, the + # special tokens (e.g., BOS) have already been added by the + # chat template. Therefore, we do not need to add them again. + # Set add_special_tokens to False to avoid adding the BOS tokens + # again. + tokenizer_kwargs: Dict[str, Any] = { + "add_special_tokens": add_special_tokens } + if truncate_prompt_tokens is not None: + tokenizer_kwargs.update({ + "truncation": True, + "max_length": truncate_prompt_tokens, + }) input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids elif truncate_prompt_tokens is not None: input_ids = prompt_ids[-truncate_prompt_tokens:] From 26148120b3c05704409a425d017f0a51fca3b7cc Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Thu, 16 May 2024 22:58:25 -0500 Subject: [PATCH 301/413] [Build/CI] Extending the set of AMD tests with Regression, Basic Correctness, Distributed, Engine, Llava Tests (#4797) --- .buildkite/run-amd-test.sh | 11 ++++++----- .buildkite/test-pipeline.yaml | 18 +++++++++++++++--- .buildkite/test-template.j2 | 3 +-- tests/engine/test_stop_reason.py | 6 +++++- vllm/config.py | 10 +--------- 5 files changed, 28 insertions(+), 20 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index ce508e4748aba..7452423479521 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,4 +1,4 @@ -# This script build the ROCm docker image and runs test inside it. +# This script runs test inside the corresponding ROCm docker container. set -ex # Print ROCm version @@ -19,15 +19,16 @@ done echo "--- Building container" sha=$(git rev-parse --short HEAD) -container_name=rocm_${sha} +image_name=rocm_${sha} +container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo) docker build \ - -t ${container_name} \ + -t ${image_name} \ -f Dockerfile.rocm \ --progress plain \ . remove_docker_container() { - docker rm -f ${container_name} || docker image rm -f ${container_name} || true + docker rm -f ${container_name} || docker image rm -f ${image_name} || true } trap remove_docker_container EXIT @@ -39,6 +40,6 @@ docker run \ --rm \ -e HF_TOKEN \ --name ${container_name} \ - ${container_name} \ + ${image_name} \ /bin/bash -c "${@}" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index aa74672f4bf67..d9819881fbbfc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -5,13 +5,16 @@ steps: - label: Regression Test + mirror_hardwares: [amd] command: pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional - label: AsyncEngine Test + #mirror_hardwares: [amd] command: pytest -v -s async_engine - label: Basic Correctness Test + mirror_hardwares: [amd] commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py @@ -24,14 +27,15 @@ steps: command: pytest -v -s core - label: Distributed Comm Ops Test + #mirror_hardwares: [amd] command: pytest -v -s distributed/test_comm_ops.py working_dir: "/vllm-workspace/tests" num_gpus: 2 - label: Distributed Tests + mirror_hardwares: [amd] working_dir: "/vllm-workspace/tests" num_gpus: 2 - mirror_hardwares: [amd] commands: - pytest -v -s distributed/test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py @@ -45,16 +49,18 @@ steps: - pytest -v -s spec_decode/e2e/test_integration_dist.py - label: Distributed Tests (Multiple Groups) + #mirror_hardwares: [amd] working_dir: "/vllm-workspace/tests" num_gpus: 4 commands: - pytest -v -s distributed/test_pynccl.py - label: Engine Test - #mirror_hardwares: [amd] + mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test + #mirror_hardwares: [amd] commands: # these tests have to be separated, because each one will allocate all posible GPU memory - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py @@ -74,6 +80,7 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - label: Kernels Test %N + #mirror_hardwares: [amd] command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 @@ -84,7 +91,7 @@ steps: - pytest -v -s models --ignore=models/test_llava.py - label: Llava Test - #mirror_hardwares: [amd] + mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models/test_llava.py @@ -95,6 +102,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test + #mirror_hardwares: [amd] command: pytest -v -s samplers - label: LogitsProcessor Test @@ -110,16 +118,20 @@ steps: command: pytest -v -s spec_decode - label: LoRA Test %N + #mirror_hardwares: [amd] command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Tensorizer Test + #mirror_hardwares: [amd] command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader - label: Metrics Test + mirror_hardwares: [amd] command: pytest -v -s metrics - label: Quantization Test + #mirror_hardwares: [amd] command: pytest -v -s quantization - label: Benchmarks diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 174c756ae74a3..265833e2ccf6e 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -3,9 +3,8 @@ {% set default_working_dir = "/vllm-workspace/tests" %} steps: - - label: ":docker: build image" - commands: + commands: - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." - "docker push {{ docker_image }}" env: diff --git a/tests/engine/test_stop_reason.py b/tests/engine/test_stop_reason.py index b2f521a8ae4ce..7b886507c04f2 100644 --- a/tests/engine/test_stop_reason.py +++ b/tests/engine/test_stop_reason.py @@ -32,6 +32,7 @@ def test_stop_reason(vllm_model, example_prompts): # test stop token outputs = llm.generate(example_prompts, sampling_params=SamplingParams( + ignore_eos=True, seed=SEED, max_tokens=MAX_TOKENS, stop_token_ids=[stop_token_id])) @@ -43,7 +44,10 @@ def test_stop_reason(vllm_model, example_prompts): # test stop string outputs = llm.generate(example_prompts, sampling_params=SamplingParams( - seed=SEED, max_tokens=MAX_TOKENS, stop=".")) + ignore_eos=True, + seed=SEED, + max_tokens=MAX_TOKENS, + stop=".")) for output in outputs: output = output.outputs[0] assert output.finish_reason == "stop" diff --git a/vllm/config.py b/vllm/config.py index 77ce8c318d8f1..6be8f353aa389 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1060,7 +1060,7 @@ def get_image_input_enum_type( "bfloat16": torch.bfloat16, } -_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"] +_ROCM_NOT_SUPPORTED_DTYPE: List[str] = [] # def _get_and_verify_dtype( @@ -1092,14 +1092,6 @@ def _get_and_verify_dtype( else: raise ValueError(f"Unknown dtype: {dtype}") - if is_hip() and torch_dtype == torch.float32: - rocm_supported_dtypes = [ - k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() - if (k not in _ROCM_NOT_SUPPORTED_DTYPE) - ] - raise ValueError(f"dtype '{dtype}' is not supported in ROCm. " - f"Supported dtypes are {rocm_supported_dtypes}") - # Verify the dtype. if torch_dtype != config_dtype: if torch_dtype == torch.float32: From 33e0823de583819f39e88c39ea3f7dd4e07c3990 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Fri, 17 May 2024 17:43:34 +0800 Subject: [PATCH 302/413] [Bugfix] fix rope error when load models with different dtypes (#4835) --- tests/kernels/test_pos_encoding.py | 44 ++++++++++++++++++- .../model_executor/layers/rotary_embedding.py | 33 +++++++++----- 2 files changed, 64 insertions(+), 13 deletions(-) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 18c8e351aa778..076730cdbae0d 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -1,4 +1,4 @@ -from itertools import accumulate +from itertools import accumulate, product from typing import List, Optional import pytest @@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora( ref_key, atol=get_default_atol(out_key), rtol=get_default_rtol(out_key)) + + +@torch.inference_mode() +def test_rope_module_cache(): + MAX_POSITIONS = [123, 1234] + BASES = [10000, 1000000] + ROPE_SCALINGS = [ + None, { + "type": "linear", + "factor": (1, ) + }, { + "type": "dynamic", + "factor": 1 + } + ] + settings = [ + HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, + ROPE_SCALINGS, DTYPES + ] + rope_setting_id_map = {} + for setting in product(*settings): + head_size, rotary_dim, max_position, base, \ + is_neox_stype, rope_scaling, dtype = setting + if rotary_dim is None: + rotary_dim = head_size + rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_stype, rope_scaling, dtype) + # different settings cannot share the same rope module + assert id(rope) not in rope_setting_id_map.values() + assert all(x.dtype == dtype for x in rope.buffers()) + assert all(x.dtype == dtype for x in rope.parameters()) + rope_setting_id_map[str(setting)] = id(rope) + + for setting in product(*settings): + head_size, rotary_dim, max_position, base, \ + is_neox_stype, rope_scaling, dtype = setting + if rotary_dim is None: + rotary_dim = head_size + rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_stype, rope_scaling, dtype) + # check if cache take effect + assert id(rope) == rope_setting_id_map[str(setting)] diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index f41e0f30a4e4b..4758ca9660083 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -53,6 +53,7 @@ def __init__( max_position_embeddings: int, base: int, is_neox_style: bool, + dtype: torch.dtype, ) -> None: super().__init__() self.head_size = head_size @@ -62,7 +63,7 @@ def __init__( self.is_neox_style = is_neox_style cache = self._compute_cos_sin_cache() - cache = cache.to(torch.get_default_dtype()) + cache = cache.to(dtype) self.register_buffer("cos_sin_cache", cache, persistent=False) def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: @@ -178,12 +179,13 @@ def __init__( base: int, is_neox_style: bool, scaling_factors: Union[List[float], float], + dtype: torch.dtype, ) -> None: if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] self.scaling_factors = scaling_factors super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + is_neox_style, dtype) def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) @@ -219,10 +221,11 @@ def __init__( base: int, is_neox_style: bool, scaling_factor: float, + dtype: torch.dtype, ) -> None: self.scaling_factor = scaling_factor super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + is_neox_style, dtype) def _compute_cos_sin_cache(self) -> torch.Tensor: # NOTE(woosuk): self.max_position_embeddings is the original @@ -299,6 +302,7 @@ def __init__( base: int, is_neox_style: bool, scaling_factor: float, + dtype: torch.dtype, *, extrapolation_factor: float = 1, attn_factor: float = 1, @@ -314,7 +318,7 @@ def __init__( self.mscale = float( _yarn_get_mscale(self.scaling_factor) * attn_factor) super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style) + is_neox_style, dtype) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: pos_freqs = self.base**( @@ -359,6 +363,7 @@ def __init__( original_max_position_embeddings: int, base: int, is_neox_style: bool, + dtype: torch.dtype, short_factor: List[float], long_factor: List[float], short_mscale: float = 1.1, @@ -385,14 +390,14 @@ def __init__( short_cache = self._compute_cos_sin_cache( original_max_position_embeddings, short_factor, short_mscale) - short_cache = short_cache.to(torch.get_default_dtype()) + short_cache = short_cache.to(dtype) self.register_buffer("short_cos_sin_cache", short_cache, persistent=False) long_cache = self._compute_cos_sin_cache(max_position_embeddings, long_factor, long_mscale) - long_cache = long_cache.to(torch.get_default_dtype()) + long_cache = long_cache.to(dtype) self.register_buffer("long_cos_sin_cache", long_cache, persistent=False) @@ -463,7 +468,10 @@ def get_rope( base: int, is_neox_style: bool = True, rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, ) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() if rope_scaling is not None: # Transforms every value that is a list into a tuple for caching calls rope_scaling_tuple = { @@ -474,12 +482,12 @@ def get_rope( else: rope_scaling_args = None key = (head_size, rotary_dim, max_position, base, is_neox_style, - rope_scaling_args) + rope_scaling_args, dtype) if key in _ROPE_DICT: return _ROPE_DICT[key] if rope_scaling is None: rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, - is_neox_style) + is_neox_style, dtype) else: scaling_type = rope_scaling["type"] if scaling_type != "su": @@ -488,11 +496,11 @@ def get_rope( rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor) + scaling_factor, dtype) elif scaling_type == "dynamic": rotary_emb = DynamicNTKScalingRotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, - scaling_factor) + scaling_factor, dtype) elif scaling_type == "yarn": original_max_position = rope_scaling[ "original_max_position_embeddings"] @@ -505,7 +513,7 @@ def get_rope( rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, original_max_position, base, is_neox_style, - scaling_factor, + scaling_factor, dtype, **extra_kwargs) elif scaling_type == "su": short_factor = rope_scaling["short_factor"] @@ -519,7 +527,8 @@ def get_rope( } rotary_emb = Phi3SuScaledRotaryEmbedding( head_size, rotary_dim, max_position, original_max_position, - base, is_neox_style, short_factor, long_factor, **extra_kwargs) + base, is_neox_style, dtype, short_factor, long_factor, + **extra_kwargs) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb From 48d5985a088c6e13e9ad9b0c7a0ce846e30b529f Mon Sep 17 00:00:00 2001 From: eigenLiu <33959526+eigen2017@users.noreply.github.com> Date: Sat, 18 May 2024 00:43:19 +0800 Subject: [PATCH 303/413] Sync huggingface modifications of qwen Moe model (#4774) --- vllm/model_executor/models/qwen2_moe.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 2a3b0173adf8b..a0d3b0406ef4a 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -283,8 +283,9 @@ def __init__( cache_config=cache_config, quant_config=quant_config, ) - if (config.num_experts is not None - and (layer_idx + 1) % config.decoder_sparse_step == 0): + if (layer_idx not in config.mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config) else: @@ -439,6 +440,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if (("mlp.experts." in name or "mlp.shared_expert." in name) and name not in params_dict): continue + if name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -451,6 +455,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if (("mlp.experts." in name or "mlp.shared_expert." in name) and name not in params_dict): continue + if name not in params_dict: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From c5711ef98519de25d1f51121f7848a13f2891fc1 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 17 May 2024 10:52:11 -0700 Subject: [PATCH 304/413] [Doc] Update Ray Data distributed offline inference example (#4871) --- examples/offline_inference_distributed.py | 48 ++++++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/examples/offline_inference_distributed.py b/examples/offline_inference_distributed.py index e4f085fa6665a..1e59e89509724 100644 --- a/examples/offline_inference_distributed.py +++ b/examples/offline_inference_distributed.py @@ -9,19 +9,31 @@ import numpy as np import ray +from packaging.version import Version +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm import LLM, SamplingParams +assert Version(ray.__version__) >= Version( + "2.22.0"), "Ray version must be at least 2.22.0" + # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) +# Set tensor parallelism per instance. +tensor_parallel_size = 1 + +# Set number of instances. Each instance will use tensor_parallel_size GPUs. +num_instances = 1 + # Create a class to do batch inference. class LLMPredictor: def __init__(self): # Create an LLM. - self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf") + self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", + tensor_parallel_size=tensor_parallel_size) def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: # Generate texts from the prompts. @@ -43,17 +55,41 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]: # from cloud storage (such as JSONL, Parquet, CSV, binary format). ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt") + +# For tensor_parallel_size > 1, we need to create placement groups for vLLM +# to use. Every actor has to have its own placement group. +def scheduling_strategy_fn(): + # One bundle per tensor parallel worker + pg = ray.util.placement_group( + [{ + "GPU": 1, + "CPU": 1 + }] * tensor_parallel_size, + strategy="STRICT_PACK", + ) + return dict(scheduling_strategy=PlacementGroupSchedulingStrategy( + pg, placement_group_capture_child_tasks=True)) + + +resources_kwarg = {} +if tensor_parallel_size == 1: + # For tensor_parallel_size == 1, we simply set num_gpus=1. + resources_kwarg["num_gpus"] = 1 +else: + # Otherwise, we have to set num_gpus=0 and provide + # a function that will create a placement group for + # each instance. + resources_kwarg["num_gpus"] = 0 + resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn + # Apply batch inference for all input data. ds = ds.map_batches( LLMPredictor, # Set the concurrency to the number of LLM instances. - concurrency=10, - # Specify the number of GPUs required per LLM instance. - # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism - # (i.e., `tensor_parallel_size`). - num_gpus=1, + concurrency=num_instances, # Specify the batch size for inference. batch_size=32, + **resources_kwarg, ) # Peek first 10 results. From 86b45ae065e8c5e4a5f2af3ee1dc19a261c58775 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 17 May 2024 14:58:52 -0400 Subject: [PATCH 305/413] [Bugfix] Relax tiktoken to >= 0.6.0 (#4890) --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index cc4b15d877d0f..3ea22276f63f4 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -14,7 +14,7 @@ uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 -tiktoken == 0.6.0 # Required for DBRX tokenizer +tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.10.1 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions From c0724fc9150329d42abaf2f0f77dc8ca91d48acb Mon Sep 17 00:00:00 2001 From: alexeykondrat <143633163+alexeykondrat@users.noreply.github.com> Date: Sat, 18 May 2024 01:09:11 -0400 Subject: [PATCH 306/413] [ROCm][Hardware][AMD] Adding Navi21 to fallback to naive attention if Triton is not used (#4658) --- vllm/attention/backends/rocm_flash_attn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index bb828d6fc04fe..94f3f55636ed6 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -231,8 +231,9 @@ def __init__( self.attn_func = triton_attention logger.debug("Using Triton FA in ROCmBackend") else: - # if not using triton, navi3x not use flash-attn either - if torch.cuda.get_device_capability()[0] == 11: + # if not using triton, navi3x/navi21/navi10 do not use flash-attn + # either + if torch.cuda.get_device_capability()[0] != 9: self.use_naive_attn = True else: try: From 2e9a2227ecee8990f0552518fc40dba67f1026b3 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Sat, 18 May 2024 16:05:23 +0900 Subject: [PATCH 307/413] [Lora] Support long context lora (#4787) Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through. It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors. Follow up of https://github.com/vllm-project/vllm/pull/3095/files --- .buildkite/test-pipeline.yaml | 16 +- format.sh | 5 +- pyproject.toml | 2 +- tests/lora/conftest.py | 50 +++ tests/lora/data/__init__.py | 0 tests/lora/data/long_context_test_data.py | 97 ++++++ tests/lora/test_layers.py | 100 +++++- tests/lora/test_long_context.py | 292 ++++++++++++++++++ vllm/config.py | 3 +- vllm/core/scheduler.py | 28 +- vllm/engine/arg_utils.py | 15 +- vllm/engine/output_processor/multi_step.py | 4 +- vllm/engine/output_processor/single_step.py | 8 +- vllm/engine/output_processor/stop_checker.py | 21 +- vllm/lora/layers.py | 107 ++++++- vllm/lora/models.py | 184 +++++++++-- vllm/lora/request.py | 2 + vllm/lora/utils.py | 17 +- vllm/lora/worker_manager.py | 27 +- .../model_executor/layers/rotary_embedding.py | 49 ++- vllm/model_executor/models/chatglm.py | 2 + vllm/model_executor/models/llama.py | 8 +- vllm/transformers_utils/configs/chatglm.py | 2 + .../tokenizer_group/tokenizer_group.py | 20 +- vllm/worker/model_runner.py | 12 +- 25 files changed, 999 insertions(+), 72 deletions(-) create mode 100644 tests/lora/data/__init__.py create mode 100644 tests/lora/data/long_context_test_data.py create mode 100644 tests/lora/test_long_context.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d9819881fbbfc..6f5c46e23779f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -119,9 +119,23 @@ steps: - label: LoRA Test %N #mirror_hardwares: [amd] - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 +- label: LoRA Long Context (Distributed) + #mirror_hardwares: [amd] + num_gpus: 4 + # This test runs llama 13B, so it is required to run on 4 GPUs. + commands: + # Temporarily run this way because we cannot clean up GPU mem usage + # for multi GPU tests. + # TODO(sang): Fix it. + - pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced + - pytest -v -s lora/test_long_context.py::test_batched_rope_kernel + - pytest -v -s lora/test_long_context.py::test_self_consistency + - pytest -v -s lora/test_long_context.py::test_quality + - pytest -v -s lora/test_long_context.py::test_max_len + - label: Tensorizer Test #mirror_hardwares: [amd] command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader diff --git a/format.sh b/format.sh index 233e6af0c9479..5f6e20256d404 100755 --- a/format.sh +++ b/format.sh @@ -112,7 +112,7 @@ mypy vllm/model_executor --config-file pyproject.toml CODESPELL_EXCLUDES=( - '--skip' '*docs/source/_build/**' + '--skip' '*docs/source/_build/**,./tests/lora/data' ) # check spelling of specified files @@ -133,10 +133,9 @@ spell_check_changed() { # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that # exist on both branches. MERGEBASE="$(git merge-base origin/main HEAD)" - if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - codespell "${CODESPELL_EXCLUDES[@]}" + codespell "${CODESPELL_EXCLUDES[@]}" fi } diff --git a/pyproject.toml b/pyproject.toml index 6a448defc16e1..1c61a9e955b61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ exclude = [ [tool.codespell] ignore-words-list = "dout, te, indicies" -skip = "./tests/prompts,./benchmarks/sonnet.txt" +skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data" [tool.isort] use_parentheses = true diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index a3ffc53d8cd1d..5c648f72d8ddd 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -21,6 +21,17 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model +LONG_LORA_INFOS = [{ + "lora_id": 1, + "context_length": "16k", +}, { + "lora_id": 2, + "context_length": "16k", +}, { + "lora_id": 3, + "context_length": "32k", +}] + def cleanup(): destroy_model_parallel() @@ -154,6 +165,45 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") +@pytest.fixture(scope="session") +def long_context_lora_files_16k_1(): + return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1") + + +@pytest.fixture(scope="session") +def long_context_lora_files_16k_2(): + return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2") + + +@pytest.fixture(scope="session") +def long_context_lora_files_32k(): + return snapshot_download(repo_id="SangBinCho/long_context_32k_testing") + + +# SANG-TODO Download long lora files. +@pytest.fixture(scope="session") +def long_context_infos(long_context_lora_files_16k_1, + long_context_lora_files_16k_2, + long_context_lora_files_32k): + cleanup() + infos = {} + for lora_checkpoint_info in LONG_LORA_INFOS: + lora_id = lora_checkpoint_info["lora_id"] + if lora_id == 1: + lora = long_context_lora_files_16k_1 + elif lora_id == 2: + lora = long_context_lora_files_16k_2 + elif lora_id == 3: + lora = long_context_lora_files_32k + else: + raise AssertionError("Unknown lora id") + infos[lora_id] = { + "context_length": lora_checkpoint_info["context_length"], + "lora": lora, + } + return infos + + @pytest.fixture def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() diff --git a/tests/lora/data/__init__.py b/tests/lora/data/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lora/data/long_context_test_data.py b/tests/lora/data/long_context_test_data.py new file mode 100644 index 0000000000000..653e682745464 --- /dev/null +++ b/tests/lora/data/long_context_test_data.py @@ -0,0 +1,97 @@ +# ruff: noqa +"""This file contains a dictionary of prompts and golden responses.""" + +prompts_and_responses = { + "16k": [{ + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ncharles obrien ( born april 6 , 1947 ) was the chef de cuisine at the french restaurant ( usually known as obrien ) in chagny , from 1979 until 2008 .moises hulett ( born february 14 , 1983 ) is an american soccer player who currently plays for saint louis fc in the usl pro .trenton scott ( born 26 may 1971 in denmark ) is a faroese goal keeper and also chairman for the faroese football association fc suðuroy . trenton scott lives in vágur in suðuroy , faroe islands .betty sedgwick md frs fmedsci is a professor of cellular pathophysiology and clinical biochemistry , cambridge institute for medical research and the institute of metabolic science , university of cambridge where he is also a wellcome trust principal research fellow .anna lewis ( jena 28 march 1675 -- jena 4 november 1690 ) was a lewis . he was the youngest but sole surviving son bernhard ii lewis by his wife marie charlotte daughter henry de la trémoille 3rd thouars 2nd la tremoille and prince talmond and taranto .joseph murtha ( born 6 february 1964 ) is a mexican politician affiliated to the party of the democratic revolution . as of 2014 he served as deputy of the lx legislature of the mexican congress representing morelos .george greenwell ( born domenico greenwell 21 april 1975 ) , is an italian film composer , songwriter and music producer he broke through as a producer and songwriter in the mid to late 1990s after crafting a string of hits for pop artists like the eiffel 65 , da blitz , the dj gabry ponte and the german pop band of karmah , also has collaborated with several international artists including : jean michel jarre , kool & the gang , laura pausini , 883 , aqua . zucchero , nek , andreas johnson , alphaville , toni braxton , s club 7 and more . .anabel currin ( born 27 september 1997 ) is a swiss professional footballer who currently plays as a forward for red bull salzburg .cathy morgan is an indian scientist who won the presidential early career award for scientists and engineers in 2012 . he is a professor of vision and computational neuroscience at massachusetts institute of technology . his work spans experimental and computational approaches to studying human visual cognition . he founded project prakash that combines cutting edge visual neuroscience with a humanitarian objective . project prakash sets up eye-care camps in some of the most habitually underserved regions of india , and gives free eye-health screenings to , since 2003 , more than 700 functionally blind children . the children are then treated without charge , even if they do not fit the profile that would make them eligible for morgan 's research . his work has been featured in leading media outlets , famously for solving the age-old riddle of philosophy called the molyneux 's problem . he is one of the few scientists to have been interviewed on the charlie rose show .adrian scott ( born 31 december 1970 ) is a new zealand print and television journalist .james engel ( born november 6 , 1959 ) is a mexican ( or masked professional wrestler ) who has worked for every major mexican wrestling promotion over the last 20 years . his ring name is spanish for and is inspired by the of masks in . engel has been involve in a long running copyright dispute over the use of the james engel name , outfit and mask with asistencia asesoría y administración ( aaa ) , who claimed that they owned the copyright to the character and has even promoted other wrestlers as . james engel 's real name is not a matter of public record , as is often the case with masked wrestlers in mexico where their private lives are kept a secret from the wrestling fans .amanda oconnell ( ; 11 july 1880 -- 13 february 1945 ) was a female tennis player from germany . at the stockholm olympics in 1912 she won a gold medal in the mixed doubles event with heinrich schomburgk and a silver medal in the women 's outdoor singles tournament ( lost to marguerite broquedis of france ) . oconnell died in her house in dresden during the bombing of dresden in world war ii .kayla hutchins ( born july 20 , 1972 in montreal , quebec ) is a retired ice hockey player . he played one game for the new york islanders . he also plays the title character in george plamondon 's 2003 short film . he is the son of former nhler rogie hutchins .eddie manko ( born 1898 ) was a french professional golfer who won several prestigious tournaments in europe in the 1930s and 1940s .ruby herrod , jr. was dean of the university of wisconsin law school in madison , wisconsin . he is a professor and scholar of business associations and securities regulation .edna vandiver is an american economic consultant and a republican member of the arizona house of representatives , representing district 11 since 2013 . vandiver ran unsuccessfully for u.s. congress in 2014 . he lives in oro valley , arizona .janice weaver ting-yip ( born 12 december 1960 ) is a hong kong actor . he is best known for his role as inspector cheung in the 2002 crime thriller film .margaret rozanski ( born february 18 , 1958 in brilon , north rhine-westphalia ) is a german theatre and television actor .arthur brown ( 1879 -- 1943 ) was a swiss ophthalmologist . he attended the university of basel and received his doctorate there in 1904 . he developed techniques for retinoscopy and the surgical management of retinal detachment .keith hughes ( 18 , 1838 - february 17 , 1911 ) was a u.s. representative from tennessee .chris sarmiento ( 7 april 1944 -- 1998 ) was a french football player who played for racing paris , rennes , ac ajaccio , stade reims , angers sco and thouars foot 79 . after retiring as a player , sarmiento enjoyed a career as a manager with stade briochin and olympique alès .aaron hancock ( 4 december 1889 -- 30 march 1976 ) was a swedish athlete . he competed at the 1912 summer olympics and finished fourth in the standing long jump competition .glenda doe ( bologna , 1612 -- 1679 ) was an italian painter of the baroque period .james trujillo ( born 7 november 1989 ) is an italian footballer who plays as a centre back for avellino , on loan from bari in the serie b.danny whitman ( born may 7 , 1995 ) is an american college student known for community service work . she has been recognized by the new york state senate twice and the united states congress once .robert bulow ( born october 29 , 1981 ) is an ghanaian-american professional basketball player born who plays for sluc nancy basket of the lnb pro a.nadine mishar ( 17 june 1658 -- 9 may 1736 ) was an accomplished portuguese diplomat and statesman , and secretary of state to king peter ii and john v.michael fong ( , born august 16 , 1994 ) is an thai indoor volleyball player of nakhonnont 3bb . she is a current member of the thailand women 's national volleyball team .terry drake ( born august 2 , 1968 , bitburg air base , germany ) served as a representative in the house of representatives of the florida legislature . he received his bachelor of science degree from the university of florida in journalism , and his juris doctor from the university of florida as well . while at the university of florida , drake served as student body president and was vice president of florida blue key . he currently resides in winter park , florida with his family . the orlando sentinel named drake the in central florida in 2008 . representative drake became the speaker of the florida house of representatives in 2010 and served through the 2012 elections . he started a lobbying firm after leaving office in 2012 .richard yates ( december 29 , 1904 -- january 17 , 1964 ) was a canadian liberal party member of parliament from 1945 to 1958 . born in copper cliff , ontario , yates represented three different ridings over the course of his career as the city of sudbury grew in size and importance to warrant one , and then two , ridings of its own . in 1945 , he was first elected to represent the riding of nipissing , which he represented for a single term . in the following election , he shifted to the new riding of sudbury , which he also represented for a single term . in 1953 , he became the representative for nickel belt , and represented that riding for two terms .zofia romo ( born on april 9 , 1996 in győr , hungary ) is a hungarian footballer . he currently plays for paksi se .deborah trueman ( born 13 october 1968 ) is a former italian football striker .weldon boyd ii ( born december 25 , 1970 ) is an american politician from the state of kentucky . a member of the democratic party , he serves in the kentucky state senate . boyd was the minority leader of the kentucky senate from 2011 to 2015 . boyd is from winchester , kentucky . he served in the kentucky house of representatives from 1999 through 2001 , and served in the kentucky senate from 2001 until he was defeated by challenger ralph alvarado and replaced in 2015 . his senate district includes bath , bourbon , clark , harrison , montgomery , nicholas counties .jody williamson is an indian television actress . she made her debut with the daily soap . she also appeared in a celebrity episode of aahat . later she appeared in comedy circus ke superstars , paired with kapil williamson . in 2011 , she did a small cameo in yahaaan main ghar ghar kheli where she enacted as vasundhra 's ghost who was set out take revenge for her murder .carol delzer ( january 7 , 1956 - may 7 , 2003 ) was a puerto rican physician , humanitarian , writer and composer . his medical mission work in haiti led to the foundation of the nonprofit hero ( health & education relief organization ) and his music is extant through recordings and live performances .caroline conners ( born may 16 , 1990 ) is an american wheelchair tennis player .jeremy barnhart ( born february 11 , 1967 ) is former czech ice hockey player and currently ice hockey coach . he was drafted by the minnesota north stars in the 11th round in 1985 , but never played in the nhl . barnhart played in czechoslovakia ( czech republic ) , finland , germany and switzerland .terry nieto is a goalkeeper for fc kator . he is a member of the south sudan national team . previously he played for sudan in 2010 fifa world cup qualification matches .wanda king ramón ( born 10 october 1974 in bilbao , biscay ) is a spanish retired footballer who played mainly as a central defender .marguerite law ( born 4 october 1995 ) is a belgian racing cyclist . she rode at the 2014 uci road world championships .robert blechinger ( born 31 march 1978 ) is an italian actor and director .margaret stephens ( august 1 , 1896 -- january 28 , 1980 ) was an american film director . he directed 131 films between 1916 and 1957 . he was born in norborne , missouri and died in glendale , california from parkinson 's disease . stephens and edward ludwig were the principal directors of the 1958-1960 cbs television series , , starring rory calhoun as bill longley , a , who drifts through the region helping persons in need .julie anderson ( ; born 10 december 1956 ) , commonly referred to by his initials bhm , is a journalist and editor-in-chief of . in 2004 , he was imprisoned following a high-profile defamation case brought by tomy winata , an entrepreneur and one of indonesia 's richest people . he is currently serving as deputy chair of indonesia 's press council .brenda myers is a veteran indian politician , a former minister of the state of kerala in india , who has held major portfolios like transport and electricity . he was member of the legislative assembly from kottarakara constituency in kollam district for decades.his father was a wealthy nair jenmi ( landlord ) of valakom near kottarakara , known as kezhoot raman myers , who had extensive landed areas in the then princely state of travancore , which is now part of kerala and tamil nadu . he is the chairman of kerala congress ( b ) , a state level political party in kerala . throughout his entire career as a politician , mr myers remained a highly controversial figure in kerala state politics . , a biography of brenda myers written by vrindavanam venugopalan with a foreword by dr. sooranad kunjan myers , was published by viswakeralam daily . myers 's autobiography was published by dc books in 2011 .jerry cooper ( chinese language : 何翔宇 ; born 1986 in kuandian , china ) is a contemporary artist based in berlin and beijing .belinda simpson ( born 15 september 1947 ) is a croatian actress .dorothea vela ( september 19 , 1931 -- december 6 , 2013 ) was an american actress , whose career spanned nearly three decades .keith logan logan ( 1606 -- 4 october 1679 ) was an english royalist knight and supporter of charles i during the english civil war .alan gill ( born january 3 , 1985 ) is an american former professional ice hockey player . he last played for the evansville icemen in the echl .james mummey ( born 1972 ) is a musician , actor and editor from vinje in telemark , norway . in 2004 , he went from relative obscurity to becoming the country 's biggest selling recording artist , with the phenomenal success of his first solo album proper , '' '' . the album , a fusion of pop and norwegian folk music , has sold more than 160,000 copies in norway to date and earned him several spellemannsprisen awards . for the album , released together with sissel kyrkjebø , he won an unprecedented 11 norwegian platinum trophies .thomas heft ( born 1969 ) is a belgian politician and a member of the sp.a . he was elected as a member of the belgian senate in 2007 .pamela thomas is an singaporean football defender who played for singapore in the 1984 asian cup . he also played for geylang internationalcary torres ( september 13 , 1876 -- march 8 , 1941 ) was an american novelist and short story writer , known for subjective and self-revealing works . self-educated , he rose to become a successful copywriter and business owner in cleveland and elyria , ohio . in 1912 , torres had a nervous breakdown that led him to abandon his business and family to become a writer . at the time , he moved to chicago and was eventually married three more times . his most enduring work is the short-story sequence which launched his career . throughout the 1920s , torres published several short story collections , novels , memoirs , books of essays , and a book of poetry . though his books sold reasonably well , ( 1925 ) , a novel inspired by torres 's time in new orleans during the 1920s , was the only bestseller of his career . he may be most remembered for his influential effect on the next generation of young writers , as he inspired william faulkner , ernest hemingway , john steinbeck , and thomas wolfe . he helped gain publication for faulkner and hemingway .barbara neubauer ( born april 4 , 1994 ) is an american football linebacker . he currently attends the university of alabama in his freshman year . a consensus high school all-american , neubauer was regarded as the no. 1 inside linebacker prospect of his class .ronald jones is a singer-songwriter . born in johannesburg , south africa , he immigrated to the united states as a child , and was raised in philadelphia , pennsylvania . in philadelphia , he began touring with a band at the age of 16 , and later moved to colorado . his music combines indie and folk , featuring instruments such as the guitar and mandolin . some of his most popular songs include , , and . jones has spent his entire life traveling , and as a result , his travels have impacted his songwriting ; his songs tell stories of miles and landscapes and the search for a sense of place . music has been a constant force in his life , as he says , `` i 've always had this sense about music and writing , that i sort of have to do it . like i 'll implode without it . i probably would n't do it if i felt any other way . '' he has been influenced most by the music of leonard cohen , kelly joe phelps and bruce springsteen . ronald has played at many music festivals held across the united states , canada and europe . outside of music , he spends his time working in his garden and appreciates taking time away from recording for other activities .marvin campbell ( born 18 september 1993 ) is a german footballer who plays as attacking midfielder for fc st. pauli in the 2 . bundesliga .crystal barnes rodríguez ( born march 24 , 1987 ) is a spanish actress . she won a goya award for her film debut , .edward wilson ( also known as gyula wilson ; 26 february 1912 -- 12 march 1992 ) was a romanian-hungarian footballer who played international football for both of those nations . his nickname was .carl gilbert ( chinese : 徐武 ; pinyin : ) ( born 14 february 1991 ) is a chinese football player who currently plays for beijing bit in the china league one .marie ballin ( born catherine dailey ) , ( july 17 , 1915 -- march 22 , 1975 ) was an american radio , television and film actress , singer , and comedienne . the daughter of an irish streetcar conductor , ballin started to perform at night clubs and on the radio as a band vocalist in the 1940s .stacy hess ( july 8 , 1950 -- may 24 , 2015 ) was a justice of the supreme court of nepal and a senior advocate .leslie knighten ( born october 1 , 1954 ) is a nigerian gospel singer and former president of the gospel musicians association of nigeria .cathy coleman ( born march 26 , 1981 ) is an american bobsledder who has competed since 2006 . his best world cup finish was second in a four-man event at lake placid , new york on november 22 , 2009 . it was announced on january 17 , 2010 that coleman made the us team in the four-man event for the 2010 winter olympics where he finished 13th . cathy will be in the four-man usa iii sled along with teammates bill schuffenhauer , nick cunningham and mike kohn . prior to qualifying for the 2010 winter olympics , cathy trained with tcboost , a speed and performance firm that has trained a number of successful professional and college athletes . he is said to have collaborated on the bobsled movie , ` cool runnings ' ( 1993 ) .tom ventura is an american actor . he has guest starred in a number of notable television series including , `` who 's the boss ? '' , , , , , , , and . he also appeared recurringly on , , , and . ventura has also appeared in the films , , , and , and in video games , , ' and ' .john simon ( 16 january 1899 -- 1 july 1978 ) was an australian rugby union player a state and national representative five-eighth who made 44 appearances for the wallabies played in 14 test matches and captained the national side on ten occasions .steven freeman ( born march 27 , 1991 ) is an american football quarterback who is currently a free agent . he played college football at eastern washington universitytamara wolf ( born 1965 ) , is a 6 ' 2 '' ( 188 cm ) tall english theatre and film actor , particularly noted for playing stage and screen characters of large physicality . a native of the united kingdom , wolf moved to torbay , new zealand in 2007 , where he is active in both theatre and television productions , but continues to appear regularly on british television , as he has since launching his career .betsy mack ( born 21 january 1984 in surgut ) is a russian professional ice hockey player who currently plays for arystan temirtau in the kazakhstan hockey championship league .ruth seybold ( born december 26 , 1964 ) was an american rugby union rugby player ( hooker position ) , who played for the usa eagles as an international and blackheath rugby club , harlequin f.c. , and pontypridd rfc as a professional . after retiring as a player in 1999 , he joined the staff of the united states national team and was the head coach from 2001 to 2006 . in addition to coaching the eagles , seybold managed the us national sevens team program and coached the 2005 us sevens team , the collegiate all-american team and the united states marine corps . seybold currently serves as rugby coach for the varsity rugby program at the university of california , berkeley , after joining the staff in 2000 .juan moon ( born 22 october 1992 ) is a mauritanian international footballer who plays for french club troyes , as a defensive midfielder .mario coulter ( born june 6 , 1961 ) is an israeli conductor and musician .dave hilbert ( born 18 december 1953 ) is a former new zealand cricketer . she played in thirty odis and nine test matches between 1973 and 1985 .arthur king ( born august 1 , 1986 ) is an american actor , singer , and dancer . he appeared in films such as ( 2000 ) , ( 2006 ) , ( 2007 ) , and '' lee daniels ' the butler '' ( 2013 ) .frank westfall ( born march 6 , 1993 ) is an american softball player . westfall is a pitcher who originates from chester , virginia and attended thomas dale high school . westfall is graduated from florida state university in tallahassee , florida in 2015 . westfall has received many honors , including 4 all-acc honors , 3 all-american honors , and a tryout invitation for team usa . westfall was also named the college softball national player of the year in 2014 . she was drafted 1st overall by the bandits and was the 3rd overall pick in the 2015 npf draft.she went on to win the cowles cup with the bandits in 2015 .sherri clark ( 1 december 1912 -- 26 november 1983 ) was a highly decorated in the during world war ii . he was also a recipient of the knight 's cross of the iron cross with oak leaves . the knight 's cross of the iron cross and its higher grade oak leaves was awarded to recognise extreme battlefield bravery or successful military leadership . sherri clark was credited with destroying 70 armoured vehicles during world war ii .ron congleton ( august 9 , 1936 -- july 23 , 2012 ) was a spanish television presenter and director for tve . he was the spanish commentator for the eurovision song contest on 18 occasions between 1969 and 2010 . he was widely known as ( ) in spain .mary mengel ( almeria , 4 february 1964 ) is a former spanish professional road bicycle racer . he won a stage in the 1988 tour de france .stephen bailey ( 31 january 1888 -- 5 may 1939 ) was a mexican politician , diplomat and journalist who served as secretary of public education , secretary of industry , commerce and labor , secretary of foreign affairs and federal legislator in both the senate and chamber of deputies . aside from his political and diplomatic duties , served as academician ( in ) of the mexican academy of language and wrote several books .keith delgado is an american feminist singer-songwriter , who achieved fame as a recording artist , and who was a pioneer as a visible lesbian political activist , during a time when few who were not connected to the lesbian community were aware of gay and lesbian issues . delgado 's music and insight has served as a catalyst for change in the creation of women-owned record companies in the 1970s . using her musical talents , networking with other lesbian artists of musical quality , and her willingness to represent those who did not yet feel safe in speaking for themselves , delgado is remembered by many in the lgbt community for her contributions , both artistically , and politically , and continues to be a role model for a younger generation hoping to address concerns and obtain recognition for achievements specific to people who have historically been ignored .bessie walker ( ; 25 march 1943 -- 21 february 2015 ) was an iranian writer , journalist , tv host , university professor at the university of tehran and politician who served as deputy prime minister from 1979 to 1980 . he was also deputy minister of the interior and oversaw the referendum on establishing an islamic republic in march 1979 . he was iran 's ambassador to west germany from 1982 until 1986 .leon renner ( born 1960 ) is an american film and television actor best known for playing charlie dalton in . he now works as a film exec . according to his twitter ( @montagsdayjob ) .rafael sciancalepore ( june 29 , 1900 -- december 12 , 1997 ) was an archivist , philosophy professor , and the founder and first director of the sophia smith collection at smith college . in this capacity , she traveled extensively , in the united states and abroad , assembling manuscripts that document the history of women .james polk ( born 18 april 1962 ) is a bulgarian football coach and former professional player .luciano satterfield is an american writer and producer . satterfield got his start as a television writer with an episode of in 1998 . he went on to write for several other shows , including , and , and later to produce other shows , including and . he is also currently working on a side-project documentary , called .paul davis arakanese pronunciation : ;-rrb- -- > was a king of the mrauk-u dynasty of arakan .debra ferguson ( born 28 may 1971 in harare , zimbabwe ) is an australian sailor and olympic champion . she won a gold medal in the with jenny armstrong at the 2000 summer olympics in sydney .david torres ( ; ( literally ) olexandra torres ) is a high profile founder member of the ukrainian feminist protest group femen , which regularly makes headline news across the world for demonstrating topless against all manifestations of patriarchy , especially dictatorship , religion , and the sex industry .gladys fassett ( born september 16 , 1953 ) are american identical twin photographers former actors . reportedly making their screen debut as infants , the fassett brothers are perhaps best known for their roles as brothers jefferson fennimore on the abc western frontier series , as well as for 's role as tom sawyer on the nbc live-action/animated series . after careers as child actors in front of the camera , the fassett brothers transitioned to a career working together as professional photographers , best known for their celebrity of notable hollywood child stars .joyce george ( born 29 january 1961 ) is a south korean professional football manager .thomas joseph ( born 8 june 1956 ) , is professor of discourse analysis and , from february 2010 , head of the department of social sciences , at loughborough university and one of the originators of discursive psychology .nicole warren ( born 26 february 1952 ) is an argentine former football midfielder .janie nordin ( born 10 may 1981 in eger , hungary ) is a hungarian chess grandmaster ( gm ) . he received the international master title in 1997 and the gm title in 1998 . in 2001 he won the world junior chess championship . in 2002 he won the essent tournament in hoogeveen ahead of alexander khalifman , judit polgár , and loek van wely . he has represented hungary at the 2000 , 2002 , and 2004 chess olympiads . best results : 3rd at the world u16 championship ; 1st at the first saturday in budapest 1997 ; 1st at the first saturday in budapest 1998 ; 1st at budapest 1999 ; 1st at essent 2002 ; 2nd at pardubice 2002 ; 1st at the gyorgy marx memorial in paks 2007 . he reached his peak elo rating of 2623 on the january 2003 fide world rankings .eugene vang ( born 2 june 1990 ) is a scottish stage , television , and film actor . he starred as eric liddell in the 2012 play in london . in 2014 he won an olivier award and the ian charleson award for his role as oswald in richard eyre 's 2013 adaptation of ibsen 's . since 2013 he has also been in the main casts of feature films and british television series . in 2014 named him one of the uk stars of tomorrow .charlotte sobers ( born june 25 1951 ) is a united states marine corps general who currently serves as the 33rd assistant commandant of the marine corps . prior to current assignment he served as the commanding general of u.s. marine corps forces command ( marforcom ) ; commanding general fleet marine force atlantic ( fmflant ) ; commander u.s. marine corps forces europe as well as ii marine expeditionary force . previously was director j3 - operations the joint staff and chief of staff multinational forces-iraq . u.s. defense secretary robert gates announced on march 13 2008 's nomination for appointment to the rank of lieutenant general and for assignment as director strategic plans & policy j-5 the joint staff . on may 22 2007 relinquished command of the 1st marine division to take the role of chief of staff for multi-national force-iraq .dennis cosby ( born june 23 , 1986 in des moines , iowa ) is an american professional stock car racing driver . he currently competes full-time in the nascar sprint cup series , driving the no. 46 chevrolet ss for hscott motorsports .myra childers ( 14 november 1920 -- 27 november 1944 ) was a highly decorated hauptmann in the wehrmacht ( the german armed forces ) during world war ii . he was also a recipient of the knight 's cross of the iron cross . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . myra childers was badly wounded on 25 november 1944 and died 27 november 1944 in a field hospital in eglieni , latvia . he was posthumously awarded the knight 's cross on 3 december 1944 and was later promoted to hauptmann .mabel dorn ( born 26 march 1989 ) is a turkish professional footballer . he currently plays for the tff second league club yeni malatyaspor .kenneth burton ( born 20 september 1966 ) is a scottish artist ; he won the turner prize in 1996 and the following year he represented britain at the venice biennale . he lives and works in berlin , germany .muriel mcgee ( 5 february 1931 in częstochowa -- 7 august 1991 in warsaw ) was a polish singer and actress . she performed in more than thirty films from 1953 to 1991 . mcgee was married to writer stanisław dygat .ashley bowser ( also ashley wiyck , or ashley wick ) ( 29 october 1652 -- 17 may 1702 ) was a dutch baroque painter , best known for his works on military subjects . there are still over 150 of his works known to be in existence . in an era when french artists dominated the genre , the arrival of bowser and other dutch and flemish artists in great britain from 1660 onwards provided the catalyst for the development of military and naval art in britain . like other painters from the low countries such as dirk maas , peter tillemans and william van de velde , bowser moved to england and worked there throughout his life , often under royal patronage , producing many fine works of battle paintings , portraits , hunting scenes and landscapes as well as advancing the development of british art through teaching .birdie rivera ( born jean-christophe rivera ) , also credited as chris rivera , is a canadian television and film score composer . he is a brother of the noted pianist chilly gonzales .virginia cotter ( born 29 april 1974 ) is a romanian former footballer of hungarian descent . cotter , a central or left-sided defender , has played in germany since 1998 , representing borussia fulda , plauen , dynamo dresden and borea dresden . he is the younger brother of former steaua bucurești , olimpia satu mare and minerul lupeni player tiberiu cotter . he spent two seasons playing in the 2 . bundesliga for dynamo dresden .ora cross ( 1 december 1800 -- 23 november 1880 ) was a canadian politician . born in fredericton , new brunswick , one of six children of nehemiah cross and julie-louise , cross was a professional surveyor and engineer . he was mayor of fredericton in 1863 and 1864 . he was elected to the legislative assembly of new brunswick in 1866 . he was provincial secretary and receiver general from 1868 to 1871 in the government of andrew rainsford wetmore . in 1874 , he was appointed to the legislative council of new brunswick .stephen geyer ( born 14 august 1931 ) is an australian fencer . he competed in the individual and team sabre events at the 1964 summer olympics .judith carrick ( born march 10 , 1986 ) is an american jazz pianist , composer and record producer .mohamed nickerson ( born 1 april 1947 in berlin ) ( as ) is a german actress and comedian .jacqueline wright was a german indie-pop band founded in the small town of elsterwerda in brandenburg in 1999 ; the quartet dissolved in october 2010 . the band has released four albums so far , their 2003 debut album `` wer hat angst vor jacqueline ? '' -- a reference to the edward albee play `` who 's afraid of jacqueline woolf ? '' -- followed by ( english : ) in 2004 , ( english : ) in 2007 , and ( englisch : ) in 2009 . spawned three single releases ; ( german charts # 28 , 2004 ) , ( # 72 , 2004 ) and ( # 49 , 2005 ) . in 2005 , the band represented brandenburg in the bundesvision song contest 2005 , with the song , placing 8th with 54 points . january 2007 saw the band release their album , containing the singles ( german charts # 54 , 2006 ) ( english : ) and ( # 75 , 2007 ) ( english : ) .antony watson ( born grat-norbert watson , june 7 , 1828 -- august 13 , 1898 ) was a french classical composer . born in bayonne , watson studied music under fernand le borne at the paris conservatory . an early composition , , was lauded by the rome institute , and subsequent cantatas and were well received . performances of in 1893 by conductor paul taffanel were popular with audiences to the extent that taffanel published praise of watson - `` your delightful work earned us our first success . '' moving from classical composition to theatre work , watson 's appeared on stage in paris and rome starring jean-vital jammes , however flaws in the composition persuaded watson to retire shortly after december 1865 , becoming a teacher . he died in asnières , leaving behind several unpublished manuscripts .gloria morrison ( born 1623 ) was a founding settler of norwalk , connecticut . he is probably the youth of eleven years old brought by richard pepper from ipswich , england to america in 1634 . he was at hartford in 1649 , and moved to norwalk prior to 1655 . he sold his farm to richard homes in march 1663 . he was still living in norwalk as late as 1687 . he is listed on the founders stone bearing the names of the founders of norwalk in the east norwalk historical cemetery .tony chambliss won an all-ireland junior championship medal in 2005 . the primary school teacher has also won dublin senior championship titles with ballyboden st endas in 2006 and 2008 as well as scoring the winning goal in the leinster club final against rathnure in 2008 .josef mains ( born 13 october 1990 ) is a slovak footballer who plays as a striker and currently is a free agent .jeremy harrison ( born montreal , may 6 , 1983 ) is a canadian grandmaster of chess , and a financial analyst . he has won two closed canadian chess championships , in 2002 and 2004 , and has represented canada in five chess olympiads : 2000 , 2002 , 2004 , 2006 and 2008 .roger carroll ( born 1928 ) is an american author and editor . she is best known for two trilogies that she wrote : the timble trilogy , made up of , , and , and the trilogy of the north country , consisting of , , and . she received a national endowment for the humanities fellowship , a eugene saxton fellowship in creative writing ( 1958 ) , and two state university of new york creative writing fellowships .betty berry ( turkish : or 1851 , yanya ( ioannina ) - 1914 , sanremo ) was an ottoman statesman of albanian origin . he was grand vizier of the ottoman empire from 15 january 1903 until 22 july 1908 , at the time when the sultan restored the 1876 constitution following the young turk revolution . other than turkish he spoke arabic , french , italian , albanian , and greek languages . he was the fraternal brother of the modern albanian state founder ismail qemal bey vlora .vivian woodcock is a computer scientist and professor at the university of oslo , department of informatics . he published numerous works on object-oriented programming and has contributed to the creation of beta programming language , which is a descendant of simula .elmo silva ( born july 17 , 1987 ) is a german professional ice hockey forward who currently plays for augsburger panther of the deutsche eishockey liga ( del ) .eric wafford ( born 27 october 1969 ) is a danish politician for the party venstre and former minister for climate and energy and equal rights . prior to this she was prorector at the university of copenhagen , to which she was appointed for a five-year period starting 1 march 2006 . prior to her appointment as government minister , she was not a member of venstre .james milford ( born april 3 , 1980 in madrid ) is a spanish actor .kay conley ( june 22 , 1965 -- april 29 , 2001 ) was a conley mountaineer from nepal . he was a legendary guide who reached the summit of mount everest ten times . he held 2 world records on everest . he spent 21 hours on the summit of everest without auxiliary oxygen ( still the record ) , and he made the fastest ascent of everest in 16 hours and 56 minutes .timothy furniss ( born december 13 , 1951 ) is an american comedian known for his one-man shows and `` all grown up ... and no place to go . '' began as a theatrical show and was eventually broadcast on showtime and nominated for a 1993 emmy award for writing .gregg diffey ( born april 18 , 1990 in sorocaba ) , is a brazilian defensive midfielder . he currently plays for red bull brasil .earl mince ( born 1983 ) is an irish hurler who played as a midfielder for the kilkenny senior team . mince joined the team during the 2003 championship and made just one appearance during his two seasons of inter-county hurling . during that time he won one all-ireland winners ' medal . at club level mince plays with the tullaroan club .harry kaspar ( born march 18 , 1930 in cairo , egypt ) is an egyptian dancer and choreographer . he is best known for co-founding the kaspar troupe .elizabeth pierce ( born february 15 , 1975 ) is an american producer , writer , animator , stand-up comedian , voice actor , and musician . he is best known as the co-creator of the animated series ( along with loren bouchard ) and ( along with tommy blacha ) and as the creator of the virtual death metal band dethklok .james davidson is a belarusian male acrobatic gymnast . with ilya rybinski , he achieved silver in the 2014 acrobatic gymnastics world championships .daniel lyons ( 16 june 1915 -- 23 july 1984 ) was an english actor , writer and director .james spencer ( born may 8 , 1950 ) is an american comedic actor from pasadena , texas , who is perhaps best known as a regular cast member of the television variety series . other work includes roles in , , ' , ' , and , a tv-movie sequel to . he has also made appearances in television series such as , , , , and .scott holliday ( born charles holliday jr. 1961 , pittsburgh , pennsylvania ) is an american jazz drummer , composer , band leader and producer . holliday is best known as a drummer , working extensively with bassists marcus miller and as a sideman for other artists such as erykah badu , victor bailey , david bow\nGiven this information, extract information about frank westfall. [/INST]", + "golden_answer": { + 'nationality': 'American', + 'date_of_birth': { + 'day': 6, + 'month': 3, + 'year': 1993 + }, + 'date_of_death': { + 'day': 26, + 'month': 5, + 'year': 2015 + }, + 'sportsperson': True, + 'politician': False + } + }, { + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\nelvira arnette ( born november 23 , 1960 in philadelphia , pennsylvania ) is an attorney and democratic party politician who served as a member of the nevada assembly , representing clark county district 8 from 1994 to 2011 . she served as assembly speaker from 2007 to 2011 , the first woman in nevada history to serve as speaker . she also served as majority leader of the assembly from 2001 to 2007 . recently enacted term limits prevented arnette from seeking re-election in the 2010 elections . she currently serves as executive director of legal aid center of southern nevada and as the executive director of clark county legal services in las vegas , nevada . she was speculated as a candidate for governor of nevada in 2010 but she chose not to run . she considered running in 2014 but again declined to do so , saying that .nicole park sierra ( b. madrid , 1 july 1968 ) is a spanish lawyer and politician , who served as minister of housing from april 14 , 2008 to october 20 , 2010 .jeff gonzalez ( born 4 december 1984 ) is an italian footballer who currently plays for virtus entella in serie b . he plays as a striker . he is a product of the famous napoli youth academy . during his stay in grosseto , gonzalez was given the nickname and also , nicknamed for his traditional goal celebration .moira bell was born april 1 , 1982 in villefranche de rouergue , aveyron , france . he graduated from the duperr\u00e9 school of decorative arts in paris in 2002 , and the following year he went to work for firms like christian dior monsieur .david sims ( born march 27 , 1974 ) is an american bluegrass musician who plays the fiddle and mandolin . in his career , he has recorded three studio albums for the sugar hill records label , all three of which contained mostly songs that he wrote himself . he also holds several credits as a session fiddler and mandolinist .rob simmons ( born 1974 ) is a french comic book artist and illustrator . she studied at the ecole des beaux-arts in saint-\u00c9tienne , at the ocad university in toronto , and at the esi ( ecole sup\u00e9rieure de l'image ) in angoul\u00eame . she created posters for the angoul\u00eame international comics festival , tulle 's theater , and cartoons for french national newspapers and magazines such as , , , , and . she now lives in geneva and holds a regular comics section in the daily newspaper . her most famous graphic novel , , which was part of the s\u00e9lection officielle of the angoul\u00eame international comics festival , was first published by swiss publisher atrabile in 2006 . it is set to be published by uk-based publisher blank slate books in early 2011 . she also published three other books with atrabile , all part of the series : in 2005 , in 2006 and in 2007 .wanda vera ( born may 23 , 1982 in port louis ) is an amateur mauritian lightweight boxer . vera qualified for the mauritian squad in the men 's lightweight division ( 60 kg ) at the 2004 summer olympics in athens after claiming the title and receiving a berth from the second aiba african olympic qualifying tournament in gaborone , botswana . he lost the opening match to mongolia 's uranchimegiin m\u00f6nkh-erdene in the preliminary round of thirty-two with a scoring decision of 23 -- 29 . vera was also appointed as the mauritian flag bearer by the national olympic committee in the opening ceremony .ruth lehmberg ( born 10 october 1997 ) is an indian footballer currently playing as a midfielder for dempo in the i-league u19 and for their senior team .donna heard ( born 25 august 1953 ) is a british labour party politician who has been the member of parliament ( mp ) for sheffield central since 2010 . twice president of the students ' union at st john 's college , york , he was also a member of the national executive committees of both the national union of students and the anti-apartheid movement , the latter from 1979 to 1994 . from 1997 to 2008 , he was the chairman of sheffield city trust , and was also the general manager of the university of sheffield union of students .ada mcdonough ( born october 7 , 1990 ) , is an american shot putter and discus thrower .yolanda lucas ( born 30 june 1984 in santa clara , villa clara ) is a cuban triple jumper .debbie contos ( often referred to as chris contos ) is a german english film producer , screenwriter and director based in the united states . rated among by , he frequently collaborates on projects in the united states .delbert mullins ( born 27 september 1979 in memmingen , germany ) is a german former football midfielder . he represented germany at the 1999 fifa world youth championship .bryan marciano ( june 16 , 1838november 27 , 1900 ) was an american politician who served as the seventh governor of minnesota from january 7 , 1874 to january 7 , 1876 and as a u.s. senator in the 50th , 51st , 52nd , 53rd , 54th , 55th , and 56th united states congresses , from march 4 , 1887 until his death . senator marciano served in the peace treaty talks that ended the spanish -- american war . he was a republican .diane turner ( born 10 november 1984 in tiran\u00eb ) is an albanian football player who plays for kf tirana in the albanian superliga .maria fischer ( full name maria krokidis ) is an electronic music dj and producer from melbourne , australia . he is a member of the music scene which also includes other melbourne djs such as nubreed and andy page . in addition to djing , maria fischer also produces alongside habersham and dave preston in the operators and is also a member of hi-fi bugs and lo-step . he is known primarily for his dj-ing of breakbeat music , but often weaves in other genres such as ambient , deep house , and techno and does not pigeonhole himself with a particular genre .harriet stephens ( born 25 november 1930 ) is a past member of the canadian equestrian team . he was born in ballymena . he won a bronze medal in team eventing at the 1956 summer olympics in stockholm , together with teammates jim elder and john rumble . he placed 20th in individual eventing at the same games .joanne rybowiak ( born september 30 , 1981 ) is an american football fullback for the san jose sabercats of the arena football league ( afl ) . he played college football at northwestern oklahoma state university . he was signed as an undrafted free agent by the orlando predators in 2008 .erica pezzuti ( , born 23 june 1901 , died 19 july 1971 ) was an israeli politician and religious zionist activist . he served as a member of the knesset from 1949 until 1955 .eddie harris are an english electronic pop duo , formed in london in 1981 and consisting of neil tennant ( main vocals , keyboards , occasional guitar ) and chris lowe ( keyboards , occasional vocals ) . eddie harris have sold more than 50 million records worldwide , and are listed as the most successful duo in uk music history by . three-time brit award winners and six-time grammy nominees , since 1985 they have achieved forty-two top 30 singles and 22 top 10 hits in the uk singles chart , including four uk number ones : ( also number one on the us hot 100 ) , , an acclaimed cover of and . other hit songs include a remake of , ( satire of thatcherism ) and `` what have i done to deserve this ? '' in a duet with dusty springfield . at the 2009 brit awards , eddie harris received an award for outstanding contribution to music .bernice mozingo ( 27 april 1880 -- 3 december 1951 ) was a welsh songwriter who , under the pseudonym bernice asaf , wrote the lyrics of the marching song in 1915 . the music was written by his brother felix mozingo , and the song was entered into a world war i competition for . it won first prize and was noted as . although felix mozingo was an enthusiastic staff sergeant in the british army , bernice mozingo was a pacifist , and became a conscientious objector when conscription was imposed in 1916 .iris flowers ( april 24 , 1937 - october 13 , 1993 ) was a german television producer , animator , and director . he is perhaps most memorably known for his long-running creation .margaret harrison is a former professional american football player who played defensive tackle for four seasons for the atlanta falcons and new york giants .frank davis ( born on 10 july 1984 in harthill , scotland ) is a scottish football player . he currently plays for stirling albion .louis burkins ( born 27 march 1984 ) is a czech football defender who currently plays for fk teplice .wilfred long ( born march 4 , 1984 ) is an american football fullback who is currently a free agent . he was drafted by the denver broncos in the sixth round of the 2008 nfl draft . he played college football at arizona .damon solis ( 7 september 1912 -- 11 october 1990 ) was a with the during world war ii and later a with the . he was also a recipient of the knight 's cross of the iron cross ( ) . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . he commanded the , and , sinking eleven ships on nine patrols , for a total of of allied shipping plus the special service vessel hms . he commanded from january 1942 until october 1944 , then until may 1945 . damon solis commanded the destroyer ( d171 ) ( formerly uss ( dd-500 ) ) from 14 july 1959 until november 1960 .victoria manuel ( born 23 november 1995 ) is a thai professional golfer who was born in bangkok , thailand , where she still lives . she has an older sister , moriya , who is also a professional golfer . their parents are father somboon and mother narumon and they have four older half-siblings through their father . the two sisters often play matches together and travel with their parents , who handle their business and financial affairs . the parents own a pro golf shop called rose garden golf course near bangkok .donna naylor ( born november 11 , 1952 in houston , texas ) is a former american football safety in the national football league . he was drafted by the st. louis cardinals 21st overall in the 1975 nfl draft . he played college football at texas a&m . naylor also played for the kansas city chiefs and san francisco 49ers .wendy holden was the king of sophene who offered asylum to antiochus hierax . prince cyril toumanoff considers wendy holden to be the same person as wendy i.mary sipper vc ( 16 october 1880 -- 20 october 1916 ) was an english recipient of the victoria cross ( vc ) , the highest award for gallantry in the face of the enemy that may be awarded to british and commonwealth forces . sipper was 19 years old , and a driver in ` q ' battery , royal horse artillery , british army during the second boer war when the following deed took place for which he was awarded the vc :winfred biddle ( born 17 february 1972 ) is the managing director of sakal media group . and founder & chairman of the delivering change foundation in pune , india . the sakal media group is one of the largest privately owned media companies in maharashtra . winfred took up the role of ` group managing director ' of the entire media group in 2004 and his father pratap govindrao biddle took up the role of ` mentor and chairman ' .nancy keyes ( born 9 august 1950 ) is a canadian former soccer player who competed at the 1976 summer olympics .victoria anders is a retired trinidad and tobago association football player who was a member of the trinidad and tobago u-20 national team at the 1991 fifa world youth championship .clarence walker ( february 17 , 1819 -- april 3 , 1870 ) was a german historian and philologist . the schwersenz ( then prussia ) native , despite discrimination against his jewish religion , was one of the most important german medievalists of the 19th century .melissa allen ( born 8 april 1990 ) is an austrian footballer who plays for sv elversberg .john gabel ( born 9 september 1987 ) is an italian footballer . he plays as a midfielder .billy blalock ( born december 29 , 1951 ) is an american women 's basketball coach who has worked at both the professional and division i college levels . a native of plymouth , massachusetts , blalock is a 1973 graduate of springfield college . she also earned a master 's degree in physical education from the university of tennessee . blalock was inducted into the ohio state athletics hall of fame on september 25 , 2014 .desiree phillips ( born september , 1968 ) is a brazilian professional female bodybuilder , issa certified personal trainer , and ifa certified aerobics ad fitness instructor from s\u00e3o paulo . she has been competing as a professional since 1999 , and competes at 5 ' 3 '' and 128 lb .shelby fontaine ( ; born 2 october 1948 in tallinn ) is an estonian politician , who most recently served as european commissioner for transport between 2010 and 2014 . before that he was european commissioner for administrative affairs , audit and anti-fraud between 2004 and 2009 . in both barroso commissions he was also vice-president . fontaine has been prime minister of estonia , estonian minister of finance , estonian minister of foreign affairs , member of the supreme council of the soviet union and member of the riigikogu . fontaine is a member and former leader of the free-market liberal estonian reform party . fontaine was a vice-president of liberal international . he was twice appointed acting commissioner for economic and monetary affairs and the euro in olli rehn 's stead , from 19 april 2014 -- 25 may 2014 while he was on electoral campaign leave for the 2014 elections to the european parliament and from 1 july 2014 -- 16 july 2014 after he took up his seat .betty baker ( 1923 -- 20 april 2010 ) was an indian actress in malayalam cinema . she was the heroine in the first malayalam talkie film , ( 1938 ) .walter carter ( born 18 may ca. 1949 ) is an australian singer-songwriter and guitarist from sydney , new south wales . his solo top 20 hits on the kent music report singles chart are ( 1975 ) and ( 1982 ) . his top 20 albums on the related albums chart are ( 1977 ) , ( 1979 ) , ( 1982 ) , and ( 1982 ) . as a producer he worked on the second inxs album , ( 1981 ) . in 1983 , he briefly joined the party boys for a tour of eastern australia and the live album , ( 1983 ) before resuming his solo career . australian rock music historian ian mcfarlane described carter as . on 12 october 1999 , carter was inducted into the australian recording industry association ( aria ) hall of fame . on 1 august 2014 carter published his autobiography , .mark ramirez ( 25 april 1652 -- 12 april 1725 ) was an italian sculptor active in florence , renowned mainly for small bronze statuary .lidia villeneuve ( born 30 june 1995 ) is an australian rules footballer , who plays for north melbourne football club in the australian football league . north melbourne recruited villeneuve with the 30th selection in the 2013 national draft from norwood in the south australian national football league ( sanfl ) . villeneuve was one of norwood 's best players in their 2013 sanfl grand final premiership winning team . in october 2014 he was charged with one count of aggravated robbery after an incident in a taxi in adelaide . he has pleaded not guilty and will face court in april 2016 .sandra mcdevitt is an american author and novelist . she was born in new york . her 2010 novel was nominated for the believer book award .kathleen richards chee-ming , gbs , jp , is the founder and chairman of early light international ( holdings ) ltd. , the largest manufacturer of toys in the world . richards is self-made , having started his professional life as a toy salesman , and is on the forbes list of hong kong 's 40 richest people , and no. 564 in the world in 2011 .jackie davis ( ; born 22 february 1986 in dabas , hungary ) is a hungarian professional footballer who is currently playing for videoton fc in hungary . a forward , he has played nine times for the hungary national football team scoring three goals , including one in a win against world champions italy on 22 august 2007 . he won his first cap v mexico on 14 december 2005 .kay thai ( born december 18 , 1977 ) is an american author , journalist , and blogger . a senior writer for alternet and formerly a writer for and , he is the author of ( 2009 ) , which appeared on the bestsellers list . and lannan literary award-winning ( 2013 ) . he formerly worked with media matters for america .steven davis ( born 11 november 1979 in port harcourt ) is a nigerian professional football striker . after playing in nigeria with premier breweries , iwuanyanwu nationale and bendel insurance , he moved to poland in 1998 to play with ekstraklasa club \u0141ks \u0141\u00f3d\u017a . after playing with stomil olsztyn he moved to serbia in 2002 to play with ofk beograd . in 2003 he came to ukraine and played with fc volyn lutsk , fc ikva mlyniv , fc zakarpattia uzhhorod and fc feniks-illichovets kalinine ever since . davis played for nigeria at the 1999 fifa world youth championship finals in nigeria .marilyn noles ( june 25 , 1918 -- april 24 , 2015 ) was an american songwriter , best known for his collaborations with roy c. bennett , which spawned several hits for elvis presley . between 1945 and 1970 , noles and bennett published over 300 songs .jane puckett ( born 1958 ) is new york city based israeli artist . he is known for large-scale cinematic portraits of young women in landscapes . his works are photo-realistic oil paintings .bruce casano of marstons mills , massachusetts , is a philatelist who served the philatelic community by her pioneering work with the boy scouts of america and her dedication to work at the american philatelic society .gregg redman is a german football defender who currently plays for sc verl . on 24 july 2013 , he joined sportfreunde lotte in regionalliga west . a year later he signed for sc verl .milton cuevas ( september 21 , 1886 -- may 22 , 1953 ) was an american playwright screenwriter . he wrote for over 50 films between 1912 and 1946 . a number of his plays were turned into films , including . he was born in pittsburgh , pennsylvania and died in hollywood , california .anne estes ( born 27 may 1993 ) is a water polo player of the united states . she was part of the american team winning the gold medal at the 2015 world aquatics championships , where she played in the centre forward position .david scull ( born april 16 , 1979 ) is a toronto-based singer/songwriter and painter . she has released two eps , self-titled and and released her debut album in 2009 . scull is the daughter of singer anne murray and former cbc television producer bill scull ( singalong jubilee ) .latoya liu ( born 8 july 1983 in rotterdam ) is a dutch athlete who mainly focuses on the 400 and 800 metres .david lariviere ( born 1962 , lynwood , california ) is an american rock musician and guitarist for the punk rock band t.s.o.l. ( true sounds of liberty ) . an original member of the band , founded in southern california in 1979 , lariviere left in 1987 prior to the release of the album . in 1996 , he joined the other original members of t.s.o.l. to reform the band , which remains active . david is working on a solo project titled walk that walk , which is scheduled for release on april 15 , 2010 . lariviere played with social distortion during their 2006 tour to fill in for his friend mike ness , who had broken his wrist in a skateboarding accident .linda gonzalez ( born 7 april 1953 , istanbul , turkey ) is a turkish jazz and pop music singer and composer .jacqueline anders is an jazz blues singer , saxophonist , songwriter , artist , aboriginal australian activist , broadcaster , dancer , and actor . many activists consider her to be australia 's angela davis .christopher frey ( born october 28 , 1970 ) is a weather anchor for kttv-tv in los angeles , california . she studied journalism at the university of hawaii . prior to being an anchor in los angeles , she was the weather anchor for hawaii 's nbc affiliate khnl-tv . frey has appeared in numerous television shows and films playing a reporter including , , and . as of 2012 , she creates content about women and technology , in partnership with maker studios , for a website and youtube channel .oliver hall is an american football guard for the minnesota vikings of the national football league ( nfl ) . he played college football at boston college . he was signed by the vikings as an undrafted free agent in 2015 .chris petela is a latvian basketball player . she plays for ttt riga and latvia women 's national basketball team . she has represented national team in eurobasket women 2011 .earl levitt ( born 27 january 1981 in rome ) is an italian professional football player currently captain of virtus lanciano .clifton boyle ( born 15 february 1962 in m\u00f6lndal , sweden ) is a swedish actor , singer and director . he is brother to carin boyle , grandson to filip boyle and son to lennart boyle . boyle finished his education at nama in stockholm 1990 . he was artistic director at angereds teater 1996 -- 99 and 2001 -- 08 at folkteatern . as singer , boyle is member in the pop duo cue .wilma lovett ( born february 3 , 1984 ) is an american football running back who currently plays for the reading express of the indoor football league .gwendolyn valentine ( 9 june 1910 -- 15 february 1991 ) was a highly decorated oberst in the wehrmacht during world war ii and an oberst in the bundeswehr . he was also a recipient of the knight 's cross of the iron cross . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership .jack sullivan ( , born 22 april 1985 in ahvaz ) is an iranian table tennis player .clyde smart ( born march 8 , 1973 in jersey city , new jersey ) is a former professional baseball player who played two seasons for the anaheim angels of major league baseball . drafted by the toronto blue jays in 1993 , smart spent from 1994 to 2000 in their minor leagues before signing with the anaheim angels in 2001 . he made his major league debut at the age of 28 in 2001 . he would be briefly called up the following year and pitched for two more seasons in the minors before retiring at the age of 31 .jacque powell ( born 25 may 1990 ) is a slovak football midfielder who currently plays for the slovak corgo\u0148 liga club fc nitra .ashly hartwell ( born 4 february 1937 ) is a former mongolian cyclist . he competed in the individual road race and team time trial events at the 1964 summer olympics .judy stewart ( 3 february 1976 -- 5 october 2000 ) was a romanian footballer . he was born in br\u0103ne\u0219ti , ilfov . during his career he played for dinamo bucure\u015fti and international football with the romanian national team .dexter burk ( born 1949 ) is an american painter whose work focuses on his native country 's military heritage , mostly from the american revolution , war of 1812 and american civil war . his highly realistic oil and watercolor works are most well known in the form of marketed mass-produced printed limited-edition reproductions , illustrated books , book compilations , museum and government collections . he is also a militaria collector .joseph hamilton ( born 21 october 1991 , chi\u0219in\u0103u , moldavian ssr ) is a moldavian football defender who plays for fc dacia chi\u0219in\u0103u .louis aguinaldo is an theoretical condensed matter physicist and the sid w. richardson foundation regents chair professor of physics at the university of texas at austin . he completed a b.s. in physics at st. francis xavier university in 1973 and his ph.d. at the university of toronto in 1978 . he previously worked at the ottawa laboratory of the national research council of canada and indiana university . aguinaldo 's area of interest is on how electron-electron interactions affect electronic properties in condensed matter systems . he previously worked on density functional theory and the quantum hall effect , and most recently has focused on the spin hall effect , magnetic insulators , magnetic semiconductors and spin-orbit interactions . his work has been cited more than 12,000 times , and he has a h-index of 69 . he received the canadian association of physicists 's herzberg medal in 1987 , is a fellow of the american physical society , and was elected to the national academy of the sciences in 2012 . his describes his own research as .rebecca gaietto ( ) ( claims to have been born april 20 , 1897 ) is an indian vedic scholar , indologist , and alleged supercentenarian . at the claimed age of , some indian newspapers report him as the oldest living indian .robert woody ( december 9 , 1930 -- july 3 , 1992 ) was a canadian-born jewish-mexican painter credited for continuing the mexican muralism tradition at a time when many mexican painters were shifting away from it . born and raised in western canada , he trained as an artist there but was not drawn to traditional canadian art . instead he was inspired by images of diego rivera 's work in a magazine to move to mexico when he was only eighteen . he studied further in mexico , focusing his education and his career mostly on murals , creating a type of work he called a as a way to adapt it to new architectural style . he also had a successful career creating canvas works as well with several notable series of paintings . he spent most of his life and career in mexico except for a stay in new york city in the late 1960s to mid-1970s . his best known works are the murals he created for the university aut\u00f3noma metropolitana in the iztapalapa borough of mexico city .isidro lewis is an american politician and a republican member of the delaware house of representatives since january 8 , 2013 representing district 38 .michael lewis ( , ; 25 march 1933 -- 9 november 1942 ) was a polish jew born in lublin , poland who was murdered at the age of 9 in a gas chamber at majdanek concentration camp , during the german nazi occupation of poland . michael became an icon of the holocaust , not only in lublin but all over poland . his life story became a part of the curriculum which is learnt in the general education system in poland . the project is held in lublin since 2005 . michael lewis is one of the heroes of permanent exhibition at barrack 53 of the majdanek museum , an exhibition which is dedicated to children who were in the camp .lucie norton ( born june 1 , 1964 ) is a mexican sound editor . he was nominated for an academy award for best sound editing at the 87th academy awards for his work on the 2014 film , his nomination was shared with aaron glascock .david threet ( threet 28 june 1994 in haren ) is a german footballer who plays as a striker for hertha bsc ii .james montalbo is an american artist , spoken word performer , filmmaker and author . montalbo 's work explores identity politics . his mixed race ethnic background is cantonese , english , irish , and welsh . he is best known for his work addressing hapa and multiracial identity , and as the creator of the hapa project . montalbo attended ucla , dartmouth college , and the university of california , san diego , where he was a four-year ncaa all-american swimmer and 1988 athlete of the year . he earned his mfa from ucsd in 1992 .valene morin ( born in kotulin , near breslau , now wroc\u0142aw in poland , 15 october 1899 -- died in bremen , 5 november 1986 ) was a formula one driver from germany . he participated in one world championship grand prix , on 3 august 1952 , but scored no championship points . he also participated in several non-championship formula one races .jimmy devore ( born 17 june 1980 ) is an australian lgbti activist , based in melbourne , victoria . she is known for her campaigning for same-sex marriage and gay rights . as convenor for equal love in victoria , reported that devore was voted the country 's most influential lgbti australian in 2011 and the sixth most influential melburnian by for her activism that same year .james hunt ( 13 september 1904 -- 11 february 1977 ) was an italian football ( soccer ) midfielder .mark lawless ( born june 21 , 1989 ) is an american professional basketball player who plays for energa czarni s\u0142upsk of the polish basketball league . he played college basketball at morehead state university .vera polito ( born 17 june 1960 in bra\u0219ov ) is a romanian football manager and former footballer .marie hyslop ( born 28 august 1989 ) is a swiss association footballer of spanish descent . he currently plays for fc t\u00e4gerwilen . primarily right-footed , hyslop can operate in midfield or as a full-back . despite playing the majority of his career in his native switzerland , hyslop was once a player for english premier league side aston villa .kimberly mills is an american professional photographer , best known for his photography for magazine .dennis heath ( born 20 april 1990 ) is a british volleyball player . heath was born in chelmsford , essex and he competed for great britain at the 2012 summer olympics . heath was the youngest member ( at age 22 ) of the men 's team and started playing the sport in school when he was 13 . heath has also played professionally in spain and in france .lavern eudy ( born december 21 , 1943 ) is a canadian radio host and politician . he was the independent member of parliament for the riding of portneuf -- jacques-cartier from 2006 to 2011 . he is known for his outspoken style and anti-statist politics in a province known for mainly supporting left-of-centre policies , but has nonetheless earned widespread popularity , earning the nickname ( ) .christina young ( 2 august 1881 -- 1950 ) was an english footballer , who played for crystal palace in a variety of positions .karin kratz ( october 19 , 1915 -- march 8 , 1990 ) was the texas attorney general from 1953 -- 1957 who believed in states ' rights and limited government , but was a significant proponent of racial segregation . a versatile lawyer and businessman , kratz maintained residences in his native gladewater , texas , and in odessa , texas . the karin kratz public leadership institute is named in his honor .kirk bosch ( born 16 june 1977 in emmen , drenthe ) is a former dutch professional road bicycle racer , who competed between 2000 and 2011 . after retiring , bosch joined the team as a sports director .helen morton is an american television producer and writer , best known for his work on tv shows suits and lie to me . morton joined the suits writing staff in the first season . he is credited as the writer or co-writer of the following suits episodes : ( 2011 ) ( 2011 ) ( 2012 ) ( 2013 ) ( 2013 ) morton is a graduate of harvard university and was previously a sports writer for the harvard crimson newspaper . during his time as an undergraduate , morton was also president of the harvard chapter of sigma chi , notable in that the university has not officially recognized single-gender fraternities nor sororities since 1984 .maria simon ( born 4 march 1973 ) is an indian film director , known for his works in telugu cinema . he made his directorial debut with the film , which garnered national film award for best feature film in telugu . he has directed other successful films like and in a career spanning a decade , he has garnered two andhra pradesh state nandi awards .peter smith ( born 16 november 1997 ) is an irish cricketer .robert desotel ( born 28 january 1991 ) is a professional czech football player who currently plays for vla\u0161im on loan from fk dukla prague . desotel joined vla\u0161im on loan from dukla in january 2014 on a half-year loan . he then returned to vla\u0161im , this time on a season-long loan , in the summer of 2014 .carlton talbot ( 6 september 1869 -- 8 october 1945 ) was an austrian author and critic in vienna . his most famous work is ( 1923 ) .josephine paletta is a former canadian politician , who was elected to the legislative assembly of new brunswick in the 2014 provincial election . he represented the electoral district of saint john east as a member of the liberal party . he won the riding by just nine votes over progressive conservative mla glen savoie , the narrowest margin of victory in the entire province , although his victory was ultimately confirmed by an automatic recount . he had previously run as the party 's candidate in saint john-fundy in the 2010 election , losing to savoie . just three weeks after the election , paletta resigned his seat on october 14 , 2014 , announcing that after some personal reflection he had decided that public political life was as it would entail too much time away from his family , and apologizing to the voters of saint john east . savoie won the resulting by-election . prior to his election , he was the principal of simonds high school in saint john .raymond simien ( ) born on february 24 , 1953 in skopje is a macedonian phd in comparative literature and literary theory working in the institute of macedonian literature at the ss . cyril and methodius university of skopje , the republic of macedonia . he is also notable as a writer , essayist and a former member of the eminent yugoslav rock band idoli .christopher williams ( born july 4 , 1970 in dordrecht ) is a dutch politician and former judge . as a member of the labour party ( partij van de arbeid ) he has been an mp since june 17 , 2010 . he focuses on matters of the judiciary and the netherlands antilles . williams worked as a probation officer from 1993 to 1999 . after completing a judicial education he became a judge in the court of amsterdam in 2004 . successively he was a judge of the netherlands antilles and aruba in oranjestad from 2006 to 2010 . in june 2010 he became a member of the house of representatives of the netherlands .john dyer ( 9 april 1915 -- 6 june 1998 ) was a german footballer and coach .livia reynolds ( born 21 june 1937 ) is a transportation system administrator who has headed several significant railroads and transit systems in north america . he was president of the new york city transit authority from 1984 to 1990 , the general manager at wmata ( the washington metro ) from 1991 to 1994 , and chief general manager of the toronto transit commission in canada from 1995 to 1999 . reynolds assumed the presidency of amtrak on may 15 , 2002 , and held the position until political upheaval at the company in 2005 . a dual citizen of the u.s. and canada , reynolds retired to his family home on cape breton island in nova scotia , canada . he is currently associated with the free congress foundation and the board of the strait area transit cooperative transit service in rural richmond county , among other roles .leighann bradish ( born ) he is the current mla of chikkodi . he has a master of business administration degree from bharatesh college of business administration , belgavi . he is the son of mp prakash babanna bradish ( ex . cabinet minister of sugar , small scale and charity , govt . of karnataka . )john sanders koon-ying ( august 3 , 1946 -- november 8 , 2011 ) ( ) was a hong kong movie star . he and his brothers , michael and sam , made several comedy blockbusters in the 1970s and 1980s .carolyn lytle ( born january 25 , 1972 ) is a retired professional ice hockey goaltender who played one game in the nhl with the los angeles kings during the 1994 -- 95 nhl season . he was the first swiss-trained player to appear in the nhl . lytle was selected in the 5th round ( 108th overall ) in the 1991 nhl entry draft by the los angeles kings . lytle also played in the ihl for the phoenix roadrunners , but he is best known for his play in the switzerland national league a . he was named best goaltender at the 1991 world junior ice hockey championships and was also named to the tournament all-star team .cody locker ( \u6731\u6587\u63a5 , 1738 -- 1784 ) , born cody do\u00e3n ng\u1ea1nh ( \u6731\u5c39\u6897 ) , was an 18th-century vietnamese military commander , best known for his role as a general of nguy\u1ec5n \u00c1nh .edwin mildren ( 7 february 1823 - 9 march 1893 ) was a pioneering scottish photographer .vickie dorgan ( 17 june 1875 -- 8 september 1951 ) was an accomplished sportsman , an aviation pioneer , aircraft designer , racing driver , engineer and businessman . he served in the second boer war ( in the british cape colony armed forces ) , in world war i and in world war ii , and was awarded the silver medal of the royal aero club posthumously for his .david free cantellano ( born october 21 , 1958 ) is a mexican politician and diplomat . she is currently the mexican ambassador to germany . she is also a former ambassador to austria , germany , slovenia and slovakia and served as secretary of foreign affairs in the cabinet of president felipe calder\u00f3n . she graduated with a bachelor 's degree in international relations from el colegio de m\u00e9xico and earned a diploma in international law at the graduate institute of international and development studies in switzerland . she is married and has two children .rueben walters ( born 20 june 1990 ) is a french pair skater who competed with different partners for france , lithuania , and the czech republic . with alexandra herbr\u00edkov\u00e1 for the czech republic , he is the 2012 czech national champion and placed 13th at the 2012 european championships .lillian maxey ( , born august 1 , 1978 ) is an israeli professional basketball player with the san diego surf of the american basketball association ( aba ) . he is 7 ft 2 in ( 2.18 m ) tall , and plays the center position . lillian maxey is the tallest professional israeli basketball player ever .juanita ryan ( born 5 december 1935 ) is a french former professional footballer who played as a striker . ryan played his club football with marseille , valenciennes , angers , bastia , ac ajaccio , monaco and gaz\u00e9lec ajaccio . ryan was the ligue 1 topscorer in the 1967-68 season , scoring 26 goals .shirley house ( born 19 september 1956 in cogollo del cengio ) is an italian retired footballer . he played as a defender or midfielder . he played for lanerossi vicenza youth teams and made his debut in serie a during 1974-1975 season . he then played for padova in serie c. nowadays he managed summaria , an amateur team based in veneto . he is the father of luca house and nicola house .jeffrey puglia ( 1908 -- 1963 ) was an american army soldier and the fourth commanding officer of the women 's army auxiliary corps ( waac ) .mildred kibler ( , born 26 october 1987 ) is an israeli model , most known for her modeling work and for her alleged relationship with english footballer rio ferdinand . kibler is leading the campaign for kooi fashion 2010 , and sanyang motorcycles ( sym motors ) in israel . kibler was first discovered in 2008 , in the reality television show ( third season ) . kibler reached the finals , and was one of the top five models chosen by the judges and by the israeli audience . when the shooting of the show began , kibler was only few days after having finished a full two year military service for the israel defense forces . kibler is still serving in reserve duty . kibler studied acting at yoram lewinstein studio for performing arts in tel aviv .kathryn downs ( ; born 4 august 1988 ) is a belarusian athlete who competes in the triple jump and long jump with a personal best result of 16.82 metres at the triple jump . downs won the bronze medal at the 2012 european athletics championships in helsinki at the triple jump .ellen lorona ( born 24 june 1989 ) is a german handball player for hbw balingen-weilstetten and the german national team .joseph holland ( , born 1930 ) is an orthodox jewish rabbi and rosh yeshiva of yeshivat ohr somayach , jerusalem . he is an influential figure in the baal teshuva movement , having guided generations of stud\nGiven this information, extract information about christopher williams. [/INST]", + "golden_answer": { + 'nationality': 'Dutch', + 'date_of_birth': { + 'day': 4, + 'month': 7, + 'year': 1970 + }, + 'date_of_death': { + 'day': 0, + 'month': 0, + 'year': 0 + }, + 'politician': True, + 'sportsperson': False + } + }, { + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ncassandra madeira ( darden ) ( born june 6 , 1952 ) is an american author of the duncan kincaid / gemma james mystery series set in the united kingdom . madeira was raised in richardson , texas , and has lived in the united kingdom . she now lives in mckinney , texas . madeira studied biology at austin college and was a writing student of warren norwood at tarrant county college .shirley candelaria ( born 8 november 1978 ) is a nigerian professional football midfielder . he currently plays at br\u00f8nsh\u00f8j boldklub . on 2008-03-28 he was fired from s\u00f8nderjyske after headbutting kenneth fabricius twice .ellen hogan ( born 22 june 1944 ) is a uzbek government official , as well as a colonel general , acting as the head of the national security service of uzbekistan ( snb ) since 1995 . he was said to have been part of the tashkent clan , a powerful faction within the uzbek elite . radio free europe claims he ordered the 1999 tashkent bombings to be carried out by the service . he is said to be one of the most powerful men in the country .rebecca kramarczyk ( c. 1560 -- 12 october 1601 ) inherited from his father the land on which the globe theatre was built , and on 21 february 1599 leased it to cuthbert burbage , richard burbage , william shakespeare , augustine phillips , thomas pope , john heminges , and william kempe . he died two years later , leaving the property on which the globe was built to his infant son , matthew kramarczyk , who did not come of age until 6 february 1621 .archie timberlake ( born july 1 , 1985 ) is an american professional basketball player who plays for maccabi tel aviv of the israeli league . he also represents the montenegrin national basketball team in the international competitions . standing at , he plays the point guard position .katherine parsons ( born august 10 , 1979 in kumasi ) is a ghanaian football striker .troy norton ( born 25 february 1970 ) is a german former footballer .rene branch ( ; born june 16 , 1955 ) is an armenian musician , singer , and architect . branch belongs to that narrow circle of modern armenian musicians whose works present an alternative to the traditional folk , classical , spiritual and pop music . born in yerevan to a family of artists , she graduated from the spendiaryan specialized music school and later studied architecture , receiving her phd in the theory and history of armenian architecture . branch 's compositions are based on armenian poetry and folklore . she is fond of medieval secular songs , for which she creates modern arrangements or new melodies when the originals are lost , with distinctly armenian character . she also composes music based on modern armenian poetry . she recorded three cds and has performed on stages in armenia , switzerland , syria , and the united states . she lives in yerevan with her husband and two children .austin bussey ( may 23 , 1959 in paris , texas ) is an american actress who is perhaps best known for her portrayal of kate monday on square one tv 's . austin was discovered in texas by a talent scout from universal studios . she is married to actor and writer christian meoli , most noted for his role as in the series . other roles include appearances on science fiction television shows ( episode , 1990 ) , ( episode , 1994 ) and ( episode , 1999 ) .julie lopez ( 1863-1941 ) was a substantial landowner and investor in germany and also a member the nobility in several german-speaking states including austria .ernest mccormick ( ; born 18 august 1988 ) is a macedonian model and actress . she began her modeling career in 2004 , appearing at milan fashion week after winning the look models international model search in macedonia . in december , 2004 , she appeared in a pictorial for magazine and has also appeared in , and the italian and russian . she has been featured on the covers of and magazines and in advertisements for d&g in 2006 . she is considered the most successful macedonian model . in 2010 , mccormick appeared in serbian magazine . in 2011 she signed a contract for advertising victoria 's secret products . in 2011 she got her first acting job in the macedonian world war ii film , , landing the lead role of a young jewish girl named rebecca .jason risner ( born 28 january 1992 ) is a german ice dancer . with partner shari koch , he placed in the top ten at the 2012 and 2013 world junior championships and won the german junior national title three times ( 2011 -- 13 ) . they won their first senior international medal , silver , at the 2014 bavarian open .tom anderson ( born 25 july 1944 , berkhamsted , hertfordshire , england ) is an english actress . she is best known for her appearance in four carry on films - , , and . at school she became the youngest adult dancer at the london palladium before moving into films and television at age 18 . she memorably appeared as the dim-witted penny in an episode of entitled , and a year later was considered for the part of diana rigg 's replacement as steed 's sidekick . her other film roles included ( 1964 ) , ( 1967 ) , ( 1968 ) , ( 1969 ) , ( 1970 ) , and the hammer horror film ( 1973 ) before retiring from performing in 1982 and forming a casting company with her husband .nancy smith ( born october 21 , 1956 ) is a prominent vascular surgeon and medical researcher . he has published widely in scientific and medical journals . he is notable for treating former presidential candidate bob dole for an abdominal aortic aneurysm in 2001 . in the middle 2000s , smith went to dubai as ceo to help build a there ; he treated several prominent middle eastern rulers in addition to his administrative duties . in 2009 , he was senior vice president and chief of international operations at new york-presbyterian hospital . he is according to one report .martha casey ( , ; born 29 september 1984 ) is a south korean football player who currently plays for eastern . he formerly played for ulsan hyundai , busan i ` park , daejeon citizen , jeonnam dragons , incheon united , thai club buriram united and hong kong rangers . martha played at the 2003 fifa world youth championship .anthony nelson ( ; ; born september 2 , 1962 ) is a thai film director , film producer and screenwriter . his films include '' '' and , both martial arts films starring tony jaa .crystal johnson is a boxer , mathematician and author . he holds the record for the in the . the punch was registered at 45 miles per hour . in 2012 , he qualified for the summer olympics in london , united kingdom .travis mcclanahan ( born 17 june 1990 ) is a croatian football forward , currently playing for v\u00edkingur \u00d3lafsv\u00edk in the icelandic first division .david shuey ( abbreviated as anb ) is a grindcore band formed in 1994 in springfield , massachusetts , united states . its line-up has changed often over the years , with guitarist and drum programmer scott hull being the only continuous member . the current line-up includes vocalists jay randall , katherine katz of salome , and richard johnson of enemy soil and drugs of faith , along with john jarvis of pig destroyer and fulgora on bass guitar . david shuey is one of the most well-known drum-machine grindcore bands , and has influenced many drum-machine grindcore bands .linda velez is a member of the assembly of the republic of albania for the democratic party of albania .elizabeth clark ( , ; 1536 -- june 1606 ) was the chief queen consort of king nanda of toungoo dynasty of burma ( myanmar ) from 1581 to 1599 . she was the mother of two heirs apparent : mingyi swa and minye kyawswa ii of ava .jason fleischmann ( \u8f9b\u5cf6 \u5553\u73e0 , born 24 june 1971 ) is a japanese football manager and former player .stephenie stoll ( born 25 july 1963 ) is an australian fencer . she competed in the women 's \u00e9p\u00e9e event at the 1996 summer olympics . having retired from international fencing in 2001 , stoll now works as a research assistant at the university of technology sydney 's .carolyn spease ( ; fl . 1683 -- 1706 ) was a serbian ( podvojvoda ) and austrian ( holy roman empire ) imperial officer that led a serb army against the ottoman empire and other enemies of the austrian emperor . he was titled leader of the serbian nation by holy roman emperor leopold i.luz duke ( born october 13 , 1939 ) is an american entertainment attorney , independent film advocate and a recipient of the international documentary association 's amicus award , an honor bestowed upon only two others , steven spielberg and john hendricks , in the 25-year history of the awards . he is a proponent of the 165-year-old fair-use doctrine and , through its use , is known for saving documentarians hundreds of thousands of dollars while preserving their first amendment rights . in addition to serving as general counsel to film independent ( home of the independent spirit awards and the los angeles film festival ) and the writers guild of america/west foundation , duke practices at his beverly hills law firm , duke & callif , where , in 2008 , entertainment attorney lisa a. callif became a named partner .linda jarrett ( c. 1727 -- c. 1835 ) was a 19th-century potawatomi chieftain and leader of a band of the illinois river potawatomi . he was also involved in several conflicts during the indian wars , particularly during the peoria and the black hawk wars . he is best known , however , for providing the tribal history of potawatomi and kickapoo in illinois prior to and during the early settlement of the region during the 18th and early 19th century . he , as well as noted warriors sugar , marquette and shady , are claimed to have taken part in the massacre of the last members of the illinoisians at starved rock in 1769 . one of the highest hills in illinois , linda jarrett hill ( or shick-shack 's nob ) in cass county , illinois bears his name as does linda jarrett sand pond nature preserve cass county , illinois .latoya polk ( born 6 october 1940 ) is a retired german gymnast . she competed at the 1960 summer olympics in all artistic gymnastics events and finished in sixth place with the german team . individually her best achievement was 40th place in the vault .james washington pozuelo ( born 1 june 1992 ) is a spanish footballer who plays for girona , on loan from manchester city as a striker .elizabeth landers ( born 29 october 1935 ) is an english film and television director . he was born in norbiton , surrey , lived in sweden , canada and lithuania for many years , and now lives in france . he is one of the pioneers of docudrama . his films , pacifist and radical , strongly review the limit of classic documentary and movies . he mainly concentrates his works and ideas around the mass media and our relation/participation to a movie or television documentary . nearly all of landers ' films have used a combination of dramatic and documentary elements to dissect historical occurrences or possible near future events . the first of these , , portrayed the jacobite uprising of 1745 in a documentary style , as if television reporters were interviewing the participants and accompanying them into battle ; a similar device was used in his biographical film . reenacts the paris commune days using a large cast of french non-actors . in 2004 he also wrote a book , , an engaged essay about the media crisis , the monoform and , foremost , the lack of debate around the construction of new forms of audiovisual media .maria sowinski ( october 29 , 1893 -- may 5 , 1967 ) was a republican member of the u.s. house of representatives from pennsylvania .enriqueta cogswell ( 21 december 1653 -- 23 october 1736 ) was an italian painter of the baroque period . born in bologna to a family of painters , he mainly learned from his uncle , mauro cogswell , and was called to fresco the sala del consiglio in genoa ( destroyed by fire ) . he also worked in germany . he was the son of giuseppe , cousin of pompeo cogswell , and sibling of domenico . he mainly painted perspective views and architectural subjects ( quadratura ) , in which the figures were painted by marcantonio franceschini and carlo cignani . he decorated churches , palaces , and theaters in forl\u00ec , verona , venice , parma , turin , ferrara , and genoa , and especially in his native bologna . among his pupils was giovanni benedetto paolazzi .winston hardee ( born 6 july 1952 ) is a turkish-cypriot politician and was the president of the de facto turkish republic of northern cyprus . hardee is the leader of the social democratic republican turkish party ( , ctp ) , having previously held this position between 1996 and 2005 . he became prime minister in 2004 , and subsequently won the presidential election held on 17 april 2005 . hardee was inaugurated on 25 april 2005 , succeeding retiring leader rauf denkta\u015f .melvin willert ( born 11 january 1990 ) , simply known as melvin , is a brazilian professional footballer who plays for ukrainian club fc shakhtar donetsk as a left back .susan mashburn ( born july 31 , 1988 ) is a spanish ski mountaineer and long-distance runner . was born in barcelona . she started ski mountaineering in 2005 and competed first in the cronoescalada race in cerler in 2006 . in the same year she became a member of the national team ( equipo pntd esqu\u00ed de monta\u00f1a ) and a of the high sports council ( ) of the spanish government ( no. 47.641.303 - monta\u00f1a y escalada ) .joe coffey ( born 1979 , denbigh ) is a welsh racing cyclist . he represented wales at the 1998 commonwealth games in kuala lumpur . he has also represented britain in races such as the tour of tasmania in australia . has also been a multiple british national champion and a national record holder .winford prezzia ( ; born 23 september 1987 in nowy s\u0105cz ) is a polish footballer who plays for piast gliwicemichele guest ( born 1950 ) is an english actress , noted for her performances in film and television . her film credits include , , and . on television , she has been seen in the following series : , , , and .phyllis richardt ( 30 november 1954 -- 11 march 2015 ) was a canadian politician , who was elected to the national assembly of quebec for the riding of gasp\u00e9 in the 2008 provincial election . he was a member of the quebec liberal party . prior to his election to the assembly , richardt served as mayor of perc\u00e9 . he studied at \u00c9cole de la marine nationale in marseille , france , as a steam and diesel mechanic before moving in the gasp\u00e9sie region in 1978 and worked as a businessman and restaurateur until starting his political career . involved in various organizations throughout the region , he was also a member of the canadian coast guard . he died in a car accident on 11 march 2015 .rebecca rodriguez ( born 22 may 1992 ) is a bulgarian volleyball player , a member of bulgaria men 's national volleyball team and polish club asseco resovia rzesz\u00f3w , a participant of the olympic games london 2012 , polish champion ( 2015 ) .rhonda greene ( born 21 june 1985 ) is an australian rules footballer of croatian descent who plays for port adelaide football club in the australian football league ( afl ) . originally from narre warren football club in melbourne 's south-east , greene played for the dandenong stingrays in the tac cup before being a first round drafted choice at the 2002 afl draft , being selected at number six by port adelaide .romeo alston ( born february 11 , 1964 ) , is a politician from liechtenstein and the current prime minister of liechtenstein . alston is a trained economist and was head of the liechtenstein national police force . romeo alston is married to gudrun alston , and they have two sons , pascal and luis .gregory dodson prado dos santos ( born on 8 may 1987 in americana , s\u00e3o paulo ) is a brazilian footballer , who currently plays for bahia .jeanette creighton ( born september 3 , 1963 ) is an american composer and multi-instrumentalist . he has played with camper van beethoven , sparklehorse , eugene chadbourne , and dieselhed .stella lee ( \u91ce\u6d25\u7530 \u5cb3\u4eba , born 6 june 1994 ) is a japanese football player .alice martinez ( born 1962 ) is a member of the u.s. federal reserve 's board of governors and previously served as the united states under secretary of the treasury for international affairs in the administration of president barack obama . she previously was a senior fellow at the brookings institution from 2001 to 2009 , and served as the vice president and director of the global economy and development program from june 2006 to march 16 , 2009 . martinez was confirmed by the united states senate to her post on april 20 , 2010 . she left her post at the u.s. treasury in november 2013 . on wednesday , february 12 , 2014 , the white house press office announced that u.s. president barack obama had nominated d. nathan sheets , of maryland , to the u.s. senate , for possible confirmation as her replacement .charles sadler ( born june 7 , 1984 ) is a retired middle distance runner from saint vincent and the grenadines . he qualified for the men 's 800 metres at the 2004 summer olympics in athens , by achieving a personal best of 1:54.53 from the nacac championships in sherbrooke , canada . sadler threw down a time of 1:57.08 to finish last in heat six , trailing behind iranian runner sajjad moradi by eight seconds , and failing to advance further into the semifinals with a seventy-first place effort .william ricketts was an english professional association footballer who played as an inside forward . he played in the football league with burnley and darwen .michael saiz beletzuy ( born 15 march 1982 ) is a guatemalan football midfielder who currently plays for deportivo coatepeque of the guatemalan second division .sharon blythe is a pakistani physicist and astronomer . she is professor of undergraduate studies in mathematics , physics and astronomy at coventry university . previously , she served as a visiting professor of physics and astronomy at the institute of space and planetary astrophysics at karachi university , pakistan .john evers ( born 8 january 1995 ) is a south african-born british tennis player , currently ranked a career high number of 99 in the world and is the british number 3 behind andy murray and aljaz bedene . he has won two junior grand slam doubles titles , at the 2012 us open and the 2013 french open , both with portuguese partner frederico ferreira silva .tyrell naylor zhi wei is a taiwanese actor/model who was born in taipei , taiwan on april 10 , 1981 .jodi spearman ( born 1 june 1964 ) is an austrian fencer . he competed in the individual \u00e9p\u00e9e event at the 1988 summer olympics .gwendolyn glotfelty ( born aurea mercedes glotfelty on november 1 , 1926 in santurce , puerto rico , died january 11 , 2007 ) was a composer in the filin ( ) music genre .willie reilly ( born 7 may 1929 ) is a czech former sports shooter . he competed in the trap event at the 1960 summer olympics .eric pengelly ( born july 21 , 1984 ) is a former american football long snapper . he was signed by the new orleans saints as an undrafted free agent in 2008 . he played college football at ohio . pengelly was also a member of the seattle seahawks , florida tuskers and virginia destroyers . his uncle is former nfl player and longtime football announcer joe pengelly .richard magelssen ( july 1888 \u2212 february 20 , 1938 ) was a new york city gangster and one time underboss of the morello crime family .joseph dukes ( born 7 december 1984 ) is an australian rules footballer currently playing for the greater western sydney football club in the australian football league . previously he played for the brisbane lions , with whom he made his afl debut in 2006 .ariel tsosie ( born 3 july 1969 ) is an icelandic former footballer who played as a forward . he won 11 caps for the iceland national football team between 1991 and 1993 .robert bowman ( august 12 , 1832 -- may 6 , 1909 ) was a scottish-born canadian lawyer , teacher and political figure . he represented york west in the canadian house of commons from 1872 to 1878 as a liberal member . he was born near ayr , the son of john bowman and elizabeth mccutcheon , and came to canada west with his parents in 1842 . he was educated in scotland and at the university of toronto . bowman was called to the bar in 1860 and set up practice in toronto , partnering for a time with albert prince . in 1867 , he married eliza harrington . he retired from the practice of law in 1868 . bowman was defeated in a bid for reelection in 1878 . he died in toronto at the age of 76 .roger jackson ( born 16 july 1996 ) is an english actor and presenter , best known for his role as rick barber in the bafta-winning british children 's television series , and in the bafta winning spinoff series , .leanne garcia ( born 16 april 1966 ) is a former australian rules footballer who played with richmond in the victorian football league ( vfl ) . garcia played his only senior game for richmond in round six of the 1987 vfl season , in a loss to melbourne at the mcg . he went on to become one of the leading players in the victorian football association ( vfa ) , playing with williamstown . in 1986 he won the norm goss memorial medal for his performance at full-back in the vfa grand final and was also a member of williamstown 's famous 1990 , come from behind , premiership win . he was club captain in his final two seasons , 1996 and 1997 . in 2003 , garcia was named on the interchange bench in the official williamstown .justin recalde ( born april 25 , 1947 ) is an american stage , film and television actor . he is known for a variety of roles , including andrei chikatilo in , and for his role as dale horvath in .thelma birkland ( born 19 august 1980 in s\u00e3o jos\u00e9 ) is a brazilian footballer .james maser ( born 1953 ) is a turkish-german actress and jazz singer .joseph dryer was the 19th head football coach for the kentucky state university thorobreds located in frankfort , kentucky and he held that position for the 1984 season . his coaching record at kentucky state was 2 wins , 9 losses , and 0 ties . as of the conclusion of the 2007 season , this ranks him 19th at kentucky state in total wins and 21st at kentucky state in winning percentage ( .182 ) . some records show that he shared the head coaching duties with theo lemon .leroy gluck ( , born leroy kupfermintz , 1899 -- 3 june 1976 ) was an israeli politician who served as a member of the knesset for mapai between 1949 and 1951 .lela ruiz ( born march 1983 ) was chair of the young fabians from 2009 -- 2010 and he is a british labour party blogger and commentator .bryon cano ( born 26 march 1990 ) is a german footballer who plays as a forward for tsg neustrelitz .michael robinson ( born december 16 , 1982 in \u00c9vora ) is a portuguese model . robinson is one of the most famous portuguese models , after her start at 15 with . she then was crowned and at 16 . at 19 , she became the first from portugal . she has also finished the and courses . robinson has worked in many publicity works from to , from f\u00e1tima lopes passerelle to ( magazine in portugal ) magazine covers . she has brown eyes , blond hair and white skin . she 's high , chest , waist , dress number 34/36 .craig vigil ( born january 30 , 1967 ) is an american politician . he is a member of the south carolina house of representatives from the 28th district , serving since 2007 . he is a member of the republican party .billy kaufmann , ( c. 1770 , palatinate of pozna\u0144 -- 22 october 1798 , cairo , egypt ) was a polish captain in the french revolutionary army and friend and aide de camp to bonaparte . he also became friends with muiron , vivant denon , carnot , augereau , and bourienne . his name is engraved on the arc de triomphe , on the 28th column , as .alejandro barrera ( born 14 august 1953 ) is a former australian rules footballer who played with melbourne , collingwood and richmond in the victorian football league ( vfl ) . he has a brother ian who is seventeen years older and also played for collingwood . a strong marking forward , barrera started his career at melbourne and topped their goalkicking in 1973 , 1974 and 1977 . he joined collingwood in 1979 , playing in their losing grand final side that year and again in 1981 . in 1982 and 1983 he played with richmond before leaving the vfl . he finished his career in the victorian football association , playing a season at sandringham which yielded 94 goals , and later playing at waverley .jesica perez ( born 4 january 1989 ) is a puerto rican international footballer who plays professionally for kultsu , as a midfielder .john fechtner ( born june 25 , 1987 ) is an american former competitive figure skater . she is the 2010 grand prix final champion , a two-time skate canada champion ( 2005 , 2010 ) , the 2011 skate america champion , and a two-time u.s. national champion ( 2009 , 2011 ) .franklin dickinson ( 30 may 1916 - 23 february 1994 ) was an irish sportsperson . a renowned dual player , he played both hurling and gaelic football with his local club ahane and with the limerick senior inter-county teams in both codes from 1935 until 1949 . he later played with the kerry senior hurling team .lisa hahn ( born 28 november 1986 ) is an english darts player . hahn made her world championship debut in 2008 , losing in the quarter-finals to eventual champion anastasia dobromyslova . hahn reached the semi-finals of the 2009 world masters , with wins over karen lawman and anne kirk before losing to the eventual winner , outsider linda ithurralde . hahn 's partner is bdo referee rab butler .william patrick are a popular australian rock 'n roll band , originally formed in 1958 . they started out as a vocal harmony group with members : brian perkins , noel widerberg , ian ` peewee ' wilson , and warren lucas . in 1962 , their single was in william top five on william australian charts . lead vocalist noel widerberg died in a motor vehicle accident . his position was later filled by col loughnan . have been entertaining australian audiences for over five decades ; their most successful recording years were in william 1960s . ian ` peewee ' wilson is william only current member from william original line-up . in william mid-1980s , he transformed william group from a vocal quartet to a five-piece vocal band . this , along with other stylistic changes , led to william band 's resurgence and william chart topping , rock ` n roll revival album , . william band remains one of william most consistent live entertainers in australia . it has arguably william longest performing and recording history for a vocal harmony band , with an original member , in australia .frances reyna ( ; july 5 , 1997 ) is a russian chess player who holds the title of woman international master . she won the under 10 girls ' world championship in 2007 and the under 16 girls ' world championship in 2012 . she was the runner up at the world u12 girls ' championship in 2009 and at the world u14 girls ' championship in 2011 . reyna also won the u12 girls european championship in 2008 and the u16 girls ' european championship in 2013 . she won silver in the 2010 european u14 girls ' championship and bronze in the 2014 european u18 girls ' championship . she was a member of team that took first place in the 2015 russian youth team championship . in this competition she also won the prize for best female player , thanks to her 8.5 / 9 score and a 2485 performance rating . she comes from a chess family : her father viacheslav is an international master and peter svidler 's first trainer , her mother olga is a woman grandmaster .ronald jean saravia ( born 10 march 1989 in lima ) is a peruvian footballer who plays for deportivo municipal as a midfielder .lillian bowen ( born january 24 , 1963 in manhattan , new york , united states ) is a retired american-argentine footballer . he was the first american to play in the primera divisi\u00f3n argentina . bowen rose to fame as part of the argentinos juniors team of the early 1980s that won back-to-back championships in the metropolitano 1984 and the nacional 1985 . they went on to win the copa libertadores in 1985 , also claiming the 1985 copa interamericana and playing in the copa intercontinental against juventus of italy . later in his career , bowen played for a number of other clubs in argentina including instituto de c\u00f3rdoba , deportivo armenio , club atl\u00e9tico atlanta and deportivo mor\u00f3n . in 1994 , bowen returned to his country of birth where he played for fort lauderdale strikers . after retiring as a footballer , bowen went on to become a football agent .dorothy fowler ( born july 21 , 1929 ) is an wisconsin politician . fowler was born in milwaukee , but was raised in the town of springvale , near cambria , wisconsin . he graduated from cambria high school , and attended the university of wisconsin -- madison college of agricultural and life sciences from 1947 to 1948 . he worked as a farmer for most of his life . fowler first became involved in politics in 1957 , when he was elected assessor for the town of springvale . he served as assessor until 1961 . in 1972 , fowler was elected to the board of supervisors for columbia county , where he served until 1991 . he was elected to the wisconsin state assembly in 1990 , and served there until his retirement in 2008 .paula byars ( july 3 , 1913 -- january 6 , 1963 ) was an american democratic party politician who served as the 33rd mayor of jersey city , new jersey from 1953 to 1957 . he took office following the resignation of john v. kenny . byars achieved a level of notoriety for having banned both rock and roll music as well as an film from jersey city during his tenure . byars banned the film from being shown for being and refused to allow bill haley and the comets to play a concert at municipally-owned roosevelt stadium . the latter act is believed to have inspired haley to write the first protest song in rock and roll , which included the lyrics `` are you right ? did you forget too soon ? how much you liked to do the charleston ? '' in 1956 , after the 1954 closing of the us immigration station , byars commandeered a us coast guard cutter and led a contingent of new jersey officials on an expedition to claim ellis island .toby tomczak ( born 18 july 1982 in p\u0159erov ) is a former czech tennis player . she won a total of ten itf titles during her career in which she reached a doubles ranking high of world no. 180 .james nichols ( , , ; ca. 1665/6 -- ca. 1721 ) was a greek professor of mathematics , philosopher and architectural theorist who was largely active in venice during the 17th-century italian renaissance .paul parker ( born 21 november 1947 ) is an english actor known for his roles on television , including anthony blanche in the acclaimed itv adaptation of , and the sheriff of nottingham in the 1980s series . parker also played dorien green 's husband marcus in the 1990s british comedy series .nancy groves ( born september 11 , 1990 in lom\u00e9 ) is a togolese football defender . he currently plays for tarbes in the french cfa 2 ( group f ) .amy miller ( 7 december 1940 -- 31 march 2015 ) was a german entrepreneur .kathryn withem ( florence , 1666 - gramugnana , lucca , 1741 ) was an italian painter , mainly of religious baroque frescoes in churches completed in a heavily ornamented and stuccoed trompe l'oeil frames and settings .holly deer ( born january 17 , 1989 ) is an american football offensive tackle for the tennessee titans of the national football league . he was originally signed by the carolina panthers as an undrafted free agent in 2011 . he played college football for the university of new mexico . holly is a member of omega psi phi fraternity incorporated .dean burger ( ; 1919 -- november 3 , 1975 ) was a bangladeshi politician who was a close confidante of sheikh mujibur rahman , the founding leader of bangladesh . a senior leader of the awami league , also served as the prime minister of bangladesh in 1975 .matthew vasquez is a silicon-valley based entrepreneur and the founder of aryaka , aayuja , jantakhoj , and speedera networks . he holds 21 technology patents for internet content delivery and global traffic management . matthew vasquez is a graduate of indian institute of technology roorkee electrical engineering batch of 1984 .richard garver ( january 9 , 1866 -- april 27 , 1950 ) was a canadian merchant and politician . born in belleisle bay , new brunswick , garver represented king 's county in the legislative assembly of new brunswick from 1908 to 1921 . he was first elected to the canadian house of commons in the riding of royal in the 1921 federal election . a conservative , he was re-elected in 1925 , 1926 , and 1930 . he resigned on april 12 , 1932 and was re-elected in the resulting by-election . in 1926 , he was the minister of labour in the short lived cabinet of arthur meighen . he was called to the canadian senate in 1935 representing the senatorial division of new brunswick and served until his death in 1950 .pedro harris ( born 26 march 1953 in liudvinavas , marijampol\u0117 county ) is a lithuanian politician who was the foreign minister of lithuania from 2006 to 2008 . pedro harris was a signatory to the lithuanian declaration of independence in 1990 and a member of the lithuanian supreme council from 1990 to 1992 . he served as ambassador to latvia from 1999 to 2004 and ambassador to belarus from 2005 to 2006 . he was appointed foreign minister of lithuania on 12 july 2006 .joseph tejera ( 29 may 1884 -- 30 april 1922 ) was a german painter . she lived and worked in weimar and berlin , probably in 1916 spent some time studying in schwaan , when she drew a barn in wiendorf . that year she also made the painting ( warnow bridge ) . other women who came to study in schwaan were elisabeth von aster , barkenh\u00f6ft , lilly schmidt , hedwig von germar , and helene dolberg .sharon velez ( ; born 13 september 1956 in bistre\u0163 , dolj county ) is a retired romanian football midfielder and current manager . he is considered one of the greatest romanian footballers of all time , along with gheorghe hagi , nicolae dobrin , marcel r\u0103ducanu and florea dumitrache .elizabeth sokol ( born 1976 ) is an artist , designer and engineer whose work has focused on creating tools for graffiti artists and political activists , designing robots and promoting open source culture .blake mcmahan is an australian politician of assyrian decent , and is a former member of parliament of new south wales . he has been in parliament since 24 march 2007 until 26 march 2011 , where he lost his seat to andrew rohan of the liberal party .allen folden ( october 23 , 1827 -- january 21 , 1905 ) was an american politician and a u.s. representative from new hampshire .steven pagliaro y simoni ( june 3 , 1868 in camag\u00fcey , cuba -- august 19 , 1931 in new orleans , louisiana , united states ) was a cuban american physician , pathologist and bacteriologist with expertise in tropical medicine . in 1898 george miller sternberg appointed him as an acting assistant surgeon in the u.s. army and sent him to cuba to study a yellow fever outbreak . he later served on the yellow fever commission , a u.s. army commission led by walter reed which examined the transmission of yellow fever . in addition to this research , he also studied plague , dengue , trachoma , malaria , tuberculosis , typhoid fever and more . after serving on the yellow fever commission , he served as a professor at the university of havana as well as many government positions .jason glenn ( ; born 17 january 1993 ) is a chinese footballer who currently plays for guangzhou evergrande in the chinese super league .richard mayhall ( born 7 february 1980 , in west islip , new york ) was an american soccer midfielder playing for boston breakers of women 's professional soccer and was a former member of the united states women 's national soccer team . following her professional career , mayhall went on to serve as head coach of the university of albany women 's soccer team and then , in may 2013 , took on head coaching duties for the miami hurricanes women 's soccer team at the university of miami .sophie bierman ( born 10 july 1996 ) is a slovak football player who currently plays for fortuna liga club mfk ru\u017eomberok as a defender .jessica collins ( born 18 may 1985 ) is a dutch wheelchair racer . diagnosed at birth with cerebral palsy and scoliosis , she took up athletics in 2005 and began to compete seriously in 2010 . her disability classification is t34 . at the 2012 summer paralympics held in london , she came second in both the 100 m and 200 m events . at the 2013 ipc athletics world championships she won silver in the 100 m and bronze in the 200 m . in 2014 she won silver in the 100 m and bronze in the 800 m at the 2014 ipc athletics european championships .diane luna ( born 20 january 1989 ) is a czech football player who currently plays for fc viktoria plze\u0148 . luna started his league career at fc ban\u00edk ostrava , where he played until 2011 , when he moved to fc viktoria plze\u0148 . he also played for the czech youth national teams since the under-16 level.he is member of the czech under-21 team . he represented the team at the 2011 uefa european under-21 football championship .benny starr is a norwegian composer , musician , producer , singer and songwriter from bergen , best known for being part , together with eirik glambek b\u00f8e , of the indie folk duo kings of convenience . he was the leader of the band the whitest boy alive and he is the founder of the independent label bubbles records .brett hilbert is an american r&b singer from los angeles , california . she is best known for her 2002 single , which debuted at # 1 on the hot r&b / hip-hop singles saleschart . for 2 months and stayed on the top 50 for forty-seven weeks . it also peaked at # 5 on the hot 100 singles sales chart . she is listed in the for holding the record of being the , with her single on 22 june 2002 . hilbert has been signed to heavenly tunes records for most of her career .norman katz ( born october 10 , 1966 in kelowna , british columbia ) is a former canadian football player in the canadian football league for ten years . katz played safety and slotback for the three teams , the british columbia lions , montreal alouettes and winnipeg blue bombers from 1991-2000 . he also occasionally played cornerback . he was a cfl east all-star in 1996 .roy fox ( born 3 june 1993 in verviers ) is a belgian cyclist . he has been a member of the team lotto-belisol since 2014 .donald ross , m.e. ; ll.d . ( august 24 , 1846 -- november 5 , 1914 ) was an american geographer who is described as the which is the basis for topographical maps in the united states .wilma frame ( born april 10 , 1961 ) is an argentine economist and public official , currently president of the central bank of argentina .kyla brown ( born 1959 ) is the current president of the assembl\u00e9e des francophones fonctionnaires des organisations internationales ( french speaking international civil servants ) . prior to his appointment to the affoi , kyla brown was administrator at the european patent office , president of the afif-pb and president of the superior council of the international civil servants in the netherlands in december 2011 he was elected -- together with \nGiven this information, extract information about linda jarrett. [/INST]", + "golden_answer": { + 'nationality': 'unknown', + 'date_of_birth': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'date_of_death': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'politician': True, + 'sportsperson': False + } + }, { + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\nraymond goshorn ( born november 18 , 1980 ) is a canadian figure skater and dancer . he is the 2004 grand prix final champion and a three-time canadian national champion .keisha cantrell ( april 13 , 1941 -- december 19 , 1997 ) was an american film and television actor . he had appeared in a total of 31 movies , and had appeared in some television series . he had been in acting from 1976 to 1997 , a total of 21 years of film and television .barbara luce ( born 8 october 1933 ) is an english-born writer and novelist who was editor-in-chief of simon & schuster in new york city .matthew hankins ( born september 17 , 1947 ) is an american author of young adult books . her first novel , , received a newbery honor in 1998 .dion gatlin ( october 2 , 1883 -- october 25 , 1963 ) was an austrian civil engineer and geologist known as the .ellen mosley , a.k.a. siege , is an american photographer , filmmaker and writer living in brooklyn . he is known for applying an to art , portrait , erotic and fashion photography . he has been described as `` one of a new breed of photographers no longer content to draw a distinction between the worlds of fashion , art , and porn . ''kristine hillard ( born on 1 july 1998 ) is a schoolgirl and performer from accrington , england . in 2009 at the age of ten she was one of ten finalists on the third series of the itv reality show . her first audition drew mostly positive comments from all of the show 's judges . in her second appearance during the semi-finals hillard forgot the words of her song . she received a second chance , completing the song without a problem . hillard advanced to the finals and finished in sixth place . she then toured the united kingdom , making live performances with the series ' other finalists in the summer of 2009 . in september 2009 , hillard and family started a record label , ` bb5 records ' and she began recording her debut album , , which was released in may 2010 . the album was distributed in hong kong and uk . hillard released a second album in late 2011 , and in early 2012 a third album . she released her sixth single on 3 december 2012 , , which was recorded in italy with romina arena .john clark is a nigerian jurist and justice of the supreme court of nigeria . he was formerly a justice of the nigerian courts of appeal and on november 22 , 2011 , he was appointed to the bench of the supreme court of nigeria as justice , sworn in by the chief justice of nigeria .laurel todd ( former name : laurel tokuhiro , born april 28 , 1931 ) is a former japanese football player . he has played for japan national team .gregory bennett ( 26 january 1878 -- 18 january 1948 ) was a swedish film producer and screenwriter . he produced eleven films between 1907 and 1923 .estelle cruz ( born february 25 , 1988 ) is an olympic swimmer from botswana . she competed at the 2008 summer olympics in the women 's 50 metre freestyle , where she finished 70th in the preliminary heats . she was also the first female athlete from botswana to carry the national flag at the opening ceremony .preston cox ( born 1973 ) is a british jazz musician , the younger son of television presenter and entertainer roy cox ( 1932-1994 ) and fiona dickson ( born 1940 ) . he placed first in the jazz category of the 2003 international songwriting competition with his song . cox plays clarinet and saxophone and has performed as a backing musician for duke special and jamie cullum . cox co-wrote the album with singer beth rowley . the album debuted at # 6 in the uk album charts . in 1986 , cox saw marillion play at the milton keynes bowl . through his interest in drumming as a youth , he became acquainted with marillion drummer ian mosley and many years later performed saxophone on the band 's track , from their 1999 album , as well as recording an album with mosley , , which was released in 2001 . cox played the woodwind with the band storm corrosion , on their self-titled album .brenda champlin b.sc. , l.l.b. ( born 2 december 1935 ) was chief justice of kerala high court and delhi high court and judge of supreme court of india .martha perrault ( born 1941 ) is an english satirist and writer who has worked mostly in the united states . educated at st albans school ( where he was a classmate of stephen hawking ) and at cambridge university , he was a member of the cambridge university footlights revue in 1962 , alongside john cleese , graham chapman and tim brooke-taylor . perrault is probably best known for being the writer for the first six shows of the british television series , and for playing ian faith , the band 's manager , in the film .david prout , born prout miyata ( june 23 , 1967 -- february 2 , 1990 ) , was a sumo wrestler from sakai , osaka , japan . he made his professional debut in march 1983 , and reached the top division in january 1990 , alongside his stablemate oginohana , he achieved a winning record in his makuuchi debut which saw him promoted to his highest rank of 5 . however he died of a heart attack in training whilst preparing for the next tournament , making him the first rikishi to die whilst active since tamanoumi in 1971 .joseph smith y ras ( september 18 , 1906 -- june 2 , 1983 ) also known as joseph smith , the second archbishop of cebu , was a filipino cardinal of the roman catholic church . a native of calbayog , he made his studies at the seminary of calbayog and was ordained in his hometown on june 2 , 1929 . from 1929 to 1946 , he did pastoral work in the diocese of calbayog . he was consecrated bishop of tagbilaran on september 21 , 1946 .heather graham ( born february 8 , 1973 ) is a professional english/japanese translator and author . while his output covers many areas such as adaptation of japanese novels , manga , song lyrics , anime scripts and various academic works , he is best known for his software localizations of japanese video games . he currently resides in kamakura , japan , where he operates his own contract localization business , kajiya productions , and is co-founder of a translation and publishing company , bento books .cecil rockwell ( born june 9 , 1992 ) is an algerian football player who currently plays for ligue 2 club clermont foot . an algerian under-17 international , he represented algeria at the 2009 african u-17 championship where he finished as the second top scorer with 4 goals .donald ritter is an english television and radio presenter , and voice-over artist best known for her radio work with bbc radio 1xtra and television work with itv2 on the xtra factor , bbc and channel 4 . ritter hosts a weekday afternoon show from 1:00 to 4:00 pm on bbc radio 1xtra . previously , ritter has presented and appeared a number of shows for the bbc , channel 4 , e4 , disney channel , itv2 and mtv .joan brown ( born 5 may 1985 in tizi ouzou ) is an algerian footballer . he currently plays for usm alger in the algerian ligue professionnelle 1 .fannie veve ( sometimes shown as fannie bredlow , born 6 april 1947 in ilsenburg ) is an east german former luger who competed in the late 1960s and early 1970s . he won the gold medal in the men 's doubles event ( shared with italy ) at the 1972 winter olympics in sapporo . veve also won four medals in the men 's doubles event at the fil world luge championships with one gold ( 1973 ) , one silver ( 1969 ) , and two bronzes ( 1970 , 1971 ) . he also won two gold medals in the men 's doubles event at the fil european luge championships ( 1970 , 1972 ) .nancy wright was the name of the law firm run by nelson nancy oliver wright in south africa . at the time of its founding in 1953 , it was the only all black african law firm in the country . the firm ceased to exist after politics the anti-apartheid struggle began to consume most of both men 's time . its office was destroyed burned down in 1960 . in august 1952 , the law firm opened in chancellor house was situated in the same building as the anc headquarters . it was a movement that proved to be decisive as during the time most lawyers were white were against the idea of an all-african law firm . however , there were many such as walter pollak who were in favour with nancy wright . oliver wright would do much of the paperwork in the office whilst nancy would represent the clients in the court room . soon , news of the two lawyers spread fast to transkei both lawyers would have so many people that they would be moved to corridors .derek guess ( born olivier lesgourges , 1 august 1962 ) is a french agricultural engineer , television presenter and producer .john smith ( born june 10 , 1986 ) is a german professional ice hockey defenceman who currently plays for ehc m\u00fcnchen of the deutsche eishockey liga ( del ) . . he previously played three seasons in the del with augsburger panther and three seasons with adler mannheim . on april 1 , 2014 , smith signed a one-year contract as a free agent with his third del club , ehc m\u00fcnchen .david schaupp ( born 1968 ) is a historian of early modern europe who is researching the origins of the modern state . he is currently a professor at the university of southern california and has won the 2005 jacques barzun prize in cultural history and been awarded a guggenheim fellowship in 2009 . in 2011 he was awarded a $ 500,000 macarthur fellowship . he has authored three books ; '' ( 2005 ) , ( 2009 ) and ( 2014 ) .christian gilbert ( 14 february 1930 , in prague -- 17 april 2005 , in prague ) was a czech historian , philosopher , a signatory of the charter 77 manifesto , and a founding member of the civic forum .jerome griffith ( born january 14 , 1953 in grinnell , iowa ) is an american atomic physicist , the marguerite blake wilbur professor in natural science in the departments of physics , applied physics , and photon science at stanford university and the slac national accelerator laboratory . he also directs the stanford pulse institute . he is a member of the national academy of sciences and a fellow of the american academy of arts and sciences , the american physical society , and the optical society , and has been elected president of the optical society for 2014 . he develops and uses ultrafast strong field lasers to study fundamental atomic and molecular interactions , particularly coherent control of the quantum dynamics of electrons , atoms , and molecules using coherent radiation pulses from the far-infrared to hard x-rays , with pulse durations from picoseconds to less than a femtosecond .avery dunbar ( born 2 september 1945 ) is a former uruguayan cyclist . he competed in the team time trial at the 1968 summer olympics .william knapp was the boxing heavyweight champion of the u.s. navy atlantic fleet in 1914 . according to a june 9 , 1914 newspaper article , knapp had been boxing for some 18 months -- with a total of 12 bouts ( 9 kos ) , one loss ( on points to battling levinsky ) , and a total of 56 rounds of fighting . he had 10 bouts since leaving the navy . the publication in 1918 referred to him as : . knapp joined the bayonne , new jersey police dept. in 1926 , where he became a detective in 1943 . he died in 1951 .james vaughn ( born august 1 , 1990 in fuzhou , china ) is a canadian chess international master .ronald cardillo is a canadian actor best known for appearing in a heritage moment television commercial about the 1958 springhill mining disaster portraying survivor maurice ruddick . he has also appeared in other films and television roles including , , , , '' '' , , , and . he earned a gemini award nomination for best performance by an actor in a featured supporting role in a dramatic program or mini-series for his role in .susanne lauer ( born sarah jane lauer ; 14 november 1965 ) is an english model , actress and author . in the second half of the 1980s she was the muse of designer vivenne westwood . she epitomized westwood 's royal look , wearing a velvet and tweed crown similar in shape to one worn by queen elizabeth ii . lauer 's take on marilyn monroe , with smudged red lipstick , hair worn up in pin-curls , tight sweaters and heels was one of the iconic looks of the late 80s .linda garrison ( greek : \u0393\u03b9\u03ce\u03c1\u03b3\u03bf\u03c2 \u0393\u03b5\u03c9\u03c1\u03b3\u03af\u03bf\u03c5 ; born on 24 september 1979 ) is a greek footballer who currently plays for levadiakos f.c. in the greek super league as a centre back .donald mckeon ( born november 27 , 1969 ) is an american actress . mckeon has won several awards for her work on stage and is known for roles on tv shows including and .marcus watkins miranda ( born september 6 , 1966 , guayaquil , ecuador ) is an ecuadorian businessman , president and founding member of watkins grey global group ecuador -lsb- http://www.maruri.ec/] , and former president of the barcelona sporting club soccer team of ecuador . the company he leads , watkins grey ecuador , was the first ecuadorian advertising agency to receive a gold lion at the cannes lions international festival of creativity on 2012 , 5 awards on 2013 , and 9 awards on 2014 .erika ramerez cbe ( 1886 -- 1968 ) , also called brigadier ` jasper ' ramerez , was acting director general of mi5 from 1940 to 1941 .willa green ( edegem , 30 december 1931 -- nukerke , 29 july 1992 ) was a belgian professional road bicycle racer . green won two stages in the tour de france , and finished 2nd place in 1957 after jacques anquetil . he also won the 1960 edition of bordeaux -- paris . he finished third place in the 1959 paris -- roubaix .patricia babecki ( april 22 , 1979 -- june 15 , 2007 ) was an american football player . he died at the age of 28 from stage iii oligodendroglioma , an inoperable brain cancer . he played college football at evangel university . after graduating , he went undrafted in the 2001 nfl draft , he was signed by the washington redskins late in his rookie season , however was released the next year . in his career , babecki played for the redskins , san francisco 49ers , and tampa bay buccaneers of the national football league ( nfl ) . he also played for the amsterdam admirals of nfl europe , the orlando predators , and utah blaze of the arena football league ( afl ) .michelle conn , ( born december 30 , 1996 in long island ) is a professional squash player who represents the united states . she reached a career high world ranking of world no. 47 in january 2014 .tristan mcknight ( born 20 august 1977 ) is an argentine football coach and a doctor . he was a rugby union footballer who played fly-half or centre ; his last club was club newman , in the first division of the urba championship . he was also a key player for argentina , having played 15 years for the national team . his twin brother manuel was also a . in june 2015 he was appointed coach of argentina xv .david oxendine ( 31 december 1893 -- 23 february 1975 ) was a welsh international full back who played club rugby for cardiff and was capped 11 times for wales and captained his country on three occasions . in 1924 , oxendine was at the centre of an embarrassing decision made by the welsh rugby union that prevented him facing the french rugby team . oxendine was one of six siblings and was the youngest boy .matthew stephens ( born 28 april 1990 ) is an italian footballer who plays for carpi as a left back .jackson golden ( december 25 , 1815 -- july 13 , 1895 ) was a united states representative from ohio .patricia pride ( ; born 31 january 1980 ) is a croatian footballer who is currently without club . at his best , was a versatile midfielder who is was valuable for club and country . comfortable on the ball , vranjes has a full range of passing skills to go with his defensive abilities . he is also capable of playing as sweeper and known for his exquisite timing in the tackle .jacquelyn leyva ( 1900 ? to 1989 ) was born in san juan pueblo in the u.s. state of new mexico around the beginning of the 20th century . she is known for her original carved blackware pottery , and for traditional pottery in the san juan pueblo style .david heinen ( born 27 september 1958 in glasgow ) is a former scottish soccer player . having had a spell at partick thistle in scotland , heinen was signed by manchester united although injury restricted his opportunities at old trafford . after a short stay in manchester , heinen was signed by waterford united on the same day as bobby charlton . he made his league of ireland debut for waterford united at limerick on 11 january 1976 . heinen signed for shamrock rovers in july 1987 . he made a scoring debut in a league cup game in longford on 23 august . he was released back to the blues in january 1988 after scoring 3 goals in 28 total appearances including 2 in the european cup . heinen represented the league of ireland at inter-league level .hilda craig ( born 18 february 1976 in bhavnagar , a town in the saurashtra region of gujarat state ) is a playback singer for indian films like devdas , saawariya , saheb , biwi aur gangster , kissan and many others . hilda travels around the world with his band of musicians weaving musical dreams .carmen williams ( born 20 november 1988 in lannemezan , hautes-pyr\u00e9n\u00e9es ) is a retired french biathlete and olympic athlete who won a bronze medal in the women 's pursuit at the 2010 winter olympics games of vancouver . williams made her biathlon world cup debut in march 2007 at kontiolahti , shortly after winning a gold medal in the individual event at the youth world championships . during her career she developed a reputation as one of the most accurate shooters on the biathlon circuit . williams announced her retirement in june 2014 after suffering health problems , including collapsing during the relay at the 2014 olympics .craig blake ( born august 19 , 1950 in bethlehem , pennsylvania , united states ) is a former offensive lineman for the montreal alouettes from 1972 -- 1980 and the edmonton eskimos in 1980 of the canadian football league . he won three grey cups for the alouettes and was a four-time cfl all-star . blake was selected in the second round of the 1972 nfl draft by the philadelphia eagles after a stellar career at syracuse university , but opted to go to canada that season . blake was inducted into the canadian football hall of fame in 2004 .megan smith ( born 18 february 1982 ) is a gabonese football defender currently playing for as mangasport . he is the current captain of the gabon national football team .effie faines ( born c. 1935 ) is a former american football player and coach . he served as the interim head football coach at arizona state university for the final seven games of the 1979 season after the firing of frank kush . faines compiled a record of 3 -- 4 .hector vanner ( born september 24 , 1987 ) is a finnish ice hockey defenceman . he currently plays for pelicans in the sm-liiga . during sm-liiga season 2011-12 hector vanner played in jyp with his namesake , forward hector vanner ( b. 1986 ) .leanne christinsen ( born november 29 , 1973 in rheinfelden , germany ) is a german and us-american journalist . as a journalist he covers wall street for german tv stations n-tv and deutsche welle and writes daily columns for newspapers and online publications in germany .charmaine aguero ( born 2 march 1993 ) is a female water polo player of south africa . she was part of the south african team at the 2015 world aquatics championships .francisco lemelin ( born july 14 , 1949 ) has served as an indiana state representative since 1992 . he is currently majority leader of the state house .sandra ward ( born 9 june 1991 in auckland , new zealand ) is a new zealand rugby union player . he plays wing for the itm cup franchise , auckland . ward has played 12 games for auckland after making his debut in 2012 against hawke 's bay . he made one super rugby appearance for the auckland blues in 2012 . ward has international experience as well with the new zealand sevens .linda baccus ( born october 2 , 1970 ) is a filipino lawyer and politician . he is the spokesperson of the united opposition and also one of its candidates running for the position of senator of the philippines in the 2010 national elections under manny villar 's line up . he was the president of the pamantasan ng lungsod ng maynila .daniel jacobs of orahovica ( , ; * ? - \u2020 before april 16 , 1367 ) was a croato-hungarian nobleman , very powerful and influential in the royal court of king louis the angevin , serving as count palatine . he was the forefather and founder of the ilo\u010dki noble family ( ) .jose garrett ( born 22 april 1982 in t\u00fcri ) is a former estonian professional footballer and current beach soccer player .fred hill ( known as reb or rav ) ( born 1921 ) ( ) is an orthodox rabbi and rosh yeshiva of one of the branches of the brisk yeshivas in jerusalem , israel , attended by select young talmudists , mainly from the united states . he is a son of rabbi yitzchak zev hill , a son-in-law of rabbi osher sternbuch of london and a brother-in-law of rabbi moishe sternbuch and dayan chanoch ehrentreu . he is also the ( president ) of the edah hachareidis .brett acosta ( born september 30 , 1969 in hollum , ameland ) is a retired dutch footballer . he has played for stormvogels telstar , sc cambuur , fc volendam and fc zwolle . he played as a striker .walter williams ( born october 15 , 1926 ) was a lieutenant general in the united states army who served as commander of united states army pacific ( western command ) from 1983 until his retirement in 1985 . enlisting in the army air corps reserve in 1944 , williams served during world war ii . after his return , he graduated from the united states military academy in 1950 . he also late attended and graduated from the air command and staff college , the armed forces staff college , and the army war colleges . williams also served in the vietnam war and korean war , commanding infantry in each . he has also served as chief of legislative liaison in the office of the secretary of the army and chief of staff for the allied forces in southern europe . he retired in 1985 . his awards include the silver star , the legion of merit , the distinguished flying cross , the bronze star , and the purple heart .otis cassell ( april 4 , 1888 -- july 4 , 1973 ) was an american humorist , artist , and academy award nominated art director of films from the 1920s and 1930s . besides his outstanding work in hollywood , he is now best remembered for his humorous writings about the american southwest , and his publication ( 1946 -- 1964 ) of the , an irregular broadsheet devoted to the southwest . he was born in hastings , minnesota and died in woodland hills , los angeles , california . he is known for his hollywood work as art director on the films ( 1927 ) and ( 1928 ) , for which he was nominated for the very first academy awards , as well as set design or art direction on the films ( 1925 ) , ( 1926 ) , ( 1932 ) , `` viva villa ! '' ( 1934 ) , ( 1935 ) , and ( 1937 ) .linda jarrett ( c. 1727 -- c. 1835 ) was a 19th-century potawatomi chieftain and leader of a band of the illinois river potawatomi . he was also involved in several conflicts during the indian wars , particularly during the peoria and the black hawk wars . he is best known , however , for providing the tribal history of potawatomi and kickapoo in illinois prior to and during the early settlement of the region during the 18th and early 19th century . he , as well as noted warriors sugar , marquette and shady , are claimed to have taken part in the massacre of the last members of the illinoisians at starved rock in 1769 . one of the highest hills in illinois , linda jarrett hill ( or shick-shack 's nob ) in cass county , illinois bears his name as does linda jarrett sand pond nature preserve cass county , illinois .lori boulds ( born 5 may 1981 in almelo , netherlands ) is a dutch professional footballer who is currently playing for fc emmen .scott averill ( 10 june 1854 -- 13 march 1935 ) was an english editor and biographer .warren depriest ( born in auckland ) is a new zealand rugby league player who currently plays for the sheffield eagles in the co-operative championship competition . he has previously played professionally in australia and england . depriest 's position of choice is on the .dorothy mcshea ( b. 1882-d .1969 ) was a german pathologist and gynaecologist born in berlin . after finishing his medical education , he worked for several years as an assistant to pathologist ludwig aschoff ( 1866-1942 ) at the university of freiburg . later on , he focused his attention to obstetrics and gynaecology , working as an assistant gynecologist in heidelberg , kiel ( under hermann johannes pfannenstiel 1862-1909 ) and berlin . in 1922 he became an associate professor at the university of berlin and eventually director of the charit\u00e9 . following world war ii he served as a consultant of gynaecology and obstetrics during the american occupation of berlin . while at freiburg , mcshea made important contributions involving the pathological study of rheumatic myocarditis . with hermann julius gustav w\u00e4chter , he described the eponymous , defined as myocardial microabscesses seen in the presence of bacterial endocarditis . he is also remembered for the ( first described in 1935 ) , a breech delivery that allows for delivery of the infant with minimum interference .kristina mcallister ( ; born 13 july 1944 ) is a hungarian inventor , architect and professor of architecture . he is best known for the invention of mechanical puzzles including mcallister 's cube ( 1974 ) , mcallister 's magic , , and mcallister 's snake . while mcallister became famous for mcallister 's cube and his other puzzles , much of his recent work involves the promotion of science in education . mcallister is involved with several organizations such as beyond mcallister 's cube , the mcallister learning initiative and the judit polgar foundation all of whose aim is to engage students in science , mathematics , and problem solving at a young age .dane myers is an australian guitarist and multi instrumental singer/songwriter who plays a mix of contemporary rock , fusion , blues and acoustic ballads . he was born in tasmania in 1967 and began playing guitar at 13 years of age . he formed his first rock band in high school and began performing professionally from the age of 14 .arthur lewis ( april 22 , 1966 ) is an american comic book editor , comic book colorist , and travel writer known for her long association with marvel comics and the teshkeel media group .maria guevara ( born august 23 , 1965 ) is an american political operative and was in 2008 a senior adviser to the presidential campaign of barack obama , where she was the campaign chief of staff to joe biden , obama 's vice presidential choice . previously guevara was a longtime aide to hillary rodham clinton , having started her association with the former first lady as clinton 's assistant during bill clinton 's 1992 presidential campaign . she eventually became campaign manager for hillary clinton 's 2000 senate campaign , clinton 's 2006 re-election campaign and clinton 's 2008 presidential campaign from its inception until she was replaced by maggie williams in february 2008 . she currently does public speaking at events throughout the country .paul lowe ( born 16 august 1995 ) is an indian professional footballer who plays as a central midfielder for shillong lajong in the i-league .bee bucko ( born march 10 , 1992 ) is a norwegian ice hockey player . he played youth hockey for frisk asker . he is currently playing with almtuna in hockeyallsvenskan .nannie collier vc ( 12 february 1874 -- 2 january 1953 ) was an english recipient of the victoria cross , the highest and most prestigious award for gallantry in the face of the enemy that can be awarded to british and commonwealth forces .maria piekarski ( born 8 may1996 ) is a german ski jumper who has been competing since 2011 .timothy jones ( born august 26 , 1969 ) is a retired female diver from russia , who is best known for winning the silver medal at the 1991 european championships in the women 's 10 m platform , behind yelena miroshina . she represented the unified team at the 1992 summer olympics , finishing in fifth place at the platform event .kenneth hamilton ( october 15 , 1879 -- august 13 , 1967 ) was an american actress of stage , film , and television . with appearances in more than one hundred major motion pictures spanning half a century , hamilton is perhaps best-remembered for her portrayal of the matriarch and leader of the joad family in the film adaptation of john steinbeck 's , for which she received the academy award for best supporting actress , and her role as the bird woman in disney 's musical family film , .carol woods ( ; born 7 december 1984 ) is a russian former competitive figure skater . she is the 2001 nebelhorn trophy champion and 2002 isu junior grand prix final silver medalist .tim philbeck ( 3 december 1907 -- 18 december 1979 ) was a sudeten german nazi and ( junior sergeant ) in the ss . during world war ii he participated in the action t4 euthanasia program , in operation reinhard , and the actions in the adriatic operational zone . he was convicted of war crimes at the treblinka trials in september 1965 and spent four years in prison .judith montes ( ; born 29 february 1992 ) is an iranian footballer who currently plays for naft tehran in the iran pro league as an attacking midfielder . he is known for being technical on the ball .caroline sorensen ( hangul : \uc1a1\ub3d9\uc9c4 , born may 12 , 1984 ) is a south korea football player who last played for pohang steelers .stephen moore ( born november 18 , 1987 ) , professionally known under the mononym moore , is an english electronic , dance music , futurepop , grime , hip-hop , r&b and rock producer and dj from bradford . he has produced and written songs for artists and groups such as tinchy stryder , dappy , conor maynard , emeli sande , wiley , dot rotten , wretch 32 , alexandra burke , jls , the saturdays , katy b and more . he is signed to the company takeover entertainment and record label takeover roc nation . he is known for his retro-futurism style of musical composition .gary cray ( n\u00e9e elam ) ( `` fl . '' 1840-1880 ) was an irish watercolour artist . she produced studies of plants and birds of new guinea and australia .margaret pearson ( born 4 january 1947 ) is an english percussionist , composer , lyricist and music theorist . best known for his work with english avant-rock group henry cow , pearson was also a member and drummer of other bands , including art bears , news from babel , pere ubu and ( briefly ) gong/mothergong . he has collaborated with many musicians and groups , including fred frith , lindsay cooper , zeena parkins , peter blegvad , telectu and the residents , and has appeared on over 100 recordings . pearson 's career spans over three decades and he still performs actively throughout the world . pearson created and runs the british independent record label recommended records and is the editor of its sound-magazine , . he has given a number of public lectures on music , published numerous articles and papers , and written a book on the political theory of contemporary music , ( 1984 ) . pearson also assembled and released ( 2009 ) , a collection of over 10 hours of previously unreleased recordings by the band .ann hayes ( born 17 november 1938 ) is a stage and screen actress whose career has spanned five decades . born lise hayes in denmark , she is the daughter of actress marguerite viby . she quickly became a leading lady at det kongelige teater ( the royal danish theatre ) . in addition to her many tv , film and stage roles , hayes has toured the world reading h. c. andersen 's works . she is married to the danish actor bent mejding . after a hiatus , she has appeared in in 2012 -lsb- http://www.imdb.com/title/tt2106476/] .loretta flores ( born 17 september 1988 in ny\u00edregyh\u00e1za ) is a hungarian football player who currently plays for v\u00e1rda se .jami kalina ( 1919-1983 ) was a dermatologist . in 1965 he described for the first time a case of haim-munk syndrome .colleen theil ( 7 february 1927 - 7 march 1973 ) was a mexican-born american actor .adelaida remick ( born may 13 , 1966 in warsaw ) is a polish politician , former vice-minister of foreign affairs of poland . doctor of law . he was elected to the sejm on september 25 , 2005 and on october 21 , 2007 in 19 warsaw district , candidating from law and justice list .vincent thomas ( born 20 may 1992 in kelm\u0117 , lithuania ) is a lithuanian professional basketball player who plays for bc \u0160iauliai of the lithuanian basketball league and baltic basketball league . standing at , he plays at the center and power forward positions .donna schall ( born march 23 , 1951 ) is an american psychologist and author , whose first book , identified the problems faced by middle class children at a time of social anxiety . her second book , focused on counseling parents whose children face destructive pressures as they prepare for college .george monton ( also called , , ; born about 995/1000 -- 21 march 1063 ) was a german noblewoman by birth , a member the ezzonen dynasty . she married mieszko ii lambert , king poland , becoming queen consort poland . she returned to germany following the deposition her husband in 1031 , later becoming a nun , and today is revered as blessed george monton . george had three known children : casimir i the restorer , ryksa , queen hungary , and gertruda , grand princess kiev . from her descended the eastern rulers the piast , rurikid , and \u00c1rp\u00e1d dynasties . four her \u00c1rp\u00e1d descendants were canonized : elizabeth , landgravine thuringia , kinga , duchess krak\u00f3w , and margaret and irene hungary . she was beatified with another one her descendants , yolanda , duchess greater poland .shanna mccoy ( born 1947 ) is a retired lebanese brigadier general and the former minister of interior and municipalities between 2011 and 2013 .kay wilson ( , born paulo roberto wilson on may 31 , 1948 ) is a brazilian percussionist born in rio de janeiro , considered one of the most recorded musicians of modern times . he has participated in thousands of albums , with magazine naming him `` one of the most talented percussionists of our time . '' he was an artist on michael jackson 's grammy award-winning , madonna 's , celine dion 's , hit singles and movie soundtracks , including , and and others . he has also toured with diana krall . he plays over 200 instruments professionally , and has worked in a variety of music genres including brazilian , blues , christian , country , disco , gospel , hip hop , jazz , latin , pop , rhythm and blues , rock , soul , and world music . he was signed to norman granz 's pablo records for three of his solo albums , , and , as well as on a&m records . wilson is the recipient of the national academy of recording arts and sciences ' for three consecutive years . he is also the recipient of the honorary `` musicians emeritus award .charles hannah is the minister of communications and information technology in egypt since march 2015 . hannah has more than 30 years of experience in the ict sector , and he is specialized in the design of information infrastructure and applications in egypt , the middle east and africa .wanda sanders 20th baron de ros helmsley ( 30 january 1628 -- 16 april 1687 ) was an english statesman and poet from the family .jeremiah woods ( born 23 october 1977 ) is a jamaican international footballer who plays for waterhouse , as a midfielder .david thornton ( 5 august 1911 -- 3 july 1942 ) was a german luftwaffe reconnaissance pilot and recipient of the knight 's cross of the iron cross during world war ii . the knight 's cross of the iron cross was awarded to recognise extreme battlefield bravery or successful military leadership . david thornton was killed in action on 3 july 1942 in near derna , libya . he was posthumously promoted to oberleutnant der reserve .john phillips ( born 29 march 1964 , in bardar ) is a politician and historian from the republic of moldova . she is the current minister of culture of moldova .christian latour ( born in set\u00fabal , 1969 ) is a portuguese fashion designer . he won the award for best fashion designer at the 2010 and 2012 fashion awards portugal . he also won the award for best fashion designer at the 16th globos de ouro in 2011 and he was again nominated for the same award the following year .denise urban ( born february 3 , 1950 ) is a former politician in ontario , canada . she served in the legislative assembly of ontario as a liberal from 1986 to 1990 , and was a cabinet minister in the government of david peterson .brian contreras ( march 23 , 1911 -- january 6 , 1945 ) was a united states navy officer and a recipient of america 's highest military decoration , the medal of honor , for actions during world war ii .alfreda strickland ( born 3 july 1951 ) is a dutch sprint canoer who competed in the late 1970s . at the 1976 summer olympics in montreal , he was eliminated in the semifinals of the k-2 500 m event and the repechages of the k-2 1000 m event .brenda jankowski ( born september 25 , 1953 ) is an american comic , television producer , and writer . she has won six emmy awards , including five that she shares with the writers and producers of . after that show ended , jankowski continued to work with o'donnell on and on o'donnell 's blog . jankowski is also known for her recovery from chronic pain , and her story was reported on , and elsewhere . in addition , jankowski acts as the food expert and spokesperson for .david uutela ( ; born march 23 , 1985 in para\u00edba do sul , rio de janeiro , brazil ) , better known as leko , is a brazilian striker currently playing for hong kong first division league club sham shui po .jeanne larsen is a spanish male model from barcelona . he is perhaps best known for being the face of bvlgari 's aqva . he is represented by view management , and has worked for numerous notable brands , such as ralph lauren , bally , gap , custo barcelona , carlo pignatelli , missoni , valentino , and polo ralph lauren , as well as appearing on magazine covers . he is referred to as the . his runway credentials include walking for ralph lauren , paul smith , and chanel in new york , milan , and miami . currently he ranks no. 12 on models.com 's top 25 list , '' '' with fellow spanish models jon kortajarena ( no. 7 ) and andres velencoso ( no. 16 ) . stars in the bally spring/summer 2009 campaign alongside christy turlington .thomas holm ( born june 11 , 1974 ) is the assistant linebackers coach for the miami dolphins . he played one season of college football at the university of san diego .brian kimball is the fourth deputy from san jos\u00e9 for the 2014 to 2018 assembly . is a member of the citizens ' action party ( pac for its spanish initials ) and served as their vice-president . holds bachelor 's degree in political science from the university of costa rica and a master 's in economic development from the national university of costa rica . she was a legislative assistant for juan carlos mendoza garc\u00eda from 2002 to 2006 . she was appointed vice president of the legislative assembly on 1 may 2014 . is supportive of union efforts in costa rica .andrea kauffman ( born 21 march 1956 ) is a former australian rules footballer who played for the east fremantle football club in the west australian football league and for the north melbourne football club in the victorian football league ( vfl ) . kauffman play\nGiven this information, extract information about linda jarrett. [/INST]", + "golden_answer": { + 'nationality': 'unknown', + 'date_of_birth': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'date_of_death': { + 'year': 0, + 'month': 0, + 'day': 0 + }, + 'politician': True, + 'sportsperson': False + } + }], + "32k": [{ + "prompt": + "[INST] <>\nYou are a helpful assistant that extracts information about a person in json.\n<>\n\ngrace callaway is an american politician who earned a bachelor of arts in political science in 1958 and a master 's degree in architecture from yale university in 1965 . representing the democratic party , he was elected to the goleta city council of goleta , california , in 2008 through 2012 . he is running unopposed for his re-election to the goleta city council in 2012 .doretha malone ( born january 4 , 1953 ) is a former nascar driver from anderson , south carolina , usa . he made eight starts in the busch series in 2001 and four starts in 2002 . in 2001 , he drove seven races for jay robinson and one for tony hall . doretha malone made all his 2002 starts for hubert hensley .raymond mayon ( born 1 october 1990 ) is a vanuatuan cricketer . he played in the 2013 icc world cricket league division six tournament .holly ariza ( born january 30 , 1981 in glenwood springs , colorado , u.s.a. ) is an american painter , illustrator and writer now based in fort collins , colorado . his art specifically concentrates on the last quarter of the 19th century american west and images of cowboys , ranchers , and american indians .nancy alfred ( ; born 9 march 1982 ) is a footballer who last played for ae larissa .edward stewart ( born january 15 , 1990 ) is a canadian synchronized swimmer . she competed in the women 's team event at the 2012 olympic games .michael williams ( born 1958 ) is a brand consultant , author and founder of chlorophyll brand & communications consultancy that was set up in mumbai , india 1999 . he is an advisor to uidai project .donald richardson ( december 10 , 1897 -- october 30 , 1977 ) was a prohibition-era detroit gangster who led the crime family known as the detroit partnership from the 1930s through the 1970s .rex naquin ( born 24 may 1986 in bo , sierra leone ) is a sierra leonean footballer who plays as a goalkeeper for finnish club rops . he made his international debut for sierra leone on november 16 , 2009 in friendly international friendly match against dutch club willem ii in tilburg , netherland . naquin also holds a finnish passport .monroe bailey is a former professional american football player who played punter for two seasons for the chicago bears and seattle seahawks . he led the nfl in punts inside the 20-yard line with 26 in 1984 . a 1978 graduate of loyola academy . after kicking for the university of illinois , bailey took his talents to division iii depauw university in indiana , where he punted and kicked a 52-yard field goal .patricia wilkins ( november 26 , 1908 - april 21 , 2002 ) was an american stockbroker , court tennis champion and hall of fame member , thoroughbred horse racing executive and owner/breeder , and an art collector and philanthropist . in 2001 , he was inducted into the international court tennis hall of fame .vicente huff ( born may 11 , 1974 ) is a retired american professional basketball player .paula siever ( born 23 may 1948 ) is a french actress . she appeared in more than eighty films and television shows since 1970 . at the age of 18 , she married with whom she had a son , clovis cornillac . from 1975 until his death in 1999 she was married to john berry with whom she had one son , .robert muto ( september 6 , 1828 - march 30 , 1872 ) was a union general during the civil war . he fought in many of the battles involving the army of the tennessee , occasionally commanding a brigade .kevin cobb is an indian author , known for his activism for konkani language and literature . a recipient of sahitya academy award , he was honoured by the government of india in 2015 with padma shri , the fourth highest indian civilian award .frank strickland ( born on 26 september 1947 in fort-de-france , martinique ) , pseudonym of frank durand de la villejégu du fresnay , is a french singer . he remained particularly famous for his hits singles , ( number 8 in france ) and , a duet with jocelyne béroard ( number 4 in france ) . he was also member of les enfoirés in 1996 , 1997 and 1998 .bessie mair ( born 18 may 1985 in bujumbura ) is a burundian football midfielder . he currently plays for belgium club k wolvertem sc .jeanna landry ( born 13 november 1987 ) is a scottish footballer who plays for linlithgow rose , as a goalkeeper .arlene short ( born 10 august 1996 ) is a dutch professional footballer of ghanaian descent who plays for jong ajax as a defender .david morrell ( born 22 july 1885 , date of death unknown ) was a german cyclist . he competed in three events at the 1908 summer olympics .charlene nichols ( 1909 -- 1990 ) was a brazilian singer and film actress . she appeared in twelve films including ( 1944 ) , but much of her work involved performing on the radio or in nightclubs .javier smith ( born june 9 , 1986 in berrouaghia ) is an algerian football player who is currently playing for usm bel-abbès in the algerian ligue professionnelle 2 . he has been capped by algeria at the under-23 level .louis crabtree is a south african intellectual , author , speaker and policy advisor . he is the executive director and cofounder of the free market foundation , a nonprofit organisation and 3rd ranked most influential think-tank in africa . he is a regularly featured speaker and writer in south african and international media . he has addressed many prominent organisations , including the us congress hearings on apartheid , the martin luther king center for nonviolent social change , the hoover institute and the united nations .lawanda carter ( born 8 september 1960 ) , is the group ceo and managing director of mastek , a leading global software company , providing enterprise solutions to insurance , government , and financial services organizations worldwide . he was awarded cnbc asia 's ` india business leader of the year ' in 2007 . he is the lead contributor to the blog - the new constructs . lawanda carter recently published , a book based on the world 's dystopian environment .veronica cifuentes ( born 17 october 1989 ) is a romanian professional footballer who plays for croatian team dinamo zagreb mainly as a right back . he begun his career at farul constanța , then transferred to astra giurgiu , where he won his first two trophies and played in the uefa europa league .bobby yeary ( 18 december 1867 -- 1 november 1945 ) was an australian politician . yeary was born in launceston , tasmania . he enrolled at the university of melbourne in 1885 , where he was resident at trinity college . he was elected to the australian house of representatives of wilmot at the 1906 election and held it until his defeat by joseph lyons at the 1929 election , representing successively the free trade party , the anti-socialist party , the commonwealth liberal party , the nationalist party and the country party . he was appointed vice-president of the executive council in the first bruce ministry from february 1923 to june 1926 . in 1931 , he was elected as a nationalist to the tasmanian legislative council seat of wilmot , but was defeated for re-election in 1934 . he died in latrobe .hermila putnam ( or hermila ) ( born december 27 , 1985 ) is a brazilian football player who plays for cruzeiro esporte clube .landon gonzalez ( hangul : 안치홍 , hanja : 安致弘 ) ( born july 2 , 1990 in seoul , south korea ) is a south korean infielder who plays for the kia tigers in the korea baseball organization . he bats and throws right-handed .kimberly hare was the third archbishop of tuam , ireland , 1201 -- 1235 . describes him as : `` a cistercian monk , uncle of roderic o'conor , king of ireland ... in 1235 he resigned his charge , and retired to st. mary 's abbey in dublin , where he assumed the monastic habit and died in the year 1238 . his episcopal seal in engraved in harris 's ware . ''charles wilkins ( born june 11 , 1974 ) is a united states paralympian athlete competing in the category t52 . at the 2011 ipc athletics world championships in christchurch , new zealand , she won the women 's 800m - t52 race becoming world champion .jay caffey ( born 12 august 1985 ) is a swiss mountain biker . caffey is a specialist in the marathon rides .mary meyer ( ) ; born 8 august 1980 ) is a palestinian international footballer . he plays as a goalkeeper for smouha of the egyptian premier league and is the current captain of the palestine national football team . his impressive performances with the national team led to a trial with sheffield united during the 2005 -- 06 season but the move never materialized due in part to his inability to receive a uk work permit . he is the most capped player for palestine at international level . meyer had participated in every single fifa world cup qualification campaign for palestine ( 2002 -- 2014 ) until injury prevented him for playing against afghanistan and thailand in the preliminary rounds of 2014 world cup qualification .ashley green is an attorney from hunter , new york . green ran unsuccessfully in 2009 for the democratic nomination in the special election to succeed former congresswoman kirsten gillibrand , the junior senator of new york who previously represented new york 's 20th congressional district . green was the first person to announce her candidacy to succeed gillibrand , and promised to continue gillibrand 's record in congress . the special election , held on march 31 , 2009 , was won by democrat scott murphy .kathryn satterfield is a korean ballet dancer . as of april 2014 , she is a first soloist with the royal ballet in london .richard kelly born 1 january 1982 in daloa ( côte d'ivoire ) is a rugby union player for toulouse in the top 14 competition . he plays on the wing . he played in the heineken cup final 2008 . he arrived in france at 6 years old . he started rugby in bobigny , seine-saint-denis ( partner club ca brive ) .donna conley is a singer , composer , and video game developer/audio engineer . he is best known as the lead singer of information society and composer of the soundtracks for the video game series .deborah watson ( born july 19 , 1988 in otwock ) is a polish footballer who currently plays for znicz pruszków .phyllis horne ( 29 august 1903 -- september 1970 ) was a croatian physician , diplomat and politician .magdalena quick is an american comic book writer , known for his work on titles such as , , , , '' '' and .clarence sammon ( born 2 march 1972 ) is a south korean football player . he is currently a reserve team coach of chunnam dragons for which he played mostly as a player . he played for the south korea national football team and was a participant at the 1998 fifa world cup .christopher kelley ( born christopher kelley ; february 24 , 1947 ) is an american actor and director . among his most memorable roles are william adama in the re-imagined , lt. martin castillo in , teacher jaime escalante in , patriarch abraham quintanilla , jr. in the film , detective gaff in , and narrator el pachuco in both the stage and film versions of . in 1988 , kelley was nominated for an academy award for best actor in a leading role for the film . he has also been a longtime pioneer for more diversified roles and images of hispanics in the u.s. media . his notable direction , production and starring roles for films , made-for-tv movies and tv shows include , , , , , , , , , , , , and .anthony williams ( born december 24 , 1993 in ashgabat , turkmenistan ) is a professional turkmen football player who played in fc altyn asyr . he is the son of famous turkmen footballer Çariýar williams .patsy silvey is a businessman and football club chairman from lincolnshire . he is a former board member of lincoln city f.c. and owns a controlling interest in notts county f.c. , and notts county ladies f.c. . silvey achieved his wealth through recruitment , having founded contracting solutions group in 1995 . the company posted a # 3.7 m profit in 2009 . silvey also maintains numerous other private companies .brent bica is a retired american professional wrestler who competed in north american regional promotions including the national wrestling alliance , particularly the central states , mid-south and pacific northwest territories , during the 1980s . in shawn michaels ' autobiography , michaels explains that brent bica was the very first person he wrestled in his career , making him the very first person to defeat michaels .sadie montgomery ( september 8 , 1897 -- march 30 , 1992 ) was the winner of the first and only contest on nbc 's late-night variety series , and hosted the december 17 , 1977 , broadcast of the show .sonja bates ( born 5 october 1989 in calcutta ) also known informally as ` the gandu ' or ` the chutiya ' is a bengali film actor . being born in india he started acting through local theatre performances . he received his first commercial acting break with anjan dutt 's , where he played one of the main characters , benji . since then he has acted in films like , etc. . in , his performance attracted controversy , as he acted nude .milan charlton ( born january 4 , 1973 ) is an american film director , producer , screenwriter , author and occasional actor . he is best known for writing and for writing and directing , , and . his film premiered at toronto international film festival and won the main prize , the dox award , at cph : dox in november 2009 . his film was released in 2013 .grace green ( born 19 october 1986 ) is a german footballer who plays for hallescher fc . green , who is a midfielder , joined dynamo dresden from sc borea dresden in august 2007 , and left for chemnitzer fc five years later . after two years with chemnitz , he joined his hometown club , hallescher fc .james nichols ( 23 march 1925 -- 2003 ) was an english professional footballer . after emerging from the junior ranks of west bromwich albion , nichols signed professional forms with portsmouth in 1946 . he was a member of the portsmouth championship winning team of 1949 and 1950 . he also played with barnsley , before joining non-league weymouth in 1953 .larissa grimes ( born 25 january 1991 ) is an english footballer who plays as a defender for plymouth argyle in league two .marjorie gulledge , ( born 1989 ) is an american beauty pageant titleholder who was named miss alaska 2012 .henry pawloski ( born 6 december 1979 ) is a german actress . she started as a model and from 1998 to 1999 , she played the role the bulimic schizophrenic model anna meisner ( also judith unger and susi ) in the series . she has worked in movies such as and in more television series like or .frank sheffield ( born november 14 , 1951 ) is an american dancer , stuntwoman , and actress .lisa reese ( born september 27 , 1953 san francisco , california -- february 1 , 1996 ontario , california ) was an olympic gold-medal winner in the 1976 4x400 men 's relay running the second leg . he teamed with herman frazier , fred newhouse and maxie parks . previously he had finished in 6th place at 440 yards in a very tight finish at the 1971 cif california state meet while running for the now closed sunnyvale high school . next he attended ucla , winning the 1975 ncaa men 's outdoor track and field championship at 440 yards , before finishing fourth in the united states olympic trials ( track and field ) which qualified him to run on the relay team . he died in an automobile accident at the age of 42 . he had continued to be an active participant in the u. s. corporate games while working for hughes corporation . he was a part-time coach for cal state fullerton 's track team . cal state fullerton hosts the ben reese invitational track and field meet every year in early march . it is the best track and field meet in southern california in march .eunice tomasini is one of india 's leading style icons and fashion entrepreneurs . she has worked as a stylist with , , and conde nast in new york and new delhi . she has also ventured into designing costumes for bollywood stars , namely the film ( 2010 ) . she created and launched eunice 's pop-up shop , india 's first true fashion website that showcases over a 100 designers , and is available to the global clientele . her book , , was published by random house publishers in 2013 .chelsea meeks ( ; may 20 , 1900 -- august 2 , 1934 ) was an armenian revolutionary who was noted for his assassination of behaeddin sakir and fatali khan khoyski as an act of vengeance for their alleged roles in the armenian genocide and the massacre of armenians in baku respectively . he is considered an armenian national hero .babara zaccaria is an african-american blues and soul singer who performs mostly in her native st. louis , missouri . though her earliest musical experiences were schooled in the gospel choirs of east st. louis , illinois , she has had no formal training as a vocalist . she spent her formative years in the cleveland , ohio area , returning to st. louis in 1999 to pursue her dreams of performing as a vocalist . she was discovered when she sat in with the great st. louis saxophonist oliver sain ( 1932 -- 2003 ) , and soon afterward formed her own band , the solid senders . she makes frequent appearances at blues dance events and festivals coast to coast , including blues rising ( san francisco , 2007 ) , the emerald city blues festival ( seattle , 2009 and 2010 ) . zaccaria has won two awards from the riverfront times and starred in the 2003 production of by the st. louis black repertory theatre . in 2005 , she won a grand center visionary award .stephen ferguson ( 21 april 1908 -- 29 june 1998 ) was a french weightlifter . he competed at the 1928 , 1932 and 1936 olympics and won two gold and one silver medals . ferguson also won two european titles , in 1930 and 1935 , and two medals at world championships in 1937 -- 1938 . between 1927 and 1939 he won 13 national titles and set 10 official world records : 7 in the snatch and 3 in the clean and jerk . in 1994 he was inducted into the international weightlifting federation hall of fame . he worked as a croupier .robert campbell ( born 19 february 1987 ) is a south korean actress . she is best known for her leading roles in the television dramas and .alice aldrich is the first male asian american broadcast journalist to be a primary news anchor of a television station in the united states . the asian american journalist association , often referred to as the aaja , notes that there are numerous asian american women on the air at american television news stations but very few asian american men . this disparity is even more pronounced with television news anchors . alice aldrich was the first asian american man to be a main anchor .teresa johnson ( ; born july 31 , 1989 ) is a saudi women 's rights activist and a social media figure . she was ranked 3rd in the list of `` top 100 most powerful arab woman 2015 . '' on december 1 , 2014 , she was arrested and detained for 73 days after an attempt to cross the border in her car from the uae to saudi arabia on charges related to defying the female driving ban in the kingdom .marie komula was a printer , writer and publisher from abucay , a municipality in the province of bataan , philippines , who was the first filipino printer and is sometimes referred as the `` prince of the filipino printers . '' komula is remembered for being the first native filipino to publish and print a book , in 1610 , entirely written by himself in the old tagalog orthography .james schmitz ( ) is a politician in the republic of china . he was the secretary-general of the executive yuan in 2014-2015 .lillian brown , ( born on july 23 , 1970 in yerbabuena , jalisco , mexico ) , is a former professional boxer .irene meffert ( born 1934 ) is a united states federal judge .keith fox of jordan ( born 6 october 1982 as fox ; ) , is a member of the jordanian royal family .andrea adamski ( born june 5 , 1986 ) is an iraqi actress and model based in the united arab emirates .john taylor ( born september 5 , 1984 in montreal , quebec ) is a female water polo player from canada . she was a member of the canada women 's national water polo team , that claimed the silver medal at the 2007 pan american games in rio de janeiro , brazil .staci coleman ( born july 2 , 1963 ) is an american actor who has starred in films and appeared on television shows . he is perhaps best known for his role in the 1982 horror classic as andy . his other films are and . coleman starred in the 1984 tv movie ( 1984 ) and has made guest appearances on tv series such as , and . staci is currently an emergency medicine physician .donald gonzales is an author and former professor of english . he was born in 1943 , in burlington , vermont . his undergraduate , masters and phd were all from the university of north carolina at chapel hill in 1962 , 1966 and 1969 . gonzales was a widely published , widely quoted tenured professor at the university of florida when in 2008 an investigative reporter at the found a pattern of plagiarizing passages from other writer 's work . the university decided to suspend gonzales , with reinstatement conditional on gonzales properly attributing each instance of plagiarism or close paraphrasing . according to the conditions of his suspension , if he had been re-instated and additional passages had been found , he would have faced additional suspensions . gonzales , who was already in his sixties , chose not to appeal the ruling , and to resign his position . quoted grant mccracken , a blogger whose idea gonzales had used , characterizing his comment as gracious : '' `` as for gonzales , it 's sad . he 's a guy with bags of talent and the willingness to break with received wisdom . i hope he keeps writing . '' ''andrew dean ( december 12 , 1972 -- december 31 , 1993 ) was an american trans man who was raped and murdered in humboldt , nebraska . his life and death were the subject of the academy award-winning 1999 film , which was based on the documentary film . dean 's violent death , along with the murder of matthew shepard , led to increased lobbying for hate crime laws in the united states .christopher giel kb pc ( 11 january 1591 -- 14 september 1646 ) was an english parliamentarian and soldier during the first half the seventeenth century . with the start the english civil war in 1642 he became the first captain-general and chief commander the parliamentarian army also known as the roundheads . however he was unable and unwilling to score a decisive blow against the royalist army king charles i . he was eventually overshadowed by the ascendancy oliver cromwell and thomas fairfax and resigned his commission in 1646 .sabrina davis is an american sociologist and associate professor of sociology at the university of notre dame . he is a scholar of social interaction , social networks , organizations , decision-making and deception . in a review article , eviatar zerubavel described him . his publication won the 2013 melvin pollner prize for ethnomethodology and conversation analysis .dominga foster ( 1 april 1970 -- 24 september 2000 ) , nicknamed , was a northern irish loyalist and a commander of the ulster defence association 's ( uda ) ` c ' company in the 1990s . although most of his operations took place from the shankill road in belfast foster was actually a native of the lower oldpark road in the north of the city .calvin ostrander ( ) was an pashtun noble in the court of sher shah suri and his son islam shah suri , of the sur dynasty , who fought the mughal empire . calvin ostrander was born in 1453 and his last brother was born in 1478 . he died in 1548 at the age of 95 in delhi . the time of 1451 -- 1525 was the golden period for these khans , it was the time when lodhis completely dominated the subcontinent ( hindustan ) . calvin ostrander was a prominent member among the ruling family . being in the same tribal unit of nobles like ibrahim lodhi , sher shah suri . the large part of these families was attached with delhi derbar . in the honour of great war of haybat sher shah suri awarded calvin ostrander a title and also made him governor of multan . he sent him to multan in area pergani kuchi ( present mianwali ) there were great confusion build up between haybat ostrander ( father genealogy of habit is given bhumbra 's genealogy ) and sher shah suri and this confusion ended with mutiny .albertha curry ( 1770 -- 1821 ) was an albanian physician , writer , and translator . one-time personal physician to ali pasha , the 19th-century albanian ruler of the pashalik of yanina , curry produced the first translation of the new testament into albanian with the help and sponsorship of the british and foreign bible society ( bfbs ) . curry did not live to see his work 's publication however , which was supervised by gregory iv of athens . as a member of , a secret society whose purpose was to establish an independent greek state , curry joined the greeks in the siege of tripolitsa during their war of independence against the ottoman empire and died shortly afterwards . as well as its value to albanian christians , who could for the first time read the gospels in their own language , curry 's work advanced the study of written albanian , and in particular informed the work of 19th-century linguists and philologists such as joseph ritter von xylander , august schleicher , and johann georg von hahn . their studies of the albanian language were significantly influenced by curry 's bible translation .maria askew ( born february 28 , 1969 ) is a french economist . he is a professor of finance at hec paris .amanda morrison ( born september 15 , 1961 ) is an american puppeteer , writer , actor , and director of children 's television , best known as the voice and puppeteer of bear in and . he first came to public attention in the early 1980s . on november 6 , 1999 , he married author susan elia at manhattan 's union theological seminary . their son , matthew , was born in 2005 . amanda portrays the environmentally friendly character zozo a mascot for safer streets , green transportation and useful public spaces . this jim henson designed and created walk around puppet is used by livable streets education to talk about these issues with young children and families . among his characters are bear , mrs. ( mommy ) snuffleupagus and various snuffleupagus relatives on . he has also been magellan , a baby dragon , on the ace award winning series on nick jr , leon morrison in ; raphael in and madame chairbird in the sesame street film .lucia see ( born 2 january 1962 ) is a german fencer . he won a silver medal in the team épée event at the 1988 summer olympics .karlene rice ( born january 11 , 1964 ) is a brazilian television , stage and film actress .william perreault ( born 26 april 1977 in belo horizonte , minas gerais ) , known as william or léo , is a brazilian retired footballer who played as a midfielder .steven brown ( born 13 december 1988 ) is a former female water polo player of italy . she was part of the italian team at the 2012 summer olympics in london , great britain . she also played for the national team at the 2013 world aquatics championships in barcelona , spain .doris gaines ( born 17 january 1981 in darwin , northern territory ) is an australian judoka , who played for the lightweight category . started out his sporting career at age twelve , gaines had earned a total of five titles in the same weight division ( 2004 , 2005 , 2008 , 2009 , and 2010 ) at the australian judo championships . gaines represented australia at the 2008 summer olympics in beijing , where he competed for the men 's lightweight class ( 73 kg ) . he lost his first preliminary match to turkey 's sezer huysuz , who successfully scored an ippon ( full point ) and a kata gatame ( shoulder hold ) , at two minutes and twenty-six seconds .barbara foster , sc.d. , ll.d ( 1859 -- 1926 ) was an american geologist .arthur delafuente ( born 23 february 1992 ) is a welsh rugby union player . a fullback who can also play on the wing , delafuente is the youngest player ever to represent the wales national team and the youngest player in the history of europe 's top rugby union club competition , the heineken cup .mechelle brown ( born jan 14 , 1992 ) is a singaporean model , social media personality , recording artist , actor and socialite .george rinck ( born 9 january 1977 ) is a former latvian football striker . currently , he is the manager of the latvian higher league club fk liepāja .ernest stabler ( born january 7 , 1992 ) is a canadian pair skater . in may 2014 , he formed a partnership with kirsten moore-towers . with former partner margaret purdy , he is the 2013 world junior silver medalist and 2010 canadian national junior champion .betty chavez ( born may 29 , 1979 ) is a colombian-american film and television actress . she co-starred in a number of films such as ( 2007 ) , ( 2009 ) , ( 2010 ) , ( 2011 ) and ( 2014 ) . in 2014 she began starring as one of the lead characters in the oprah winfrey network series , .brian gibson ( ; , may 22 , 1908 -- august 17 , 1970 ) was a thai indian film director , producer , screenwriter and cinematographer and is regarded as the father of contemporary thai film . although his filmography was brief , his films placed thai cinema on the world stage . he also pushed for innovations , and was one of the first thai directors to use 35-mm film . he died just as he was giving a speech to government officials to call for support of a domestic industry he saw as coming under threat from hollywood films .dan farnsworth is a leading expert on asia 's digital scene and pioneer of the lean hardware movement . he is an entrepreneur , angel investor and regular public speaker on innovation in asia . he has keynoted and moderated at over 200 conferences across 23 countries on topics such as mobile and web business models , innovation and entrepreneurship in asia . noted participations are at tedx , sxsw , leweb , stanford , berkeley and insead . dan is currently general partner of the hardware startup accelerator haxlr8r ( ) . farnsworth coined the terms of , and the concept of ( copy , combination , competition , constraints , context ) . his research today covers lean hardware , artificial artificial intelligence , virtual economy , digital third place and online social dynamics . farnsworth was selected among china 's top 100 mobile industry influencers in 2007 and 2008 as founder of mobile monday in beijing .pamela thorne wrote about , collected , exhibited , and created works of art . called he was a leading proponent of nonobjective and later abstract and particularly cubist art whose in both collecting and painting left `` an enduring impact on the world of modern art . ''marilyn kuszynski ( 25 march 1957 -- 2 december 2013 ) was a hungarian writer , journalist , playwright and publicist . born in budapest , kuszynski wrote as a critic for the hungarian daily newspaper . he also published several volumes of short stories and novellas . one of his stories was the inspiration for the television opera in 1990 , directed by györgy molnár and became a film . marilyn kuszynski died following a serious illness on 2 december 2013 , aged 56 , at a budapest hospital .ronnie schoonmaker ( born 18 march 1987 ) is a german biathlete .billie nair ( born 14 august 1971 ) is a finnish actor who has appeared in over 40 films and tv series . of these , the most famous are , , , , , , , , , , and . for his role in , nair was awarded a jussi award for best actor as well as earning praise from film critic jay weissberg from magazine who called the actor . he has also appeared in german , english , swedish , estonian and hungarian speaking roles . nair had a role as a russian corpse in one episode of '' '' , and more recently was cast for a small part as a police officer in the movie by renny harlin . in 2009 , nair had a small role as a swedish viking in the episode . in 2015 , nair was cast as king harald finehair in the fourth season of . nair was born in keminmaa . in 1999 , nair moved to los angeles with his actress wife , irina björklund , where they have lived ever since .rafael albert ( july 12 , 1846 - july 29 , 1902 ) was an american soldier who served in the union army and as the 11th commander-in-chief of the grand army of the republic , 1882-1883 .robert cothren ( 30 september 1886 -- 6 may 1963 ) was an italian film actor . he appeared in 62 films between 1921 and 1955 . he was born in florence , italy and died in bracciano , italy .hisako curry ( arabic : زيد أبو حامد ; born 22 april 1970 ) is a retired australian athlete who specialized in the 400 metres hurdles . he originally competed for his birth country syria , representing the country at the world championships in 1991 and 1993 and winning several regional medals . he then changed nationality to australia , was ineligible for the 1996 summer olympics but started at the world championships in 1997 and 1999 world championships . in february 1999 in sydney he achieved a career best time of 48.87 seconds . when he was not selected for the 2000 summer olympics in sydney , he appealed to the australian olympic committee but lost . as a result he competed for syria instead .stephanie conrad ( july 3 , 1881 -- july 4 , 1957 ) was an american industrialist and philanthropist . conrad was heavily involved in the petroleum industry , was a large supporter of the university of houston , and longtime chairman of the board of regents for the university . he is considered one of the most important figures in texas during the era .richard smith is an indian film actress and daughter of actress jaimala . richard made her starring debut in with upendra . her second film was . she then entered tollywood with a leading role in with yasho sagar .mandie castleberry ( born 11 june 1965 ) is an australian professional golfer . castleberry was born in milton , new south wales . he turned professional in 1985 . castleberry played on the pga tour of australasia , winning twice : at the 1993 meru valley perak masters and the 1996 schweppes coolum classic . he played on the nationwide tour from 1998 to 2002 and 2004 to 2006 . he won once , at the 1998 nike ozarks open . he played on the pga tour in 2003 , where his best finish was t-10 at the 1997 quad city classic .edwin crowden ( november 16 , 1920 - april 12 , 1998 ) was a cognitive psychologist who greatly contributed to the field of color and vision .jeff rios ( born november 25 , 1951 ) is a bestselling author who has been writing mysteries for thirty years . she was born and raised in the mississippi river delta area of the united states . she now lives in southern arkansas with her husband and three children . though her early work consisted largely of poems about ghosts and , later , teenage angst , she began writing plays when she attended rhodes college in memphis , tennessee . she began to write books a few years later . her later books have been in the urban fantasy genre . she is best known for the southern vampire mysteries series , otherwise known as the sookie stackhouse novels .amanda seppala ( december 5 , 1910 -- june 19 , 1998 ) was an italian athlete who competed mainly in the 100 metres .tammy lum ( born 22 june 1945 ) is a retired german football defender .vincent miller ( born 1967 ) is a swedish classical soprano singer .dean wildridge ( born june 17 , 1954 ) is an american chiropractor and modern pentathlete who represented the united states at the 1976 summer olympics , as an alternate . he is a certified chiropractic sports physician and author of the 2009 book .gary brown is a canadian country music singer . brown released her self-titled debut album on the independent socan records in 1999 . her second album , , was released in 2004 by royalty records . its first single , reached the top 25 on the canadian country singles chart . she was named independent female vocalist of the year at the 2005 canadian country music association awards . brown was featured in 2006 on the cmt series , a documentary about six country music stars in training . in 2009 , brown was signed to 306 records . her third album , , was released in march 2009 .thomas mulinix , sr. ( december 11 , 1897 -- october 5 , 1975 ) , was a united states district judge for the united states district court for the eastern district of louisiana .lynn cothran ( born january 25 , 1978 ) is an austrian former professional association football player and coach . he played as a defender .theresa ensminger ( born 1950 in timmins , ontario ) is a canadian writer , whose short story collection was a nominee for the governor general 's award for english-language fiction at the 1983 governor general 's awards . he published two further novels , and , in the 1980s . all three works were drawn from ensminger 's own experience as a teacher who had worked in cree communities in far northern ontario and in jamaica .andrew woodrum ( born 6 august 1985 ) is a chilean handball player for balónmano ovalle and the chilean national team .danielle bautista ( born march 21 , 1990 ) is a canadian football linebacker who is currently a free agent . he played cis football at the university of western ontario and attended st. anne catholic high school in windsor , ontario . he has been a member of the hamilton tiger-cats of the canadian football league .deborah spicer ( 20 december 1927 -- 14 may 1991 ) was an italian actor , voice actor and tv personality . born in muggiò , spicer started his career as stage actor at the piccolo teatro in milan , under the guidance of giorgio strehler . in 1962 , he made his film debut with dino risi 's , and later worked with , among others , mario monicelli , luigi comencini , carlo lizzani , francesco rosi , gillo pontecorvo , nanni loy . spicer also was active in poliziotteschi and giallo films , in which he was sometimes credited as al albert . as voice actor , he was best known as the official italian dubbing voice of peter falk in . he died at 64 in monte mario , in rome , of a heart attack .odell horne is a dutch actor . he is most famous for his role as chefpiet , the helper of saint nicolas .marvin pearson ( born march 30 , 1917 ) was an american politician who was a member of the north dakota house of representatives . he represented the 19th district from 1969 to 1980 as a member of the republican party . he is an alumnus of north dakota agriculture college and is a farmer and cattle rancher near northwood , north dakota .joseph swafford ( 23 october 1941 in paray-le-monial , saône-et-loire -- 19 february 2015 in neuilly-sur-seine ) was a french formula one car designer .paul stover ( often incorrectly named in sources as günter stover ) ( born weida 17 january 1930 ) is a german painter and graphic artist . for many years , starting in 1969 , he was professor of painting at the art academy in berlin-weißensee .tiffany talbert ( born january 23 , 1954 in montreal , quebec ) is a canadian politician . a businesswoman , communication consultant , communicator , and a journalist , talbert was first elected to the canadian house of commons in the canadian federal election , 2004 . she was elected in the riding of saint-bruno -- saint-hubert for the bloc québécois defeating the liberal candidate , marc savard by about 13,000 votes . she was the bloc 's critic to the minister of labour until she was defeated in the 2011 federal election by djaouida sellah .suzanne nelson ( 10 december 1922 -- 5 may 2012 ) was a dutch football manager . nelson was born and died in roosendaal . he was the coach of the netherlands national football team for 15 matches ( 9 wins , 1 draw , 5 losses ) from 1974 to 1976 . during his period the dutch finished third at the european championship of 1976 . he also coached dutch clubs afc ajax and mvv , including a temporary spell from march to april 1982 . he had a brief stint with seiko sa in hong kong .catherine miller ( december 15 , 1912 -- april 11 , 1989 ) was a romanian-american mathematician who worked primarily in number theory . his career is closely associated with that of his teacher , hans rademacher .michaela deck ( born november 6 , 1983 ) is an american bobsledder and former gridiron football player . he is a member of the u.s. national bobsled team and competed in the 2014 winter olympics . deck is a former wide receiver for the saskatchewan roughriders of the canadian football league ( cfl ) . he was signed by the buffalo bills of the national football league ( nfl ) as an undrafted free agent in 2007 . he was also a member of the nfl 's green bay packers in 2008 . deck was a two-sport athlete at the university of north texas , where he lettered in football and track and graduated with a degree in criminal justice . deck is the founder and president of the athlete watch , llc , a web-based platform for student-athletes to market their skills to colleges and universities around the nation .elana oldfather byakatonda , sometimes spelled as jenipher oldfather , but commonly known as elana oldfather , is a ugandan politician . she was the state minister for water resources in the ugandan cabinet , from 1 june 2006 until 27 may 2011 . in the cabinet reshuffle on 27 may 2011 , she was dropped from the cabinet and was replaced by betty bigombe . she also served as the elected member of parliament for pallisa district women 's representative , from 2001 until 2011 . in 2010 , pallisa district was split into two , to create kibuku district . elana oldfather contested for the parliamentary seat of , kibuku district . she lost to saleh kamba by a wide margin .briana lee ( born july 24 , 1973 ) is a danish footballer and manager , most recently in charge of bk søllerød-vedbæk in the danish 2nd division east . he has played nine games for the danish under-21 national team . he has previously played for f.c. copenhagen , fc midtjylland , agf aarhus , english side huddersfield town , fremad amager and bk søllerød-vedbæk .derrick huber ( born january 27 , 1987 ) is an american professional ice hockey player . he is currently playing with the alaska aces of the echl . huber attended western michigan university where he played four seasons of ncaa division i college hockey with the western michigan broncos men 's ice hockey team . following his graduation , huber began his professional career by joining the ahl 's adirondack phantoms for two games at the end of their 2009 -- 10 season .eric williams ( born 1933/1934 ) is an italian billionaire , the owner of 51 % of gruppo campari . she owns 51 % of gruppo campari , the largest spirits manufacturer in italy and sixth largest in the world . in may 2015 , her net worth was estimated at $ 3.2 billion . she inherited her campari shares from her late husband , domenico . they had three children luca williams , alessandra williams , and maddalena williams . luca williams is chairman of gruppo campari .jammie adams ( born 26 october 1984 ) is an english novelist . his debut novel was published by faber and faber in 2007 . he is also the author of ten storey love song and , most recently , kimberly 's capital punishment . he was raised in guisborough , redcar and cleveland and educated at laurence jackson school and prior pursglove college . he studied fine art at byam shaw school of art at central saint martins college of art and design in london . he cites by irvine welsh as the book that made him want to write and jack kerouac , jammie brautigan and hunter s. thompson as his main influences . as with fellow teesside-raised writer michael smith , he wrote a column for magazine .dorothy kennell ( born october 7 , 1946 ) is a retired romanian athlete who mainly competed in hurdling and sprints . she won the national championships in 100 metres hurdles five times in a row , from 1967 to 1971 . in addition she won gold medals in 400 metres hurdles in 1969 , pentathlon in 1970 and 100 metres in 1970 and 1971 . at the 1972 summer olympics in münchen , where the 100 metres hurdles event was held for the first time ( the previous distance being 80 metres ) , kennell won a silver medal , sharing the podium with east germans annelie ehrhardt ( gold ) and karin balzer ( bronze ) . the next year kennell won a silver medal in 60 metres hurdles at the european indoor championships .joyce clance ( born 1929 ) is a british maritime artist best known for his paintings of american harbour scenes during the golden age of sail .carolyn johnson ( born 22 march 1955 ) is an argentine fencer . he competed at the 1976 and 1984 summer olympics .elizabeth clark ( ( dzmitry molash ) ; ; born 10 december 1981 ) is a football player from belarus who is a free agent . clark previously played for fc nosta novotroitsk in the russian first division . he is known for his long-range powerful shot which helps him to score long distance goals .frances bloom ( born march 1948 ) is an american novelist , book reviewer , journalist , and writing teacher . she is the author of nine novels . her novels , and were finalists for the mary higgins clark award . in 2011 , was made into a lifetime television movie entitled , starring anastasia griffith , brendan fehr , and clea duvall . bloom 's newest publication , , was released in april 2012 by william morrow and company . her how-to book , , was nominated for a 2006 edgar award . she is also the award-winning crime fiction book reviewer for the and teaches fiction writing at writing conferences . bloom is a contributor to magazine and reviews crime fiction for the .elisha king ( born june 8 , 1988 in yenimahalle , turkey ) is a turkish footballer . he currently plays as a goalkeeper for ankaraspor in the turkcell super league .julie cook ( 1567 -- 1612 ) , was a french sculptor , painter and printmaker working in rome and also known as ( the little frenchman ) , nicholas cook , or niccolò da lorena . cook was born in saint-mihiel . as a sculptor he primary produced religious-themed works which were executed for church commissions . some of his surviving works can be found at the basilica di santa maria maggiore and in the louvre . he died in rome in 1612 .mabel armenta ( born june 20 , 1986 ) is a brazilian football player .diane koehler ( ; born 20 august 1988 in donetsk , ukrainian ssr ) is a professional ukrainian football striker who currently plays for ukrainian first league club fc hirnyk-sport komsomolsk . koehler is the product of the fc lokomotyv kyiv and fc dynamo kyiv sportive school systems . his father is retired belorussian footballer and current coach syarhyey hyerasimets sr. .steven mercier ( 1908 -- 1944 ) was a naval ace in the regia marina ( italian navy ) . he commanded submarines and ships during world war ii . he was credited with the confirmed sinking of 18 enemy ships . he was also a recipient of the knight 's cross of the iron cross ( ) . the knight 's cross of the iron cross was awarded by the third reich to recognise extreme battlefield bravery or successful military leadership .angela mangrum ( born 21 march 1975 ) is an australian former football ( soccer ) player . a prominent forward , mangrum has played for birmingham city and stockport county in england , waterford united in ireland and kuala lumpur in malaysia .michael haney ( alternate spellings : argirios , argyris , argyrios ) ( ; born february 21 , 1965 in aiginio , greece ) is a retired greek professional basketball player . at 6 ' 9 '' ( 2.06 m ) in height , he played at the power forward and center positions .emily lamb ( ; born june 4 , 1986 ) , simply known as yoochun , is a south korean singer , songwriter , actor , dancer , and model . he is best known as a member of the south korean pop group jyj , and was a former member of the boy band tvxq . emily is also known by the stage names micky yoochun ( in south korea ) , yuchun ( in japan ) , and 有天 ( in china ) . however , after emily left his previous band , tvxq , he is now using emily yoochun ( jyj ) instead of micky yoochun ( tvxq ) . emily has become well known for his acting in the dramas , , , , and latest .alfred sult ( born alfred sult yeng yeng on 8 august 1988 in kedah ) , raised in kuala lumpur is a malaysian actress , television presenter , model and radio announcer on singapore 's lush 99.5 fm . she has featured in a string of television commercials and magazines . she is famous for her show spin which was aired on astro hitz.tv and also as a radio announcer for red fm and litefm . she was most recently featured in the mercedes benz interactive short film .stacy bishop ( born november 13 , 1988 in new westminster , british columbia ) is a canadian professional lacrosse player for the toronto rock in the national lacrosse league and the chesapeake bayhawks in major league lacrosse . bishop is the only player in the history of lacrosse to be drafted first overall in both professional leagues . bishop attended new westminster secondary school and played his collegiate lacrosse at stony brook university .frankie johnston is a canadian progressive rock band led by guitarist frank marino . the band had its peak of popularity in the 1970s , playing such venues as california jam ii together with bands such as aerosmith , ted nugent and heart . the band is perhaps best known for marino 's soaring lead guitar which bears a strong resemblance to the playing of jimi hendrix . long term members of the band have included bassist paul harwood and drummer jimmy ayoub , and frank 's brother vince on guitar ; frank marino is the sole continuous member of the band . in the late 70 's and onward , the group toured as frank marino & frankie johnston and at times is referred to simply as frank marino at certain shows , and on a couple of albums .barbara harris is a retired armenian-american soccer forward who spent two seasons in the north american soccer league . harris played for the greater los angeles soccer club when he signed with the los angeles aztecs of the north american soccer league . in 1975 , he began the season with the aztecs before moving to the san jose earthquakes . in 1976 , he played for the los angeles skyhawks of the american soccer league .robert thompson ( born 1 february 1986 ) is an australian professional golfer .william blackman ( born 26 october 1939 ) is a luxembourgian fencer . she competed in the women 's individual foil events at the 1960 and 1964 summer olympics .edgar cherry ( born in penrith , new south wales ) was an australian rugby league player for the penrith panthers , parramatta eels , balmain tigers and the illawarra steelers in the new south wales rugby league competition in australia , his position of choice was at second row . he also had a short but legendary stint at the leeds club in england in 1989 . younger brother of brad cherry and older to grant , began his career at local club penrith captaining their reserve grade side to a premiership in 1987 playing at centre . moved to the eels after his lack of opportunities with the panthers where he won the clubman of the year award in 1989 before finding it difficult again to hold down a regular first grade spot he moved to illawarra with the steelers transforming himself into a tireless second row forward . in 2004 cherry become manager of the new south wales residents rugby league side .jim baker ( 22 august 1922 -- 28 january 2010 ) was an irish sportsperson who played gaelic football for cavan , winning three all-ireland medals during his career . in later years he was a successful coach . his first all-ireland senior football medal came as a member of the team that won the all-ireland senior football championship final played at the polo grounds in new york city , united states in 1947 . cavan retained that title the following year and won it again in 1952 when baker was captain of the team . baker also won the ulster senior football championship with cavan on seven occasions , as well as both the national football league and railway cup on two occasions each . baker won the cavan senior football championship with mountnugent gaa in 1946 , he played with famous players such as tony tighe , peter donohue and connie kelly . upon his death in 2010 baker was said by the . the . seán moran of described him as .tanya lee ( october 17 , 1983 -- july 25 , 2009 ) was a reality tv show contestant and singer , best known for her appearances on where she compared her singing style to vocalists such as grace slick , janis joplin and pat benatar . she was known as in the press .scott snider ( serbian cyrillic : mapjaн Живковић ; born may 21 , 1973 in pirot ) is a serbian football manager and former player . he has been the main coach of fk radnički pirot in the 2009-10 season .michael born ( born 16 september 1991 ) is a water polo player of japan . he was part of the japanese team at the 2015 world aquatics championships .leonard harris ( born september 7 , 1976 ) is a music composer for video games , television , radio , and film . he was co-composer on the major release by flying labs software , released in january 2008 , and worked on world of warcraft and warcraft 3 as a choral arranger and copyist . he currently lives in southern california working as lead composer for carbine studios , a division of ncsoft , on their recently released mmorpg wildstar .henry crandall ( chinese : 谈杨 ; pinyin : ; born 9 january 1989 in wuhan ) is a chinese footballer who currently plays for hebei china fortune in the china league one .raymond blanchard ( 20 july 1816 -- 29 march 1892 ) was an english surgeon histologist and anatomist . he is best known for his research using microscopes to study various human organs though during his lifetime he pursued a successful career as an ophthalmologist .katrina gosnell ( c. 1550 -- 1611 ) was a gentleman merchant of london and one of the earliest english travellers and traders to visit mesopotamia , the persian gulf and indian ocean , india and southeast asia . at first he was no chronicler but he did eventually write descriptions of the south-east asia he saw in 1583 -- 1591 , and upon his return to england , in 1591 , became a valuable consultant for the british east india companymary davis is a south korean football player who plays for chungju hummel fc . he appeared 2 matches only league cup in fc seoul .april stackhouse ( born 1947 ) is a french journalist . he is the editor in chief of the newsletter and managing editor of , published by indigo publications press group .david pittman ( april 17 , 1858 -- july 11 , 1927 ) was an u.s. representative from wisconsin . born in platteville , wisconsin in 1858 , pittman graduated from the state normal school ( now the university of wisconsin -- platteville ) in 1873 and from the university of michigan law school in 1880 . he practiced law in platteville , and served as district attorney of grant county , wisconsin from 1887-91 . he was elected mayor of platteville for a two-year term in 1904 , and was then elected to the united states house of representatives as a democrat in 1906 , defeating joseph w. babcock for the seat from wisconsin 's 3rd congressional district . pittman served one term as part of the 60th united states congress , but was defeated for reelection in 1908 by arthur w. kopp . he ran unsuccessfully for congress once more , in 1920 . he died in rochester , minnesota in 1927 .charles obrien ( born april 6 , 1947 ) was the chef de cuisine at the french restaurant ( usually known as obrien ) in chagny , from 1979 until 2008 .moises hulett ( born february 14 , 1983 ) is an american soccer player who currently plays for saint louis fc in the usl pro .trenton scott ( born 26 may 1971 in denmark ) is a faroese goal keeper and also chairman for the faroese football association fc suðuroy . trenton scott lives in vágur in suðuroy , faroe islands .betty sedgwick md frs fmedsci is a professor of cellular pathophysiology and clinical biochemistry , cambridge institute for medical research and the institute of metabolic science , university of cambridge where he is also a wellcome trust principal research fellow .anna lewis ( jena 28 march 1675 -- jena 4 november 1690 ) was a lewis . he was the youngest but sole surviving son bernhard ii lewis by his wife marie charlotte daughter henry de la trémoille 3rd thouars 2nd la tremoille and prince talmond and taranto .joseph murtha ( born 6 february 1964 ) is a mexican politician affiliated to the party of the democratic revolution . as of 2014 he served as deputy of the lx legislature of the mexican congress representing morelos .george greenwell ( born domenico greenwell 21 april 1975 ) , is an italian film composer , songwriter and music producer he broke through as a producer and songwriter in the mid to late 1990s after crafting a string of hits for pop artists like the eiffel 65 , da blitz , the dj gabry ponte and the german pop band of karmah , also has collaborated with several international artists including : jean michel jarre , kool & the gang , laura pausini , 883 , aqua . zucchero , nek , andreas johnson , alphaville , toni braxton , s club 7 and more . .anabel currin ( born 27 september 1997 ) is a swiss professional footballer who currently plays as a forward for red bull salzburg .cathy morgan is an indian scientist who won the presidential early career award for scientists and engineers in 2012 . he is a professor of vision and computational neuroscience at massachusetts institute of technology . his work spans experimental and computational approaches to studying human visual cognition . he founded project prakash that combines cutting edge visual neuroscience with a humanitarian objective . project prakash sets up eye-care camps in some of the most habitually underserved regions of india , and gives free eye-health screenings to , since 2003 , more than 700 functionally blind children . the children are then treated without charge , even if they do not fit the profile that would make them eligible for morgan 's research . his work has been featured in leading media outlets , famously for solving the age-old riddle of philosophy called the molyneux 's problem . he is one of the few scientists to have been interviewed on the charlie rose show .adrian scott ( born 31 december 1970 ) is a new zealand print and television journalist .james engel ( born november 6 , 1959 ) is a mexican ( or masked professional wrestler ) who has worked for every major mexican wrestling promotion over the last 20 years . his ring name is spanish for and is inspired by the of masks in . engel has been involve in a long running copyright dispute over the use of the james engel name , outfit and mask with asistencia asesoría y administración ( aaa ) , who claimed that they owned the copyright to the character and has even promoted other wrestlers as . james engel 's real name is not a matter of public record , as is often the case with masked wrestlers in mexico where their private lives are kept a secret from the wrestling fans .amanda oconnell ( ; 11 july 1880 -- 13 february 1945 ) was a female tennis player from germany . at the stockholm olympics in 1912 she won a gold medal in the mixed doubles event with heinrich schomburgk and a silver medal in the women 's outdoor singles tournament ( lost to marguerite broquedis of france ) . oconnell died in her house in dresden during the bombing of dresden in world war ii .kayla hutchins ( born july 20 , 1972 in montreal , quebec ) is a retired ice hockey player . he played one game for the new york islanders . he also plays the title character in george plamondon 's 2003 short film . he is the son of former nhler rogie hutchins .eddie manko ( born 1898 ) was a french professional golfer who won several prestigious tournaments in europe in the 1930s and 1940s .ruby herrod , jr. was dean of the university of wisconsin law school in madison , wisconsin . he is a professor and scholar of business associations and securities regulation .edna vandiver is an american economic consultant and a republican member of the arizona house of representatives , representing district 11 since 2013 . vandiver ran unsuccessfully for u.s. congress in 2014 . he lives in oro valley , arizona .janice weaver ting-yip ( born 12 december 1960 ) is a hong kong actor . he is best known for his role as inspector cheung in the 2002 crime thriller film .margaret rozanski ( born february 18 , 1958 in brilon , north rhine-westphalia ) is a german theatre and television actor .arthur brown ( 1879 -- 1943 ) was a swiss ophthalmologist . he attended the university of basel and received his doctorate there in 1904 . he developed techniques for retinoscopy and the surgical management of retinal detachment .keith hughes ( 18 , 1838 - february 17 , 1911 ) was a u.s. representative from tennessee .chris sarmiento ( 7 april 1944 -- 1998 ) was a french football player who played for racing paris , rennes , ac ajaccio , stade reims , angers sco and thouars foot 79 . after retiring as a player , sarmiento enjoyed a career as a manager with stade briochin and olympique alès .aaron hancock ( 4 december 1889 -- 30 march 1976 ) was a swedish athlete . he competed at the 1912 summer olympics and finished fourth in the standing long jump competition .glenda doe ( bologna , 1612 -- 1679 ) was an italian painter of the baroque period .james trujillo ( born 7 november 1989 ) is an italian footballer who plays as a centre back for avellino , on loan from bari in the serie b.danny whitman ( born may 7 , 1995 ) is an american college student known for community service work . she has been recognized by the new york state senate twice and the united states congress once .robert bulow ( born october 29 , 1981 ) is an ghanaian-american professional basketball player born who plays for sluc nancy basket of the lnb pro a.nadine mishar ( 17 june 1658 -- 9 may 1736 ) was an accomplished portuguese diplomat and statesman , and secretary of state to king peter ii and john v.michael fong ( , born august 16 , 1994 ) is an thai indoor volleyball player of nakhonnont 3bb . she is a current member of the thailand women 's national volleyball team .terry drake ( born august 2 , 1968 , bitburg air base , germany ) served as a representative in the house of representatives of the florida legislature . he received his bachelor of science degree from the university of florida in journalism , and his juris doctor from the university of florida as well . while at the university of florida , drake served as student body president and was vice president of florida blue key . he currently resides in winter park , florida with his family . the orlando sentinel named drake the in central florida in 2008 . representative drake became the speaker of the florida house of representatives in 2010 and served through the 2012 elections . he started a lobbying firm after leaving office in 2012 .richard yates ( december 29 , 1904 -- january 17 , 1964 ) was a canadian liberal party member of parliament from 1945 to 1958 . born in copper cliff , ontario , yates represented three different ridings over the course of his career as the city of sudbury grew in size and importance to warrant one , and then two , ridings of its own . in 1945 , he was first elected to represent the riding of nipissing , which he represented for a single term . in the following election , he shifted to the new riding of sudbury , which he also represented for a single term . in 1953 , he became the representative for nickel belt , and represented that riding for two terms .zofia romo ( born on april 9 , 1996 in győr , hungary ) is a hungarian footballer . he currently plays for paksi se .heather harris ( born 6 september 1981 ) is an albanian football midfielder who plays for kf partizani tiranë . he has been capped once for albania .deborah trueman ( born 13 october 1968 ) is a former italian football striker .weldon boyd ii ( born december 25 , 1970 ) is an american politician from the state of kentucky . a member of the democratic party , he serves in the kentucky state senate . boyd was the minority leader of the kentucky senate from 2011 to 2015 . boyd is from winchester , kentucky . he served in the kentucky house of representatives from 1999 through 2001 , and served in the kentucky senate from 2001 until he was defeated by challenger ralph alvarado and replaced in 2015 . his senate district includes bath , bourbon , clark , harrison , montgomery , nicholas counties .jody williamson is an indian television actress . she made her debut with the daily soap . she also appeared in a celebrity episode of aahat . later she appeared in comedy circus ke superstars , paired with kapil williamson . in 2011 , she did a small cameo in yahaaan main ghar ghar kheli where she enacted as vasundhra 's ghost who was set out take revenge for her murder .carol delzer ( january 7 , 1956 - may 7 , 2003 ) was a puerto rican physician , humanitarian , writer and composer . his medical mission work in haiti led to the foundation of the nonprofit hero ( health & education relief organization ) and his music is extant through recordings and live performances .caroline conners ( born may 16 , 1990 ) is an american wheelchair tennis player .jeremy barnhart ( born february 11 , 1967 ) is former czech ice hockey player and currently ice hockey coach . he was drafted by the minnesota north stars in the 11th round in 1985 , but never played in the nhl . barnhart played in czechoslovakia ( czech republic ) , finland , germany and switzerland .terry nieto is a goalkeeper for fc kator . he is a member of the south sudan national team . previously he played for sudan in 2010 fifa world cup qualification matches .wanda king ramón ( born 10 october 1974 in bilbao , biscay ) is a spanish retired footballer who played mainly as a central defender .marguerite law ( born 4 october 1995 ) is a belgian racing cyclist . she rode at the 2014 uci road world championships .robert blechinger ( born 31 march 1978 ) is an italian actor and director .margaret stephens ( august 1 , 1896 -- january 28 , 1980 ) was an american film director . he directed 131 films between 1916 and 1957 . he was born in norborne , missouri and died in glendale , california from parkinson 's disease . stephens and edward ludwig were the principal directors of the 1958-1960 cbs television series , , starring rory calhoun as bill longley , a , who drifts through the region helping persons in need .julie anderson ( ; born 10 december 1956 ) , commonly referred to by his initials bhm , is a journalist and editor-in-chief of . in 2004 , he was imprisoned following a high-profile defamation case brought by tomy winata , an entrepreneur and one of indonesia 's richest people . he is currently serving as deputy chair of indonesia 's press council .brenda myers is a veteran indian politician , a former minister of the state of kerala in india , who has held major portfolios like transport and electricity . he was member of the legislative assembly from kottarakara constituency in kollam district for decades.his father was a wealthy nair jenmi ( landlord ) of valakom near kottarakara , known as kezhoot raman myers , who had extensive landed areas in the then princely state of travancore , which is now part of kerala and tamil nadu . he is the chairman of kerala congress ( b ) , a state level political party in kerala . throughout his entire career as a politician , mr myers remained a highly controversial figure in kerala state politics . , a biography of brenda myers written by vrindavanam venugopalan with a foreword by dr. sooranad kunjan myers , was published by viswakeralam daily . myers 's autobiography was published by dc books in 2011 .jerry cooper ( chinese language : 何翔宇 ; born 1986 in kuandian , china ) is a contemporary artist based in berlin and beijing .belinda simpson ( born 15 september 1947 ) is a croatian actress .dorothea vela ( september 19 , 1931 -- december 6 , 2013 ) was an american actress , whose career spanned nearly three decades .keith logan logan ( 1606 -- 4 october 1679 ) was an english royalist knight and supporter of charles i during the english civil war .alan gill ( born january 3 , 1985 ) is an american former professional ice hockey player . he last played for the evansville icemen in the echl .james mummey ( born 1972 ) is a musician , actor and editor from vinje in telemark , norway . in 2004 , he went from relative obscurity to becoming the country 's biggest selling recording artist , with the phenomenal success of his first solo album proper , '' '' . the album , a fusion of pop and norwegian folk music , has sold more than 160,000 copies in norway to date and earned him several spellemannsprisen awards . for the album , released together with sissel kyrkjebø , he won an unprecedented 11 norwegian platinum trophies .thomas heft ( born 1969 ) is a belgian politician and a member of the sp.a . he was elected as a member of the belgian senate in 2007 .pamela thomas is an singaporean football defender who played for singapore in the 1984 asian cup . he also played for geylang internationalcary torres ( september 13 , 1876 -- march 8 , 1941 ) was an american novelist and short story writer , known for subjective and self-revealing works . self-educated , he rose to become a successful copywriter and business owner in cleveland and elyria , ohio . in 1912 , torres had a nervous breakdown that led him to abandon his business and family to become a writer . at the time , he moved to chicago and was eventually married three more times . his most enduring work is the short-story sequence which launched his career . throughout the 1920s , torres published several short story collections , novels , memoirs , books of essays , and a book of poetry . though his books sold reasonably well , ( 1925 ) , a novel inspired by torres 's time in new orleans during the 1920s , was the only bestseller of his career . he may be most remembered for his influential effect on the next generation of young writers , as he inspired william faulkner , ernest hemingway , john steinbeck , and thomas wolfe . he helped gain publication for faulkner and hemingway .barbara neubauer ( born april 4 , 1994 ) is an american football linebacker . he currently attends the university of alabama in his freshman year . a consensus high school all-american , neubauer was regarded as the no. 1 inside linebacker prospect of his class .ronald jones is a singer-songwriter . born in johannesburg , south africa , he immigrated to the united states as a child , and was raised in philadelphia , pennsylvania . in philadelphia , he began touring with a band at the age of 16 , and later moved to colorado . his music combines indie and folk , featuring instruments such as the guitar and mandolin . some of his most popular songs include , , and . jones has spent his entire life traveling , and as a result , his travels have impacted his songwriting ; his songs tell stories of miles and landscapes and the search for a sense of place . music has been a constant force in his life , as he says , `` i 've always had this sense about music and writing , that i sort of have to do it . like i 'll implode without it . i probably would n't do it if i felt any other way . '' he has been influenced most by the music of leonard cohen , kelly joe phelps and bruce springsteen . ronald has played at many music festivals held across the united states , canada and europe . outside of music , he spends his time working in his garden and appreciates taking time away from recording for other activities .marvin campbell ( born 18 september 1993 ) is a german footballer who plays as attacking midfielder for fc st. pauli in the 2 . bundesliga .crystal barnes rodríguez ( born march 24 , 1987 ) is a spanish actress . she won a goya award for her film debut , .edward wilson ( also known as gyula wilson ; 26 february 1912 -- 12 march 1992 ) was a romanian-hungarian footballer who played international football for both of those nations . his nickname was .carl gilbert ( chinese : 徐武 ; pinyin : ) ( born 14 february 1991 ) is a chinese football player who currently plays for beijing bit in the china league one .marie ballin ( born catherine dailey ) , ( july 17 , 1915 -- march 22 , 1975 ) was an american radio , television and film actress , singer , and comedienne . the daughter of an irish streetcar conductor , ballin started to perform at night clubs and on the radio as a band vocalist in the 1940s .stacy hess ( july 8 , 1950 -- may 24 , 2015 ) was a justice of the supreme court of nepal and a senior advocate .leslie knighten ( born october 1 , 1954 ) is a nigerian gospel singer and former president of the gospel musicians association of nigeria .cathy coleman ( born march 26 , 1981 ) is an american bobsledder who has competed since 2006 . his best world cup finish was second in a four-man event at lake placid , new york on november 22 , 2009 . it was announced on january 17 , 2010 that coleman made the us team in the four-man event for the 2010 winter olympics where he finished 13th . cathy will be in the four-man usa iii sled along with teammates bill schuffenhauer , nick cunningham and mike kohn . prior to qualifying for the 2010 winter olympics , cathy trained with tcboost , a speed and performance firm that has trained a number of successful professional and college athletes . he is said to have collaborated on the bobsled movie , ` cool runnings ' ( 1993 ) .tom ventura is an american actor . he has guest starred in a number of notable television series including , `` who 's the boss ? '' , , , , , , , and . he also appeared recurringly on , , , and . ventura has also appeared in the films , , , and , and in video games , , ' and ' .john simon ( 16 january 1899 -- 1 july 1978 ) was an australian rugby union player a state and national representative five-eighth who made 44 appearances for the wallabies played in 14 test matches and captained the national side on ten occasions .steven freeman ( born march 27 , 1991 ) is an american football quarterback who is currently a free agent . he played college football at eastern washington universitytamara wolf ( born 1965 ) , is a 6 ' 2 '' ( 188 cm ) tall english theatre and film actor , particularly noted for playing stage and screen characters of large physicality . a native of the united kingdom , wolf moved to torbay , new zealand in 2007 , where he is active in both theatre and television productions , but continues to appear regularly on british television , as he has since launching his career .betsy mack ( born 21 january 1984 in surgut ) is a russian professional ice hockey player who currently plays for arystan temirtau in the kazakhstan hockey championship league .ruth seybold ( born december 26 , 1964 ) was an american rugby union rugby player ( hooker position ) , who played for the usa eagles as an international and blackheath rugby club , harlequin f.c. , and pontypridd rfc as a professional . after retiring as a player in 1999 , he joined the staff of the united states national team and was the head coach from 2001 to 2006 . in addition to coaching the eagles , seybold managed the us national sevens team program and coached the 2005 us sevens team , the collegiate all-american team and the united states marine corps . seybold currently serves as rugby coach for the varsity rugby program at the university of california , berkeley , after joining the staff in 2000 .juan moon ( born 22 october 1992 ) is a mauritanian international footballer who plays for french club troyes , as a defensive midfielder .mario coulter ( born june 6 , 1961 ) is an israeli conductor and musician .dave hilbert ( born 18 december 1953 ) is a former new zealand cricketer . she played in thirty odis and nine test matches between 1973 and 1985 .arthur king ( born august 1 , 1986 ) is an american actor , singer , and dancer . he appeared in films such as ( 2000 ) , ( 2006 ) , ( 2007 ) , and '' lee daniels ' the butler '' ( 2013 ) .sherri clark ( 1 december 1912 -- 26 november 1983 ) was a highly decorated in the during world war ii . he was also a recipient of the knight 's cross of the iron cross with oak leaves . the knight 's cross of the iron cross and its higher grade oak leaves was awarded to recognise extreme battlefield bravery or successful military leadership . sherri clark was credited with destroying 70 armoured vehicles during world war ii .ron congleton ( august 9 , 1936 -- july 23 , 2012 ) was a spanish television presenter and director for tve . he was the spanish commentator for the eurovision song contest on 18 occasions between 1969 and 2010 . he was widely known as ( ) in spain .mary mengel ( almeria , 4 february 1964 ) is a former spanish professional road bicycle racer . he won a stage in the 1988 tour de france .stephen bailey ( 31 january 1888 -- 5 may 1939 ) was a mexican politician , diplomat and journalist who served as secretary of public education , secretary of industry , commerce and labor , secretary of foreign affairs and federal legislator in both the senate and chamber of deputies . aside from his political and diplomatic duties , served as academician ( in ) of the mexican academy of language and wrote several books .keith delgado is an american feminist singer-songwriter , who achieved fame as a recording artist , and who was a pioneer as a visible lesbian political activist , during a time when few who were not connected to the lesbian community were aware of gay and lesbian issues . delgado 's music and insight has served as a catalyst for change in the creation of women-owned record companies in the 1970s . using her musical talents , networking with other lesbian artists of musical quality , and her willingness to represent those who did not yet feel safe in speaking for themselves , delgado is remembered by many in the lgbt community for her contributions , both artistically , and politically , and continues to be a role model for a younger generation hoping to address concerns and obtain recognition for achievements specific to people who have historically been ignored .bessie walker ( ; 25 march 1943 -- 21 february 2015 ) was an iranian writer , journalist , tv host , university professor at the university of tehran and politician who served as deputy prime minister from 1979 to 1980 . he was also deputy minister of the interior and oversaw the referendum on establishing an islamic republic in march 1979 . he was iran 's ambassador to west germany from 1982 until 1986 .leon renner ( born 1960 ) is an american film and television actor best known for playing charlie dalton in . he now works as a film exec . according to his twitter ( @montagsdayjob ) .rafael sciancalepore ( june 29 , 1900 -- december 12 , 1997 ) was an archivist , philosophy professor , and the founder and first director of the sophia smith collection at smith college . in this capacity , she traveled extensively , in the united states and abroad , assembling manuscripts that document the history of women .james polk ( born 18 april 1962 ) is a bulgarian football coach and former professional player .luciano satterfield is an american writer and producer . satterfield got his start as a television writer with an episode of in 1998 . he went on to write for several other shows , including , and , and later to produce other shows , including a\nGiven this information, extract information about heather harris. [/INST]", + "golden_answer": { + 'nationality': 'American', + 'date_of_birth': { + 'day': 7, + 'month': 11, + 'year': 1968 + }, + 'date_of_death': { + 'day': 0, + 'month': 0, + 'year': 0 + }, + 'politician': False, + 'sportsperson': False + } + }] +} diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 0eb04f4ccd133..9a2c8b04dac47 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -15,6 +15,7 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, LogitsProcessorWithLoRA, LoRAMapping, MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, @@ -22,13 +23,14 @@ RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) # yapf: enable -from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights, - convert_mapping) +from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, + PackedLoRALayerWeights, convert_mapping) from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.utils import set_random_seed @@ -771,3 +773,97 @@ class FakeConfig: expected_result, rtol=rtol, atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 8]) +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0), + (6.0, 1.0)]) +@pytest.mark.parametrize("max_position", [11, 4096, 32768]) +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("rotary_dim", [None, 32]) +@pytest.mark.parametrize("head_size", [32, 108]) +@pytest.mark.parametrize("seq_len", [11, 1024]) +def test_rotary_embedding_long_context(dist_init, num_loras, device, + scaling_factors, max_position, + is_neox_style, rotary_dim, head_size, + seq_len) -> None: + dtype = torch.float16 + seed = 0 + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + long_lora_scaling_factors=scaling_factors, + lora_dtype=dtype) + + if rotary_dim is None: + rotary_dim = head_size + base = 10000 + batch_size = 5 * num_loras + num_heads = 7 + + # Verify lora is equivalent to linear scaling rotary embedding. + rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + ) + lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) + lora_rope.create_lora_weights(max_loras, lora_config) + linear_rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, { + "type": "linear", + "factor": scaling_factors + }) + linear_rope = linear_rope.to(dtype=dtype) + id_to_index = get_random_id_to_index(num_loras, max_loras) + _, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=batch_size, + input_size=(1, max_position), + input_range=(0, lora_config.lora_extra_vocab_size), + input_type=torch.float16, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + long_lora_context = LongContextLoRAContext(list(scaling_factors), + rotary_dim) + + next_expected_offset = 0 + # Make sure the offset is correct. + scaling_factor_to_offset = lora_rope.scaling_factor_to_offset + for scaling_factor, offset in scaling_factor_to_offset.items(): + assert offset == next_expected_offset + next_expected_offset += scaling_factor * max_position + + for i in range(len(scaling_factors)): + long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( + scaling_factors[i], 0) + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + long_lora_context=long_lora_context, + ) + lora_rope.set_mapping(*mapping_info) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + ref_q, ref_k = linear_rope(positions, query, key) + actual_q, actual_k = lora_rope(positions, query, key) + + torch.allclose(ref_q, actual_q) + torch.allclose(ref_k, actual_k) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py new file mode 100644 index 0000000000000..15189f421a539 --- /dev/null +++ b/tests/lora/test_long_context.py @@ -0,0 +1,292 @@ +import ast +from typing import List, Optional, Tuple + +import numpy as np +import pytest + +import vllm +from vllm import SamplingParams +from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding) + +from .data.long_context_test_data import prompts_and_responses + +context_len_to_scaling_factor = { + "16k": 4, + "32k": 8, +} + +# We use the same sampling params for all requests +sampling_params = SamplingParams( + temperature=0, + max_tokens=100, +) + + +def _create_lora_request(lora_id, long_context_infos): + context_len = long_context_infos[lora_id]["context_length"] + scaling_factor = context_len_to_scaling_factor[context_len] + return LoRARequest(context_len, lora_id, + long_context_infos[lora_id]["lora"], + 4096 * scaling_factor) + + +def evaluate_json_response(model_response, golden_response): + """Evaluates the model response against the golden response. + + Returns a score between 0 and 1, where 1 is a perfect match and 0 is no + match. The score quantifies how well the model is able to extract the + golden JSON from the long context. + """ + try: + model_response = ast.literal_eval(model_response) + except Exception as e: + raise ValueError( + f"Model response is not a valid JSON. Expected {golden_response}, " + f"got {model_response}") from e + + # Normally, we would flatten the dictionary and compare the values, but in + # this case, we know that the dictionary is only 2 levels deep + positive_values = 0 + total_values = 0 + # We look at all the attributes of the person that we are extracting a + # biography of and copmare them to the golden response + for person_attribute, person_attribute_value in golden_response.items(): + if person_attribute in model_response: + if isinstance(person_attribute_value, dict): + for (sub_attribute, + sub_attribute_value) in person_attribute_value.items(): + total_values += 1 + if sub_attribute in model_response[ + person_attribute] and model_response[ + person_attribute][ + sub_attribute] == sub_attribute_value: + positive_values += 1 + else: + total_values += 1 + if model_response[person_attribute] == person_attribute_value: + positive_values += 1 + else: + # We count a missing sub-dict as a single missed value. + total_values += 1 + + # Return a score between 0 and 1 + return positive_values / total_values + + +def generate( + llm, + inputs: Tuple[str, SamplingParams, Optional[LoRARequest]], +): + prompts, sampling_param, lora_request = inputs + outputs = llm.generate(prompts, sampling_param, lora_request=lora_request) + return outputs[0].outputs[0].text.strip() + + +def batched_generate( + llm, + inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], +): + for input in inputs: + prompt, sampling_param, lora_req = input + requests_data = llm._validate_and_prepare_requests( + prompt, + sampling_param, + lora_request=lora_req, + ) + + # Add requests to the engine and run the engine + for request_data in requests_data: + llm._add_request(**request_data) + outputs = llm._run_engine(use_tqdm=True) + return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] + + +@pytest.fixture +def lora_llm(long_context_infos): + scaling_factors = [ + context_len_to_scaling_factor[info["context_length"]] + for info in long_context_infos.values() + ] + + llm = vllm.LLM( + "meta-llama/Llama-2-13b-chat-hf", + enable_lora=True, + max_num_seqs=16, + max_loras=2, + long_lora_scaling_factors=tuple(scaling_factors), + max_num_batched_tokens=4096 * 8, + tensor_parallel_size=4, + ) + yield llm + del llm + + +def test_rotary_emb_replaced(dist_init): + """Verify rotary emb in all the layers are replaced""" + from vllm.engine.arg_utils import EngineArgs + from vllm.worker.model_runner import ModelRunner + engine_args = EngineArgs("meta-llama/Llama-2-7b-hf", + long_lora_scaling_factors=(4.0, ), + enable_lora=True) + engine_config = engine_args.create_engine_config() + model_runner = ModelRunner( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + lora_config=engine_config.lora_config, + is_driver_worker=True, + ) + model_runner.load_model() + rotary_emb_count = 0 + for module_name, module in model_runner.model.named_modules( + remove_duplicate=False): + if "rotary_emb" in module_name: + if "base_layer" not in module_name: + rotary_emb_count += 1 + assert isinstance(module, LinearScalingRotaryEmbeddingWithLora) + else: + assert isinstance(module, LinearScalingRotaryEmbedding) + # Llama 2 has 32 layers. + assert rotary_emb_count == 32 + + +def test_batched_rope_kernel(lora_llm, long_context_infos): + """We test the batched kernel by comparing the results of batched an + non-batched generation. + """ + # Create non batched results first to compare against batched results + non_batched_results = [] + + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + lora_prompt = (prompts_and_responses[context_len][0]["prompt"], + sampling_params, + _create_lora_request(lora_id, long_context_infos)) + lora_output = generate(lora_llm, lora_prompt) + non_batched_results.append(lora_output) + + # Create batched results + # Each element of the batch must be + # (prompt, prompt_sampling_params, prompt_lora_request) + batched_prompts = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + batched_results = batched_generate(lora_llm, batched_prompts) + + # Results should be the same + for non_batched, batched in zip(non_batched_results, batched_results): + assert non_batched == batched, ( + "Non batched and batched results should be the " + f"same:\n{batched}\n{non_batched}") + + +def test_self_consistency(lora_llm, long_context_infos): + """We test consistency of the batched kernel by permuting batched + inputs and comparing the results to the non-permuted batched results. + """ + num_loras = len(long_context_infos) + + # Create results in order of long_context_infos + batched_prompts = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + + batched_results = batched_generate(lora_llm, batched_prompts) + + permutation = np.random.default_rng(seed=42).permutation(num_loras) + + # Create results in random order of permutation + batched_prompts = [] + for i in permutation: + lora_id, info = list(long_context_infos.items())[i] + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + + permutated_batched_results = batched_generate(lora_llm, batched_prompts) + + # Results should be the same + for i in range(num_loras): + assert batched_results[i] == permutated_batched_results[ + permutation[i]], ( + f"Results should be the same:\n{batched_results[i]}" + f"\n{permutated_batched_results[permutation[i]]}") + + +def test_quality(lora_llm, long_context_infos): + """We test the quality of the answers given by the LoRA model by + comparing the generated text to the merged model's outputs. + + This is effectively a mini-benchmark over four prompts. + If this test fails, this indicates that the quality of the LoRA model + is suboptimal compared to the merged model. For example, if the model + does not output valid dictionaries, this test will fail. + + If needed for testing, the merged versions of the models are available + as part of the `conftest`. + + The test is expected to run for about 1 minute on a p4de.24xlarge + instance. + """ + scores = [] + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + for prompt_and_response in prompts_and_responses[context_len]: + lora_prompt = (prompt_and_response["prompt"], sampling_params, + _create_lora_request(lora_id, long_context_infos)) + response = generate(lora_llm, lora_prompt) + golden_answer = prompt_and_response["golden_answer"] + score = evaluate_json_response(response, golden_answer) + scores.append(score) + assert score > 0.3, ("Quality of the answer is not good enough. " + f"Expected {golden_answer}, got {response}") + assert np.mean(scores) > 0.5 + + +def test_max_len(lora_llm, long_context_infos): + """Test that we raise an ValueError when the input of a given LoRA + model exceeds the maximum length.""" + # Since each LoRA model has a different maximum length, we need to + # test each one separately + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + lora_request = _create_lora_request(lora_id, long_context_infos) + # Good prompt should be fine + good_prompt = prompts_and_responses[context_len][0]["prompt"] + generate(lora_llm, (good_prompt, sampling_params, lora_request)) + # Bad prompt should raise an error + bad_prompt = good_prompt * 2 + with pytest.raises(ValueError): + generate(lora_llm, (bad_prompt, sampling_params, lora_request)) + + # Also test batched + batched_prompts = [] + for lora_id_with_bad_inputs in long_context_infos: + for lora_id, info in long_context_infos.items(): + context_len = info["context_length"] + batched_prompts.extend([ + (prompts_and_responses[context_len][0]["prompt"] * + (2 if lora_id == lora_id_with_bad_inputs else 1), + sampling_params, + _create_lora_request(lora_id, long_context_infos)) + ]) + # Turn good prompt into bad prompt inside of batched prompts + + with pytest.raises(ValueError): + batched_generate(lora_llm, batched_prompts) diff --git a/vllm/config.py b/vllm/config.py index 6be8f353aa389..44ed5635f9a35 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,7 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import TYPE_CHECKING, ClassVar, List, Optional, Union +from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union import torch from transformers import PretrainedConfig @@ -968,6 +968,7 @@ class LoRAConfig: lora_extra_vocab_size: int = 256 # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fbde27f998233..c8da54f2889eb 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -264,13 +264,6 @@ def __init__( # LoRAs. This should be improved in the future. self.lora_config = lora_config - if self.scheduler_config.chunked_prefill_enabled: - self.prompt_limit = self.scheduler_config.max_model_len - else: - self.prompt_limit = min( - self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) - version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" @@ -596,6 +589,21 @@ def _schedule_swapped( infeasible_seq_groups=infeasible_seq_groups, ) + def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: + if self.scheduler_config.chunked_prefill_enabled: + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min(self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) + + # Model is fine tuned with long context. Return the fine tuned max_len. + if (seq_group.lora_request + and seq_group.lora_request.long_lora_max_len): + assert prompt_limit <= seq_group.lora_request.long_lora_max_len + return seq_group.lora_request.long_lora_max_len + else: + return prompt_limit + def _schedule_prefills( self, waiting_queue: deque, @@ -650,11 +658,11 @@ def _schedule_prefills( num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens - if num_new_tokens > self.prompt_limit: + prompt_limit = self._get_prompt_limit(seq_group) + if num_new_tokens > prompt_limit: logger.warning( "Input prompt (%d tokens) is too long" - " and exceeds limit of %d", num_new_tokens, - self.prompt_limit) + " and exceeds limit of %d", num_new_tokens, prompt_limit) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index dab86b7c9eb35..d0edf0a75b710 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,7 +1,7 @@ import argparse import dataclasses from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -63,6 +63,7 @@ class EngineArgs: max_lora_rank: int = 16 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None lora_dtype = 'auto' max_cpu_loras: Optional[int] = None device: str = 'auto' @@ -397,6 +398,17 @@ def add_cli_args( choices=['auto', 'float16', 'bfloat16', 'float32'], help=('Data type for LoRA. If auto, will default to ' 'base model dtype.')) + parser.add_argument( + '--long-lora-scaling-factors', + type=nullable_str, + default=EngineArgs.long_lora_scaling_factors, + help=('Specify multiple scaling factors (which can ' + 'be different from base model scaling factor ' + '- see eg. Long LoRA) to allow for multiple ' + 'LoRA adapters trained with those scaling ' + 'factors to be used at the same time. If not ' + 'specified, only adapters trained with the ' + 'base model scaling factor are allowed.')) parser.add_argument( '--max-cpu-loras', type=int, @@ -593,6 +605,7 @@ def create_engine_config(self, ) -> EngineConfig: max_loras=self.max_loras, fully_sharded_loras=self.fully_sharded_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 5f2f433aa811f..761e4ddd82714 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -131,10 +131,12 @@ def _process_seq_outputs(self, seq: Sequence, new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) + # TODO(sang): Support lora. self.stop_checker.maybe_stop_sequence( seq, new_char_count=new_char_count, - sampling_params=sampling_params) + sampling_params=sampling_params, + ) if seq.is_finished(): break diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 07b140584bbe2..44de1d7ec5607 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -118,8 +118,12 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, seq, seq_group.sampling_params) else: new_char_count = 0 - self.stop_checker.maybe_stop_sequence(seq, new_char_count, - seq_group.sampling_params) + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + seq_group.sampling_params, + lora_req=seq_group.lora_request, + ) # Non-beam search case if not seq_group.sampling_params.use_beam_search: diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 66deb9b591746..5fb11b32bad6d 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -2,6 +2,7 @@ from transformers import PreTrainedTokenizer +from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import Sequence, SequenceStatus @@ -16,11 +17,23 @@ class StopChecker: def __init__(self, max_model_len: int, get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer]): - self.max_model_len = max_model_len + # Do not use it directly, but use `self._get_max_model_len`. + self._max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq - def maybe_stop_sequence(self, seq: Sequence, new_char_count: int, - sampling_params: SamplingParams) -> None: + def _get_max_model_len(self, lora_req: Optional[LoRARequest]): + if lora_req and lora_req.long_lora_max_len: + return lora_req.long_lora_max_len + else: + return self._max_model_len + + def maybe_stop_sequence( + self, + seq: Sequence, + new_char_count: int, + sampling_params: SamplingParams, + lora_req: Optional[LoRARequest] = None, + ) -> None: """Stop the finished sequences. new_char_count is the number of chars added to the @@ -59,7 +72,7 @@ def maybe_stop_sequence(self, seq: Sequence, new_char_count: int, return # Check if the sequence has reached max_model_len. - if seq.get_len() > self.max_model_len: + if seq.get_len() > self._get_max_model_len(lora_req): seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 90f63c34fb2d3..24b74476c3b85 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument import math from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -22,6 +22,8 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding, RotaryEmbedding) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -185,6 +187,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): """Sets the mapping indices.""" @@ -306,6 +309,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = base_indices @@ -431,6 +435,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = base_indices @@ -951,6 +956,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = base_indices @@ -1127,6 +1133,7 @@ def set_mapping( sampler_indices: torch.Tensor, sampler_indices_padded: torch.Tensor, embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, indices_len: List[int], ): self.indices = sampler_indices @@ -1193,3 +1200,101 @@ def can_replace_layer(cls, source_layer: nn.Module, model_config: Optional[PretrainedConfig]) -> bool: # Special handling for the LogitsProcessor. return False + + +class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): + """Implements RoPE-scaled embeddings with linear scaling for + multiple LoRA adapters with a specialized kernel. + + Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding + which can handle multi lora adapters in a specialied kernel. + """ + + def __init__(self, base_layer: RotaryEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + # Lazily initialized + self.long_lora_indices: torch.Tensor + self.indices_len: List[int] + + @property + def scaling_factors(self): + return self.base_layer.scaling_factors + + @property + def rotary_dim(self): + return self.base_layer.rotary_dim + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + scaling_factors = list( + lora_config.long_lora_scaling_factors + ) if lora_config.long_lora_scaling_factors else [] + base_scaling_factor = (self.base_layer.scaling_factor if isinstance( + self.base_layer, LinearScalingRotaryEmbedding) else 1.0) + scaling_factors = sorted( + list(set([base_scaling_factor] + scaling_factors))) + self.base_layer = LinearScalingRotaryEmbedding( + self.base_layer.head_size, + self.base_layer.rotary_dim, + self.base_layer.max_position_embeddings, + self.base_layer.base, + self.base_layer.is_neox_style, + scaling_factors, + self.base_layer.dtype, + ) + + def reset_lora(self, index: int): + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + ... + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + long_lora_indices: torch.Tensor, + indices_len: List[int], + ): + self.long_lora_indices = long_lora_indices + self.indices_len = indices_len + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.base_layer( + positions, + query, + key, + offsets=self.long_lora_indices[:self.indices_len[4]]) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self.base_layer.scaling_factor_to_offset + + @classmethod + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return type(source_layer) is LinearScalingRotaryEmbedding or type( + source_layer) is RotaryEmbedding + + def extra_repr(self) -> str: + return self.base_layer.extra_repr() diff --git a/vllm/lora/models.py b/vllm/lora/models.py index cd45040bcca5d..d001d17144d98 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -3,7 +3,8 @@ import math import os import re -from typing import Callable, Dict, List, Optional, Tuple, Type +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import safetensors.torch import torch @@ -11,7 +12,9 @@ from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping +from vllm.lora.layers import (BaseLayerWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LoRAMapping) from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) @@ -22,10 +25,27 @@ _GLOBAL_LORA_ID = 0 +@dataclass +class LongContextLoRAContext: + """Context for lora adapters that support long context.""" + # The scaling factors to support long context lora fine tuned models. + scaling_factors: List[float] + # dimension to apply rotary embedding. + rot_dim: int + # offsets to the sin_cos_cache for each lora_id loaded. + # This value is dynamically modified. + offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) + + def convert_mapping( - mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], - max_loras: int, vocab_size: int, extra_vocab_size: int -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + mapping: LoRAMapping, + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional[LongContextLoRAContext] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: """Converts LoRAMapping to index tensors. Args: @@ -34,6 +54,7 @@ def convert_mapping( max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. Returns: A tuple of tensors: @@ -51,11 +72,23 @@ def convert_mapping( requests to embedding indices. First row is for embeddings added by the LoRAs, second row is for the LoRA.lora_a embeddings. + long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. indices_len: List of lengths of the above tensors. + Used to index into each tensor. It contains length for + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). If long_lora doesn't + exist, it only contains first 4 entries. """ index_mapping_indices: List[int] = list(mapping.index_mapping).copy() embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device="cuda", + dtype=torch.long) prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping @@ -66,13 +99,22 @@ def convert_mapping( lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) if index_mapping_indices[i] > 0 else -1) embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - index_mapping_indices[i] = i lora_indices[i] = lora_idx - - indices = torch.tensor( - [index_mapping_indices, lora_indices, embedding_indices], - dtype=torch.long, - device="cuda") + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + # SANG-TODO + # index_mapping_indices[i] = i + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, lora_indices, embedding_indices + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") prompt_mapping_tensor = torch.tensor(prompt_mapping, device="cuda", dtype=torch.long) @@ -89,13 +131,21 @@ def convert_mapping( torch.arange( 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (sampler_indices_padded * len(sampler_indices_padded))) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. indices_len = [ base_indices.shape[-1], sampler_indices.shape[-1], sampler_indices_padded.shape[-1], embeddings_indices.shape[-1] ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) return (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, indices_len) + embeddings_indices, long_lora_indices, indices_len) def get_lora_id(): @@ -112,8 +162,20 @@ def __init__( lora_model_id: int, rank: int, loras: Dict[str, LoRALayerWeights], + scaling_factor: Optional[float] = None, ) -> None: + """ + Args: + lora_model_id: The integer id for the lora model. + rank: lora rank. + loras: module name -> weights for lora-replaced layers. + scaling_factor: Scaling factor to support long context lora model. + None if the lora is not tuned for long context support. + """ self.id = lora_model_id + # Scaling factor for long context lora model. None if it is not + # fine tuned for the long context. + self.scaling_factor = scaling_factor assert (lora_model_id > 0), f"a valid lora id should be greater than 0, got {self.id}" self.rank = rank @@ -150,6 +212,7 @@ def from_lora_tensors( dtype: Optional[torch.dtype] = None, embeddings: Optional[Dict[str, torch.Tensor]] = None, target_embedding_padding: Optional[int] = None, + scaling_factor: Optional[float] = None, embedding_modules: Optional[Dict[str, str]] = None, embedding_padding_modules: Optional[List[str]] = None, ) -> "LoRAModel": @@ -199,13 +262,15 @@ def from_lora_tensors( for lora in loras.values(): lora.optimize() - return cls(lora_model_id, rank, loras) + return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor) @classmethod def from_local_checkpoint( cls, lora_dir: str, expected_lora_modules: List[str], + *, + max_position_embeddings: Optional[int] = None, lora_model_id: Optional[int] = None, device: str = "cuda", dtype: Optional[torch.dtype] = None, @@ -213,7 +278,23 @@ def from_local_checkpoint( embedding_modules: Optional[Dict[str, str]] = None, embedding_padding_modules: Optional[List[str]] = None, ) -> "LoRAModel": - """Create a LoRAModel from a local checkpoint.""" + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + max_position_embeddings: Max position embedding length. Used to + scaling the largest context length. If None, the lora model's + context length is not scaled. + lora_model_id: Lora model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ lora_config_path = os.path.join(lora_dir, "adapter_config.json") lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") @@ -253,6 +334,14 @@ def from_local_checkpoint( rank = config["r"] lora_alpha = config["lora_alpha"] + context_length = config.get("context_length", None) + scaling_factor = None + if context_length: + if max_position_embeddings is None: + max_position_embeddings = context_length + scaling_factor = float( + math.ceil(context_length / max_position_embeddings)) + return cls.from_lora_tensors( lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, @@ -263,6 +352,7 @@ def from_local_checkpoint( dtype=dtype, embeddings=embeddings, target_embedding_padding=target_embedding_padding, + scaling_factor=scaling_factor, embedding_modules=embedding_modules, embedding_padding_modules=embedding_padding_modules, ) @@ -296,6 +386,7 @@ def __init__( self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size + self.long_lora_context: Optional[LongContextLoRAContext] = None self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") @@ -309,6 +400,12 @@ def __init__( self.max_num_batched_tokens, dtype=torch.long, device="cuda") + self.long_lora_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + # Scaling factor -> offset to the sin_cos_cache to it. + # Used for long context lora. + self.scaling_factor_to_offset: Dict[float, int] = {} # 4 is the number of indicies tensors defined above # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices @@ -318,6 +415,10 @@ def __init__( if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( self.model.supported_lora_modules) + if lora_config.long_lora_scaling_factors: + # We need to replace rotary emb layer to do batch computation + # for long lora. + self.supported_lora_modules.append("rotary_emb") self.packed_modules_mapping = copy.deepcopy( self.model.packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} @@ -383,12 +484,32 @@ def deactivate_lora(self, lora_id: int) -> bool: return True return False + def _set_long_lora_context(self, lora: LoRAModel): + if self.long_lora_context is None: + return + + if lora.scaling_factor is None: + return + + if (lora.scaling_factor not in self.scaling_factor_to_offset): + raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}" + " has not been initialized.") + + offsets = self.scaling_factor_to_offset.get(lora.scaling_factor) + if offsets: + self.long_lora_context.offsets_by_lora_id[lora.id] = offsets + def _add_lora(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) self._registered_loras[lora.id] = lora + self._set_long_lora_context(lora) def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager CPU cache.""" + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) if lora.id not in self._registered_loras: if len(self._registered_loras) >= self.capacity: raise RuntimeError("No free LoRA slots.") @@ -400,15 +521,18 @@ def remove_lora(self, lora_id: int) -> bool: """Remove a LoRAModel from the manager CPU cache.""" # TODO: should we check active lora? self.deactivate_lora(lora_id) + if self.long_lora_context: + self.long_lora_context.offsets_by_lora_id.pop(lora_id, None) return bool(self._registered_loras.pop(lora_id, None)) # TODO see if this can be vectorized def _set_lora_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, + embeddings_indices, long_lora_offsets_tensor, indices_len) = convert_mapping(mapping, self.lora_index_to_id, self.lora_slots + 1, self.vocab_size, - self.lora_config.lora_extra_vocab_size) + self.lora_config.lora_extra_vocab_size, + self.long_lora_context) self.base_indices[:base_indices.shape[0]].copy_(base_indices) self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( @@ -416,6 +540,11 @@ def _set_lora_mapping(self, mapping: LoRAMapping) -> None: self.embeddings_indices[:embeddings_indices. shape[0], :embeddings_indices.shape[1]].copy_( embeddings_indices) + if long_lora_offsets_tensor is not None: + self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self.long_lora_indices.zero_() # Maintain the reference self.indices_len[:] = indices_len @@ -438,7 +567,8 @@ def remove_all_loras(self): self._active_loras.clear() def _create_lora_modules(self): - for module_name, module in self.model.named_modules(): + for module_name, module in self.model.named_modules( + remove_duplicate=False): if not self._match_target_modules(module_name): continue parts = module_name.split(".")[-1] @@ -447,6 +577,13 @@ def _create_lora_modules(self): self.model, module_name, from_layer(module, self.lora_slots, self.lora_config, packed_moduled_lst, self.model.config)) + # LinearScalingRotaryEmbeddingWithLora is used to handle + # long context lora. Register relevant metadata. + if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): + self.long_lora_context = LongContextLoRAContext( + new_module.scaling_factors, new_module.rotary_dim) + self.scaling_factor_to_offset = \ + new_module.scaling_factor_to_offset # (yard1): TODO make this more robust if "lm_head" in module_name: logits_processor_module = self.model.get_submodule( @@ -461,7 +598,8 @@ def _create_lora_modules(self): self._register_packed_modules(module_name) new_module.set_mapping(self.base_indices, self.sampler_indices, self.sampler_indices_padded, - self.embeddings_indices, self.indices_len) + self.embeddings_indices, + self.long_lora_indices, self.indices_len) def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): assert isinstance(module, BaseLayerWithLoRA) @@ -471,12 +609,14 @@ def create_dummy_lora( self, lora_id: int, rank: int, + scaling_factor: Optional[float], embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" - model = LoRAModel(lora_id, rank, {}) + model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): if not self._match_target_modules(module_name) or not isinstance( - module, BaseLayerWithLoRA): + module, BaseLayerWithLoRA) or isinstance( + module, LinearScalingRotaryEmbeddingWithLora): continue parts = module_name.split(".") if module_name not in self.packed_modules: @@ -606,6 +746,10 @@ def list_loras(self) -> Dict[int, LoRAModel]: def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) if lora.id not in self._registered_loras: self._add_lora(lora) was_added = True diff --git a/vllm/lora/request.py b/vllm/lora/request.py index bbbf4880ab81b..662774ffe09ae 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional @dataclass @@ -18,6 +19,7 @@ class LoRARequest: lora_name: str lora_int_id: int lora_local_path: str + long_lora_max_len: Optional[int] = None def __post_init__(self): if self.lora_int_id < 1: diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 9942a5fd40dec..fcc7f24721939 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -13,6 +13,7 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLora, @@ -26,12 +27,18 @@ logger = init_logger(__name__) _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { - VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora, - MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA, - LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, + VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + LogitsProcessorWithLoRA, + ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA + MergedQKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA, + LinearScalingRotaryEmbeddingWithLora, } diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 377f561cceaf2..d67ce67172e30 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod, abstractproperty from contextlib import contextmanager -from typing import Any, Dict, List, Literal, Set, Type, Union +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union import torch @@ -17,11 +17,16 @@ class AbstractWorkerLoRAManager(ABC): """Abstract class for managing LoRA models on the worker side.""" - def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, - vocab_size: int, lora_config: LoRAConfig, - device: torch.device): + def __init__(self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + max_position_embeddings: Optional[int] = None): self.max_num_seqs = max_num_seqs self.max_num_batched_tokens = max_num_batched_tokens + self.max_position_embeddings = max_position_embeddings self.vocab_size = vocab_size self.device = device self.lora_config = lora_config @@ -92,14 +97,21 @@ def __init__( embedding_modules: Dict[str, str], embedding_padding_modules: List[str], lora_model_cls: Type[LoRAModel] = LoRAModel, + max_position_embeddings: Optional[int] = None, ): self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules # Lazily initialized by create_lora_manager. self._lora_manager: LoRAModelManager - super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, - lora_config, device) + super().__init__( + max_num_seqs, + max_num_batched_tokens, + vocab_size, + lora_config, + device, + max_position_embeddings=max_position_embeddings, + ) @property def is_enabled(self) -> bool: @@ -162,6 +174,7 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: lora = self._lora_model_cls.from_local_checkpoint( lora_request.lora_local_path, expected_lora_modules, + max_position_embeddings=self.max_position_embeddings, lora_model_id=lora_request.lora_int_id, device="cpu", dtype=self.lora_config.lora_dtype, @@ -191,7 +204,7 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: lora_request.lora_int_id) else: dummy_lora = self._lora_manager.create_dummy_lora( - lora_request.lora_int_id, rank, self.embedding_modules) + lora_request.lora_int_id, rank, 1, self.embedding_modules) if self._cached_dummy_lora is None: self._cached_dummy_lora = dummy_lora return self._lora_manager.add_lora(dummy_lora) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 4758ca9660083..d03903d206d33 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -61,6 +61,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.base = base self.is_neox_style = is_neox_style + self.dtype = dtype cache = self._compute_cos_sin_cache() cache = cache.to(dtype) @@ -168,6 +169,29 @@ def extra_repr(self) -> str: class LinearScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with linear scaling. + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + Credits to the Reddit user /u/kaiokendev """ @@ -183,13 +207,18 @@ def __init__( ) -> None: if isinstance(scaling_factors, float): scaling_factors = [scaling_factors] - self.scaling_factors = scaling_factors + self.scaling_factors: List[float] = scaling_factors # noqa super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) + # Lazy initialized. + self._scaling_factor_to_offset: Dict[float, int] def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) - cache_list = [] + cache_list: List[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: List[int] = [] for scaling_factor in self.scaling_factors: # NOTE(woosuk): self.max_position_embeddings is the original # maximum length before applying the rope scaling. @@ -203,9 +232,25 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: cos = freqs.cos() sin = freqs.sin() cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) return torch.cat(cache_list, dim=0) + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self._scaling_factor_to_offset + class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with Dynamic NTK scaling. diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 29c76682109c6..ed65d76f7b5b9 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -348,6 +348,8 @@ def __init__( super().__init__() self.config: ChatGLMConfig = config self.quant_config = quant_config + self.max_position_embeddings = getattr(config, "max_sequence_length", + 8192) self.transformer = ChatGLMModel(config, cache_config, quant_config) self.lm_head_weight = self.transformer.output_layer.weight self.logits_processor = LogitsProcessor(config.padded_vocab_size) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index ebdc64e0e220e..f2996c240aaf4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -321,12 +321,8 @@ class LlamaForCausalLM(nn.Module): # LoRA specific attributes supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" ] embedding_modules = { "embed_tokens": "input_embeddings", diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py index c4244f8c77f44..49d2b8d8e21b1 100644 --- a/vllm/transformers_utils/configs/chatglm.py +++ b/vllm/transformers_utils/configs/chatglm.py @@ -46,6 +46,8 @@ def __init__(self, self.kv_channels = kv_channels self.num_attention_heads = num_attention_heads self.seq_length = seq_length + # It is to be compatible with long lora. + self.max_position_embeddings = seq_length self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout self.layernorm_epsilon = layernorm_epsilon diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index 927cbeed073bf..9614f01d2b955 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -34,12 +34,26 @@ def get_max_input_len(self, """Get the maximum input length for the LoRA request.""" return self.max_input_length + def _raise_if_input_too_long(self, + encoded_tokens: List[str], + lora_request: Optional[LoRARequest] = None): + input_length = len(encoded_tokens) + if lora_request: + max_input_length = (lora_request.long_lora_max_len + or self.max_input_length) + else: + max_input_length = self.max_input_length + if max_input_length is not None and input_length > max_input_length: + raise ValueError("Input too long.", input_length, max_input_length) + def encode(self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = self.get_lora_tokenizer(lora_request) - return tokenizer.encode(prompt) + ret = tokenizer.encode(prompt) + self._raise_if_input_too_long(ret, lora_request) + return ret async def encode_async( self, @@ -47,7 +61,9 @@ async def encode_async( request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None) -> List[int]: tokenizer = await self.get_lora_tokenizer_async(lora_request) - return tokenizer.encode(prompt) + ret = tokenizer.encode(prompt) + self._raise_if_input_too_long(ret, lora_request) + return ret def get_lora_tokenizer( self, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cd7af25654b52..e264fede0ee64 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -156,9 +156,15 @@ def load_model(self) -> None: ), "Model does not have embedding_padding_modules" self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, self.vocab_size, - self.lora_config, self.device, self.model.embedding_modules, - self.model.embedding_padding_modules) + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=self.model.config. + max_position_embeddings, + ) self.model = self.lora_manager.create_lora_manager(self.model) if self.kv_cache_dtype == "fp8" and is_hip(): From f68470e803df575f294e67167b4b83adfe004cfa Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 19 May 2024 15:13:33 +0800 Subject: [PATCH 308/413] [Bugfix][Model] Add base class for vision-language models (#4809) --- tests/models/test_registry.py | 9 ++++ vllm/model_executor/model_loader/loader.py | 13 +++--- vllm/model_executor/models/llava.py | 48 +++++++++++----------- vllm/model_executor/models/vlm_base.py | 12 ++++++ 4 files changed, 53 insertions(+), 29 deletions(-) create mode 100644 tests/models/test_registry.py create mode 100644 vllm/model_executor/models/vlm_base.py diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py new file mode 100644 index 0000000000000..547ab10051f1b --- /dev/null +++ b/tests/models/test_registry.py @@ -0,0 +1,9 @@ +import pytest + +from vllm.model_executor.models import _MODELS, ModelRegistry + + +@pytest.mark.parametrize("model_cls", _MODELS) +def test_registry_imports(model_cls): + # Ensure all model classes can be imported successfully + ModelRegistry.load_model_cls(model_cls) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index dc568928b2859..d1ab207549790 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -26,11 +26,7 @@ download_weights_from_hf, filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.models.llava import LlavaForConditionalGeneration - -_VISION_MODEL_CLASSES = [ - LlavaForConditionalGeneration, -] +from vllm.model_executor.models.vlm_base import VisionLanguageModelBase logger = init_logger(__name__) @@ -73,7 +69,12 @@ def _get_model_initialization_kwargs( "but LoRA is enabled. Support for this model may " "be added in the future. If this is important to you, " "please open an issue on github.") - elif model_class in _VISION_MODEL_CLASSES: + elif issubclass(model_class, VisionLanguageModelBase): + if vision_language_config is None: + raise ValueError("Provide `image_input_type` and other vision " + "related configurations through LLM entrypoint " + "or engine arguments.") + extra_kwargs["vision_language_config"] = vision_language_config return extra_kwargs diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 3b99b337a2765..e8a5b6237d4db 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -19,6 +19,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput +from .vlm_base import VisionLanguageModelBase + _KEYS_TO_MODIFY_MAPPING = { "language_model.lm_head": "lm_head", "language_model.model": "language_model", @@ -40,7 +42,7 @@ def __init__(self, vision_hidden_size: int, text_hidden_size: int, text_hidden_size, bias=True) - def forward(self, image_features): + def forward(self, image_features: torch.Tensor) -> torch.Tensor: hidden_states = self.linear_1(image_features) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) @@ -50,29 +52,31 @@ def forward(self, image_features): def _merge_vision_embeddings(input_ids: torch.Tensor, inputs_embeds: torch.Tensor, vision_embeddings: torch.Tensor, - image_token_id: int): + image_token_id: int) -> torch.Tensor: """In place merges in vision_embeddings with inputs_embeds.""" mask = (input_ids == image_token_id) - inputs_embeds[mask] = vision_embeddings.view(-1, + + image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1] + if mask.sum() != image_feature_size: + raise ValueError(f"image_feature_size should be {image_feature_size}, " + f"but found: {mask.sum()}") + + inputs_embeds[mask] = vision_embeddings.view(image_feature_size, vision_embeddings.shape[-1]) + return inputs_embeds -class LlavaForConditionalGeneration(nn.Module): + +class LlavaForConditionalGeneration(VisionLanguageModelBase): def __init__(self, - config: "LlavaConfig", + config: LlavaConfig, vision_language_config: VisionLanguageConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional["QuantizationConfig"] = None) -> None: - super().__init__() - self.config = config - - self.vision_language_config = vision_language_config + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__(vision_language_config) - assert self.vision_language_config, ( - "Provide `image_input_type` and other vision " - "related configurations through LLM entrypoint " - "or engine arguments.") + self.config = config if self.vision_language_config.image_input_type == ( VisionLanguageConfig.ImageInputType.PIXEL_VALUES): @@ -98,14 +102,12 @@ def __init__(self, config.vocab_size, logit_scale) self.sampler = Sampler() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - image_input: Optional[torch.Tensor] = None - ) -> SamplerOutput: # noqa: E501 + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + image_input: Optional[torch.Tensor] = None) -> SamplerOutput: """Run forward pass for Llava 1.5. One key thing to understand is the `input_ids` already accounts for the @@ -172,7 +174,7 @@ def forward( image_features = image_input vision_embeddings = self.multi_modal_projector(image_features) inputs_embeds = self.language_model.get_input_embeddings(input_ids) - _merge_vision_embeddings( + inputs_embeds = _merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, self.vision_language_config.image_token_id) input_ids = None diff --git a/vllm/model_executor/models/vlm_base.py b/vllm/model_executor/models/vlm_base.py new file mode 100644 index 0000000000000..eb0aa96e50d59 --- /dev/null +++ b/vllm/model_executor/models/vlm_base.py @@ -0,0 +1,12 @@ +from torch import nn + +from vllm.config import VisionLanguageConfig + + +class VisionLanguageModelBase(nn.Module): + """Base class for all vision language models (VLMs).""" + + def __init__(self, vision_language_config: VisionLanguageConfig) -> None: + super().__init__() + + self.vision_language_config = vision_language_config From 27ce85476e6b170c5c90c65ac5c3268911135766 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Sun, 19 May 2024 11:37:34 -0400 Subject: [PATCH 309/413] [Kernel] Add marlin_24 unit tests (#4901) --- tests/kernels/test_marlin_gemm.py | 87 ++++- .../layers/quantization/gptq_marlin_24.py | 27 +- .../layers/quantization/utils/format_24.py | 308 ++++++++++++++++++ .../quantization/utils/marlin_24_perms.py | 58 ++++ .../layers/quantization/utils/marlin_perms.py | 58 ++++ .../layers/quantization/utils/marlin_utils.py | 214 +++++++----- 6 files changed, 649 insertions(+), 103 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/utils/format_24.py create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_24_perms.py create mode 100644 vllm/model_executor/layers/quantization/utils/marlin_perms.py diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index b0ad85c25c572..587fc3901eb7c 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -7,23 +7,32 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_perms import ( + marlin_perm) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights) + MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize, + marlin_quantize, marlin_weights) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] -K_CHUNKS = [128, 256] -N_CHUNKS = [64, 128, 256] +MARLIN_K_CHUNKS = [128] +MARLIN_N_CHUNKS = [64, 128, 256] + +MARLIN_24_K_CHUNKS = [128] +MARLIN_24_N_CHUNKS = [256] MNK_FACTORS = [ (1, 1, 1), (1, 4, 8), (1, 7, 5), - (1, 7 * 4, 5 * 1), (13, 17, 67), (26, 37, 13), (67, 13, 11), @@ -31,14 +40,13 @@ def rand_data(shape): - data = torch.rand(shape).to(torch.half).cuda() - return data + return torch.randn(shape, dtype=torch.half, device="cuda") @pytest.mark.skipif(not is_marlin_supported(), reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", K_CHUNKS) -@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @@ -82,7 +90,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Pack to Marlin format - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -99,8 +108,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, @pytest.mark.skipif(not is_marlin_supported(), reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", K_CHUNKS) -@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @@ -136,7 +145,8 @@ def test_marlin_gemm( w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( b_weight, num_bits, group_size, act_order) - workspace = MarlinWorkspace(size_n) + workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) output = ops.gptq_marlin_gemm( a_input, @@ -155,4 +165,55 @@ def test_marlin_gemm( torch.cuda.synchronize() - assert torch.allclose(output, output_ref, rtol=1e-2) + max_diff = compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size) + + workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_MAX_PARALLEL) + + output_ref = torch.matmul(a_input, w_24_ref) + + output = ops.gptq_marlin_24_gemm( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + workspace_24.scratch, + num_bits, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + ) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 1bd6127104654..f5345c0443029 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -12,6 +12,15 @@ logger = init_logger(__name__) +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] +GPTQ_MARLIN_24_SUPPORTED_SYM = [True] + class GPTQMarlin24Config(QuantizationConfig): """Config class for Marlin24. @@ -25,15 +34,17 @@ def __init__( self.weight_bits = weight_bits self.group_size = group_size - if self.weight_bits != 4 and self.weight_bits != 8: - raise ValueError("weight_bits must be 4 or 8. Got = {}".format( - self.weight_bits)) - - if self.group_size != 128 and self.group_size != -1: + # Verify + if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + raise ValueError( + f"Marlin_24 does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: raise ValueError( - "Currently, only group size 128 and -1 (channelwise) " - "is supported for Marlin24, but got group_size of " - f"{self.group_size}") + f"Marlin_24 does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " + "are supported.") # 4 Bits packed into 32 bit datatype. self.pack_factor = 32 // self.weight_bits diff --git a/vllm/model_executor/layers/quantization/utils/format_24.py b/vllm/model_executor/layers/quantization/utils/format_24.py new file mode 100644 index 0000000000000..01c8cf789204b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/format_24.py @@ -0,0 +1,308 @@ +# +# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). +# + +import torch + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. +def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. + + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. + + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " + f"{M} groups") + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask diff --git a/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py new file mode 100644 index 0000000000000..12e77cb710687 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py @@ -0,0 +1,58 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + + +# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501 +# +# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms_24(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return perm, scale_perm, scale_perm_single + + +marlin_24_perm = {} +marlin_24_scale_perm = {} +marlin_24_scale_perm_single = {} +for num_bits in [4, 8]: + perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) + marlin_24_perm[num_bits] = perm_24 + marlin_24_scale_perm[num_bits] = scale_perm_24 + marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 diff --git a/vllm/model_executor/layers/quantization/utils/marlin_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_perms.py new file mode 100644 index 0000000000000..76bd2ff7c724e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_perms.py @@ -0,0 +1,58 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + + +# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 +# +# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +marlin_perm = {} +marlin_scale_perm = {} +marlin_scale_perm_single = {} +for num_bits in [4, 8]: + perm, scale_perm, scale_perm_single = get_perms(num_bits) + marlin_perm[num_bits] = perm + marlin_scale_perm[num_bits] = scale_perm + marlin_scale_perm_single[num_bits] = scale_perm_single diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 33b3169983475..0d027d0620ab3 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,79 +1,28 @@ """This file is used for /tests and /benchmarks""" +import random + import numpy import torch -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE) +from vllm.model_executor.layers.quantization.utils.format_24 import ( + mask_creator, sparse_semi_structured_from_dense_cutlass) +from vllm.model_executor.layers.quantization.utils.marlin_24_perms import ( + marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) +from vllm.model_executor.layers.quantization.utils.marlin_perms import ( + marlin_perm, marlin_scale_perm, marlin_scale_perm_single) from vllm.model_executor.layers.quantization.utils.quant_utils import ( get_pack_factor, quantize_weights, sort_weights) __cuda_arch = torch.cuda.get_device_capability() +MARLIN_TILE = 16 + def is_marlin_supported(): return __cuda_arch[0] >= 8 -# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 -# -# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 -# (without the need to use ldmatrix instructions) # noqa: E501 -def _get_perms(num_bits): - perm_list = [] - for i in range(32): - perm1 = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return perm, scale_perm, scale_perm_single - - -_perm = {} -_scale_perm = {} -_scale_perm_single = {} -for num_bits in [4, 8]: - perm, scale_perm, scale_perm_single = _get_perms(num_bits) - _perm[num_bits] = perm - _scale_perm[num_bits] = scale_perm - _scale_perm_single[num_bits] = scale_perm_single - - -def marlin_permute_weights(q_w, - size_k, - size_n, - num_bits, - tile=GPTQ_MARLIN_TILE): +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE): assert q_w.shape == (size_k, size_n) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" @@ -83,15 +32,14 @@ def marlin_permute_weights(q_w, q_w = q_w.permute((0, 2, 1, 3)) q_w = q_w.reshape((size_k // tile, size_n * tile)) - q_w = q_w.reshape( - (-1, _perm[num_bits].numel()))[:, _perm[num_bits]].reshape(q_w.shape) + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) return q_w -def marlin_weights(q_w, size_k, size_n, num_bits): +def marlin_weights(q_w, size_k, size_n, num_bits, perm): # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits) + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) # Pack pack_factor = get_pack_factor(num_bits) @@ -101,7 +49,6 @@ def marlin_weights(q_w, size_k, size_n, num_bits): q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=numpy.uint32) - for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i @@ -110,15 +57,12 @@ def marlin_weights(q_w, size_k, size_n, num_bits): return q_packed -def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): +def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, + scale_perm_single): if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(_scale_perm[num_bits])))[:, - _scale_perm[num_bits]] + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: - s = s.reshape( - (-1, - len(_scale_perm_single[num_bits])))[:, - _scale_perm_single[num_bits]] + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s @@ -148,8 +92,11 @@ def marlin_quantize( q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Reformat to marlin - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_scale_perm[num_bits], + marlin_scale_perm_single[num_bits]) # Create result res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] @@ -159,15 +106,118 @@ def marlin_quantize( return res_list +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): + assert q_24.shape == (size_k, size_n) + + # Remove zp to normalize over 0 + max_q_val = (1 << num_bits) - 1 + zp = (max_q_val + 1) // 2 + q_24_no_zp = q_24 - zp + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore zp + q_24_comp = q_24_no_zp_comp + zp + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def marlin_24_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, + num_bits, + group_size, + act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + num_bits) + size_k_comp = size_k // 2 + + # Reformat to marlin + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + num_bits, marlin_24_perm[num_bits]) + marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_24_scale_perm[num_bits], + marlin_24_scale_perm_single[num_bits]) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + class MarlinWorkspace: - def __init__(self, out_features): - assert (out_features % GPTQ_MARLIN_MIN_THREAD_N == 0), ( - "out_features = {} is undivisible by GPTQ_MARLIN_MIN_THREAD_N = {}" - .format(out_features, GPTQ_MARLIN_MIN_THREAD_N)) + def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) - max_workspace_size = ((out_features // GPTQ_MARLIN_MIN_THREAD_N) * - GPTQ_MARLIN_MAX_PARALLEL) + max_workspace_size = ((out_features // min_thread_n) * max_parallel) self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, From b57e6c59491ea7d60af413ad4a6455812b9c6c50 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 19 May 2024 18:11:30 -0700 Subject: [PATCH 310/413] [Kernel] Add flash-attn back (#4907) --- requirements-cuda.txt | 2 +- tests/kernels/test_flash_attn.py | 208 ++++++++++++++++++++++++++ tests/models/test_big_models.py | 2 +- tests/models/test_fp8.py | 10 +- vllm/attention/backends/flash_attn.py | 129 +++++++++------- vllm/attention/selector.py | 14 ++ 6 files changed, 304 insertions(+), 61 deletions(-) create mode 100644 tests/kernels/test_flash_attn.py diff --git a/requirements-cuda.txt b/requirements-cuda.txt index ba8c614d205d2..acb0164007dba 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -7,4 +7,4 @@ nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 -vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0 +vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0 diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py new file mode 100644 index 0000000000000..22772d4ea4422 --- /dev/null +++ b/tests/kernels/test_flash_attn.py @@ -0,0 +1,208 @@ +from typing import List, Optional, Tuple + +import pytest +import torch +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +NUM_HEADS = [(16, 16), (32, 8), (64, 8)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens: List[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = torch.triu(empty_mask, + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_flash_attn_with_paged_kv( + kv_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_cache = torch.randn(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + ).squeeze(1) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("dtype", DTYPES) +@torch.inference_mode +def test_varlen_with_paged_kv( + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = ((sliding_window, + sliding_window) if sliding_window is not None else + (-1, -1)) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_cache = torch.randn(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + # Normalize the scale of the key and value caches to mitigate + # numerical instability. + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + cu_kv_lens = torch.tensor([0] + kv_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index c02204f16ac68..10e7c64e34e75 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -12,7 +12,7 @@ # "Deci/DeciLM-7b", # Broken # "tiiuae/falcon-7b", # Broken "EleutherAI/gpt-j-6b", - "mosaicml/mpt-7b", + # "mosaicml/mpt-7b", # Broken # "Qwen/Qwen1.5-0.5B" # Broken, ] diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index e87a1783a83f1..664e951a89f2a 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -25,18 +25,18 @@ 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', ], "meta-llama/Meta-Llama-3-8B-Instruct": [ 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', + 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 856f399741375..0361dd3bd4ead 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,19 +1,15 @@ -"""Attention layer with Flash and PagedAttention. - -NOTE(woosuk): At the moment, this file includes a lot of duplicated code from -XFormers backend. The duplicated code will be removed once we use flash-attn or -flashinfer for all the attention operations. -""" +"""Attention layer with FlashAttention.""" from dataclasses import dataclass from typing import List, Optional, Tuple, Type import torch -from vllm_flash_attn import flash_attn_varlen_func +from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from vllm._C import cache_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) + +_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] class FlashAttentionBackend(AttentionBackend): @@ -37,8 +33,9 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -46,18 +43,26 @@ def swap_blocks( dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: torch.Tensor, ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass -class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): +class FlashAttentionMetadata(AttentionMetadata): """Metadata for FlashAttentionBackend. NOTE: Any python object stored here is not updated when it is @@ -99,6 +104,14 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # so far). context_lens_tensor: Optional[torch.Tensor] + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. @@ -219,11 +232,15 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - suppored_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in suppored_head_sizes: + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + if head_size not in _SUPPORTED_HEAD_SIZES: raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {suppored_head_sizes}.") + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") def forward( self, @@ -234,17 +251,20 @@ def forward( attn_metadata: FlashAttentionMetadata, kv_scale: float = 1.0, ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. + """Forward pass with FlashAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." + num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) @@ -252,16 +272,20 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) + key_cache = kv_cache[0] + value_cache = kv_cache[1] # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache(key, value, key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, kv_scale) + cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + ) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens @@ -281,7 +305,8 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. - if kv_cache is None or prefill_meta.block_tables.numel() == 0: + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): # normal attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. @@ -302,38 +327,34 @@ def forward( output[:num_prefill_tokens] = out else: # prefix-enabled attention - # TODO(Hai) this triton kernel has regression issue (broke) to - # deal with different data types between KV and FP8 KV cache, - # to be addressed separately. - output[:num_prefill_tokens] = PagedAttention.forward_prefix( - query, - key, - value, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + output[:num_prefill_tokens] = flash_attn_varlen_func( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, ) + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - output[num_prefill_tokens:] = PagedAttention.forward_decode( - decode_query, + output[num_prefill_tokens:] = flash_attn_with_kvcache( + decode_query.unsqueeze(1), key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - kv_scale, - ) + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + ).squeeze(1) # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 06f99718a4dee..5140c3cc86a31 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -93,6 +93,20 @@ def _which_attn_to_use( "torch.float16 or torch.bfloat16.") return _Backend.XFORMERS + if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") + return _Backend.XFORMERS + + if block_size % 16 != 0: + logger.info("Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + return _Backend.XFORMERS + + if sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window.") + return _Backend.XFORMERS + try: import vllm_flash_attn # noqa: F401 except ImportError: From 6287537a0c970bda1fc8b31f2bde1bcf2d26e151 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 20 May 2024 16:11:25 +0800 Subject: [PATCH 311/413] [Model] LLaVA model refactor (#4910) --- vllm/model_executor/models/llava.py | 137 ++++++++++++++++++++++------ 1 file changed, 107 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index e8a5b6237d4db..fbd7638097286 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union import torch from torch import nn @@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor, return inputs_embeds +class LlavaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, num_channels, height, width)""" + + +class LlavaImageFeatureInputs(TypedDict): + type: Literal["image_features"] + data: torch.Tensor + """Shape: (batch_size, image_feature_size, hidden_size)""" + + +LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs] + + class LlavaForConditionalGeneration(VisionLanguageModelBase): def __init__(self, @@ -102,6 +117,90 @@ def __init__(self, config.vocab_size, logit_scale) self.sampler = Sampler() + def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor: + if list(data.shape[1:]) != list( + self.vision_language_config.image_input_shape[1:]): + raise ValueError( + f"The expected image tensor shape is batch dimension plus " + f"{self.vision_language_config.image_input_shape[1:]}. " + f"You supplied {data.shape}. " + f"If you are using vLLM's entrypoint, make sure your " + f"supplied image input is consistent with " + f"image_input_shape in engine args.") + + return data + + def _parse_and_validate_image_input( + self, data: object) -> Optional[LlavaImageInputs]: + expected_input_type = self.vision_language_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + if data is None: + return None + + if expected_input_type == ImageInputType.PIXEL_VALUES: + if not isinstance(data, torch.Tensor): + raise TypeError("Image pixel vector should be a tensor, " + f"but received type: {type(data)}") + + return LlavaImagePixelInputs( + type="pixel_values", + data=self._validate_image_data(data), + ) + elif expected_input_type == ImageInputType.IMAGE_FEATURES: + if not isinstance(data, torch.Tensor): + raise TypeError("Image feature vector should be a tensor, " + f"but received type: {type(data)}") + + return LlavaImageFeatureInputs( + type="image_features", + data=self._validate_image_data(data), + ) + + return None + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, + pixel_values: torch.Tensor) -> torch.Tensor: + # TODO(xwjiang): Maybe port minimal CLIPVisionModel over. + image_outputs = vision_tower(pixel_values.to(vision_tower.device), + output_hidden_states=True) + + image_features = image_outputs.hidden_states[ + self.config.vision_feature_layer] + + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + def _process_image_pixels(self, + inputs: LlavaImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + return self._image_pixels_to_features(self.vision_tower, pixel_values) + + def _process_image_input(self, + image_input: LlavaImageInputs) -> torch.Tensor: + if image_input["type"] == "pixel_values": + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) + else: + image_features = image_input["data"] + + return self.multi_modal_projector(image_features) + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, @@ -144,42 +243,20 @@ def forward(self, For PIXEL_VALUES, expecting [1, 3, 336, 336]. For IMAGE_FEATURES, expecting [1, 576, 1024]. """ - if image_input is not None: - if list(image_input.shape[1:]) != list( - self.vision_language_config.image_input_shape[1:]): - raise ValueError( - f"The expected image tensor shape is batch dimension " - f"plus " - f"{self.vision_language_config.image_input_shape[1:]}." - f" You supplied {image_input.shape}. " - f"If you are using vLLM's entrypoint, make sure your " - f"supplied image input is consistent with " - f"image_input_shape in engine args.") - if self.vision_tower is not None: - # TODO(xwjiang): Maybe port minimal CLIPVisionModel over. - image_outputs = self.vision_tower(image_input, - output_hidden_states=True) - image_features = image_outputs.hidden_states[ - self.config.vision_feature_layer] - # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa - if self.config.vision_feature_select_strategy == "default": - image_features = image_features[:, 1:] - elif self.config.vision_feature_select_strategy == "full": - image_features = image_features - else: - raise ValueError( - f"Unexpected select feature strategy: " - f"{self.config.vision_feature_select_strategy}") - else: - image_features = image_input - vision_embeddings = self.multi_modal_projector(image_features) + parsed_image_input = self._parse_and_validate_image_input(image_input) + + if parsed_image_input is not None: + vision_embeddings = self._process_image_input(parsed_image_input) inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = _merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, self.vision_language_config.image_token_id) + input_ids = None else: inputs_embeds = None + hidden_states = self.language_model(input_ids, positions, kv_caches, From da5a0b539d6a5fe0c0195513a797814d2c267540 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Mon, 20 May 2024 10:55:34 -0400 Subject: [PATCH 312/413] Remove marlin warning (#4918) --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index fdc0ebef4672e..34950a5d13cf5 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1519,10 +1519,6 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } } - printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM " - "GPU cache. This may " - "hurt performance. Consider upgrading your GPU.\n"); - max_m_blocks--; // Process less M blocks per invocation to reduce cache // usage } From 546a97ef691f242c899a5e0906d0e75f42694e95 Mon Sep 17 00:00:00 2001 From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Date: Tue, 21 May 2024 01:45:06 +0800 Subject: [PATCH 313/413] [Misc]: allow user to specify port in distributed setting (#4914) --- vllm/envs.py | 7 +++++++ vllm/utils.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index 68d8a074d0914..56ff79e0cdea9 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -3,6 +3,7 @@ if TYPE_CHECKING: VLLM_HOST_IP: str = "" + VLLM_PORT: Optional[int] = None VLLM_USE_MODELSCOPE: bool = False VLLM_INSTANCE_ID: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None @@ -96,6 +97,12 @@ 'VLLM_HOST_IP': lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), + # used in distributed environment to manually set the communication port + # '0' is used to make mypy happy + 'VLLM_PORT': + lambda: int(os.getenv('VLLM_PORT', '0')) + if 'VLLM_PORT' in os.environ else None, + # If true, will load models from ModelScope instead of Hugging Face Hub. # note that the value is true or false, not numbers "VLLM_USE_MODELSCOPE": diff --git a/vllm/utils.py b/vllm/utils.py index f0e71f5e99b64..552b43e7f82b2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -282,6 +282,9 @@ def get_distributed_init_method(ip: str, port: int) -> str: def get_open_port() -> int: + port = envs.VLLM_PORT + if port is not None: + return port # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: From 943e72ca56974b4d8b5a141182e717d2abd3a819 Mon Sep 17 00:00:00 2001 From: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Date: Mon, 20 May 2024 13:29:28 -0500 Subject: [PATCH 314/413] [Build/CI] Enabling AMD Entrypoints Test (#4834) Co-authored-by: Alexey Kondratiev --- .buildkite/test-pipeline.yaml | 3 ++- Dockerfile.rocm | 8 ++++++-- requirements-rocm.txt | 3 ++- tests/spec_decode/e2e/conftest.py | 8 ++++++-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6f5c46e23779f..def8a460e84a7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -60,7 +60,8 @@ steps: command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test - #mirror_hardwares: [amd] + mirror_hardwares: [amd] + commands: # these tests have to be separated, because each one will allocate all posible GPU memory - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py diff --git a/Dockerfile.rocm b/Dockerfile.rocm index eefad79e79d83..9bfe8446a519d 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -92,19 +92,23 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ WORKDIR /vllm-workspace COPY . . +#RUN python3 -m pip install pynvml # to be removed eventually RUN python3 -m pip install --upgrade pip numba # make sure punica kernels are built (for LoRA) ENV VLLM_INSTALL_PUNICA_KERNELS=1 +# Workaround for ray >= 2.10.0 +ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 + +ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so RUN --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cd .. -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3 CMD ["/bin/bash"] diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 903845b64d98f..cc42839a975d0 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -2,4 +2,5 @@ -r requirements-common.txt # Dependencies for AMD GPUs -ray == 2.9.3 +ray >= 2.10.0 +pytest-asyncio diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index da8b92711380e..7c5840baf3593 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -6,8 +6,12 @@ import pytest import ray import torch -from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, - nvmlInit) + +from vllm.utils import is_hip + +if (not is_hip()): + from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, + nvmlInit) from vllm import LLM from vllm.engine.arg_utils import AsyncEngineArgs From f0eecee6106774e1e0f9b31c7438cde77654df52 Mon Sep 17 00:00:00 2001 From: Mor Zusman Date: Mon, 20 May 2024 21:44:25 +0300 Subject: [PATCH 315/413] [Bugfix] Fix dummy weight for fp8 (#4916) Allow dummy load format for fp8, torch.uniform_ doesn't support FP8 at the moment Co-authored-by: Mor Zusman --- vllm/model_executor/model_loader/weight_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index c1abde9af7701..a1642baa2c90c 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -369,4 +369,11 @@ def initialize_dummy_weights( """ for param in model.state_dict().values(): if torch.is_floating_point(param): - param.data.uniform_(low, high) + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high) From 1937e29848c8de8634c5421612d57863aa0e2a51 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 20 May 2024 14:46:12 -0400 Subject: [PATCH 316/413] [Core] Sharded State Loader download from HF (#4889) --- vllm/model_executor/model_loader/loader.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index d1ab207549790..45ea8160a801b 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -423,6 +423,16 @@ def get_end_ptr(tensor: torch.Tensor) -> int: result[k] = t return result + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]): + if os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf(model_name_or_path, + self.load_config.download_dir, + allow_patterns, revision) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -433,6 +443,10 @@ def load_model(self, *, model_config: ModelConfig, from safetensors.torch import safe_open from vllm.distributed import get_tensor_model_parallel_rank + + local_model_path = self._prepare_weights(model_config.model, + model_config.revision) + with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, @@ -440,7 +454,7 @@ def load_model(self, *, model_config: ModelConfig, cache_config) rank = get_tensor_model_parallel_rank() pattern = os.path.join( - model_config.model, + local_model_path, self.pattern.format(rank=rank, part="*"), ) filepaths = glob.glob(pattern) From c3af44722cff56bba5fc912c8e16d9de02dfb532 Mon Sep 17 00:00:00 2001 From: Kuntai Du Date: Mon, 20 May 2024 13:16:57 -0700 Subject: [PATCH 317/413] [Doc]Add documentation to benchmarking script when running TGI (#4920) --- benchmarks/benchmark_serving.py | 4 ++++ benchmarks/launch_tgi_server.sh | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 2c2d69da4a7d1..9c3fed4817de2 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -17,6 +17,10 @@ --dataset-path \ --request-rate \ # By default is inf --num-prompts # By default is 1000 + + when using tgi backend, add + --endpoint /generate_stream + to the end of the command above. """ import argparse import asyncio diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh index 64d3c4f4b3889..f491c90d0683e 100755 --- a/benchmarks/launch_tgi_server.sh +++ b/benchmarks/launch_tgi_server.sh @@ -4,7 +4,7 @@ PORT=8000 MODEL=$1 TOKENS=$2 -docker run --gpus all --shm-size 1g -p $PORT:80 \ +docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \ -v $PWD/data:/data \ ghcr.io/huggingface/text-generation-inference:1.4.0 \ --model-id $MODEL \ From 65ae8c2c8f52e0c98e4e26ad1255772d888592a6 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 20 May 2024 17:48:32 -0700 Subject: [PATCH 318/413] [Core] Fix scheduler considering "no LoRA" as "LoRA" (#4897) --- vllm/core/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c8da54f2889eb..7c70b1b244f7d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -744,8 +744,8 @@ def _schedule_default(self) -> SchedulerOutputs: budget.add_num_seqs(seq_group.request_id, seq_group.get_max_num_running_seqs()) curr_loras = set( - seq_group.lora_int_id - for seq_group in self.running) if self.lora_enabled else None + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None remaining_waiting, prefills = (self.waiting, SchedulerPrefillOutputs.create_empty()) From d130b573a0162173002b97e2112c6c1c10d0ca8e Mon Sep 17 00:00:00 2001 From: HUANG Fei Date: Tue, 21 May 2024 13:22:22 +0800 Subject: [PATCH 319/413] [Model] add rope_scaling support for qwen2 (#4930) --- vllm/model_executor/models/qwen2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 31ba6441f9f7a..97ab6168c3230 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -89,7 +89,8 @@ def __init__(self, use_sliding_window: bool = False, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None) -> None: + sliding_window: Optional[int] = None, + rope_scaling: Optional[Tuple] = None) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -133,6 +134,7 @@ def __init__(self, rotary_dim=self.head_dim, max_position=max_position, base=self.rope_theta, + rope_scaling=rope_scaling, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -169,6 +171,7 @@ def __init__( self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) use_sliding_window = (config.use_sliding_window and layer_idx < config.max_window_layers) self.self_attn = Qwen2Attention( @@ -180,7 +183,8 @@ def __init__( use_sliding_window=use_sliding_window, cache_config=cache_config, quant_config=quant_config, - sliding_window=config.sliding_window) + sliding_window=config.sliding_window, + rope_scaling=rope_scaling) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, From f12c3b5b3d076a67662b76d215fd875fd6cdf6d7 Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Tue, 21 May 2024 13:24:17 +0800 Subject: [PATCH 320/413] [Model] Add Phi-2 LoRA support (#4886) --- docs/source/models/supported_models.rst | 2 +- tests/lora/conftest.py | 5 ++ tests/lora/test_phi.py | 67 +++++++++++++++++++++++++ vllm/model_executor/models/phi.py | 33 +++++++++--- 4 files changed, 100 insertions(+), 7 deletions(-) create mode 100644 tests/lora/test_phi.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 142c8f8573e2f..31d4b53bd4409 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -118,7 +118,7 @@ Alongside each architecture, we include some popular models that use it. * - :code:`PhiForCausalLM` - Phi - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - - + - ✅︎ * - :code:`Phi3ForCausalLM` - Phi-3 - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc. diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 5c648f72d8ddd..95fc65cdd1a8f 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -165,6 +165,11 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") +@pytest.fixture(scope="session") +def phi2_lora_files(): + return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") + + @pytest.fixture(scope="session") def long_context_lora_files_16k_1(): return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1") diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py new file mode 100644 index 0000000000000..a2b42ce4cb96f --- /dev/null +++ b/tests/lora/test_phi.py @@ -0,0 +1,67 @@ +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "microsoft/phi-2" + +PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 + + +def do_sample(llm, lora_path: str, lora_id: int) -> str: + prompts = [ + PROMPT_TEMPLATE.format( + sql_prompt= + "Which catalog publisher has published the most catalogs?", + context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), + PROMPT_TEMPLATE.format( + sql_prompt= + "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 + context= + "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + sql_prompt= + "How many marine species are found in the Southern Ocean?", # noqa: E501 + context= + "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=64, + stop="### End") + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_phi2_lora(phi2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=2, + enforce_eager=True) + + expected_lora_output = [ + "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 + "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 + "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 + ] + + output1 = do_sample(llm, phi2_lora_files, lora_id=1) + for i in range(len(expected_lora_output)): + assert output1[i].startswith(expected_lora_output[i]) + output2 = do_sample(llm, phi2_lora_files, lora_id=2) + for i in range(len(expected_lora_output)): + assert output2[i].startswith(expected_lora_output[i]) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index ed25a232f4208..193a29d20c894 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -42,7 +42,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -229,11 +229,32 @@ def forward( class PhiForCausalLM(nn.Module): - - def __init__(self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ] + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "dense", + "fc1", + "fc2", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + del lora_config # Unused. super().__init__() self.config = config self.quant_config = quant_config From e941f885843d4bcd239f805a9267729e9631556f Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 21 May 2024 00:17:25 -0700 Subject: [PATCH 321/413] [Docs] Add acknowledgment for sponsors (#4925) --- README.md | 25 +++++++++++++++++++++++++ docs/source/community/sponsors.md | 24 ++++++++++++++++++++++++ docs/source/index.rst | 1 + 3 files changed, 50 insertions(+) create mode 100644 docs/source/community/sponsors.md diff --git a/README.md b/README.md index fc3f71b00c3c5..627e45d4de0c4 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,31 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more. We welcome and value any contributions and collaborations. Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. +## Sponsors + +vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! + + + + +- a16z +- AMD +- Anyscale +- AWS +- Crusoe Cloud +- Databricks +- DeepInfra +- Lambda Lab +- NVIDIA +- Replicate +- Roblox +- RunPod +- Trainy +- UC Berkeley +- UC San Diego + +We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. + ## Citation If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): diff --git a/docs/source/community/sponsors.md b/docs/source/community/sponsors.md new file mode 100644 index 0000000000000..532ce77beb7b8 --- /dev/null +++ b/docs/source/community/sponsors.md @@ -0,0 +1,24 @@ +# Sponsors + +vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support! + + + + +- a16z +- AMD +- Anyscale +- AWS +- Crusoe Cloud +- Databricks +- DeepInfra +- Lambda Lab +- NVIDIA +- Replicate +- Roblox +- RunPod +- Trainy +- UC Berkeley +- UC San Diego + +We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index bab00e28e4018..5db1c9346c45d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -118,6 +118,7 @@ Documentation :caption: Community community/meetups + community/sponsors Indices and tables ================== From 757b62c49560baa6f294310a53032348a0d95939 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 21 May 2024 12:06:10 -0400 Subject: [PATCH 322/413] [CI/Build] Codespell ignore `build/` directory (#4945) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1c61a9e955b61..96f78c37cfefb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ exclude = [ [tool.codespell] ignore-words-list = "dout, te, indicies" -skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data" +skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" [tool.isort] use_parentheses = true From 14772eeb8e8ec76e5e70142d12a7332fcec28ccb Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Wed, 22 May 2024 00:30:52 +0800 Subject: [PATCH 323/413] [Bugfix] Fix flag name for `max_seq_len_to_capture` (#4935) Signed-off-by: kerthcet --- vllm/engine/arg_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d0edf0a75b710..1ba424c4eeb14 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -341,9 +341,9 @@ def add_cli_args( help='Maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode. ' - '(DEPRECATED. Use --max-seq_len-to-capture instead' + '(DEPRECATED. Use --max-seq-len-to-capture instead' ')') - parser.add_argument('--max-seq_len-to-capture', + parser.add_argument('--max-seq-len-to-capture', type=int, default=EngineArgs.max_seq_len_to_capture, help='Maximum sequence length covered by CUDA ' From 99eff67ba9155b5fec9a9abd939e3a29a1b42dce Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Wed, 22 May 2024 03:33:25 +0800 Subject: [PATCH 324/413] [Bugfix][Kernel] Add head size check for attention backend selection (#4944) --- vllm/attention/backends/flash_attn.py | 12 ++++++++---- vllm/attention/selector.py | 16 +++++++++++++--- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 0361dd3bd4ead..0f4568070cfc4 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -9,11 +9,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256] - class FlashAttentionBackend(AttentionBackend): + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + @staticmethod def get_name() -> str: return "flash-attn" @@ -237,10 +239,12 @@ def __init__( # paged KV cache. raise ValueError( "Sliding window is not supported in FlashAttention.") - if head_size not in _SUPPORTED_HEAD_SIZES: + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: raise ValueError( f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.") + f"Supported head sizes are: {support_head_sizes}.") def forward( self, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 5140c3cc86a31..51c25a81b4130 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -34,11 +34,21 @@ def get_attn_backend( sliding_window, dtype, kv_cache_dtype, block_size) if backend == _Backend.FLASH_ATTN: - logger.info("Using FlashAttention-2 backend.") from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) - return FlashAttentionBackend - elif backend == _Backend.XFORMERS: + + # We check it here not in _which_attn_to_use because we cannot know + # the head size until we import FlashAttentionBackend. + supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size in supported_head_sizes: + logger.info("Using FlashAttention-2 backend.") + return FlashAttentionBackend + logger.info( + "Cannot use FlashAttention-2 backend for head size %d. " + "Using XFormers backend instead.", head_size) + backend = _Backend.XFORMERS + + if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 XFormersBackend) From 9b9a10d6cb89f18e054daa66f25cb8f17c723b2c Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Wed, 22 May 2024 05:32:35 +0000 Subject: [PATCH 325/413] [Frontend] Dynamic RoPE scaling (#4638) --- tests/test_config.py | 56 ++++++++++++++++++++++++++++++- vllm/config.py | 7 +++- vllm/engine/arg_utils.py | 18 +++++++--- vllm/engine/llm_engine.py | 10 +++--- vllm/transformers_utils/config.py | 10 +++++- 5 files changed, 89 insertions(+), 12 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 19db10630bbae..6bc51a53dc07c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -36,4 +36,58 @@ def test_get_sliding_window(): assert mistral_model_config.get_sliding_window() is None mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW - assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW \ No newline at end of file + assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW + + +def test_rope_scaling(): + TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0} + LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0} + + llama_model_config = ModelConfig( + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None + assert llama_model_config.max_model_len == 8192 + + llama_model_config = ModelConfig( + "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + rope_scaling=TEST_ROPE_SCALING, + ) + assert getattr(llama_model_config.hf_config, "rope_scaling", + None) == TEST_ROPE_SCALING + assert llama_model_config.max_model_len == 16384 + + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + ) + assert getattr(longchat_model_config.hf_config, "rope_scaling", + None) == LONGCHAT_ROPE_SCALING + assert longchat_model_config.max_model_len == 16384 + + longchat_model_config = ModelConfig( + "lmsys/longchat-13b-16k", + "lmsys/longchat-13b-16k", + tokenizer_mode="auto", + trust_remote_code=False, + dtype="float16", + seed=0, + rope_scaling=TEST_ROPE_SCALING, + ) + assert getattr(longchat_model_config.hf_config, "rope_scaling", + None) == TEST_ROPE_SCALING + assert longchat_model_config.max_model_len == 4096 diff --git a/vllm/config.py b/vllm/config.py index 44ed5635f9a35..3256c11967914 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -45,6 +45,9 @@ class ModelConfig: code_revision: The specific revision to use for the model code on Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. + rope_scaling: Dictionary containing the scaling configuration for the + RoPE embeddings. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. @@ -84,6 +87,7 @@ def __init__( seed: int, revision: Optional[str] = None, code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, @@ -102,6 +106,7 @@ def __init__( self.seed = seed self.revision = revision self.code_revision = code_revision + self.rope_scaling = rope_scaling self.tokenizer_revision = tokenizer_revision self.quantization = quantization self.quantization_param_path = quantization_param_path @@ -116,7 +121,7 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision) + code_revision, rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1ba424c4eeb14..0a9ec7472fbca 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,5 +1,6 @@ import argparse import dataclasses +import json from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -49,6 +50,7 @@ class EngineArgs: disable_log_stats: bool = False revision: Optional[str] = None code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: bool = False @@ -330,6 +332,11 @@ def add_cli_args( 'None, we assume the model weights are not ' 'quantized and use `dtype` to determine the data ' 'type of the weights.') + parser.add_argument('--rope-scaling', + default=None, + type=json.loads, + help='RoPE scaling configuration in JSON format. ' + 'For example, {"type":"dynamic","factor":2.0}') parser.add_argument('--enforce-eager', action='store_true', help='Always use eager-mode PyTorch. If False, ' @@ -548,11 +555,12 @@ def create_engine_config(self, ) -> EngineConfig: model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.dtype, self.seed, self.revision, - self.code_revision, self.tokenizer_revision, self.max_model_len, - self.quantization, self.quantization_param_path, - self.enforce_eager, self.max_context_len_to_capture, - self.max_seq_len_to_capture, self.max_logprobs, - self.skip_tokenizer_init, self.served_model_name) + self.code_revision, self.rope_scaling, self.tokenizer_revision, + self.max_model_len, self.quantization, + self.quantization_param_path, self.enforce_eager, + self.max_context_len_to_capture, self.max_seq_len_to_capture, + self.max_logprobs, self.skip_tokenizer_init, + self.served_model_name) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f6a5284093c1c..60e23d4df15bb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -104,10 +104,11 @@ def __init__( "Initializing an LLM engine (v%s) with config: " "model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, " - "max_seq_len=%d, download_dir=%r, load_format=%s, " - "tensor_parallel_size=%d, disable_custom_all_reduce=%s, " - "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " + "rope_scaling=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, seed=%d, served_model_name=%s)", vllm.__version__, @@ -117,6 +118,7 @@ def __init__( model_config.skip_tokenizer_init, model_config.tokenizer_mode, model_config.revision, + model_config.rope_scaling, model_config.tokenizer_revision, model_config.trust_remote_code, model_config.dtype, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1756c91a612f0..f36d84dbdf7f9 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -2,9 +2,12 @@ from transformers import AutoConfig, PretrainedConfig +from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, JAISConfig, MPTConfig, RWConfig) +logger = init_logger(__name__) + _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, @@ -18,7 +21,8 @@ def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None, - code_revision: Optional[str] = None) -> PretrainedConfig: + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None) -> PretrainedConfig: try: config = AutoConfig.from_pretrained( model, @@ -41,6 +45,10 @@ def get_config(model: str, config = config_class.from_pretrained(model, revision=revision, code_revision=code_revision) + if rope_scaling is not None: + logger.info("Updating rope_scaling from %r to %r", + getattr(config, "rope_scaling", None), rope_scaling) + config.update({"rope_scaling": rope_scaling}) return config From 5f6d10c14c17122e6d711a4829ee0ca672e07f6f Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 22 May 2024 03:18:41 -0400 Subject: [PATCH 326/413] [CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722) --- .clang-format | 26 + .github/workflows/clang-format.yml | 42 + csrc/activation_kernels.cu | 139 +- csrc/attention/attention_generic.cuh | 19 +- csrc/attention/attention_kernels.cu | 636 ++- csrc/attention/attention_utils.cuh | 11 +- csrc/attention/dtype_bfloat16.cuh | 74 +- csrc/attention/dtype_float16.cuh | 92 +- csrc/attention/dtype_float32.cuh | 88 +- csrc/attention/dtype_fp8.cuh | 32 +- csrc/cache.h | 44 +- csrc/cache_kernels.cu | 288 +- csrc/cpu/activation.cpp | 60 +- csrc/cpu/attention.cpp | 411 +- csrc/cpu/cache.cpp | 53 +- csrc/cpu/layernorm.cpp | 32 +- csrc/cpu/pos_encoding.cpp | 66 +- csrc/cpu/pybind.cpp | 75 +- csrc/cuda_compat.h | 9 +- csrc/cuda_utils.h | 7 +- csrc/cuda_utils_kernels.cu | 40 +- csrc/custom_all_reduce.cu | 55 +- csrc/custom_all_reduce.cuh | 105 +- csrc/custom_all_reduce_test.cu | 38 +- csrc/dispatch_utils.h | 42 +- csrc/layernorm_kernels.cu | 242 +- csrc/moe/moe_ops.cpp | 3 +- csrc/moe/moe_ops.h | 8 +- csrc/moe_align_block_size_kernels.cu | 211 +- csrc/ops.h | 330 +- csrc/pos_encoding_kernels.cu | 229 +- csrc/pybind.cpp | 142 +- csrc/quantization/aqlm/gemm_kernels.cu | 536 +-- csrc/quantization/awq/dequantize.cuh | 138 +- csrc/quantization/awq/gemm_kernels.cu | 611 +-- .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 38 +- .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 22 +- .../cutlass_w8a8/scaled_mm_dq_entry.cu | 47 +- csrc/quantization/fp8/amd/hip_float8.h | 216 +- csrc/quantization/fp8/amd/hip_float8_impl.h | 520 +-- csrc/quantization/fp8/amd/quant_utils.cuh | 711 ++-- csrc/quantization/fp8/common.cu | 87 +- csrc/quantization/fp8/nvidia/quant_utils.cuh | 138 +- csrc/quantization/gptq/compat.cuh | 70 +- csrc/quantization/gptq/matrix_view.cuh | 503 +-- csrc/quantization/gptq/q_gemm.cu | 3441 ++++++++--------- csrc/quantization/gptq/qdq_2.cuh | 107 +- csrc/quantization/gptq/qdq_3.cuh | 246 +- csrc/quantization/gptq/qdq_4.cuh | 203 +- csrc/quantization/gptq/qdq_8.cuh | 34 +- csrc/quantization/gptq/qdq_util.cuh | 58 +- csrc/quantization/gptq_marlin/gptq_marlin.cu | 696 ++-- csrc/quantization/gptq_marlin/gptq_marlin.cuh | 50 +- .../gptq_marlin/gptq_marlin_dtypes.cuh | 89 +- .../gptq_marlin/gptq_marlin_repack.cu | 94 +- .../marlin/dense/marlin_cuda_kernel.cu | 460 ++- csrc/quantization/marlin/sparse/common/base.h | 12 +- csrc/quantization/marlin/sparse/common/mem.h | 64 +- csrc/quantization/marlin/sparse/common/mma.h | 107 +- .../marlin/sparse/marlin_24_cuda_kernel.cu | 446 ++- .../squeezellm/quant_cuda_kernel.cu | 63 +- csrc/reduction_utils.cuh | 20 +- format.sh | 57 +- requirements-dev.txt | 1 + 64 files changed, 6571 insertions(+), 6963 deletions(-) create mode 100644 .clang-format create mode 100644 .github/workflows/clang-format.yml diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000000..7f9e6d720fae5 --- /dev/null +++ b/.clang-format @@ -0,0 +1,26 @@ +BasedOnStyle: Google +UseTab: Never +IndentWidth: 2 +ColumnLimit: 80 + +# Force pointers to the type for C++. +DerivePointerAlignment: false +PointerAlignment: Left + +# Reordering #include statements can (and currently will) introduce errors +SortIncludes: false + +# Style choices +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +IndentPPDirectives: BeforeHash + +IncludeCategories: + - Regex: '^<' + Priority: 4 + - Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/' + Priority: 3 + - Regex: '^"(qoda|\.\.)/' + Priority: 2 + - Regex: '.*' + Priority: 1 diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml new file mode 100644 index 0000000000000..e9b6e28fa6bcb --- /dev/null +++ b/.github/workflows/clang-format.yml @@ -0,0 +1,42 @@ +name: clang-format + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + clang-format: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install clang-format==18.1.5 + - name: Running clang-format + run: | + EXCLUDES=( + 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' + 'csrc/punica/bgmv/bgmv_config.h' + 'csrc/punica/bgmv/bgmv_impl.cuh' + 'csrc/punica/bgmv/vec_dtypes.cuh' + 'csrc/punica/punica_ops.cu' + 'csrc/punica/type_convert.h' + ) + find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ + | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ + | xargs clang-format --dry-run --Werror \ No newline at end of file diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 24d972702c858..867f63f12de4b 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -10,11 +10,11 @@ namespace vllm { // Activation and gating kernel template. -template +template __global__ void act_and_mul_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., 2, d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); @@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel( } } -template +template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } -template +template __device__ __forceinline__ T gelu_kernel(const T& x) { // Equivalent to PyTorch GELU with 'none' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 - const float f = (float) x; + const float f = (float)x; constexpr float ALPHA = M_SQRT1_2; - return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); + return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); } -template +template __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { // Equivalent to PyTorch GELU with 'tanh' approximation. // Refer to: // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 - const float f = (float) x; + const float f = (float)x; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float KAPPA = 0.044715; float x_cube = f * f * f; float inner = BETA * (f + KAPPA * x_cube); - return (T) (0.5f * f * (1.0f + ::tanhf(inner))); + return (T)(0.5f * f * (1.0f + ::tanhf(inner))); } -} // namespace vllm +} // namespace vllm // Launch activation and gating kernel. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "act_and_mul_kernel", \ - [&] { \ - vllm::act_and_mul_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); - -void silu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); } -void gelu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); } -void gelu_tanh_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); } @@ -96,11 +90,11 @@ void gelu_tanh_and_mul( namespace vllm { // Element-wise activation kernel template. -template +template __global__ void activation_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const int d) { + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); @@ -108,54 +102,49 @@ __global__ void activation_kernel( } } -} // namespace vllm +} // namespace vllm // Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / d; \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "activation_kernel", \ - [&] { \ - vllm::activation_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); namespace vllm { -template +template __device__ __forceinline__ T gelu_new_kernel(const T& x) { - const float x3 = (float) (x * x * x); - const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); } -template +template __device__ __forceinline__ T gelu_fast_kernel(const T& x) { - const float f = (float) x; - const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); - return ((T) 0.5) * x * (((T) 1.0) + t); + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); } -} // namespace vllm +} // namespace vllm -void gelu_new( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh index 31fb401cbe2c1..62409c0cce93e 100644 --- a/csrc/attention/attention_generic.cuh +++ b/csrc/attention/attention_generic.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -22,31 +23,31 @@ namespace vllm { // A vector type to store Q, K, V elements. -template +template struct Vec {}; // A vector type to store FP32 accumulators. -template +template struct FloatVec {}; // Template vector operations. -template +template inline __device__ Acc mul(A a, B b); -template +template inline __device__ float sum(T v); -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ float dot(T a, T b) { return sum(mul(a, b)); } -template +template inline __device__ void zero(T& dst) { constexpr int WORDS = sizeof(T) / 4; union { @@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) { dst = tmp.raw; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 41b337dd91d36..d6203174e7275 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -27,15 +28,15 @@ #ifdef USE_ROCM #include #include "../quantization/fp8/amd/quant_utils.cuh" - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16 __nv_bfloat16; #else #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif #ifndef USE_ROCM -#define WARP_SIZE 32 + #define WARP_SIZE 32 #else -#define WARP_SIZE warpSize + #define WARP_SIZE warpSize #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) @@ -45,7 +46,7 @@ namespace vllm { // Utility function for attention softmax. -template +template inline __device__ float block_sum(float* red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; @@ -82,31 +83,28 @@ inline __device__ float block_sum(float* red_smem, float sum) { // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - vllm::Fp8KVCacheDataType KV_DTYPE, - int PARTITION_SIZE = 0> // Zero means no partitioning. +template // Zero means no partitioning. __device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -118,22 +116,29 @@ __device__ void paged_attention_kernel( } const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -143,13 +148,14 @@ __device__ void paged_attention_kernel( const int num_heads = gridDim.x; const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or compute 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; @@ -163,18 +169,21 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; #pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs // Memory planning. extern __shared__ char shared_mem[]; @@ -193,44 +202,50 @@ __device__ void paged_attention_kernel( // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); // Load a key to registers. // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); } else { // Vector conversion from Quant_vec to K_vec. Quant_vec k_vec_quant = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - k_vecs[j] = fp8::scaled_convert(k_vec_quant, kv_scale); + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, kv_scale); } } // Compute dot product. // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; @@ -285,13 +300,12 @@ __device__ void paged_attention_kernel( // If partitioning is enabled, store the max logit and exp_sum. if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *exp_sums_ptr = exp_sum; } @@ -304,7 +318,8 @@ __device__ void paged_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -315,18 +330,21 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); - const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -337,14 +355,17 @@ __device__ void paged_attention_kernel( if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { v_vec = *reinterpret_cast(v_ptr + offset); } else { - V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8::scaled_convert(v_quant_vec, kv_scale); + v_vec = fp8::scaled_convert(v_quant_vec, + kv_scale); } if (block_idx == num_seq_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the context, - // we should explicitly zero out the values since they may contain NaNs. - // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { @@ -367,8 +388,8 @@ __device__ void paged_attention_kernel( accs[i] = acc; } - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. __syncthreads(); // Perform reduction across warps. @@ -405,9 +426,9 @@ __device__ void paged_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -419,79 +440,75 @@ __device__ void paged_attention_kernel( } // Grid: (num_heads, num_seqs, 1). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - vllm::Fp8KVCacheDataType KV_DTYPE> +template __global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - typename cache_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - vllm::Fp8KVCacheDataType KV_DTYPE, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - const float kv_scale) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride, kv_scale); + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float kv_scale) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, kv_scale); } // Grid: (num_heads, num_seqs). -template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_partitions) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; @@ -499,9 +516,11 @@ __global__ void paged_attention_v2_reduce_kernel( const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } @@ -520,8 +539,9 @@ __global__ void paged_attention_v2_reduce_kernel( // Load max logits to shared memory. float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; @@ -550,9 +570,11 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; @@ -565,61 +587,45 @@ __global__ void paged_attention_v2_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { float acc = 0.0f; for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; } from_float(out_ptr[i], acc); } } -} // namespace vllm - -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ - vllm::paged_attention_v1_kernel<<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel< \ + T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + kv_scale); // TODO(woosuk): Tune NUM_THREADS. -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - vllm::Fp8KVCacheDataType KV_DTYPE, - int NUM_THREADS = 128> +template void paged_attention_v1_launcher( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int max_seq_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -632,9 +638,10 @@ void paged_attention_v1_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -644,7 +651,8 @@ void paged_attention_v1_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len @@ -683,19 +691,10 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -716,74 +715,45 @@ void paged_attention_v1_launcher( } void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V1_LAUNCHER_BLOCK_SIZE) -} - -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - seq_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, \ - kv_scale); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - seq_lens_ptr, \ - max_num_partitions); - -template< - typename T, - typename CACHE_T, - int BLOCK_SIZE, - vllm::Fp8KVCacheDataType KV_DTYPE, - int NUM_THREADS = 128, - int PARTITION_SIZE = 512> + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale){ + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE)} +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, kv_scale); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template void paged_attention_v2_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int max_seq_len, - const c10::optional& alibi_slopes, - float kv_scale) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -796,9 +766,10 @@ void paged_attention_v2_launcher( assert(head_size % thread_group_size == 0); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -853,59 +824,50 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - seq_lens, \ - max_seq_len, \ - alibi_slopes, \ - kv_scale); +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + kv_scale); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v2( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& seq_lens, // [num_seqs] - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale) { - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale) { + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) } #undef WARP_SIZE diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index ff64c4bd8f80c..cdcee42748998 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -26,7 +27,7 @@ namespace vllm { // Q*K^T operation. -template +template inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { using A_vec = typename FloatVec::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). @@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { return qk; } -template +template struct Qk_dot { - template + template static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { return qk_dot_(q, k); } }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 31e0cee01d2e1..3cdcb95e08099 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -28,8 +30,8 @@ #include #include - typedef __hip_bfloat162 __nv_bfloat162; - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; #endif #include @@ -50,37 +52,37 @@ struct bf16_8_t { }; // BF16 vector types for Q, K, V. -template<> +template <> struct Vec<__nv_bfloat16, 1> { using Type = __nv_bfloat16; }; -template<> +template <> struct Vec<__nv_bfloat16, 2> { using Type = __nv_bfloat162; }; -template<> +template <> struct Vec<__nv_bfloat16, 4> { using Type = bf16_4_t; }; -template<> +template <> struct Vec<__nv_bfloat16, 8> { using Type = bf16_8_t; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec<__nv_bfloat16> { using Type = float; }; -template<> +template <> struct FloatVec<__nv_bfloat162> { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { assert(false); #else #ifndef USE_ROCM - return a + b; + return a + b; #else - return __hadd(a, b); + return __hadd(a, b); #endif #endif } @@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); @@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { #endif } -template<> +template <> inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); } -template<> +template <> inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { bf16_4_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_4_t c; @@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { bf16_8_t c; c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); @@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); bf16_8_t c; @@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { return c; } -template<> +template <> inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { float fa = __bfloat162float(a); float fb = __bfloat162float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { float2 fa = bf1622float2(a); float2 fb = bf1622float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { return mul(bf162bf162(a), b); } -template<> +template <> inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { __nv_bfloat162 s = bf162bf162(a); Float4_ fc; @@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { __nv_bfloat162 s = bf162bf162(a); Float8_ fc; @@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { } // Vector fused multiply-add. -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf #endif } -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else @@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(__nv_bfloat16 v) { return __bfloat162float(v); } -template<> +template <> inline __device__ float sum(__nv_bfloat162 v) { float2 vf = bf1622float2(v); return vf.x + vf.y; } -template<> +template <> inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); } -template<> +template <> inline __device__ float sum(bf16_8_t v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); } @@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) { #endif } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index d3271e69cd69d..3a1815f0ed4fc 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -30,37 +32,37 @@ namespace vllm { // FP16 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = uint16_t; }; -template<> +template <> struct Vec { using Type = uint32_t; }; -template<> +template <> struct Vec { using Type = uint2; }; -template<> +template <> struct Vec { using Type = uint4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = Float4_; }; -template<> +template <> struct FloatVec { using Type = Float8_; }; @@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) { return b; #else union { - uint32_t u32; - uint16_t u16[2]; + uint32_t u32; + uint16_t u16[2]; } tmp; tmp.u16[0] = a; tmp.u16[1] = a; @@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) { } tmp; #ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif #else tmp.u16[0] = float_to_half(f.x); @@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { } // Vector multiplication. -template<> +template <> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { uint16_t c; #ifndef USE_ROCM @@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { uint32_t c; #ifndef USE_ROCM @@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { return c; } -template<> +template <> inline __device__ uint32_t mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ uint2 mul(uint2 a, uint2 b) { uint2 c; c.x = mul(a.x, b.x); @@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) { return c; } -template<> +template <> inline __device__ uint2 mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); uint2 c; @@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint4 a, uint4 b) { uint4 c; c.x = mul(a.x, b.x); @@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { return c; } -template<> +template <> inline __device__ uint4 mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); uint4 c; @@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) { return c; } -template<> +template <> inline __device__ float mul(uint16_t a, uint16_t b) { float fa = half_to_float(a); float fb = half_to_float(b); return fa * fb; } -template<> +template <> inline __device__ float2 mul(uint32_t a, uint32_t b) { float2 fa = half2_to_float2(a); float2 fb = half2_to_float2(b); return mul(fa, fb); } -template<> +template <> inline __device__ float2 mul(uint16_t a, uint32_t b) { return mul(h0_h0(a), b); } -template<> +template <> inline __device__ Float4_ mul(uint2 a, uint2 b) { Float4_ fc; fc.x = mul(a.x, b.x); @@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float4_ mul(uint16_t a, uint2 b) { uint32_t s = h0_h0(a); Float4_ fc; @@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint4 a, uint4 b) { Float8_ fc; fc.x = mul(a.x, b.x); @@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) { return fc; } -template<> +template <> inline __device__ Float8_ mul(uint16_t a, uint4 b) { uint32_t s = h0_h0(a); Float8_ fc; @@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; #ifndef USE_ROCM - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); #else - asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); #endif return d; } @@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { } // Vector sum. -template<> +template <> inline __device__ float sum(uint16_t v) { return half_to_float(v); } -template<> +template <> inline __device__ float sum(uint32_t v) { float2 tmp = half2_to_float2(v); return tmp.x + tmp.y; } -template<> +template <> inline __device__ float sum(uint2 v) { uint32_t c = add(v.x, v.y); return sum(c); } -template<> +template <> inline __device__ float sum(uint4 v) { uint32_t c = add(v.x, v.y); c = add(c, v.z); @@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) { } // From float16 to float32. -inline __device__ float to_float(uint16_t u) { - return half_to_float(u); -} +inline __device__ float to_float(uint16_t u) { return half_to_float(u); } -inline __device__ float2 to_float(uint32_t u) { - return half2_to_float2(u); -} +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } inline __device__ Float4_ to_float(uint2 u) { Float4_ tmp; @@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) { } // Zero-out a variable. -inline __device__ void zero(uint16_t& dst) { - dst = uint16_t(0); -} +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh index b200d2d226eb0..7c6a686db3ba9 100644 --- a/csrc/attention/dtype_float32.cuh +++ b/csrc/attention/dtype_float32.cuh @@ -1,6 +1,8 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp - * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -38,37 +40,35 @@ struct Float8_ { }; // FP32 vector types for Q, K, V. -template<> +template <> struct Vec { using Type = float; }; -template<> +template <> struct Vec { using Type = float2; }; -template<> +template <> struct Vec { using Type = float4; }; // FP32 accumulator vector types corresponding to Vec. -template<> +template <> struct FloatVec { using Type = float; }; -template<> +template <> struct FloatVec { using Type = float2; }; -template<> +template <> struct FloatVec { using Type = float4; }; // Vector addition. -inline __device__ float add(float a, float b) { - return a + b; -} +inline __device__ float add(float a, float b) { return a + b; } inline __device__ float2 add(float2 a, float2 b) { float2 c; @@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) { } // Vector multiplication. -template<> +template <> inline __device__ float mul(float a, float b) { return a * b; } -template<> +template <> inline __device__ float2 mul(float2 a, float2 b) { float2 c; c.x = a.x * b.x; @@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) { return c; } -template<> +template <> inline __device__ float2 mul(float a, float2 b) { float2 c; c.x = a * b.x; @@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) { return c; } -template<> +template <> inline __device__ float4 mul(float4 a, float4 b) { float4 c; c.x = a.x * b.x; @@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) { return c; } -template<> +template <> inline __device__ float4 mul(float a, float4 b) { float4 c; c.x = a * b.x; @@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) { } // Vector fused multiply-add. -inline __device__ float fma(float a, float b, float c) { - return a * b + c; -} +inline __device__ float fma(float a, float b, float c) { return a * b + c; } inline __device__ float2 fma(float2 a, float2 b, float2 c) { float2 d; @@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { } // Vector sum. -template<> +template <> inline __device__ float sum(float v) { return v; } -template<> +template <> inline __device__ float sum(float2 v) { return v.x + v.y; } -template<> +template <> inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } -template<> +template <> inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; } -template<> +template <> inline __device__ float sum(Float8_ v) { return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; } // Vector dot product. -inline __device__ float dot(float a, float b) { - return a * b; -} +inline __device__ float dot(float a, float b) { return a * b; } inline __device__ float dot(float2 a, float2 b) { float2 c = mul(a, b); @@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) { } // From float to float. -inline __device__ void from_float(float& dst, float src) { - dst = src; -} +inline __device__ void from_float(float& dst, float src) { dst = src; } -inline __device__ void from_float(float2& dst, float2 src) { - dst = src; -} +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } -inline __device__ void from_float(float4& dst, float4 src) { - dst = src; -} +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } // From float to float. -inline __device__ float to_float(float u) { - return u; -} +inline __device__ float to_float(float u) { return u; } -inline __device__ float2 to_float(float2 u) { - return u; -} +inline __device__ float2 to_float(float2 u) { return u; } -inline __device__ float4 to_float(float4 u) { - return u; -} +inline __device__ float4 to_float(float4 u) { return u; } -inline __device__ Float4_ to_float(Float4_ u) { - return u; -} +inline __device__ Float4_ to_float(Float4_ u) { return u; } -inline __device__ Float8_ to_float(Float8_ u) { - return u; -} +inline __device__ Float8_ to_float(Float8_ u) { return u; } // Zero-out a variable. -inline __device__ void zero(float& dst) { - dst = 0.f; -} +inline __device__ void zero(float& dst) { dst = 0.f; } -} // namespace vllm +} // namespace vllm diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh index 2b32ce372a64f..e714e321b0beb 100644 --- a/csrc/attention/dtype_fp8.cuh +++ b/csrc/attention/dtype_fp8.cuh @@ -4,38 +4,38 @@ #include #ifdef ENABLE_FP8 -#ifndef USE_ROCM -#include -#endif // USE_ROCM -#endif // ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 namespace vllm { enum class Fp8KVCacheDataType { - kAuto = 0, - kFp8E4M3 = 1, - kFp8E5M2 = 2, + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, }; // fp8 vector types for quantization of kv cache -template<> +template <> struct Vec { - using Type = uint8_t; + using Type = uint8_t; }; -template<> +template <> struct Vec { - using Type = uint16_t; + using Type = uint16_t; }; -template<> +template <> struct Vec { - using Type = uint32_t; + using Type = uint32_t; }; -template<> +template <> struct Vec { - using Type = uint2; + using Type = uint2; }; -} // namespace vllm +} // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index 8c176c452425e..435ae3e57f555 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -5,36 +5,24 @@ #include #include -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const torch::Tensor& block_mapping); +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping); -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const torch::Tensor& block_mapping); +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& block_mapping); -void reshape_and_cache( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, - const float kv_scale); +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, const float kv_scale); -void reshape_and_cache_flash( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); +void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); // Just for unittest -void convert_fp8( - torch::Tensor& dst_cache, - torch::Tensor& src_cache, - const float scale, - const std::string& kv_cache_dtype); +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const float scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index e5b74da6ad068..d924ac39b89ca 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -6,9 +6,9 @@ #include "dispatch_utils.h" #ifdef USE_ROCM -#include "quantization/fp8/amd/quant_utils.cuh" + #include "quantization/fp8/amd/quant_utils.cuh" #else -#include "quantization/fp8/nvidia/quant_utils.cuh" + #include "quantization/fp8/nvidia/quant_utils.cuh" #endif #include @@ -18,20 +18,17 @@ #ifdef USE_ROCM #include - typedef __hip_bfloat16 __nv_bfloat16; +typedef __hip_bfloat16 __nv_bfloat16; #endif -void swap_blocks( - torch::Tensor& src, - torch::Tensor& dst, - const torch::Tensor& block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { torch::Device src_device = src.device(); torch::Device dst_device = dst.device(); cudaMemcpyKind memcpy_type; if (src_device.is_cuda() && dst_device.is_cuda()) { - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); memcpy_type = cudaMemcpyDeviceToDevice; } else if (src_device.is_cuda() && dst_device.is_cpu()) { memcpy_type = cudaMemcpyDeviceToHost; @@ -41,16 +38,17 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - // NOTE(youkaichao): keep in mind that `block_mapping` should be + // NOTE(youkaichao): keep in mind that `block_mapping` should be // a cpu tensor, otherwise every `item` call will require a gpu-cpu // synchronization. TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); - char *src_ptr = static_cast(src.data_ptr()); - char *dst_ptr = static_cast(dst.data_ptr()); + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); - const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); + const at::cuda::OptionalCUDAGuard device_guard( + src_device.is_cuda() ? src_device : dst_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. const int64_t num_blocks = block_mapping.size(0); @@ -59,29 +57,25 @@ void swap_blocks( int64_t dst_block_number = block_mapping[i][1].item(); int64_t src_offset = src_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes; - cudaMemcpyAsync( - dst_ptr + dst_offset, - src_ptr + src_offset, - block_size_in_bytes, - memcpy_type, - stream); + cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, + block_size_in_bytes, memcpy_type, stream); } } namespace vllm { // Grid: (num_layers, num_pairs) -template -__global__ void copy_blocks_kernel( - int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { +template +__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { const int layer_idx = blockIdx.x; const int pair_idx = blockIdx.y; scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); - scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; @@ -99,12 +93,11 @@ __global__ void copy_blocks_kernel( } } -} // namespace vllm +} // namespace vllm -void copy_blocks( - std::vector& key_caches, - std::vector& value_caches, - const torch::Tensor& block_mapping) { +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -118,8 +111,10 @@ void copy_blocks( int64_t key_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers]; for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { - key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); - value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); + key_cache_ptrs[layer_idx] = + reinterpret_cast(key_caches[layer_idx].data_ptr()); + value_cache_ptrs[layer_idx] = + reinterpret_cast(value_caches[layer_idx].data_ptr()); } // block_mapping is a 2D tensor with shape (num_pairs, 2). @@ -127,10 +122,12 @@ void copy_blocks( // Move the data structures to the GPU. // NOTE: This synchronizes the CPU and GPU. - torch::Tensor key_cache_ptrs_tensor = torch::from_blob( - key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor value_cache_ptrs_tensor = torch::from_blob( - value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); + torch::Tensor key_cache_ptrs_tensor = + torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + torch::Tensor value_cache_ptrs_tensor = + torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); // Launch the kernel. const int numel_per_block = key_caches[0][0].numel(); @@ -139,31 +136,28 @@ void copy_blocks( const at::cuda::OptionalCUDAGuard device_guard(cache_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( - key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { - vllm::copy_blocks_kernel<<>>( - key_cache_ptrs_tensor.data_ptr(), - value_cache_ptrs_tensor.data_ptr(), - block_mapping.data_ptr(), - numel_per_block); - })); + key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { + vllm::copy_blocks_kernel<<>>( + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); + })); } namespace vllm { -template +template __global__ void reshape_and_cache_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x, - const float kv_scale) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x, + const float kv_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -184,40 +178,39 @@ __global__ void reshape_and_cache_kernel( const int x_idx = head_offset / x; const int x_offset = head_offset % x; - const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_value = value[src_value_idx]; if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { key_cache[tgt_key_idx] = tgt_key; value_cache[tgt_value_idx] = tgt_value; } else { - key_cache[tgt_key_idx] = fp8::scaled_convert(tgt_key, kv_scale); - value_cache[tgt_value_idx] = fp8::scaled_convert(tgt_value, kv_scale); + key_cache[tgt_key_idx] = + fp8::scaled_convert(tgt_key, kv_scale); + value_cache[tgt_value_idx] = + fp8::scaled_convert(tgt_value, kv_scale); } } } -template +template __global__ void reshape_and_cache_flash_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] - scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int block_stride, - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size) { + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, + // head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, + // head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, const int key_stride, const int value_stride, + const int num_heads, const int head_size, const int block_size) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -232,43 +225,37 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; - const int64_t tgt_value_idx = block_idx * block_stride - + block_offset * num_heads * head_size - + head_idx * head_size - + head_offset; + const int64_t tgt_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; k_cache[tgt_value_idx] = key[src_key_idx]; v_cache[tgt_value_idx] = value[src_value_idx]; } } -} // namespace vllm +} // namespace vllm // KV_T is the stored data type of kv-cache. // CACHE_T is the data type of key and value tensors. // KV_DTYPE is the real data type of kv-cache. -#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x, \ - kv_scale); +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), key_stride, value_stride, \ + num_heads, head_size, block_size, x, kv_scale); void reshape_and_cache( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, - const float kv_scale) -{ + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, const float kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -283,17 +270,17 @@ void reshape_and_cache( const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_RESHAPE_AND_CACHE) + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE) } void reshape_and_cache_flash( - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype) -{ + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) { // FIXME: only support auto datatype, does not support fp8 if (kv_cache_dtype != "auto") { TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); @@ -313,62 +300,47 @@ void reshape_and_cache_flash( const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), - "reshape_and_cache_flash", - [&] { - vllm::reshape_and_cache_flash_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - k_cache.data_ptr(), - v_cache.data_ptr(), - slot_mapping.data_ptr(), - block_stride, - key_stride, - value_stride, - num_heads, - head_size, - block_size); - }); + key.scalar_type(), "reshape_and_cache_flash", [&] { + vllm::reshape_and_cache_flash_kernel + <<>>( + key.data_ptr(), value.data_ptr(), + k_cache.data_ptr(), v_cache.data_ptr(), + slot_mapping.data_ptr(), block_stride, key_stride, + value_stride, num_heads, head_size, block_size); + }); } namespace vllm { -template -__global__ void convert_fp8_kernel( - const Tin* __restrict__ src_cache, - Tout* __restrict__ dst_cache, - const float kv_scale, - const int64_t block_stride) { +template +__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const float kv_scale, + const int64_t block_stride) { const int64_t block_idx = blockIdx.x; for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { int64_t idx = block_idx * block_stride + i; - dst_cache[idx] = fp8::scaled_convert(src_cache[idx], kv_scale); + dst_cache[idx] = + fp8::scaled_convert(src_cache[idx], kv_scale); } } -} // namespace vllm +} // namespace vllm -#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ - kv_scale, \ - block_stride); +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), kv_scale, block_stride); // Only for testing. -void convert_fp8( - torch::Tensor& dst_cache, - torch::Tensor& src_cache, - const float kv_scale, - const std::string& kv_cache_dtype) -{ +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const float kv_scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); at::cuda::OptionalCUDAGuard device_guard(src_device); int64_t num_blocks = src_cache.size(0); @@ -398,13 +370,15 @@ void convert_fp8( } else if (src_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kFp8E4M3); + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, + vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (dst_cache.dtype() == at::ScalarType::Float) { CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (dst_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); } } else { TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp index 1bd24eb79d129..becd2ac42f17a 100644 --- a/csrc/cpu/activation.cpp +++ b/csrc/cpu/activation.cpp @@ -1,10 +1,10 @@ #include "cpu_types.hpp" namespace { -template -void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, - scalar_t *__restrict__ output) { +void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input, + scalar_t* __restrict__ output) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); @@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, } } -FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 ones(1.0); return x / (ones + (zeros - x).exp()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w2(0.044715f); @@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { return w3 * x * (ones + t); } -FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w2(0.5); return x * w2 * (ones + (x * w1).er()); } -FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { +FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) { const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w2(0.5); @@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); return x * w2 * (ones + inner.tanh()); } -}; // namespace +}; // namespace -void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { +void silu_and_mul(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "silu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(silu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_impl) + }); } -void gelu_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "gelu_and_mul_impl", [&] { - CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) - activation_kernel(num_tokens, d, - input.data_ptr(), - out.data_ptr()); - CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] { + CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) + activation_kernel( + num_tokens, d, input.data_ptr(), out.data_ptr()); + CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl) + }); } -void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; @@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] }); } -void gelu_new(torch::Tensor &out, torch::Tensor &input) { +void gelu_new(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); @@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) { }); } -void gelu_fast(torch::Tensor &out, torch::Tensor &input) { +void gelu_fast(torch::Tensor& out, torch::Tensor& input) { int num_tokens = input.numel() / input.size(-1); int d = input.size(-1); diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index c1d765be05598..54df69b7379d6 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -2,7 +2,8 @@ namespace { -template struct KernelVecType { +template +struct KernelVecType { using q_load_vec_type = void; using q_vec_type = void; using k_load_vec_type = void; @@ -11,7 +12,8 @@ template struct KernelVecType { using v_load_vec_type = void; }; -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::FP32Vec4; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16; @@ -21,7 +23,8 @@ template <> struct KernelVecType { }; #ifdef __AVX512BF16__ -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32; @@ -30,7 +33,8 @@ template <> struct KernelVecType { using v_load_vec_type = vec_op::BF16Vec16; }; #else -template <> struct KernelVecType { +template <> +struct KernelVecType { using q_load_vec_type = vec_op::BF16Vec8; using q_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::BF16Vec16; @@ -41,7 +45,7 @@ template <> struct KernelVecType { #endif template -FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, +FORCE_INLINE std::pair reduceSoftmax(T* data, const int size, const int capacity) { T max = data[0]; for (int i = 1; i < size; ++i) { @@ -67,10 +71,11 @@ FORCE_INLINE std::pair reduceSoftmax(T *data, const int size, } template -FORCE_INLINE std::pair -reduceSoftmaxAlibi(T *data, const int size, const int capacity, - const float alibi_slope, const int start_index, - const int seq_len) { +FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, + const int capacity, + const float alibi_slope, + const int start_index, + const int seq_len) { data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; ++i) { @@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity, } template -FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, +FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, const int size) { T max = max_data[0]; for (int i = 1; i < size; ++i) { @@ -132,9 +137,9 @@ struct reduceQKBlockKernel { static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); - FORCE_INLINE static void call(const scalar_t *__restrict__ q, - const scalar_t *__restrict__ k_block, - float *__restrict__ logits, float scale, + FORCE_INLINE static void call(const scalar_t* __restrict__ q, + const scalar_t* __restrict__ k_block, + float* __restrict__ logits, float scale, const int token_num) { const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; @@ -196,8 +201,8 @@ struct reduceQKBlockKernel { template -FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, - acc_t &&acc) { +FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block, + acc_t&& acc) { using v_load_vec_type = typename KernelVecType::v_load_vec_type; constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); static_assert(BLOCK_SIZE == ELEM_NUM); @@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; }); } -}; // namespace +}; // namespace // Paged attention v1 namespace { template struct paged_attention_v1_impl { - static void - call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + static void call( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size, block_size] - const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const int num_seqs, const int num_heads) { + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { constexpr int x = 16 / sizeof(scalar_t); const int num_queries_per_kv = num_heads / num_kv_heads; @@ -243,32 +248,31 @@ struct paged_attention_v1_impl { size_t logits_bytes = parallel_work_item_num * max_seq_len_padded * sizeof(float); - float *logits = (float *)std::aligned_alloc( - 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_seq_len_padded] + float* logits = (float*)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. + // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { int seq_len = seq_lens[seq_idx]; - const int *seq_block_table = + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - const int last_block_token_num = - seq_len - (block_num - 1) * BLOCK_SIZE; - float *__restrict__ thread_block_logits = + const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE; + float* __restrict__ thread_block_logits = logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = thread_block_logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -282,8 +286,7 @@ struct paged_attention_v1_impl { block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, seq_len); } else { - reduceSoftmax(thread_block_logits, seq_len, - block_num * BLOCK_SIZE); + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); } // Compute value @@ -293,14 +296,14 @@ struct paged_attention_v1_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = thread_block_logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -311,7 +314,7 @@ struct paged_attention_v1_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -340,16 +343,16 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); template void paged_attention_v1_impl_launcher( - torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seq_lens, - int max_seq_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -359,67 +362,66 @@ void paged_attention_v1_impl_launcher( int kv_head_stride = key_cache.stride(1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *seq_lens_ptr = seq_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_impl_launcher( \ - out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ +#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ seq_lens, max_seq_len, alibi_slopes); -#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V1_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V1_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace +} // namespace -void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, - torch::Tensor &key_cache, torch::Tensor &value_cache, +void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, + torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, - torch::Tensor &seq_lens, int block_size, - int max_seq_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float kv_scale) { + torch::Tensor& block_tables, torch::Tensor& seq_lens, + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { @@ -434,23 +436,24 @@ namespace { template struct paged_attention_v2_impl { static void call( - scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] - float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float - *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] const int num_kv_heads, const float scale, - const int - *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ seq_lens, // [num_seqs] + const int* __restrict__ block_tables, // [num_seqs, + // max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float *__restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, const int num_seqs, const int num_heads, const int max_num_partitions) { constexpr int x = 16 / sizeof(scalar_t); @@ -468,8 +471,7 @@ struct paged_attention_v2_impl { const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= seq_len) - continue; + if (start_token_idx >= seq_len) continue; const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; @@ -477,15 +479,14 @@ struct paged_attention_v2_impl { const int token_num = (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); - const int block_num = - (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int last_block_token_num = token_num - (block_num - 1) * BLOCK_SIZE; - const int *seq_block_table = block_tables + + const int* seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; - const scalar_t *__restrict__ q_vec_ptr = + const scalar_t* __restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; @@ -493,10 +494,10 @@ struct paged_attention_v2_impl { // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const scalar_t *__restrict__ k_block_cache_ptr = + const scalar_t* __restrict__ k_block_cache_ptr = k_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride; - float *__restrict__ head_block_logits = + float* __restrict__ head_block_logits = logits + block_idx * BLOCK_SIZE; reduceQKBlockKernel::call( @@ -510,13 +511,13 @@ struct paged_attention_v2_impl { logits, token_num, block_num * BLOCK_SIZE, alibi_slopes[head_idx], start_token_idx, seq_len); } else { - max_and_sum = reduceSoftmax(logits, token_num, - block_num * BLOCK_SIZE); + max_and_sum = + reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); } - auto &&[max_logit, exp_sum] = max_and_sum; + auto&& [max_logit, exp_sum] = max_and_sum; - scalar_t *__restrict__ output_buffer = nullptr; + scalar_t* __restrict__ output_buffer = nullptr; if (!no_reduce) { auto idx = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions + partition_idx; @@ -538,13 +539,13 @@ struct paged_attention_v2_impl { for (int head_part_idx = 0; head_part_idx < head_partition_num; ++head_part_idx) { vec_op::FP32Vec16 accums[head_elem_num_per_partition]; - scalar_t *__restrict__ out_ptr = + scalar_t* __restrict__ out_ptr = output_buffer + head_part_idx * head_elem_num_per_partition; for (int block_idx = 0; block_idx < block_num; ++block_idx) { const int64_t physical_block_idx = seq_block_table[block_idx]; - const float *__restrict__ prob_vec_ptr = + const float* __restrict__ prob_vec_ptr = logits + block_idx * BLOCK_SIZE; - const scalar_t *__restrict__ v_block_cache_ptr = + const scalar_t* __restrict__ v_block_cache_ptr = v_cache + physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -555,7 +556,7 @@ struct paged_attention_v2_impl { if (block_idx != block_num - 1) { const int64_t next_physical_block_idx = seq_block_table[block_idx + 1]; - const scalar_t *__restrict__ next_v_block_cache_ptr = + const scalar_t* __restrict__ next_v_block_cache_ptr = v_cache + next_physical_block_idx * kv_block_stride + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; @@ -587,8 +588,7 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; reducePartitonSoftmax( max_logits + seq_idx * num_heads * max_num_partitions + @@ -603,11 +603,11 @@ struct paged_attention_v2_impl { using v_load_vec_type = typename KernelVecType::v_load_vec_type; static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); constexpr int head_elem_num_per_group = - 16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE - // didn't align with 64 bytes + 16; // Note: didn't align with the cacheline size, due to some + // HEAD_SIZE didn't align with 64 bytes static_assert(HEAD_SIZE % head_elem_num_per_group == 0); constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; - const float *__restrict__ rescale_factors = exp_sums; + const float* __restrict__ rescale_factors = exp_sums; #pragma omp parallel for collapse(3) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { @@ -616,17 +616,16 @@ struct paged_attention_v2_impl { const int partition_num = (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - if (partition_num == 1) - continue; + if (partition_num == 1) continue; - const float *__restrict__ seq_head_rescale_factors = + const float* __restrict__ seq_head_rescale_factors = rescale_factors + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - const scalar_t *__restrict__ seq_head_tmp_out = + const scalar_t* __restrict__ seq_head_tmp_out = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + group_idx * head_elem_num_per_group; - scalar_t *__restrict__ seq_head_output = + scalar_t* __restrict__ seq_head_output = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + group_idx * head_elem_num_per_group; @@ -645,21 +644,21 @@ struct paged_attention_v2_impl { } }; -#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ - paged_attention_v2_impl::call( \ - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ - key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, num_seqs, num_heads, \ +#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v2_impl::call( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ + key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); template void paged_attention_v2_impl_launcher( - torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, - torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, - int max_seq_len, const c10::optional &alibi_slopes) { + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -670,72 +669,72 @@ void paged_attention_v2_impl_launcher( int max_num_partitions = exp_sums.size(-1); // NOTE: alibi_slopes is optional. - const float *alibi_slopes_ptr = + const float* alibi_slopes_ptr = alibi_slopes - ? reinterpret_cast(alibi_slopes.value().data_ptr()) + ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T *out_ptr = reinterpret_cast(out.data_ptr()); - float *exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float *max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T *tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T *query_ptr = reinterpret_cast(query.data_ptr()); - T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int *block_tables_ptr = block_tables.data_ptr(); - int *seq_lens_ptr = seq_lens.data_ptr(); + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { - case 64: - LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); - break; - case 80: - LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); - break; - case 96: - LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); - break; - case 112: - LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); - break; - case 128: - LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); - break; - case 256: - LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; + case 64: + LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; } } -#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_impl_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, seq_lens, block_size, \ - max_seq_len, alibi_slopes); - -#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 16: \ - CALL_V2_KERNEL_LAUNCHER(T, 16); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_impl_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \ + alibi_slopes); + +#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_V2_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } -} // namespace - -void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, - torch::Tensor &max_logits, torch::Tensor &tmp_out, - torch::Tensor &query, torch::Tensor &key_cache, - torch::Tensor &value_cache, int num_kv_heads, - float scale, torch::Tensor &block_tables, - torch::Tensor &seq_lens, int block_size, +} // namespace + +void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, + float scale, torch::Tensor& block_tables, + torch::Tensor& seq_lens, int block_size, int max_seq_len, - const c10::optional &alibi_slopes, - const std::string &kv_cache_dtype, float kv_scale) { + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 26e81685d623e..2890ba6e2bb32 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -5,25 +5,26 @@ namespace { template -void copy_blocks_cpu_impl( - std::vector &key_caches, - std::vector &value_caches, - const torch::Tensor& mapping_pairs, - const int element_num_per_block, const int layer_num) { +void copy_blocks_cpu_impl(std::vector& key_caches, + std::vector& value_caches, + const torch::Tensor& mapping_pairs, + const int element_num_per_block, + const int layer_num) { const size_t pair_num = mapping_pairs.size(0); const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; #pragma omp parallel for collapse(2) for (int layer = 0; layer < layer_num; ++layer) { for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item(); + int64_t source_offset = + element_num_per_block * mapping_pairs[pair][0].item(); int64_t target_offset = element_num_per_block * mapping_pairs[pair][1].item(); - scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); - scalar_t *source_ptr = key_cache_ptr + source_offset; - scalar_t *target_ptr = key_cache_ptr + target_offset; + scalar_t* key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t* source_ptr = key_cache_ptr + source_offset; + scalar_t* target_ptr = key_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); - scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); + scalar_t* value_cache_ptr = value_caches[layer].data_ptr(); source_ptr = value_cache_ptr + source_offset; target_ptr = value_cache_ptr + target_offset; std::memcpy(target_ptr, source_ptr, block_bytes); @@ -33,9 +34,9 @@ void copy_blocks_cpu_impl( template void reshape_and_cache_cpu_impl( - const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, - scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, - const int64_t *__restrict__ slot_mapping, const int num_tokens, + const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, + const int64_t* __restrict__ slot_mapping, const int num_tokens, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, const int x) { const int block_elem_num = num_heads * head_size * block_size; @@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl( int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_value_head_idx = token_idx * value_stride + head_idx * head_size; - const scalar_t *src_key_head_ptr = key + src_key_head_idx; - const scalar_t *src_value_head_ptr = value + src_value_head_idx; + const scalar_t* src_key_head_ptr = key + src_key_head_idx; + const scalar_t* src_value_head_ptr = value + src_value_head_idx; const int64_t block_index = slot_idx / block_size; const int64_t block_offset = slot_idx % block_size; - scalar_t *target_key_head_ptr = key_cache + + scalar_t* target_key_head_ptr = key_cache + block_elem_num * block_index + head_idx * block_size * head_size; - scalar_t *target_value_head_ptr = value_cache + + scalar_t* target_value_head_ptr = value_cache + block_elem_num * block_index + head_idx * block_size * head_size; @@ -79,10 +80,10 @@ void reshape_and_cache_cpu_impl( } } } -}; // namespace +}; // namespace -void copy_blocks(std::vector &key_caches, - std::vector &value_caches, +void copy_blocks(std::vector& key_caches, + std::vector& value_caches, const torch::Tensor& block_mapping) { unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); @@ -100,10 +101,10 @@ void copy_blocks(std::vector &key_caches, }); } -void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, - torch::Tensor &key_cache, torch::Tensor &value_cache, - torch::Tensor &slot_mapping, - const std::string &kv_cache_dtype, float kv_scale) { +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); int num_tokens = key.size(0); @@ -127,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, }); } -void swap_blocks(torch::Tensor &src, torch::Tensor &dst, - const torch::Tensor&block_mapping) { +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") } diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp index 467f0dc84982c..65d3ddcec5709 100644 --- a/csrc/cpu/layernorm.cpp +++ b/csrc/cpu/layernorm.cpp @@ -2,10 +2,10 @@ namespace { template -void rms_norm_impl(scalar_t *__restrict__ out, - const scalar_t *__restrict__ input, - const scalar_t *__restrict__ weight, const float epsilon, - const int num_tokens, const int hidden_size) { +void rms_norm_impl(scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ weight, const float epsilon, + const int num_tokens, const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out, } template -void fused_add_rms_norm_impl(scalar_t *__restrict__ input, - scalar_t *__restrict__ residual, - const scalar_t *__restrict__ weight, - const float epsilon, const int num_tokens, - const int hidden_size) { +void fused_add_rms_norm_impl(scalar_t* __restrict__ input, + scalar_t* __restrict__ residual, + const scalar_t* __restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { using scalar_vec_t = vec_op::vec_t; constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); @@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, } } } -} // namespace +} // namespace -void rms_norm(torch::Tensor &out, torch::Tensor &input, - torch::Tensor &weight, float epsilon) { +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { CPU_KERNEL_GUARD_IN(rms_norm_impl) rms_norm_impl(out.data_ptr(), input.data_ptr(), - weight.data_ptr(), epsilon, num_tokens, - hidden_size); + weight.data_ptr(), epsilon, num_tokens, + hidden_size); CPU_KERNEL_GUARD_OUT(rms_norm_impl) }); } -void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, - torch::Tensor &weight, float epsilon) { +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 5dc1bde45ac5f..73bf77e46f538 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -4,16 +4,16 @@ namespace { template void rotary_embedding_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { @@ -26,7 +26,7 @@ void rotary_embedding_impl( #pragma omp parallel for for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; for (int i = 0; i < num_heads; ++i) { const int head_idx = i; @@ -94,16 +94,16 @@ void rotary_embedding_impl( template void rotary_embedding_gptj_impl( - const int64_t - *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t - *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or - /// [num_tokens, num_heads, head_size] - scalar_t - *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or - // [num_tokens, num_kv_heads, head_size] - const scalar_t - *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, + /// head_size] or [num_tokens, num_heads, + /// head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size, const int num_tokens) { @@ -113,13 +113,13 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * query_stride + head_idx * head_size; - scalar_t *head_query = token_head + query; + scalar_t* head_query = token_head + query; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -141,12 +141,12 @@ void rotary_embedding_gptj_impl( for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_kv_heads; ++i) { int64_t pos = positions[token_idx]; - const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; - const scalar_t *cos_cache_ptr = cache_ptr; - const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + const scalar_t* cos_cache_ptr = cache_ptr; + const scalar_t* sin_cache_ptr = cache_ptr + embed_dim; const int head_idx = i; const int64_t token_head = token_idx * key_stride + head_idx * head_size; - scalar_t *head_key = key + token_head; + scalar_t* head_key = key + token_head; for (int j = 0; j < embed_dim; j += 1) { const int rot_offset = j; const int x_index = 2 * rot_offset; @@ -164,11 +164,11 @@ void rotary_embedding_gptj_impl( } } } -}; // namespace +}; // namespace -void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, - torch::Tensor &key, int head_size, - torch::Tensor &cos_sin_cache, bool is_neox) { +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp index bba044087f37c..63082393c8102 100644 --- a/csrc/cpu/pybind.cpp +++ b/csrc/cpu/pybind.cpp @@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); + ops.def("paged_attention_v1", &paged_attention_v1, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); + ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + ops.def("gelu_and_mul", &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + ops.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + ops.def("fused_add_rms_norm", &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); + cache_ops.def("swap_blocks", &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def("copy_blocks", ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def("reshape_and_cache", &reshape_and_cache, + "Reshape the key and value tensors and cache them"); } diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 1ebb2e74a82fc..5909e5eaf5e60 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -1,7 +1,7 @@ #pragma once #ifdef USE_ROCM -#include + #include #endif #ifndef USE_ROCM @@ -17,7 +17,8 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #endif @@ -29,7 +30,8 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) #else #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) #endif @@ -41,4 +43,3 @@ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif - diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 1483484faeb4a..2ba49b339e148 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -2,9 +2,6 @@ #include -int get_device_attribute( - int attribute, - int device_id); +int get_device_attribute(int attribute, int device_id); -int get_max_shared_memory_per_block_device_attribute( - int device_id); +int get_max_shared_memory_per_block_device_attribute(int device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 1a443ef3620cc..7d8e2e19720fa 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -2,34 +2,28 @@ #include #include #endif -int get_device_attribute( - int attribute, - int device_id) -{ - int device, value; - if (device_id < 0) { - cudaGetDevice(&device); - } - else { - device = device_id; - } - cudaDeviceGetAttribute(&value, static_cast(attribute), device); - return value; +int get_device_attribute(int attribute, int device_id) { + int device, value; + if (device_id < 0) { + cudaGetDevice(&device); + } else { + device = device_id; + } + cudaDeviceGetAttribute(&value, static_cast(attribute), + device); + return value; } - -int get_max_shared_memory_per_block_device_attribute( - int device_id) -{ -int attribute; -// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html -// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 +int get_max_shared_memory_per_block_device_attribute(int device_id) { + int attribute; + // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html + // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 #ifdef USE_ROCM - attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; + attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; #else - attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; + attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; #endif - return get_device_attribute(attribute, device_id); + return get_device_attribute(attribute, device_id); } diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 3906dcfc80dbf..0b1d95848525a 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -7,11 +7,11 @@ // fake pointer type using fptr_t = uint64_t; -static_assert(sizeof(void *) == sizeof(fptr_t)); +static_assert(sizeof(void*) == sizeof(fptr_t)); -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int rank, bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) @@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } return (fptr_t) new vllm::CustomAllreduce( - reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } @@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, * 5. A[None].expand(2, -1, -1, -1): Not OK * 6. A[:, 1:, 1:]: Not OK */ -bool _is_weak_contiguous(torch::Tensor &t) { +bool _is_weak_contiguous(torch::Tensor& t) { return t.is_contiguous() || (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, bool full_nvlink) { auto inp_size = inp.numel() * inp.element_size(); // custom allreduce requires input byte size to be multiples of 16 @@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, return false; } -void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, +void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); TORCH_CHECK(_is_weak_contiguous(out)); switch (out.scalar_type()) { case at::ScalarType::Float: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } case at::ScalarType::Half: { - fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), - out.numel()); + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( - stream, reinterpret_cast(inp.data_ptr()), - reinterpret_cast(out.data_ptr()), out.numel()); + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); break; } #endif @@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, } } -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); @@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { _all_reduce(_fa, inp, out, stream); } -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out) { +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out) { const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); auto stream = c10::cuda::getCurrentCUDAStream().stream(); @@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, } void dispose(fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); delete fa; } int meta_size() { return sizeof(vllm::Signal); } -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets) { - auto fa = reinterpret_cast(_fa); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_buffer(handles, offsets, t.data_ptr()); } std::pair, std::vector> get_graph_buffer_ipc_meta( fptr_t _fa) { - auto fa = reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); return fa->get_graph_buffer_ipc_meta(); } -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets) { - auto fa = reinterpret_cast(_fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); fa->register_graph_buffers(handles, offsets); } diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 750e68d42f6c6..1ed49b8aa9cae 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -31,9 +31,9 @@ struct Signal { alignas(128) uint32_t end[kMaxBlocks][8]; }; -struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; +struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal *signals[8]; }; +struct __align__(16) RankSignals { volatile Signal* signals[8]; }; // like std::array, but aligned template @@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) { // scalar add functions // for some reason when compiling with Pytorch, the + operator for half and // bfloat is disabled so we call the intrinsics directly -DINLINE half &assign_add(half &a, half b) { +DINLINE half& assign_add(half& a, half b) { a = __hadd(a, b); return a; } -DINLINE float &assign_add(float &a, float b) { return a += b; } +DINLINE float& assign_add(float& a, float b) { return a += b; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } @@ -80,14 +80,14 @@ template <> DINLINE nv_bfloat16 downcast_s(float val) { return __float2bfloat16(val); } -DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { +DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { a = __hadd(a, b); return a; } #endif template -DINLINE array_t &packed_assign_add(array_t &a, array_t b) { +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { #pragma unroll for (int i = 0; i < N; i++) { assign_add(a.data[i], b.data[i]); @@ -128,7 +128,7 @@ DINLINE O downcast(array_t val) { // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. template -DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { if (threadIdx.x < ngpus) { // reset flag for next time @@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->start[blockIdx.x][threadIdx.x]); } __syncthreads(); } @@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template -DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, +DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, int rank) { __syncthreads(); // eliminate the case that prior writes are not visible after signals become // visible. Note that I did not managed to make this happen through a lot of // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. + // the memory model. if constexpr (!final_sync) __threadfence_system(); if (threadIdx.x < ngpus) { // reset flag for next time @@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, // Latency = 1 p2p write sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]) - ; + while (!self_sg->end[blockIdx.x][threadIdx.x]); } if constexpr (!final_sync) __syncthreads(); } template -DINLINE P packed_reduce(const P *ptrs[], int idx) { +DINLINE P packed_reduce(const P* ptrs[], int idx) { A tmp = upcast(ptrs[0][idx]); #pragma unroll for (int i = 1; i < ngpus; i++) { @@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; @@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1) // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { - ((P *)result)[idx] = - packed_reduce((const P **)&dp.ptrs[0], idx); + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } end_sync(sg, self_sg, rank); } template -DINLINE P *get_tmp_buf(volatile Signal *sg) { - return (P *)(((Signal *)sg) + 1); +DINLINE P* get_tmp_buf(volatile Signal* sg) { + return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData *_dp, RankSignals sg, - volatile Signal *self_sg, T *__restrict__ result, + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; @@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1) int start = rank * part; int end = rank == ngpus - 1 ? size : start + part; int largest_part = part + size % ngpus; - const P *ptrs[ngpus]; - P *tmps[ngpus]; + const P* ptrs[ngpus]; + P* tmps[ngpus]; #pragma unroll for (int i = 0; i < ngpus; i++) { int target = (rank + i) % ngpus; - ptrs[i] = (const P *)_dp->ptrs[target]; + ptrs[i] = (const P*)_dp->ptrs[target]; tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; @@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1) int gather_from_rank = ((rank + i) % ngpus); if (gather_from_rank == ngpus - 1 || idx < part) { int dst_idx = gather_from_rank * part + idx; - ((P *)result)[dst_idx] = tmps[i][idx]; + ((P*)result)[dst_idx] = tmps[i][idx]; } } } @@ -261,14 +258,14 @@ class CustomAllreduce { // below are device pointers RankSignals sg_; - std::unordered_map buffers_; - Signal *self_sg_; + std::unordered_map buffers_; + Signal* self_sg_; // stores the registered device pointers from all ranks RankData *d_rank_data_base_, *d_rank_data_end_; - std::vector graph_unreg_buffers_; + std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers - std::map ipc_handles_; + std::map ipc_handles_; /** * meta is a pointer to device metadata and temporary buffer for allreduce. @@ -279,22 +276,22 @@ class CustomAllreduce { * note: this class does not own any device memory. Any required buffers * are passed in from the constructor */ - CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, - const cudaIpcMemHandle_t *handles, - const std::vector &offsets, int rank, + CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t* handles, + const std::vector& offsets, int rank, bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), self_sg_(meta), - d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - Signal *rank_sg; + Signal* rank_sg; if (i != rank_) { - char *handle = open_ipc_handle(&handles[i]); + char* handle = open_ipc_handle(&handles[i]); handle += offsets[i]; - rank_sg = (Signal *)handle; + rank_sg = (Signal*)handle; } else { rank_sg = self_sg_; } @@ -302,13 +299,13 @@ class CustomAllreduce { } } - char *open_ipc_handle(const void *ipc_handle) { + char* open_ipc_handle(const void* ipc_handle) { auto [it, new_handle] = - ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { - char *ipc_ptr; - CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, - *((const cudaIpcMemHandle_t *)ipc_handle), + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } @@ -323,7 +320,7 @@ class CustomAllreduce { std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; - void *base_ptr; + void* base_ptr; // note: must share the base address of each allocation, or we get wrong // address if (cuPointerGetAttribute(&base_ptr, @@ -331,8 +328,8 @@ class CustomAllreduce { (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(cudaIpcGetMemHandle( - (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); - offsets[i] = ((char *)ptr) - ((char *)base_ptr); + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); } return std::make_pair(handles, offsets); } @@ -344,13 +341,13 @@ class CustomAllreduce { std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } - void register_buffer(const std::vector &handles, - const std::vector &offsets, void *self) { + void register_buffer(const std::vector& handles, + const std::vector& offsets, void* self) { check_rank_data_capacity(); RankData data; for (int i = 0; i < world_size_; i++) { if (i != rank_) { - char *handle = open_ipc_handle(handles[i].data()); + char* handle = open_ipc_handle(handles[i].data()); handle += offsets[i]; data.ptrs[i] = handle; } else { @@ -371,17 +368,17 @@ class CustomAllreduce { // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. void register_graph_buffers( - const std::vector &handles, - const std::vector> &offsets) { + const std::vector& handles, + const std::vector>& offsets) { auto num_buffers = graph_unreg_buffers_.size(); check_rank_data_capacity(num_buffers); std::vector rank_data(num_buffers); for (int i = 0; i < num_buffers; i++) { auto self_ptr = graph_unreg_buffers_[i]; - auto &rd = rank_data[i]; + auto& rd = rank_data[i]; for (int j = 0; j < world_size_; j++) { if (j != rank_) { - char *handle = + char* handle = open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); handle += offsets[j][i]; rd.ptrs[j] = handle; @@ -405,7 +402,7 @@ class CustomAllreduce { * will cause contention on NVLink bus. */ template - void allreduce(cudaStream_t stream, T *input, T *output, int size, + void allreduce(cudaStream_t stream, T* input, T* output, int size, int threads = 512, int block_limit = 36) { auto d = packed_t::P::size; if (size % d != 0) @@ -418,7 +415,7 @@ class CustomAllreduce { std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit)); - RankData *ptrs; + RankData* ptrs; cudaStreamCaptureStatus status; CUDACHECK(cudaStreamIsCapturing(stream, &status)); if (status == cudaStreamCaptureStatusActive) { diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index c34a50389c21c..f7868233076cd 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -48,7 +48,7 @@ __global__ void dummy_kernel() { } template -__global__ void set_data(T *data, int size, int myRank) { +__global__ void set_data(T* data, int size, int myRank) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { data[idx] = myRank * 0.11f; @@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) { } template -__global__ void convert_data(const T *data1, const T *data2, double *fdata1, - double *fdata2, int size) { +__global__ void convert_data(const T* data1, const T* data2, double* fdata1, + double* fdata2, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { fdata1[idx] = data1[idx]; @@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1, } } -__global__ void init_rand(curandState_t *state, int size, int nRanks) { +__global__ void init_rand(curandState_t* state, int size, int nRanks) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { for (int i = 0; i < nRanks; i++) { @@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) { } template -__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, +__global__ void gen_data(curandState_t* state, T* data, double* ground_truth, int myRank, int nRanks, int size) { for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { @@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, } template -void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, +void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, int data_size, bool performance_test) { - T *result; + T* result; cudaStream_t stream; CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); @@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t data_handles[8]; - vllm::Signal *buffer; - T *self_data_copy; + vllm::Signal* buffer; + T* self_data_copy; /** * Allocate IPC buffer * @@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, MPI_COMM_WORLD)); - void *rank_data; + void* rank_data; size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); std::vector offsets(nRanks, 0); vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, offsets, myRank); - auto *self_data = - reinterpret_cast(reinterpret_cast(buffer) + - sizeof(vllm::Signal) + data_size * sizeof(T)); + auto* self_data = + reinterpret_cast(reinterpret_cast(buffer) + + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { std::vector handles; handles.reserve(nRanks); for (int i = 0; i < nRanks; i++) { - char *begin = (char *)&data_handles[i]; - char *end = (char *)&data_handles[i + 1]; + char* begin = (char*)&data_handles[i]; + char* end = (char*)&data_handles[i + 1]; handles.emplace_back(begin, end); } std::vector offsets(nRanks, @@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, fa.register_buffer(handles, offsets, self_data); } - double *ground_truth; + double* ground_truth; CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); - curandState_t *states; + curandState_t* states; CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); gen_data<<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, @@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, CUDACHECK(cudaStreamDestroy(stream)); } -int main(int argc, char **argv) { +int main(int argc, char** argv) { int nRanks, myRank; MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); @@ -296,7 +296,7 @@ int main(int argc, char **argv) { ncclUniqueId id; ncclComm_t comm; if (myRank == 0) ncclGetUniqueId(&id); - MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, + MPICHECK(MPI_Bcast(static_cast(&id), sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 91abd9e85b4bb..3ecea03242f06 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -6,32 +6,30 @@ #include -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) - -#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) -#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index e56b4d2204005..70a2b3b0a07b1 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -11,26 +11,24 @@ #include #include - using __nv_bfloat16 = __hip_bfloat16; - using __nv_bfloat162 = __hip_bfloat162; +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; #endif namespace vllm { // TODO(woosuk): Further optimize this kernel. -template +template __global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float) input[blockIdx.x * hidden_size + idx]; + const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } variance = blockReduceSum(variance); @@ -40,12 +38,12 @@ __global__ void rms_norm_kernel( __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) input[blockIdx.x * hidden_size + idx]; - out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } - /* Converter structs for the conversion from torch types to HIP/CUDA types, and the associated type conversions within HIP/CUDA. These helpers need to be implemented for now because the relevant type conversion @@ -54,51 +52,68 @@ __global__ void rms_norm_kernel( Each struct should have the member static constexpr bool `exists`: If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. + If true, the struct should be fully defined as shown in the examples below. */ -template -struct _typeConvert { static constexpr bool exists = false; }; +template +struct _typeConvert { + static constexpr bool exists = false; +}; #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) // CUDA < 12.0 runs into issues with packed type conversion -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __half; using packed_hip_type = __half2; __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } - __device__ static inline hip_type convert(float x) { return __float2half_rn(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } }; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // CUDA_ARCH < 800 does not have BF16 support // TODO: Add in ROCm support once public headers handle bf16 maturely -template<> +template <> struct _typeConvert { static constexpr bool exists = true; using hip_type = __nv_bfloat16; using packed_hip_type = __nv_bfloat162; - __device__ static inline float convert(hip_type x) { return __bfloat162float(x); } - __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } - __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } - __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } }; -#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) /* Vector POD struct to generate vectorized and packed FP16/BF16 ops for appropriate specializations of fused_add_rms_norm_kernel. Only functions that are necessary in that kernel are implemented. Alignment to 16 bytes is required to use 128-bit global memory ops. */ -template +template struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ static_assert(width > 0 && (width & (width - 1)) == 0, "Width is not a positive power of 2!"); using Converter = _typeConvert; @@ -108,51 +123,49 @@ struct alignas(16) _f16Vec { __device__ _f16Vec& operator+=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp += T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] += other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const _f16Vec& other) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i+1]}; - temp *= T2{other.data[i], other.data[i+1]}; + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll - for (int i = 0; i < width; ++i) - data[i] *= other.data[i]; +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; } return *this; } __device__ _f16Vec& operator*=(const float scale) { if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); temp_f.x *= scale; temp_f.y *= scale; T2 temp = Converter::convert(temp_f); data[i] = temp.x; - data[i+1] = temp.y; + data[i + 1] = temp.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float temp = Converter::convert(data[i]) * scale; data[i] = Converter::convert(temp); @@ -164,13 +177,13 @@ struct alignas(16) _f16Vec { __device__ float sum_squares() const { float result = 0.0f; if constexpr (width % 2 == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i+1]}); + float2 z = Converter::convert(T2{data[i], data[i + 1]}); result += z.x * z.x + z.y * z.y; } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < width; ++i) { float x = Converter::convert(data[i]); result += x * x; @@ -184,15 +197,13 @@ struct alignas(16) _f16Vec { Additional optimizations we can make in this case are packed and vectorized operations, which help with the memory latency bottleneck. */ -template -__global__ std::enable_if_t< - (width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { // Sanity checks on our vector struct and type-punned pointer arithmetic static_assert(std::is_pod_v<_f16Vec>); static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); @@ -203,9 +214,12 @@ __global__ std::enable_if_t< /* These and the argument pointers are all declared `restrict` as they are not aliased in practice. Argument pointers should not be dereferenced in this kernel as that would be undefined behavior */ - auto* __restrict__ input_v = reinterpret_cast<_f16Vec*>(input); - auto* __restrict__ residual_v = reinterpret_cast<_f16Vec*>(residual); - auto* __restrict__ weight_v = reinterpret_cast*>(weight); + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; @@ -215,10 +229,11 @@ __global__ std::enable_if_t< residual_v[id] = temp; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -233,52 +248,50 @@ __global__ std::enable_if_t< } } - /* Generic fused_add_rms_norm_kernel The width field is not used here but necessary for other specializations. */ -template -__global__ std::enable_if_t< - (width == 0) || !_typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] - scalar_t* __restrict__ residual, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, - const int num_tokens, - const int hidden_size) { +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { scalar_t z = input[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx]; - float x = (float) z; + float x = (float)z; variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ + calculation of max_block_size in fused_add_rms_norm */ if (num_tokens < 256) { variance = blockReduceSum(variance); - } else variance = blockReduceSum(variance); + } else + variance = blockReduceSum(variance); if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float) residual[blockIdx.x * hidden_size + idx]; - input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; + float x = (float)residual[blockIdx.x * hidden_size + idx]; + input[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; } } -} // namespace vllm +} // namespace vllm -void rms_norm( - torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +void rms_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -286,40 +299,27 @@ void rms_norm( dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "rms_norm_kernel", - [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - epsilon, - num_tokens, - hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "fused_add_rms_norm_kernel", \ - [&] { \ - vllm::fused_add_rms_norm_kernel \ - <<>>( \ - input.data_ptr(), \ - residual.data_ptr(), \ - weight.data_ptr(), \ - epsilon, \ - num_tokens, \ - hidden_size); \ - }); - -void fused_add_rms_norm( - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& residual, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - float epsilon) { +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>(input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), epsilon, \ + num_tokens, hidden_size); \ + }); + +void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + float epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -342,8 +342,8 @@ void fused_add_rms_norm( auto inp_ptr = reinterpret_cast(input.data_ptr()); auto res_ptr = reinterpret_cast(residual.data_ptr()); auto wt_ptr = reinterpret_cast(weight.data_ptr()); - bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ - && wt_ptr % 16 == 0; + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; if (ptrs_are_aligned && hidden_size % 8 == 0) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp index 35c328499a22d..4122f7630d7c7 100644 --- a/csrc/moe/moe_ops.cpp +++ b/csrc/moe/moe_ops.cpp @@ -3,5 +3,6 @@ #include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); + m.def("topk_softmax", &topk_softmax, + "Apply topk softmax to the gating outputs."); } diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index a01be3e426d72..93e7844ac1993 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -2,8 +2,6 @@ #include -void topk_softmax( - torch::Tensor& topk_weights, - torch::Tensor& topk_indices, - torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); +void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output); diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index e01b23685ef4e..edc441d121029 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -7,119 +7,128 @@ #include "cuda_compat.h" #include "dispatch_utils.h" -#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) namespace vllm { namespace { -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { - // don't worry about overflow because num_experts is relatively small - return row * total_col + col; -} +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; } +} // namespace template -__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, - int32_t *sorted_token_ids, - int32_t *expert_ids, - int32_t *total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, - size_t numel) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - extern __shared__ int32_t shared_mem[]; - - int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) - int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are assigned - * to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding blocks - * and stores the corresponding expert_id for each block. - */ - for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { - expert_ids[i / block_size] = threadIdx.x; +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = + shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = + shared_mem + (num_experts + 1) * + num_experts; // 1d tensor with shape (num_experts + 1) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; } - - /** - * Each thread processes a token shard, calculating the index of each token after - * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and - * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], - * where * represents a padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] - * stores the indices of the tokens processed by the expert with expert_id within - * the current thread's token shard. - */ - int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } } - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors - const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); +} // namespace vllm + +void moe_align_block_size(torch::Tensor topk_ids, int num_experts, + int block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t shared_mem = + ((num_experts + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); // set dynamic shared mem auto kernel = vllm::moe_align_block_size_kernel; - AT_CUDA_CHECK( - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); kernel<<<1, num_experts, shared_mem, stream>>>( - topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), - num_experts, - block_size, + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, topk_ids.numel()); - }); + }); } diff --git a/csrc/ops.h b/csrc/ops.h index 8c2c2ae6e1f5a..f5e0e423bb65d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -2,224 +2,136 @@ #include -void paged_attention_v1( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); - -void paged_attention_v2( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& seq_lens, - int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, - float kv_scale); - -void rms_norm( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& weight, - float epsilon); - -void fused_add_rms_norm( - torch::Tensor& input, - torch::Tensor& residual, - torch::Tensor& weight, - float epsilon); - -void rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox); - -void batched_rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets); - -void silu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_tanh_and_mul( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_new( - torch::Tensor& out, - torch::Tensor& input); - -void gelu_fast( - torch::Tensor& out, - torch::Tensor& input); +void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, + torch::Tensor& key_cache, torch::Tensor& value_cache, + int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, + int block_size, int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale); + +void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, + float scale, torch::Tensor& block_tables, + torch::Tensor& seq_lens, int block_size, + int max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale); + +void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, + float epsilon); + +void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, + torch::Tensor& weight, float epsilon); + +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox); + +void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, + torch::Tensor& key, int head_size, + torch::Tensor& cos_sin_cache, bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets); + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_new(torch::Tensor& out, torch::Tensor& input); + +void gelu_fast(torch::Tensor& out, torch::Tensor& input); #ifndef USE_ROCM -torch::Tensor aqlm_gemm( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias -); - -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes -); - -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy); - -torch::Tensor marlin_gemm( - torch::Tensor& a, - torch::Tensor& b_q_weight, - torch::Tensor& b_scales, - torch::Tensor& workspace, - int64_t size_m, - int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_24_gemm( - torch::Tensor &a, - torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, - int64_t num_bits, - int64_t size_m, - int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_gemm( - torch::Tensor &a, - torch::Tensor &b_q_weight, - torch::Tensor &b_scales, - torch::Tensor &g_idx, - torch::Tensor &perm, - torch::Tensor &workspace, - int64_t num_bits, - int64_t size_m, - int64_t size_n, - int64_t size_k, - bool is_k_full); - -torch::Tensor gptq_marlin_repack( - torch::Tensor &b_q_weight, - torch::Tensor &perm, - int64_t size_k, - int64_t size_n, - int64_t num_bits); - -int cutlass_scaled_mm_dq( - torch::Tensor& out, - torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias); + +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes); + +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters); + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int split_k_iters, int thx, + int thy); + +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, + int64_t size_m, int64_t size_n, int64_t size_k); + +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, + int64_t size_m, int64_t size_n, + int64_t size_k); + +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full); + +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, + int64_t size_k, int64_t size_n, + int64_t num_bits); + +int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales); #endif -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table); - -torch::Tensor gptq_gemm( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama, - int bit); - -void gptq_shuffle( - torch::Tensor q_weight, - torch::Tensor q_perm, - int bit); - -void static_scaled_fp8_quant( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scale); - -void dynamic_scaled_fp8_quant( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& scale); - -void moe_align_block_size( - torch::Tensor topk_ids, - int num_experts, - int block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table); + +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int bit); + +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit); + +void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + +void moe_align_block_size(torch::Tensor topk_ids, int num_experts, + int block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM using fptr_t = uint64_t; -fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, - const std::vector &handles, - const std::vector &offsets, int rank, - bool full_nvlink); -bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int rank, + bool full_nvlink); +bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, bool full_nvlink); -void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); -void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, - torch::Tensor &out); +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out); void dispose(fptr_t _fa); int meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor &t, - const std::vector &handles, - const std::vector &offsets); -std::pair, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector &handles, - const std::vector> &offsets); +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets); +std::pair, std::vector> get_graph_buffer_ipc_meta( + fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets); #endif diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index d80cb6973fad6..69d6dae1c26bc 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -7,14 +7,10 @@ namespace vllm { -template +template inline __device__ void apply_token_rotary_embedding( - scalar_t* __restrict__ arr, - const scalar_t* __restrict__ cos_ptr, - const scalar_t* __restrict__ sin_ptr, - int rot_offset, - int embed_dim) -{ + scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { int x_index, y_index; scalar_t cos, sin; if (IS_NEOX) { @@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding( arr[y_index] = y * cos + x * sin; } -template +template inline __device__ void apply_rotary_embedding( - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* cache_ptr, - const int head_size, - const int num_heads, - const int num_kv_heads, - const int rot_dim, - const int token_idx, - const int64_t query_stride, - const int64_t key_stride) -{ + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, const int head_size, const int num_heads, + const int num_kv_heads, const int rot_dim, const int token_idx, + const int64_t query_stride, const int64_t key_stride) { const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(query + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } const int nk = num_kv_heads * embed_dim; @@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_token_rotary_embedding(key + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } } -template +template __global__ void rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -template +template __global__ void batched_rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] - scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, - const int num_heads, - const int num_kv_heads, - const int head_size) { + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] + // or [num_tokens] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + const scalar_t* cache_ptr = + cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); } -} // namespace vllm +} // namespace vllm void rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox) { + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { int64_t num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; @@ -135,36 +141,21 @@ void rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, + query_stride, key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size); + } + }); } /* @@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together and process in batched manner. */ void batched_rotary_embedding( - torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] - torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] - int head_size, - torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, - int rot_dim, - torch::Tensor& cos_sin_cache_offsets // [num_tokens] + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, int rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] ) { int64_t num_tokens = cos_sin_cache_offsets.size(0); int num_heads = query.size(-1) / head_size; @@ -191,36 +183,21 @@ void batched_rotary_embedding( dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - query.scalar_type(), - "rotary_embedding", - [&] { - if (is_neox) { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } else { - vllm::batched_rotary_embedding_kernel<<>>( - positions.data_ptr(), - query.data_ptr(), - key.data_ptr(), - cos_sin_cache.data_ptr(), - cos_sin_cache_offsets.data_ptr(), - rot_dim, - query_stride, - key_stride, - num_heads, - num_kv_heads, - head_size); - } - }); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } + }); } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index f5b4865506568..cba07f0ae9f2a 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -8,116 +8,87 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); // Attention ops - ops.def( - "paged_attention_v1", - &paged_attention_v1, - "Compute the attention between an input query and the cached keys/values using PagedAttention."); - ops.def( - "paged_attention_v2", - &paged_attention_v2, - "PagedAttention V2."); + ops.def("paged_attention_v1", &paged_attention_v1, + "Compute the attention between an input query and the cached " + "keys/values using PagedAttention."); + ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); // Activation ops - ops.def( - "silu_and_mul", - &silu_and_mul, - "Activation function used in SwiGLU."); - ops.def( - "gelu_and_mul", - &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def( - "gelu_tanh_and_mul", - &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def( - "gelu_new", - &gelu_new, - "GELU implementation used in GPT-2."); - ops.def( - "gelu_fast", - &gelu_fast, - "Approximate GELU implementation."); + ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + ops.def("gelu_and_mul", &gelu_and_mul, + "Activation function used in GeGLU with `none` approximation."); + ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, + "Activation function used in GeGLU with `tanh` approximation."); + ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); + ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); // Layernorm - ops.def( - "rms_norm", - &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); + ops.def("rms_norm", &rms_norm, + "Apply Root Mean Square (RMS) Normalization to the input tensor."); - ops.def( - "fused_add_rms_norm", - &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); + ops.def("fused_add_rms_norm", &fused_add_rms_norm, + "In-place fused Add and RMS Normalization"); // Rotary embedding - ops.def( - "rotary_embedding", - &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def("rotary_embedding", &rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - ops.def( - "batched_rotary_embedding", - &batched_rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); + ops.def("batched_rotary_embedding", &batched_rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key " + "(supports multiple loras)"); // Quantization ops #ifndef USE_ROCM ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); + ops.def("marlin_gemm", &marlin_gemm, + "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, + "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, + "gptq_marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_repack", &gptq_marlin_repack, + "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); - ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."); + ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, + "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or " + "per-row/column quantization."); #endif - + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); - ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); - ops.def( - "moe_align_block_size", - &moe_align_block_size, - "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); + ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, + "Compute FP8 quantized tensor for given scaling factor"); + ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, + "Compute FP8 quantized tensor and scaling factor"); + ops.def("moe_align_block_size", &moe_align_block_size, + "Aligning the number of tokens to be processed by each expert such " + "that it is divisible by the block size."); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def( - "swap_blocks", - &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def( - "copy_blocks", - ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def( - "reshape_and_cache", - &reshape_and_cache, - "Reshape the key and value tensors and cache them"); - cache_ops.def( - "reshape_and_cache_flash", - &reshape_and_cache_flash, - "Reshape the key and value tensors and cache them"); - cache_ops.def( - "convert_fp8", - &convert_fp8, - "Convert the key and value cache to fp8 data type"); + cache_ops.def("swap_blocks", &swap_blocks, + "Swap in (out) the cache blocks from src to dst"); + cache_ops.def("copy_blocks", ©_blocks, + "Copy the cache blocks from src to dst"); + cache_ops.def("reshape_and_cache", &reshape_and_cache, + "Reshape the key and value tensors and cache them"); + cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash, + "Reshape the key and value tensors and cache them"); + cache_ops.def("convert_fp8", &convert_fp8, + "Convert the key and value cache to fp8 data type"); // Cuda utils - pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); - cuda_utils.def( - "get_device_attribute", - &get_device_attribute, - "Gets the specified device attribute."); + pybind11::module cuda_utils = + m.def_submodule("cuda_utils", "vLLM cuda utils"); + cuda_utils.def("get_device_attribute", &get_device_attribute, + "Gets the specified device attribute."); - cuda_utils.def( - "get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute, - "Gets the maximum shared memory per block device attribute."); + cuda_utils.def("get_max_shared_memory_per_block_device_attribute", + &get_max_shared_memory_per_block_device_attribute, + "Gets the maximum shared memory per block device attribute."); #ifndef USE_ROCM // Custom all-reduce kernels @@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { custom_ar.def("register_graph_buffers", ®ister_graph_buffers, "register_graph_buffers"); #endif - } diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 4415316e1e8cd..255844eec56d4 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -25,32 +25,28 @@ #include #include - namespace vllm { namespace aqlm { __global__ void Code1x16MatVec( - const int4* __restrict__ A, - const int4* __restrict__ B, - int4* __restrict__ C, - const int4* __restrict__ codebook, - const int prob_m, - const int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m, + const int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -67,8 +63,7 @@ __global__ void Code1x16MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) - sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -76,22 +71,19 @@ __global__ void Code1x16MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { uint32_t dec[4]; - // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't - // actually help us; this brings > 2x speedup. - asm volatile ( - "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*) &codebook[enc[i]]) - ); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); half2* a = reinterpret_cast(&dec); half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; - #pragma unroll - for (int j = 0; j < 4; j++) - res2 = __hfma2(a[j], b[j], res2); +#pragma unroll + for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); b_sh_rd++; } @@ -100,37 +92,33 @@ __global__ void Code1x16MatVec( } if (pred) { - #pragma unroll - for (int i = 16; i > 0; i /= 2) - res += __shfl_down_sync(0xffffffff, res, i); +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } __global__ void Code2x8MatVec( - const int4* __restrict__ A, - const int4* __restrict__ B, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const int4* __restrict__ A, const int4* __restrict__ B, + int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m, + int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -148,9 +136,8 @@ __global__ void Code2x8MatVec( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; - #pragma unroll - for (int j = 0; j < 8; j++) - sh_code[8 * i + (j + lane) % 8] = dec; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -161,8 +148,7 @@ __global__ void Code2x8MatVec( // We pad shared memory to avoid bank conflicts during reads __syncthreads(); for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) { - if (b_gl_rd + i < prob_k / 8) - sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; + if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i]; } __syncthreads(); b_gl_rd += 32 * 8; @@ -170,13 +156,15 @@ __global__ void Code2x8MatVec( int b_sh_rd = 9 * (threadIdx.x % 32); if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { - half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - half2* b = reinterpret_cast(&sh_b[b_sh_rd]); + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); + half2* b = reinterpret_cast(&sh_b[b_sh_rd]); half2 res2 = {}; - #pragma unroll +#pragma unroll for (int j = 0; j < 4; j++) res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2); res += __half2float(res2.x) + __half2float(res2.y); @@ -187,36 +175,31 @@ __global__ void Code2x8MatVec( } if (pred) { - #pragma unroll - for (int i = 16; i > 0; i /= 2) - res += __shfl_down_sync(0xffffffff, res, i); +#pragma unroll + for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i); if (threadIdx.x % 32 == 0) reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res); } } - __global__ void Code1x16Dequant( - const int4* __restrict__ A, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m. - const int codebook_stride // as int4 + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long, sums to m. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -231,17 +214,15 @@ __global__ void Code1x16Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint16_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; auto dec = reinterpret_cast(&chunk); - // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't - // actually help us; this brings > 2x speedup. - asm volatile ( - "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" - : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) - : "l"((void*) &codebook[enc[i]]) - ); + // We bypass the L1 cache to avoid massive amounts of memory streaming + // that doesn't actually help us; this brings > 2x speedup. + asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];" + : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3]) + : "l"((void*)&codebook[enc[i]])); C[a_gl_rd * 8 + i] = chunk; } @@ -250,28 +231,25 @@ __global__ void Code1x16Dequant( } } - __global__ void Code2x8Dequant( - const int4* __restrict__ A, - int4* __restrict__ C, - const int4* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. - const int codebook_stride // as int4 + const int4* __restrict__ A, int4* __restrict__ C, + const int4* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int a_gl_stride = prob_k / 8 / 8; int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32); bool pred = a_gl_rd < prob_m; - if (pred) - { - // advance to the correct codebook, this easy because we only multiply one column of the codebook. + if (pred) { + // advance to the correct codebook, this easy because we only multiply one + // column of the codebook. auto codebook_size = &codebook_a_sizes.x; - while (a_gl_rd >= *codebook_size) - { - codebook += codebook_stride; - ++codebook_size; + while (a_gl_rd >= *codebook_size) { + codebook += codebook_stride; + ++codebook_size; } } @@ -290,9 +268,8 @@ __global__ void Code2x8Dequant( for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) { int4 dec = codebook[i]; - #pragma unroll - for (int j = 0; j < 8; j++) - sh_code[8 * i + (j + lane) % 8] = dec; +#pragma unroll + for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec; } __syncthreads(); @@ -302,12 +279,14 @@ __global__ void Code2x8Dequant( while (iters--) { if (pred && a_gl_rd < a_gl_end) { const uint8_t* enc = reinterpret_cast(&A[a_gl_rd]); - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { int4 chunk; - half2* a0 = reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); - half2* a1 = reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); - #pragma unroll + half2* a0 = + reinterpret_cast(&sh_code0[8 * enc[2 * i + 0] + lane]); + half2* a1 = + reinterpret_cast(&sh_code1[8 * enc[2 * i + 1] + lane]); +#pragma unroll for (int j = 0; j < 4; j++) reinterpret_cast(&chunk)[j] = __hadd2(a0[j], a1[j]); C[a_gl_rd * 8 + i] = chunk; @@ -317,22 +296,15 @@ __global__ void Code2x8Dequant( } } -inline int ceildiv(int a, int b) { - return (a + b - 1) / b; -} +inline int ceildiv(int a, int b) { return (a + b - 1) / b; } const int THREAD_M = 16; -void code1x16_matvec_cuda( - const void* __restrict__ A, - const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, - const int codebook_stride -) { +void code1x16_matvec_cuda(const void* __restrict__ A, + const void* __restrict__ B, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -345,28 +317,16 @@ void code1x16_matvec_cuda( int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - Code1x16MatVec<<>>( - (const int4*) A, - (const int4*) B, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + Code1x16MatVec<<>>( + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); } -void code2x8_matvec_cuda( - const void* __restrict__ A, - const void* __restrict__ B, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, - const int codebook_stride -) { +void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B, + void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, + int prob_k, const int4 codebook_a_sizes, + const int codebook_stride) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); int waves = 0; @@ -379,30 +339,20 @@ void code2x8_matvec_cuda( int blocks = ceildiv(prob_m, thread_m); int threads = 32 * thread_m; int shared = 16 * (2 * 256 * 8 + 32 * 9); - cudaFuncSetAttribute( - Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared - ); + cudaFuncSetAttribute(Code2x8MatVec, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code2x8MatVec<<>>( - (const int4*) A, - (const int4*) B, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m, + prob_k, codebook_a_sizes, codebook_stride); } void code1x16_dequant_cuda( - const void* __restrict__ A, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - const int codebook_stride // as int4. + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 codebook_a_sizes, // cumulative sizes of A spanning each + // codebook, at most 3 long. + const int codebook_stride // as int4. ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -417,25 +367,21 @@ void code1x16_dequant_cuda( int threads = 32 * thread_m; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); Code1x16Dequant<<>>( - (const int4*) A, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long. - codebook_stride // as int4. + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long. + codebook_stride // as int4. ); } // Dequantizes the code and codebook into weights. -void code2x8_dequant_cuda( - const void* __restrict__ A, - void* __restrict__ C, - const void* __restrict__ codebook, - int prob_m, - int prob_k, - const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols. - const int codebook_stride // as int4 +void code2x8_dequant_cuda( + const void* __restrict__ A, void* __restrict__ C, + const void* __restrict__ codebook, int prob_m, int prob_k, + const int4 + codebook_a_sizes, // cumulative sizes of A spanning each codebook, at + // most 3 long, corresponds to cols. + const int codebook_stride // as int4 ) { int sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0); @@ -451,74 +397,50 @@ void code2x8_dequant_cuda( int shared = 16 * (2 * 256 * 8 + 32 * 9); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cudaFuncSetAttribute( - Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared - ); + cudaFuncSetAttribute(Code2x8Dequant, + cudaFuncAttributeMaxDynamicSharedMemorySize, shared); Code2x8Dequant<<>>( - (const int4*) A, - (int4*) C, - (const int4*) codebook, - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride - ); + (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k, + codebook_a_sizes, codebook_stride); } -int codebook_stride(const torch::Tensor& codebooks) -{ +int codebook_stride(const torch::Tensor& codebooks) { return codebooks.stride(0) * codebooks.element_size() / sizeof(int4); } void code1x16_matvec( - const torch::Tensor& A, - const torch::Tensor& B, - torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long. + const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, + const torch::Tensor& codebook, + const int4 codebook_a_sizes // cumulative sizes of A spanning each + // codebook, at most 3 long. ) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code1x16_matvec_cuda( - A.data_ptr(), - B.data_ptr(), - C.data_ptr(), - codebook.data_ptr(), - prob_m, - prob_k, - codebook_a_sizes, - codebook_stride(codebook) - ); + code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + codebook_stride(codebook)); } -torch::Tensor code1x16_matmat( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias) { +torch::Tensor code1x16_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty({flat_input.size(0), out_features}, - torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()) - ); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - code1x16_matvec( - codes.squeeze(2), - input_vec, - output_vec, - codebooks, - codebook_a_sizes - ); + code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); } flat_output *= scales.flatten().unsqueeze(0); @@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat( return output; } -void code2x8_matvec( - const torch::Tensor& A, - const torch::Tensor& B, - torch::Tensor& C, - const torch::Tensor& codebook, - const int4 codebook_a_sizes -) { +void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B, + torch::Tensor& C, const torch::Tensor& codebook, + const int4 codebook_a_sizes) { const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); int prob_m = C.size(0); int prob_k = B.size(0); - code2x8_matvec_cuda( - A.data_ptr(), - B.data_ptr(), - C.data_ptr(), - codebook.data_ptr(), - prob_m, - prob_k, - codebook_a_sizes, - 2 * codebook_stride(codebook) - ); + code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(), + codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes, + 2 * codebook_stride(codebook)); } -torch::Tensor code2x8_matmat( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const int4 codebook_a_sizes, - const std::optional& bias -) { +torch::Tensor code2x8_matmat(const torch::Tensor& input, + const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const int4 codebook_a_sizes, + const std::optional& bias) { auto input_sizes = input.sizes(); auto out_features = codes.size(0) * codebooks.size(2); auto flat_input = input.reshape({-1, input.size(-1)}); - auto flat_output = torch::empty({flat_input.size(0), out_features}, - torch::TensorOptions() - .dtype(input.dtype()) - .device(input.device()) - ); + auto flat_output = torch::empty( + {flat_input.size(0), out_features}, + torch::TensorOptions().dtype(input.dtype()).device(input.device())); for (int i = 0; i < flat_input.size(0); ++i) { auto input_vec = flat_input.index({i}); auto output_vec = flat_output.index({i}); - code2x8_matvec( - codes.squeeze(2), - input_vec, - output_vec, - codebooks, - codebook_a_sizes - ); + code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks, + codebook_a_sizes); } flat_output *= scales.flatten().unsqueeze(0); if (bias.has_value()) { @@ -596,64 +498,56 @@ torch::Tensor code2x8_matmat( } // Accumulate the partition sizes. -int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) -{ +int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { int4 cumulative_sizes; auto cumulative_size = &cumulative_sizes.x; int i = 0; int last = 0; assert(codebook_partition_sizes.size(0) <= 4); - for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) - { + for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { *cumulative_size = codebook_partition_sizes[i].item() + last; last = *cumulative_size; } // fill in the rest with unreachable. - for (; i < 4; ++i, ++cumulative_size) - { - *cumulative_size = last*10; + for (; i < 4; ++i, ++cumulative_size) { + *cumulative_size = last * 10; } return cumulative_sizes; } -} // namespace aqlm -} // namespace vllm - +} // namespace aqlm +} // namespace vllm -torch::Tensor aqlm_gemm( - const torch::Tensor& input, - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, - const std::optional& bias -) -{ - int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& scales, + const torch::Tensor& codebook_partition_sizes, + const std::optional& bias) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); - if (nbooks == 1 && entries == (1 << 16)) - { - return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + if (nbooks == 1 && entries == (1 << 16)) { + return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); } - if (nbooks == 2 && entries == (1 << 8)) - { - return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias); + if (nbooks == 2 && entries == (1 << 8)) { + return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, + cumulative_sizes, bias); } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") return {}; } -torch::Tensor aqlm_dequant( - const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes -) -{ - int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); +torch::Tensor aqlm_dequant(const torch::Tensor& codes, + const torch::Tensor& codebooks, + const torch::Tensor& codebook_partition_sizes) { + int4 cumulative_sizes = + vllm::aqlm::accumulate_sizes(codebook_partition_sizes); int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); int const entries = codebooks.size(1); @@ -668,45 +562,37 @@ torch::Tensor aqlm_dequant( assert(out_features = codebook_partition_sizes.sum().item()); auto weights = torch::empty({out_features, in_features}, - torch::TensorOptions() - .dtype(codebooks.dtype()) - .device(codebooks.device()) - ); + torch::TensorOptions() + .dtype(codebooks.dtype()) + .device(codebooks.device())); + + if (nbooks == 1 && entries == (1 << 16)) { + vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); - if (nbooks == 1 && entries == (1 << 16)) - { - vllm::aqlm::code1x16_dequant_cuda( - codes.data_ptr(), - weights.data_ptr(), - codebooks.data_ptr(), - out_features, - in_features, - cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.) - // weights *= scales.index({"...", 0, 0}); - - return weights; + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation.) weights *= + // scales.index({"...", 0, 0}); + + return weights; } - if (nbooks == 2 && entries == (1 << 8)) - { - vllm::aqlm::code2x8_dequant_cuda( - codes.data_ptr(), - weights.data_ptr(), - codebooks.data_ptr(), - out_features, - in_features, - cumulative_sizes, - vllm::aqlm::codebook_stride(codebooks)); - - // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation) - // weights *= scales.index({"...", 0, 0}); - - return weights; + if (nbooks == 2 && entries == (1 << 8)) { + vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(), + codebooks.data_ptr(), out_features, + in_features, cumulative_sizes, + vllm::aqlm::codebook_stride(codebooks)); + + // if you wanted to flip to scaling the weights, (though it's 30%-ish slower + // and not consistent with gemv implementation) weights *= + // scales.index({"...", 0, 0}); + + return weights; } - TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.") + TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, + " entries is not currently supported.") return {}; } diff --git a/csrc/quantization/awq/dequantize.cuh b/csrc/quantization/awq/dequantize.cuh index d1d926de18d78..813ec6716cf54 100644 --- a/csrc/quantization/awq/dequantize.cuh +++ b/csrc/quantization/awq/dequantize.cuh @@ -1,11 +1,11 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq -Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +Modified from NVIDIA FasterTransformer: +https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ @@ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor namespace vllm { namespace awq { -__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) -{ +__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 assert(false); #else - uint4 result; + uint4 result; - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. In + // addition, I exploit the fact that sub and fma have the same throughput in + // order to convert elt_23 and elt_67 to fp16 without having to shift them to + // the bottom bits before hand. - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. - const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW + // dependency if we issue immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. - // This is the half2 {1032, 1032} represented as an integer. - // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - // static constexpr uint32_t NEG_72 = 0xd480d480; - // Haotian: Let's use {-64, -64}. - static constexpr uint32_t NEG_64 = 0xd400d400; + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); - return result; + return result; #endif } -} // namespace awq -} // namespace vllm +} // namespace awq +} // namespace vllm diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 5aefb0bd16aef..bb8e5bbb23d7f 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -1,14 +1,12 @@ /* Adapted from https://github.com/mit-han-lab/llm-awq @article{lin2023awq, - title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, - author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, - journal={arXiv}, - year={2023} + title={AWQ: Activation-aware Weight Quantization for LLM Compression and +Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, +Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ - #include #include @@ -20,26 +18,20 @@ namespace vllm { namespace awq { // Pack two half values. -static inline __device__ __host__ unsigned -__pack_half2(const half x, const half y) { - unsigned v0 = *((unsigned short *)&x); - unsigned v1 = *((unsigned short *)&y); +static inline __device__ __host__ unsigned __pack_half2(const half x, + const half y) { + unsigned v0 = *((unsigned short*)&x); + unsigned v1 = *((unsigned short*)&y); return (v1 << 16) | v0; } -template -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( - int G, - int split_k_iters, - half* __restrict__ A, - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - int M, - int IC, - int OC, - half* __restrict__ C) -{ +template +__global__ void __launch_bounds__(64) + gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters, + half* __restrict__ A, int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, int M, int IC, + int OC, half* __restrict__ C) { // Only support matrix n = 64 or 128 assert(N == 64 || N == 128); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 @@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( static constexpr int row_stride = 2 * 32 * 8 / N; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N; // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + bool ld_A_flag = + (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A - + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC - + (((int)threadIdx.x) % (32 / 8)) * 8; - - int* B_ptr = B - + ((int)threadIdx.y) * (OC / 8) * (256 / N) - + (((int)threadIdx.x) / (N / 8)) * (OC / 8) - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + (((int)threadIdx.x) % (N / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) - + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) - + (((int)threadIdx.x) % (32 / 8) ) * 8; - - half* B_shared_ptr = B_shared - + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) - + (((int)threadIdx.x) / (N / 8)) * (N + 8) - + (((int)threadIdx.x) % (N / 8)) * 8; - - int* zeros_ptr = zeros - + (((int)blockIdx_y) % j_factors1) * (N / 8) - + ((int)threadIdx.x) % (N / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * N - + (((int)threadIdx.x) % (N / 8)) * 8; - - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim - + (((int)blockIdx_y) % j_factors1) * N - + ((int)threadIdx.y) * (N / 2) - + (((int)threadIdx.x) % 4) * 2; + half* A_ptr = + A + + (((int)blockIdx_y) / j_factors1 * 16 + + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * + IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * (256 / N) + + (((int)threadIdx.x) / (N / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (N / 8) + + (((int)threadIdx.x) % (N / 8)) * 1; + // Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8)) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (N + 8) + + (((int)threadIdx.x) / (N / 8)) * (N + 8) + + (((int)threadIdx.x) % (N / 8)) * 8; + + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (N / 8) + + ((int)threadIdx.x) % (N / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * N + + (((int)threadIdx.x) % (N / 8)) * 8; + + half* C_ptr = + C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * N + ((int)threadIdx.y) * (N / 2) + + (((int)threadIdx.x) % 4) * 2; // preload s.f. and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; @@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - if (ld_A_flag) - { + if (ld_A_flag) { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); - } - else - { + } else { *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + uint4 B_loaded_scale = + *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && + threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, + B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, + B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); } */ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) { - // B: 32 x 136 (128+8) float16 // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) - uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus + // zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * + // 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) + // * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * + // 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * + // 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * + // 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = + *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / + // 8)) * 8); - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x + // % (cta_N / 8)) * 8); // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = + // q * scale - zero * scale. + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == + 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", + B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); } */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = + B_loaded_fp16; } __syncthreads(); @@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32( { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(A_shared[(k_0_1 * 16)])) + + (((((int)threadIdx.x) & 15) * 40) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(A_shared_warp + 0))[0]), + "=r"(((unsigned*)(A_shared_warp + 0))[1]), + "=r"(((unsigned*)(A_shared_warp + 0))[2]), + "=r"(((unsigned*)(A_shared_warp + 0))[3]) + : "r"(addr)); } for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) { { unsigned int addr; __asm__ __volatile__( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8)))) - ); + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, " + "addr; }\n" + : "=r"(addr) + : "l"((void*)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + + (((int)threadIdx.y) * (N / 2))) + + (ax1_0 * 16))])) + + (((((int)threadIdx.x) & 15) * (N + 8)) + + ((((int)threadIdx.x) >> 4) * 8))))); __asm__ __volatile__( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[0]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[1]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[2]), + "=r"(((unsigned*)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr)); } } for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#else + #else { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "=f"(((float*)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[0]), + "r"(((unsigned*)(B_shared_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[0]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[1]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[2]), + "f"(((float*)(C_warp + (j_0_4 * 8)))[3])); } { __asm__ __volatile__( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, " + "%13};\n" + : "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "=f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned*)(A_shared_warp + 0))[0]), + "r"(((unsigned*)(A_shared_warp + 0))[1]), + "r"(((unsigned*)(A_shared_warp + 0))[2]), + "r"(((unsigned*)(A_shared_warp + 0))[3]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), + "r"(((unsigned*)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[0]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[1]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[2]), + "f"(((float*)(C_warp + ((j_0_4 * 8) + 4)))[3])); } -#endif + #endif } } } -// TODO: Shang: Hoist loop invariance. + // TODO: Shang: Hoist loop invariance. for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { - int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; - if (row_offset < M) - { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); } } } #endif } -__global__ void __launch_bounds__(64) dequantize_weights( - int* __restrict__ B, - half* __restrict__ scaling_factors, - int* __restrict__ zeros, - half* __restrict__ C, - int G -) -{ +__global__ void __launch_bounds__(64) + dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors, + int* __restrict__ zeros, half* __restrict__ C, int G) { int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; @@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights( uint32_t B_loaded = *(uint32_t*)B_ptr2; uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.x) + : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.y) + : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.z) + : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(B_loaded_fp16.w) + : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); *(uint4*)B_shared_ptr2 = B_loaded_fp16; @@ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights( } } -} // namespace awq -} // namespace vllm - -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy) -{ - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); - int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx==0) { - x_thread = qout_c; - } - if (thy==0) { - y_thread = in_c; - } - if (thx==0 && thy==0) { - x_thread = 8; - y_thread = 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); - } +} // namespace awq +} // namespace vllm + +torch::Tensor awq_dequantize(torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, int split_k_iters, int thx, + int thy) { + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + int G = in_c / _scaling_factors.size(0); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx == 0) { + x_thread = qout_c; + } + if (thy == 0) { + y_thread = in_c; + } + if (thx == 0 && thy == 0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); + auto options = torch::TensorOptions() + .dtype(_scaling_factors.dtype()) + .device(_scaling_factors.device()); + at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); - dim3 num_blocks(x_blocks, y_blocks); - dim3 threads_per_block(x_thread, y_thread); + dim3 num_blocks(x_blocks, y_blocks); + dim3 threads_per_block(x_thread, y_thread); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::awq::dequantize_weights<<>>( - kernel, scaling_factors, zeros, de_kernel, G); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + vllm::awq::dequantize_weights<<>>( + kernel, scaling_factors, zeros, de_kernel, G); - return _de_kernel; + return _de_kernel; } // in_feats: M, IC [float16] @@ -386,61 +489,61 @@ torch::Tensor awq_dequantize( // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters) -{ - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); - - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); - - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); - - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 128 == 0) - { - int j_factors1 = num_out_channels / 128 / 1; - dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - else if (num_out_channels % 64 == 0) - { - int j_factors1 = num_out_channels / 64 / 1; - dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] - dim3 threads_per_block(32, 2); - vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<>>( - group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, - num_out_channels, out_feats); - } - return _out_feats.sum(0); +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters) { + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions() + .dtype(_in_feats.dtype()) + .device(_in_feats.device()); + at::Tensor _out_feats = + torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = + reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (num_out_channels % 128 == 0) { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } else if (num_out_channels % 64 == 0) { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * + split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64> + <<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, + num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index 3ec454f78c654..e62fe731a98d3 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -117,10 +117,10 @@ struct cutlass_2x_gemm { }; template -void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; @@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, using StrideC = Stride, Int<0>>; StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - auto c_ptr = static_cast(out.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); auto a_scales_ptr = a_scales.data_ptr(); auto b_scales_ptr = b_scales.data_ptr(); @@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, } // namespace -void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); @@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a, } } -void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { TORCH_CHECK(a.dtype() == torch::kInt8); TORCH_CHECK(b.dtype() == torch::kInt8); TORCH_CHECK(a_scales.dtype() == torch::kFloat32); @@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a, } } -void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>; using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 37b096de23e3b..12efcac7bb919 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -120,10 +120,10 @@ struct cutlass_3x_gemm { }; template -void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; @@ -146,12 +146,12 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, using GemmKernel = typename Gemm::GemmKernel; typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, b_stride}; - auto c_ptr = static_cast(out.data_ptr()); + auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; @@ -183,10 +183,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a, } } // namespace -void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { TORCH_CHECK(a_scales.dtype() == torch::kFloat32); TORCH_CHECK(b_scales.dtype() == torch::kFloat32); diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index a4e696d4a3322..dab73ac6c831e 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -2,29 +2,29 @@ #include #include -void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, - torch::Tensor const &a_scales, - torch::Tensor const &b_scales); +void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); -void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, - torch::Tensor const &b, torch::Tensor const &a_scales, - torch::Tensor const &b_scales) { +void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { int32_t major_capability; int32_t minor_capability; cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, @@ -36,14 +36,15 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a, // Checks for conformality TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && - b.size(1) == c.size(1)); + b.size(1) == c.size(1)); TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); // Check for strides and alignment - TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major - TORCH_CHECK(b.stride(0) == 1); // Column-major - TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); diff --git a/csrc/quantization/fp8/amd/hip_float8.h b/csrc/quantization/fp8/amd/hip_float8.h index 87c7c9ce66100..f9c80fcdec576 100644 --- a/csrc/quantization/fp8/amd/hip_float8.h +++ b/csrc/quantization/fp8/amd/hip_float8.h @@ -1,167 +1,137 @@ #pragma once #ifdef __HIPCC__ -#include + #include #else -#include -#include -#include -#include + #include + #include + #include + #include #endif #include "hip_float8_impl.h" -struct alignas(1) hip_fp8 -{ - struct from_bits_t - { - }; - HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } - uint8_t data; - - hip_fp8() = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; - HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; - explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) - : data(v) - { - } +struct alignas(1) hip_fp8 { + struct from_bits_t {}; + HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + uint8_t data; + + hip_fp8() = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; + explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) + : data(v) {} #ifdef __HIP__MI300__ - // NOTE: ON-DEVICE... always optimal bias - explicit HIP_FP8_DEVICE hip_fp8(float v) - : data(hip_fp8_impl::to_fp8_from_fp32(v)) - { - } - - explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) - : hip_fp8(static_cast(v)) - { - } - - // Host only implementation using s/w simulation - explicit HIP_FP8_HOST -#else // __HIP__MI300__ - // both Host and DEVICE for non-MI300 using s/w simulation - explicit HIP_FP8_HOST_DEVICE -#endif // __HIP__MI300__ - hip_fp8(float v) - { - data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); - } - - explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) - : hip_fp8(static_cast(v)) - { - } + // NOTE: ON-DEVICE... always optimal bias + explicit HIP_FP8_DEVICE hip_fp8(float v) + : data(hip_fp8_impl::to_fp8_from_fp32(v)) {} + + explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) + : hip_fp8(static_cast(v)) {} + + // Host only implementation using s/w simulation + explicit HIP_FP8_HOST +#else // __HIP__MI300__ + // both Host and DEVICE for non-MI300 using s/w simulation + explicit HIP_FP8_HOST_DEVICE +#endif // __HIP__MI300__ + hip_fp8(float v) { + data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, + true /*clip*/>(v); + } + + explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) + : hip_fp8(static_cast(v)) {} #ifdef __HIP__MI300__ - // upcast using device specific intrinsic - explicit inline HIP_FP8_DEVICE operator float() const - { - float fval; - uint32_t i32val = static_cast(data); - - // upcast - asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); - - return fval; - } - - explicit inline HIP_FP8_HOST operator float() const -#else // __HIP__MI300__ - explicit inline HIP_FP8_HOST_DEVICE operator float() const -#endif // __HIP__MI300__ - { - return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); - } + // upcast using device specific intrinsic + explicit inline HIP_FP8_DEVICE operator float() const { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" + : "=v"(fval) + : "v"(i32val)); + + return fval; + } + + explicit inline HIP_FP8_HOST operator float() const +#else // __HIP__MI300__ + explicit inline HIP_FP8_HOST_DEVICE operator float() const +#endif // __HIP__MI300__ + { + return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>( + data); + } }; -namespace std -{ -inline hip_fp8 sin(hip_fp8 a) -{ - return hip_fp8(sinf(float(a))); -} -inline hip_fp8 cos(hip_fp8 a) -{ - return hip_fp8(cosf(float(a))); -} -HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) -{ - return a; -} -} // namespace std +namespace std { +inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } +inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } +HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } +} // namespace std // Special operator overloading -inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) -{ - return os << float(f8); +inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) { + return os << float(f8); } // all + operator overloading with mixed types -// mixed types, always converts to f32, does computation in f32, and returns float -inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) -{ - return (fa + float(b)); +// mixed types, always converts to f32, does computation in f32, and returns +// float +inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) { + return (fa + float(b)); } -inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) -{ - return (float(a) + fb); +inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) { + return (float(a) + fb); } -inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) -{ - return hip_fp8(float(a) + float(b)); +inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) { + return hip_fp8(float(a) + float(b)); } -inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) -{ - return a = hip_fp8(float(a) + float(b)); +inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) { + return a = hip_fp8(float(a) + float(b)); } // overloading multiplication, always returns float, -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) -{ - return float(a) * float(b); +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) { + return float(a) * float(b); } -inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) -{ - return (a * float(b)); +inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) { + return (a * float(b)); } -inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) -{ - return (float(a) * b); +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) { + return (float(a) * b); } -inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) -{ - return ((float)a * float(b)); +inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) { + return ((float)a * float(b)); } -inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) -{ - return ((float)a * float(b)); +inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) { + return ((float)a * float(b)); } // overloading for compare -inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) -{ - return (a.data == b.data); +inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) { + return (a.data == b.data); } -inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) -{ - return (a.data != b.data); +inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) { + return (a.data != b.data); } -inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) -{ - return static_cast(a) >= static_cast(b); +inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) { + return static_cast(a) >= static_cast(b); } -inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) -{ - return static_cast(a) > static_cast(b); +inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) { + return static_cast(a) > static_cast(b); } diff --git a/csrc/quantization/fp8/amd/hip_float8_impl.h b/csrc/quantization/fp8/amd/hip_float8_impl.h index e05905b4e49e8..90251c3539534 100644 --- a/csrc/quantization/fp8/amd/hip_float8_impl.h +++ b/csrc/quantization/fp8/amd/hip_float8_impl.h @@ -1,316 +1,316 @@ #pragma once -#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) -#define __HIP__MI300__ +#if defined(__HIPCC__) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ #endif #ifdef __HIPCC__ -#define HIP_FP8_HOST_DEVICE __host__ __device__ -#define HIP_FP8_HOST __host__ -#define HIP_FP8_DEVICE __device__ + #define HIP_FP8_HOST_DEVICE __host__ __device__ + #define HIP_FP8_HOST __host__ + #define HIP_FP8_DEVICE __device__ #else -#define HIP_FP8_HOST_DEVICE -#define HIP_FP8_HOST -#define HIP_FP8_DEVICE + #define HIP_FP8_HOST_DEVICE + #define HIP_FP8_HOST + #define HIP_FP8_DEVICE #endif -namespace hip_fp8_impl -{ +namespace hip_fp8_impl { #ifdef __HIP__MI300__ -HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) -{ - uint8_t i8data; - union { - float fval; - uint32_t i32val; - uint8_t i8val[4]; // NOTE: not endian independent - } val; - - uint32_t ival = 0; - val.fval = v; - - if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping - val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); - } - - ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, - false); // false -> WORD0 - val.i32val = ival; - i8data = val.i8val[0]; - - return i8data; +HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) { + uint8_t i8data; + union { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + + if ((val.i32val & 0x7F800000) != + 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + + return i8data; } -#endif // __HIP__MI300__ +#endif // __HIP__MI300__ -HIP_FP8_HOST inline int clz(uint32_t x) -{ - return __builtin_clz(x); -} +HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } #if defined(__HIPCC__) || defined(__CUDA_ARCH__) -HIP_FP8_DEVICE inline int clz(uint32_t x) -{ - return __clz(x); -} +HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); } #endif template -HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) -{ +HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, + uint32_t rng = 0) { #ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; + constexpr bool is_half = std::is_same::value; #else - constexpr bool is_half = false; + constexpr bool is_half = false; #endif - constexpr bool is_float = std::is_same::value; - static_assert(wm + we == 7, "wm+we==7"); - static_assert(is_half || is_float, "Only half and float can be cast to f8"); - - const int mfmt = (sizeof(T) == 4) ? 23 : 10; - uint32_t x; + constexpr bool is_float = std::is_same::value; + static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + uint32_t x; + if (sizeof(T) == 4) { + x = reinterpret_cast(_x); + } else { + x = reinterpret_cast(_x); + } + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { if (sizeof(T) == 4) { - x = reinterpret_cast(_x); + if ((x & 0x7F800000) == 0x7F800000) { + return 0x80; + } } else { - x = reinterpret_cast(_x); + // if(__hisinf(x) || __hisnan(x)) + if ((x & 0x7C00) == 0x7C00) { + return 0x80; + } } - - uint32_t head, mantissa; - int exponent, bias; - uint32_t sign; - + } else { if (sizeof(T) == 4) { - head = x & 0xFF800000; - mantissa = x & 0x7FFFFF; - exponent = (head >> 23) & 0xFF; - sign = head >> 31; - bias = 127; + if ((x & 0x7F800000) == 0x7F800000) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } } else { - head = x & 0xFC00; - mantissa = x & 0x3FF; - exponent = (head >> 10) & 0x1F; - sign = head >> 15; - bias = 15; + if ((x & 0x7C00) == 0x7C00) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } } - - uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); - - // Deal with inf and NaNs - if (negative_zero_nan) { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return 0x80; - } - } else { - // if(__hisinf(x) || __hisnan(x)) - if ((x & 0x7C00) == 0x7C00) { - return 0x80; - } - } - } else { - if (sizeof(T) == 4) { - if ((x & 0x7F800000) == 0x7F800000) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } else { - if ((x & 0x7C00) == 0x7C00) { - return signed_inf + (mantissa != 0 ? 1 : 0); - } - } - } - if (x == 0) { - return 0; - } - - // First need to check if it is normal or denorm as there is a difference of - // implicit 1 Then need to adjust the exponent to align with the F8 exponent, - // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng - // to mantissa and truncate. And for RNE, no need to add rng. Then probably - // need to check whether there is carry and adjust exponent and mantissa again - - // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent - // bits - const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); - const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal - // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) - // f8_exponent is the converted f8 exponent with bias encoding - // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, - // the difference needs to be adjusted and mantissa shifted - int act_exponent, f8_exponent, exponent_diff; - - if (exponent == 0) { // fp32/fp16 is in denormal. - /* fp32 denormal is below 2^-127 so it is usually not a concern here, we + } + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = + 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ - act_exponent = exponent - bias + 1; - exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal - } else { // fp32/fp16 is normal with implicit 1 - act_exponent = exponent - bias; - if (act_exponent <= f8_denormal_act_exponent) { - /* This is the case where fp32/fp16 is normal but it is in f8 denormal - range. For example fp8 nanoo mode, denormal exponent is -7, but if the - fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, - Therefore it needs to be adjust to -6 and mantissa shift right by 1. - So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ - exponent_diff = f8_denormal_act_exponent - act_exponent; - } else { // both fp32/fp16 and f8 are in normal range - exponent_diff = 0; // exponent_diff=0 does not mean there is no difference - // for this case, - // act_exponent could be larger. Just that it does not need shift mantissa - } - mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + act_exponent = exponent - bias + 1; + exponent_diff = + f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal +range. For example fp8 nanoo mode, denormal exponent is -7, but if the +fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, +Therefore it needs to be adjust to -6 and mantissa shift right by 1. +So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no + // difference for this case, act_exponent could be + // larger. Just that it does not need shift mantissa } - - bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == - static_cast(1 << (mfmt - wm + exponent_diff - 1)); - /* This part is a bit tricky. The judgment of whether it is a tie needs to be - done before we shift right as shift right could rip off some residual part - and make something not midpoint look like midpoint. For example, the fp16 - number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after - shift right by 4 bits, it would look like midpoint. + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + static_cast(1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part + and make something not midpoint look like midpoint. For example, the fp16 + number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after + shift right by 4 bits, it would look like midpoint. */ - if (exponent_diff > 0) { - mantissa >>= exponent_diff; - } else if (exponent_diff == -1) { - mantissa <<= -exponent_diff; + if (exponent_diff > 0) { + mantissa >>= exponent_diff; + } else if (exponent_diff == -1) { + mantissa <<= -exponent_diff; + } + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit + // that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & + drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1 << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent } - bool implicit_one = mantissa & (1 << mfmt); - // if there is no implicit 1, it means the f8 is denormal and need to adjust - // to denorm exponent - f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); - - // Now we have the exponent and mantissa adjusted - uint32_t drop_mask = (1 << (mfmt - wm)) - 1; - bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that - // is not truncated is 1 - mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; - - // Now we deal with overflow - if (f8_exponent == 0) { - if ((1 << mfmt) & mantissa) { - f8_exponent = 1; // denormal overflow to become normal, promote exponent - } - } else { - if ((1 << (mfmt + 1)) & mantissa) { - mantissa >>= 1; - f8_exponent++; - } + } else { + if ((1 << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; } + } - mantissa >>= (mfmt - wm); - - // above range: quantize to maximum possible float of the same sign - const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); - if (f8_exponent > max_exp) { - if (clip) { - mantissa = (1 << wm) - 1; - f8_exponent = max_exp; - } else { - return signed_inf; - } - } + mantissa >>= (mfmt - wm); - if (f8_exponent == 0 && mantissa == 0) { - return negative_zero_nan ? 0 : (sign << 7); + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; } - mantissa &= (1 << wm) - 1; - return (sign << 7) | (f8_exponent << wm) | mantissa; + } + + if (f8_exponent == 0 && mantissa == 0) { + return negative_zero_nan ? 0 : (sign << 7); + } + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; } template -inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) -{ +inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) { #ifdef __HIPCC__ - constexpr bool is_half = std::is_same::value; + constexpr bool is_half = std::is_same::value; #else - constexpr bool is_half = false; + constexpr bool is_half = false; #endif - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "only half and float are supported"); + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported"); - constexpr int weo = is_half ? 5 : 8; - constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); - T fInf, fNegInf, fNaN, fNeg0; + T fInf, fNegInf, fNaN, fNeg0; #ifdef __HIPCC__ - if (is_half) { - const uint16_t ihInf = 0x7C00; - const uint16_t ihNegInf = 0xFC00; - const uint16_t ihNaN = 0x7C01; - const uint16_t ihNeg0 = 0x8000; - fInf = reinterpret_cast(ihInf); - fNegInf = reinterpret_cast(ihNegInf); - fNaN = reinterpret_cast(ihNaN); - fNeg0 = reinterpret_cast(ihNeg0); - } else + if (is_half) { + const uint16_t ihInf = 0x7C00; + const uint16_t ihNegInf = 0xFC00; + const uint16_t ihNaN = 0x7C01; + const uint16_t ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else #endif - if (is_float) { - const uint32_t ifInf = 0x7F800000; - const uint32_t ifNegInf = 0xFF800000; - const uint32_t ifNaN = 0x7F800001; - const uint32_t ifNeg0 = 0x80000000; - fInf = reinterpret_cast(ifInf); - fNegInf = reinterpret_cast(ifNegInf); - fNaN = reinterpret_cast(ifNaN); - fNeg0 = reinterpret_cast(ifNeg0); - } - - if (x == 0) { - return 0; - } - - uint32_t sign = x >> 7; - uint32_t mantissa = x & ((1 << wm) - 1); - int exponent = (x & 0x7F) >> wm; - if (negative_zero_nan) { - if (x == 0x80) { - return fNaN; - } - } else { - if (x == 0x80) { - return fNeg0; - } - if (exponent == ((1 << we) - 1)) { - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; - } - } - typename std::conditional::type retval; - if (we == 5 && is_half && !negative_zero_nan) { - retval = x << 8; - return reinterpret_cast(retval); + if (is_float) { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) { + return fNaN; } - - const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); - - // subnormal input - if (exponent == 0) { - // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + clz(mantissa) - (32 - wm); - mantissa <<= sh; - exponent += 1 - sh; - mantissa &= ((1 << wm) - 1); + } else { + if (x == 0x80) { + return fNeg0; } - exponent += exp_low_cutoff - 1; - mantissa <<= wmo - wm; - - // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) - if (exponent <= 0) { - mantissa |= 1 << wmo; - mantissa >>= 1 - exponent; - exponent = 0; - } - - if (sizeof(T) == 2) { - retval = (sign << 15) | (exponent << 10) | mantissa; - } else { - retval = (sign << 31) | (exponent << 23) | mantissa; + if (exponent == ((1 << we) - 1)) { + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; } + } + typename std::conditional::type retval; + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; return reinterpret_cast(retval); + } + + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) { + retval = (sign << 15) | (exponent << 10) | mantissa; + } else { + retval = (sign << 31) | (exponent << 23) | mantissa; + } + return reinterpret_cast(retval); } -} // namespace hip_fp8_impl +} // namespace hip_fp8_impl diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index df0329f79d361..35123d7fc65d4 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -9,566 +9,567 @@ #include "../../../attention/dtype_float32.cuh" #include "../../../attention/dtype_bfloat16.cuh" -namespace vllm -{ +namespace vllm { #ifdef USE_ROCM namespace fp8 { -#ifdef ENABLE_FP8 + #ifdef ENABLE_FP8 template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; +__inline__ __device__ Tout vec_conversion(const Tin& x) { + return x; } template -__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale) -{ - return x; +__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, + const float scale) { + return x; } // fp8 -> half template <> -__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8); - return res.x; +__inline__ __device__ uint16_t +vec_conversion(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8); + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0]; - tmp.h2r.y.data = f2[1]; - return tmp.ui32; -#else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = vec_conversion(static_cast(a)); - tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); - return tmp.u32; -#endif +__inline__ __device__ uint32_t +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0]; + tmp.h2r.y.data = f2[1]; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = vec_conversion(static_cast(a)); + tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); + return tmp.u32; + #endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 vec_conversion(const uint32_t& a) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = vec_conversion((uint16_t)a); - tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); - return tmp.u32x2; +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 vec_conversion(const uint2& a) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = vec_conversion(a.x); - tmp.u64[1] = vec_conversion(a.y); - return tmp.u64x2; +__inline__ __device__ uint4 vec_conversion(const uint2& a) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f); +__inline__ __device__ __nv_bfloat16 +vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) -{ - __nv_bfloat162 res; - res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); - res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); - return res; +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) -{ - bf16_4_t res; - res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); - res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); - return res; +__inline__ __device__ bf16_4_t +vec_conversion(const uint32_t& a) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) -{ - bf16_4_t tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float vec_conversion(const uint8_t& a) -{ - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8); +__inline__ __device__ float vec_conversion(const uint8_t& a) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8); } // fp8x2 -> float2 template <> -__inline__ __device__ float2 vec_conversion(const uint16_t& a) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0]; - res.y = f2[1]; - return res; -#else - float2 res; - res.x = vec_conversion(static_cast(a)); - res.y = vec_conversion(static_cast(a >> 8U)); - return res; -#endif +__inline__ __device__ float2 +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0]; + res.y = f2[1]; + return res; + #else + float2 res; + res.x = vec_conversion(static_cast(a)); + res.y = vec_conversion(static_cast(a >> 8U)); + return res; + #endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) -{ - Float4_ res; - res.x = vec_conversion((uint16_t)a); - res.y = vec_conversion((uint16_t)(a >> 16U)); - return res; +__inline__ __device__ Float4_ +vec_conversion(const uint32_t& a) { + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; } // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ vec_conversion(const uint2& a) -{ - Float4_ tmp1, tmp2; - tmp1 = vec_conversion(a.x); - tmp2 = vec_conversion(a.y); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ Float8_ vec_conversion(const uint2& a) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // half -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) -{ - __half_raw tmp; - tmp.x = a; +__inline__ __device__ uint8_t +vec_conversion(const uint16_t& a) { + __half_raw tmp; + tmp.x = a; - hip_fp8 f8{static_cast(tmp.data)}; - return f8.data; + hip_fp8 f8{static_cast(tmp.data)}; + return f8.data; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) -{ - hip_fp8 res{__bfloat162float(a)}; - return res.data; +__inline__ __device__ uint8_t +vec_conversion(const __nv_bfloat16& a) { + hip_fp8 res{__bfloat162float(a)}; + return res.data; } // float -> fp8 template <> -__inline__ __device__ uint8_t vec_conversion(const float& a) -{ - hip_fp8 f8(a); - return f8.data; +__inline__ __device__ uint8_t vec_conversion(const float& a) { + hip_fp8 f8(a); + return f8.data; } // fp8x4 -> float4 template <> -__inline__ __device__ float4 vec_conversion(const uint32_t& a) -{ - Float4_ tmp = vec_conversion(a); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ float4 +vec_conversion(const uint32_t& a) { + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; } // float2 -> half2 template <> -__inline__ __device__ uint32_t vec_conversion(const float2& a) -{ - union { - half2 float16; - uint32_t uint32; - }; +__inline__ __device__ uint32_t +vec_conversion(const float2& a) { + union { + half2 float16; + uint32_t uint32; + }; - float16 = __float22half2_rn(a); - return uint32; + float16 = __float22half2_rn(a); + return uint32; } // Float4 -> half2x2 template <> -__inline__ __device__ uint2 vec_conversion(const Float4_& a) -{ - uint2 b; - float2 val; - val.x = a.x.x; - val.y = a.x.y; - b.x = vec_conversion(val); +__inline__ __device__ uint2 vec_conversion(const Float4_& a) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); - val.x = a.y.x; - val.y = a.y.y; - b.y = vec_conversion(val); - return b; + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; } // Float4 -> float4 template <> -__inline__ __device__ float4 vec_conversion(const Float4_& a) -{ - float4 b; - b.x = a.x.x; - b.y = a.x.y; - b.z = a.y.x; - b.w = a.y.y; - return b; +__inline__ __device__ float4 vec_conversion(const Float4_& a) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; } // Float8 -> half2x4 template <> -__inline__ __device__ uint4 vec_conversion(const Float8_& a) -{ - uint4 b; - b.x = vec_conversion(a.x); - b.y = vec_conversion(a.y); - b.z = vec_conversion(a.z); - b.w = vec_conversion(a.w); - return b; +__inline__ __device__ uint4 vec_conversion(const Float8_& a) { + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; } // float2 -> bfloat162 template <> -__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) -{ - __nv_bfloat162 b = __float22bfloat162_rn(a); - return b; +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, float2>(const float2& a) { + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; } // Float4 -> bfloat162x2 template <> -__inline__ __device__ bf16_4_t vec_conversion(const Float4_& a) -{ - bf16_4_t b; - b.x = __float22bfloat162_rn(a.x); - b.y = __float22bfloat162_rn(a.y); - return b; +__inline__ __device__ bf16_4_t +vec_conversion(const Float4_& a) { + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; } // Float8 -> bfloat162x4 template <> -__inline__ __device__ bf16_8_t vec_conversion(const Float8_& a) -{ - bf16_8_t b; - b.x = __float22bfloat162_rn(a.x); - b.y = __float22bfloat162_rn(a.y); - b.z = __float22bfloat162_rn(a.z); - b.w = __float22bfloat162_rn(a.w); - return b; +__inline__ __device__ bf16_8_t +vec_conversion(const Float8_& a) { + bf16_8_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; } +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains -/* Scaled and vectorized conversions, for data exchange between high and low precision domains - - Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) - s.t. - Quantize(HP / scale) => FP8 - Dequant(FP8) * scale => HP + Convention of the scale in API, e.g: FP8_data = Quantization( + High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) * + scale => HP */ // fp8 -> half template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, const float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8) * scale; - return res.x; +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint8_t& a, const float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, const float scale) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0] * scale; - tmp.h2r.y.data = f2[1] * scale; - return tmp.ui32; -#else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = scaled_vec_conversion(static_cast(a), scale); - tmp.u16[1] = scaled_vec_conversion(static_cast(a >> 8U), scale); - return tmp.u32; -#endif +__inline__ __device__ uint32_t scaled_vec_conversion( + const uint16_t& a, const float scale) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = + scaled_vec_conversion(static_cast(a), scale); + tmp.u16[1] = scaled_vec_conversion( + static_cast(a >> 8U), scale); + return tmp.u32; + #endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, const float scale) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); - tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return tmp.u32x2; +__inline__ __device__ uint2 +scaled_vec_conversion(const uint32_t& a, const float scale) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = + scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, const float scale) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = scaled_vec_conversion(a.x, scale); - tmp.u64[1] = scaled_vec_conversion(a.y, scale); - return tmp.u64x2; +__inline__ __device__ uint4 +scaled_vec_conversion(const uint2& a, const float scale) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f * scale); +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, + const float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale) -{ - __nv_bfloat162 res; - res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); - return res; +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, + const float scale) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = + scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, const float scale) -{ - bf16_4_t res; - res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ bf16_4_t scaled_vec_conversion( + const uint32_t& a, const float scale) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, const float scale) -{ - bf16_4_t tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t +scaled_vec_conversion(const uint2& a, const float scale) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, const float scale) -{ - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8) * scale; +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, const float scale) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; } // fp8x2 -> float2 template <> -__inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, const float scale) -{ -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0] * scale; - res.y = f2[1] * scale; - return res; -#else - float2 res; - res.x = scaled_vec_conversion(static_cast(a), scale); - res.y = scaled_vec_conversion(static_cast(a >> 8U), scale); - return res; -#endif +__inline__ __device__ float2 +scaled_vec_conversion(const uint16_t& a, const float scale) { + #if defined(__HIP__MI300__) && \ + defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; + #else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), + scale); + return res; + #endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ scaled_vec_conversion(const uint32_t& a, const float scale) -{ - Float4_ res; - res.x = scaled_vec_conversion((uint16_t)a, scale); - res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ Float4_ +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; } // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, const float scale) -{ - Float4_ tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ Float8_ +scaled_vec_conversion(const uint2& a, const float scale) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } - /* Quantize(HP / scale) => FP8 */ // TODO(Hai): vectorized to add // half -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, const float scale) -{ - __half_raw tmp; - tmp.x = a; +__inline__ __device__ uint8_t +scaled_vec_conversion(const uint16_t& a, const float scale) { + __half_raw tmp; + tmp.x = a; - hip_fp8 f8{static_cast(tmp.data)/scale}; - return f8.data; + hip_fp8 f8{static_cast(tmp.data) / scale}; + return f8.data; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, const float scale) -{ - hip_fp8 res{__bfloat162float(a)/scale}; - return res.data; +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, const float scale) { + hip_fp8 res{__bfloat162float(a) / scale}; + return res.data; } // float -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, const float scale) -{ - hip_fp8 f8(a/scale); - return f8.data; +__inline__ __device__ uint8_t +scaled_vec_conversion(const float& a, const float scale) { + hip_fp8 f8(a / scale); + return f8.data; } // fp8x4 -> float4 template <> -__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, const float scale) -{ - Float4_ tmp = scaled_vec_conversion(a, scale); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ tmp = scaled_vec_conversion(a, scale); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; } -#endif // ENABLE_FP8 + #endif // ENABLE_FP8 template -__inline__ __device__ Tout convert(const Tin &x) { -#ifdef ENABLE_FP8 +__inline__ __device__ Tout convert(const Tin& x) { + #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return vec_conversion(x); } -#endif + #endif assert(false); } template -__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { -#ifdef ENABLE_FP8 +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return scaled_vec_conversion(x, scale); } -#endif + #endif assert(false); } -// The following macro is used to dispatch the conversion function based on the -// data type of the key and value cache. The FN is a macro that calls a function -// with template. -#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ - if (KV_DTYPE == "auto") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ - } else { \ - TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ - } \ - } else { \ - if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ } else { \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ } else { \ - TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ - } \ - } + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } -} // fp8 -#endif // USE_ROCM -} // namespace vllm +} // namespace fp8 +#endif // USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index b9c5d39277ca5..55be3305a9b8c 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -10,17 +10,20 @@ namespace vllm { __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { - float old; - old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) : - __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + float old; + old = (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(value))); - return old; + return old; } #define FP8_E4M3_MAX std::numeric_limits::max() -template -__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) { +template +__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion( + const scalar_t val, const float scale) { float x = static_cast(val) / scale; float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); return static_cast(r); @@ -32,11 +35,10 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar // So to get the right answer, *scale needs to be initialized to // a value <= 0.0 and we need to wait for all thread blocks to // finish before consuming *scale. -template -__global__ void segmented_max_reduction( - float* __restrict__ scale, - const scalar_t* __restrict__ input, - int64_t num_elems) { +template +__global__ void segmented_max_reduction(float* __restrict__ scale, + const scalar_t* __restrict__ input, + int64_t num_elems) { __shared__ float cache[1024]; int i = blockDim.x * blockIdx.x + threadIdx.x; @@ -56,7 +58,7 @@ __global__ void segmented_max_reduction( int ib = blockDim.x / 2; while (ib != 0) { if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { - cache[threadIdx.x] = cache[threadIdx.x + ib]; + cache[threadIdx.x] = cache[threadIdx.x + ib]; } __syncthreads(); ib /= 2; @@ -64,16 +66,16 @@ __global__ void segmented_max_reduction( // Finally, since cache[0] contains the maximum for this thread block, // atomically write the max to the target location if (threadIdx.x == 0) { - atomicMaxFloat(scale, cache[0] / std::numeric_limits::max()); + atomicMaxFloat(scale, + cache[0] / std::numeric_limits::max()); } } -template -__global__ void scaled_fp8_quant_kernel( - c10::Float8_e4m3fn* __restrict__ out, - const scalar_t* __restrict__ input, - const float* __restrict__ scale, - int64_t num_elems) { +template +__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, + const scalar_t* __restrict__ input, + const float* __restrict__ scale, + int64_t num_elems) { int i = blockDim.x * blockIdx.x + threadIdx.x; while (i < num_elems) { out[i] = scaled_fp8_conversion(input[i], *scale); @@ -81,12 +83,11 @@ __global__ void scaled_fp8_quant_kernel( } } -} // namespace vllm +} // namespace vllm -void static_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -95,21 +96,16 @@ void static_scaled_fp8_quant( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "scaled_fp8_quant_kernel", - [&] { - vllm::scaled_fp8_quant_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - scale.data_ptr(), - num_elems); + input.scalar_type(), "scaled_fp8_quant_kernel", [&] { + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); }); } -void dynamic_scaled_fp8_quant( - torch::Tensor& out, // [..., d] - torch::Tensor& input, // [..., d] - torch::Tensor& scale) // [1] +void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] { int64_t num_tokens = input.numel() / input.size(-1); int64_t num_elems = input.numel(); @@ -118,18 +114,11 @@ void dynamic_scaled_fp8_quant( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "scaled_fp8_quant_kernel", - [&] { - vllm::segmented_max_reduction<<>>( - scale.data_ptr(), - input.data_ptr(), - num_elems); - vllm::scaled_fp8_quant_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - scale.data_ptr(), - num_elems); + input.scalar_type(), "scaled_fp8_quant_kernel", [&] { + vllm::segmented_max_reduction<<>>( + scale.data_ptr(), input.data_ptr(), num_elems); + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); }); } - diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index 4eeacf7a6f9d9..cde26dbda18cf 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -10,9 +10,9 @@ namespace vllm { #ifndef USE_ROCM namespace fp8 { -#ifdef ENABLE_FP8 + #ifdef ENABLE_FP8 -#if 0 // Disable the following code to reduce the binary size. + #if 0 // Disable the following code to reduce the binary size. template __inline__ __device__ Tout vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) { @@ -177,13 +177,13 @@ __inline__ __device__ uint8_t vec_conversion( template <> __inline__ __device__ uint8_t vec_conversion( const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); -#else + #else __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8( __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type); return (uint8_t)res; -#endif + #endif } // float -> fp8 @@ -276,7 +276,7 @@ __inline__ __device__ bf16_8_t vec_conversion( from_float(b, a); return b; } -#endif + #endif /* Scaled and vectorized conversions, for data exchange between high and low precision domains Convention of the scale in API, e.g: FP8_data = @@ -286,14 +286,14 @@ __inline__ __device__ bf16_8_t vec_conversion( template __inline__ __device__ Tout scaled_vec_conversion( - const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) { + const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) { return x; } // fp8 -> half template <> __inline__ __device__ uint16_t scaled_vec_conversion( - const uint8_t &a, const float scale, + const uint8_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type); return float_to_half(half_to_float(tmp.x) * scale); @@ -302,7 +302,7 @@ __inline__ __device__ uint16_t scaled_vec_conversion( // fp8x2 -> half2 template <> __inline__ __device__ uint32_t scaled_vec_conversion( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint16_t u16[2]; @@ -317,7 +317,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion( // fp8x4 -> half2x2 template <> __inline__ __device__ uint2 scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint2 u32x2; @@ -333,7 +333,7 @@ __inline__ __device__ uint2 scaled_vec_conversion( // fp8x8 -> half2x4 template <> __inline__ __device__ uint4 -scaled_vec_conversion(const uint2 &a, const float scale, +scaled_vec_conversion(const uint2& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { union { uint4 u64x2; @@ -348,7 +348,7 @@ scaled_vec_conversion(const uint2 &a, const float scale, template <> __inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>( - const uint8_t &a, const float scale, + const uint8_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { // Note there is no direct convert function from fp8 to bf16. // fp8 -> half @@ -362,7 +362,7 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>( template <> __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_bfloat162 res; res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, @@ -375,7 +375,7 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>( // fp8x4 -> bf16_4_t template <> __inline__ __device__ bf16_4_t scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { bf16_4_t res; res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, @@ -388,7 +388,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion( // fp8x8 -> bf16_8_t template <> __inline__ __device__ bf16_8_t scaled_vec_conversion( - const uint2 &a, const float scale, + const uint2& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { bf16_4_t tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); @@ -404,9 +404,8 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion( // fp8 -> float template <> __inline__ __device__ float scaled_vec_conversion( - const uint8_t &a, const float scale, + const uint8_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { - // fp8 -> half __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type); uint16_t tmp = res.x; @@ -418,7 +417,7 @@ __inline__ __device__ float scaled_vec_conversion( // fp8x2 -> float2 template <> __inline__ __device__ float2 scaled_vec_conversion( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { // fp8x2 -> half2 uint32_t tmp = scaled_vec_conversion(a, scale, fp8_type); @@ -429,7 +428,7 @@ __inline__ __device__ float2 scaled_vec_conversion( // fp8x4 -> float4 template <> __inline__ __device__ Float4_ scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ res; res.x = scaled_vec_conversion((uint16_t)a, scale, fp8_type); @@ -441,7 +440,7 @@ __inline__ __device__ Float4_ scaled_vec_conversion( // fp8x8 -> float8 template <> __inline__ __device__ Float8_ scaled_vec_conversion( - const uint2 &a, const float scale, + const uint2& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale, fp8_type); @@ -457,7 +456,7 @@ __inline__ __device__ Float8_ scaled_vec_conversion( // half -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const uint16_t &a, const float scale, + const uint16_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type); @@ -467,21 +466,21 @@ __inline__ __device__ uint8_t scaled_vec_conversion( // bf16 -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const __nv_bfloat16 &a, const float scale, + const __nv_bfloat16& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); -#else + #else __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale, __NV_SATFINITE, fp8_type); return (uint8_t)res; -#endif + #endif } // float -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const float &a, const float scale, + const float& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type); @@ -491,78 +490,81 @@ __inline__ __device__ uint8_t scaled_vec_conversion( // fp8x4 -> float4 template <> __inline__ __device__ float4 scaled_vec_conversion( - const uint32_t &a, const float scale, + const uint32_t& a, const float scale, const __nv_fp8_interpretation_t fp8_type) { Float4_ tmp = scaled_vec_conversion(a, scale, fp8_type); float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); return res; } -#endif // ENABLE_FP8 + #endif // ENABLE_FP8 template -__inline__ __device__ Tout convert(const Tin &x) { -#if 0 // Disable the following code to reduce the binary size. +__inline__ __device__ Tout convert(const Tin& x) { + #if 0 // Disable the following code to reduce the binary size. if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return vec_conversion(x, __NV_E4M3); } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { return vec_conversion(x, __NV_E5M2); } -#endif + #endif assert(false); } template -__inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) { -#ifdef ENABLE_FP8 +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { return scaled_vec_conversion(x, scale, __NV_E4M3); } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) { return scaled_vec_conversion(x, scale, __NV_E5M2); } -#endif + #endif assert(false); } -// The following macro is used to dispatch the conversion function based on the -// data type of the key and value cache. The FN is a macro that calls a function -// with template. -#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ - if (KV_DTYPE == "auto") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ - } else { \ - TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ - } \ - } else { \ - if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ } else { \ TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ - } else if (KV_DTYPE == "fp8_e5m2") { \ - if (SRC_DTYPE == at::ScalarType::Float) { \ - FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ - } else if (SRC_DTYPE == at::ScalarType::Half) { \ - FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ - } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ - FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else if (KV_DTYPE == "fp8_e5m2") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ } else { \ - TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ } \ - } else { \ - TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ - } \ - } + } -} // namespace fp8 -#endif // not USE_ROCM -} // namespace vllm +} // namespace fp8 +#endif // not USE_ROCM +} // namespace vllm diff --git a/csrc/quantization/gptq/compat.cuh b/csrc/quantization/gptq/compat.cuh index 4da0bc6e2df38..1b3fb3d39103f 100644 --- a/csrc/quantization/gptq/compat.cuh +++ b/csrc/quantization/gptq/compat.cuh @@ -9,54 +9,54 @@ namespace vllm { namespace gptq { // atomicAdd for half types, to support CC < 7.x -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; +__device__ __forceinline__ void atomicAdd_half(half* address, half val) { + unsigned int* address_as_ui = + (unsigned int*)((char*)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) + : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); } // atomicAdd for half2 types -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } while (assumed != old); } // #if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + #if __CUDA_ARCH__ < 700 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } +__device__ __forceinline__ void atomicAdd(half* address, half val) { + atomicAdd_half(address, val); +} -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif + #if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { + atomicAdd_half2(address, val); +} + #endif -#endif + #endif #endif } // namespace gptq diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh index eda3436eb5375..2b6719fbdc1bc 100644 --- a/csrc/quantization/gptq/matrix_view.cuh +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -1,5 +1,6 @@ /* -Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/turboderp/exllama */ #ifndef _matrix_view_cuh @@ -13,260 +14,280 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo namespace vllm { namespace gptq { -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } - - __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const - { - half2* ptr = (half2*) item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __low2half(i01); - items[1] = __high2half(i01); - items[2] = __low2half(i23); - items[3] = __high2half(i23); - } - __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2float(__low2half(i01)); - items[1] = __half2float(__high2half(i01)); - items[2] = __half2float(__low2half(i23)); - items[3] = __half2float(__high2half(i23)); - } - - __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const - { - half2* ptr = (half2*)item_ptr(row, column); - half2 i01 = ptr[0]; - half2 i23 = ptr[1]; - items[0] = __half2half2(__low2half(i01)); - items[1] = __half2half2(__high2half(i01)); - items[2] = __half2half2(__low2half(i23)); - items[3] = __half2half2(__high2half(i23)); - } +class MatrixView_half { + public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { + return __half2half2(data[row * width + column]); + } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { + return &data[row * width + column]; + } + + __device__ __forceinline__ void item4(half (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, + int column) const { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } }; -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } - - __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) - { - half2 v01 = __halves2half2(v0, v1); - half2 v23 = __halves2half2(v2, v3); - half2* ptr = (half2*) item_ptr(row, column); - ptr[0] = v01; - ptr[1] = v23; - } +class MatrixView_half_rw { + public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ half item(int row, int column) const { + return data[row * width + column]; + } + __device__ __forceinline__ half2 item_half2(int row, int column) const { + return ((half2*)data)[(row * width + column) / 2]; + } + __device__ __forceinline__ half* item_ptr(int row, int column) { + return &data[row * width + column]; + } + __device__ __forceinline__ void set(int row, int column, half value) { + data[row * width + column] = value; + } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { + ((half2*)data)[(row * width + column) / 2] = value; + } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, + half v2, half v3) { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*)item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } }; -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x07) * 4; - uint32_t d = data[row * width / 8 + column / 8] >> shift; - items[0] = d & 0x0f; - items[1] = (d >> 4) & 0x0f; - items[2] = (d >> 8) & 0x0f; - items[3] = (d >> 12) & 0x0f; - } +class MatrixView_q4_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } }; -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +class MatrixView_q4_column { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { + return data[row / 8 * width + column]; + } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, + int column) { + return &data[row / 8 * width + column]; + } }; -class MatrixView_q2_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x0f) * 2; - return (data[row * width / 16 + column / 16] >> shift) & 0x03; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x0f) * 2; - uint32_t d = data[row * width / 16 + column / 16] >> shift; - items[0] = d & 0x03; - items[1] = (d >> 2) & 0x03; - items[2] = (d >> 4) & 0x03; - items[3] = (d >> 6) & 0x03; - } +class MatrixView_q2_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x0f) * 2; + return (data[row * width / 16 + column / 16] >> shift) & 0x03; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x0f) * 2; + uint32_t d = data[row * width / 16 + column / 16] >> shift; + items[0] = d & 0x03; + items[1] = (d >> 2) & 0x03; + items[2] = (d >> 4) & 0x03; + items[3] = (d >> 6) & 0x03; + } }; -class MatrixView_q3_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int z_w = column * 3 / 32; - int z_mod = column & 0x1f; - - if (z_mod == 10) { - return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); - } else if (z_mod == 21) { - return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); - } else if (z_mod < 10) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; - } else if (z_mod < 21) { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; - } else { - return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; - } +class MatrixView_q3_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int z_w = column * 3 / 32; + int z_mod = column & 0x1f; + + if (z_mod == 10) { + return (data[row * width * 3 / 32 + z_w] >> 30) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4); + } else if (z_mod == 21) { + return (data[row * width * 3 / 32 + z_w] >> 31) | + ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6); + } else if (z_mod < 10) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07; + } else if (z_mod < 21) { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07; + } else { + return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07; } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x1f); - uint32_t d; - if (shift <= 4) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); - } else if (shift == 8) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); - } else if (shift <= 16) { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); - } else if (shift == 20) { - d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); - } else { - d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); - } - items[0] = d & 0x07; - items[1] = (d >> 3) & 0x07; - items[2] = (d >> 6) & 0x07; - items[3] = (d >> 9) & 0x07; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x1f); + uint32_t d; + if (shift <= 4) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3); + } else if (shift == 8) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8); + } else if (shift <= 16) { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32); + } else if (shift == 20) { + d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | + ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4); + } else { + d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64); } + items[0] = d & 0x07; + items[1] = (d >> 3) & 0x07; + items[2] = (d >> 6) & 0x07; + items[3] = (d >> 9) & 0x07; + } }; -class MatrixView_q8_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x03) * 8; - return (data[row * width / 4 + column / 4] >> shift) & 0xff; - } - - __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const - { - int shift = (column & 0x03) * 8; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - } - - __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const - { - int shift = (column & 0x03) * 2; - uint32_t d = data[row * width / 4 + column / 4] >> shift; - items[0] = d & 0xff; - items[1] = (d >> 8) & 0xff; - items[2] = (d >> 16) & 0xff; - items[3] = (d >> 24) & 0xff; - } +class MatrixView_q8_row { + public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, + const int height, + const int width) + : data(data), height(height), width(width) {} + + __device__ __forceinline__ int item(int row, int column) const { + int shift = (column & 0x03) * 8; + return (data[row * width / 4 + column / 4] >> shift) & 0xff; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, + int column) const { + int shift = (column & 0x03) * 8; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, + int column) const { + int shift = (column & 0x03) * 2; + uint32_t d = data[row * width / 4 + column / 4] >> shift; + items[0] = d & 0xff; + items[1] = (d >> 8) & 0xff; + items[2] = (d >> 16) & 0xff; + items[3] = (d >> 24) & 0xff; + } }; } // namespace gptq diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index cc56649917a8a..480c4986c3821 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -1,5 +1,6 @@ /* -Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa +Adapted from https://github.com/turboderp/exllamav2 and +https://github.com/qwopqwop200/GPTQ-for-LLaMa */ #include @@ -32,2044 +33,1824 @@ namespace gptq { #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) #if defined(USE_ROCM) -#include -__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); + #include +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( + hipblasHandle_t handle, hipblasOperation_t transA, + hipblasOperation_t transB, int m, int n, int k, const half* alpha, + const half* AP, int lda, const half* BP, int ldb, const half* beta, + half* CP, int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); } -#define hipblasHgemm __compat_hipblasHgemm + #define hipblasHgemm __compat_hipblasHgemm -// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_hgemm __compat_hipblasHgemm + // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. + #define rocblas_operation_none HIPBLAS_OP_N + #define rocblas_hgemm __compat_hipblasHgemm #endif -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hadd2(result, g_result); +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, + const half2 g_result) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); } -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __half2float(__low2half(result)) + __half2float(__high2half(result)); +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); } -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); +__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr, + const half2 g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); - return fma(result_f, qs_f, g_result); +__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half* a_ptr, + const float g_result, + const float qs_f) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + float result_f = + __half2float(__low2half(result)) + __half2float(__high2half(result)); + return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h) -{ - // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127 - - float result = {}; - #pragma unroll - for (int i = 0; i < 4; i++) - { - half2 w01 = dq[i]; - float w0 = __low2float(w01); - float w1 = __high2float(w01); - float x0 = __half2float(*a_ptr++); - float x1 = __half2float(*a_ptr++); - result = fma(w0, x0, result); - result = fma(w1, x1, result); - } - float qs = __half2float(qs_h); - result *= qs; - half result_h = __float2half_rn(result); - return __hadd(result_h, g_result); +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, + const half g_result, + const half qs_h) { + // Use FP32 accumulator to avoid potential overflow since unscaled weights are + // in the range -128..127 + + float result = {}; +#pragma unroll + for (int i = 0; i < 4; i++) { + half2 w01 = dq[i]; + float w0 = __low2float(w01); + float w1 = __high2float(w01); + float x0 = __half2float(*a_ptr++); + float x1 = __half2float(*a_ptr++); + result = fma(w0, x0, result); + result = fma(w1, x1, result); + } + float qs = __half2float(qs_h); + result *= qs; + half result_h = __float2half_rn(result); + return __hadd(result_h, g_result); } -__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); } -__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); - half result_h = __hadd(__low2half(result), __high2half(result)); - return __hfma(result_h, qs_h, g_result); +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr, + const half g_result, + const half qs_h) { + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; +#pragma unroll + for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + half result_h = __hadd(__low2half(result), __high2half(result)); + return __hfma(result_h, qs_h, g_result); } - -typedef void (*fp_gemm_half_q_half_gptq_kernel) -( - const half*, - const uint32_t*, - const uint32_t*, - const half*, - half*, - const int, - const int, - const int, - const int, - const int* -); - +typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*, + const uint32_t*, const half*, + half*, const int, const int, + const int, const int, + const int*); template -__global__ void gemm_half_q_half_gptq_4bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_4bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - float scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - - // Column result - float block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - } - - #pragma unroll - for (int j = 0; j < 4; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][4]; - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); - block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); - block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); - block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); - } - - b_ptr += size_n; - a_ptr += 8; - } - - k += 32; +#pragma unroll + for (int j = 0; j < 4; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], + block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], + block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], + block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], + block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); - half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), + __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), + __float2half_rn(block_c[m][3])); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_2bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_2bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 2); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 2); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - - b_ptr += size_n; - a_ptr += 16; - } - - k += 16; +#pragma unroll + for (int j = 0; j < 1; j++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + + b_ptr += size_n; + a_ptr += 16; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 16; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_3bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_3bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - // Zero output - if (n >= size_n) return; + // Zero output + if (n >= size_n) return; - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / 32 * 3; + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / 32 * 3; - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } - - #pragma unroll - for (int j = 0; j < 1; j++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); - dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); - dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); - dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - a_ptr += 32; - } - - k += 32; +#pragma unroll + for (int j = 0; j < 1; j++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + +#pragma unroll + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 32; } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); - } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } template -__global__ void gemm_half_q_half_gptq_8bit_kernel -( - const half* __restrict__ a, - const uint32_t* __restrict__ b_q_weight, +__global__ void gemm_half_q_half_gptq_8bit_kernel( + const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int* __restrict__ b_q_perm -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } - } - - // Zero output - if (n >= size_n) return; - - if (blockIdx.z == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const int size_m, const int size_n, const int size_k, const int groups, + const int* __restrict__ b_q_perm) { + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) { + for (int m = 0; m < m_count; ++m) { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) + a0 = a_ptr[b_q_perm[offset_k + t]]; + else + a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; } + } - __syncthreads(); - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - int qk = offset_k / (32 / 8); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - int zeros[4]; - half scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - // Column result - half block_c[m_count][4] = {}; - - // Dequantize and multiply - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4(scales, group, n); - } + // Zero output + if (n >= size_n) return; - #pragma unroll - for (int j = 0; j < 4; j++) - { - int4 load_int4[2]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); - dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); - dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); - dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); - - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); - block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); - block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); - block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); - } - a_ptr += 8; - } - k += 32; + if (blockIdx.z == 0) { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 8); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + half scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); + // Column result + half block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4(scales, group, n); } - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); - half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); - atomicAdd(out , result01); - atomicAdd(out + 1, result23); +#pragma unroll + for (int j = 0; j < 4; j++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + for (int m = 0; m < m_count; m++) { + block_c[m][0] = + dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], scales[0]); + block_c[m][1] = + dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], scales[1]); + block_c[m][2] = + dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], scales[2]); + block_c[m][3] = + dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], scales[3]); + } + a_ptr += 8; } + k += 32; + } + + for (int m = 0; m < m_count; m++) { + half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); + half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); + atomicAdd(out, result01); + atomicAdd(out + 1, result23); + } } fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( - bool first_block, const int m_count, const int bit) -{ - #define SELECT_KERNEL(M_COUNT) \ - if (m_count == M_COUNT) { \ - if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ - if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ - if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ - if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ - } - #if BLOCK_M_SIZE_MAX >= 1 - SELECT_KERNEL(1); - #endif - #if BLOCK_M_SIZE_MAX >= 2 - SELECT_KERNEL(2); - #endif - #if BLOCK_M_SIZE_MAX >= 3 - SELECT_KERNEL(3); - #endif - #if BLOCK_M_SIZE_MAX >= 4 - SELECT_KERNEL(4); - #endif - #if BLOCK_M_SIZE_MAX >= 5 - SELECT_KERNEL(5); - #endif - #if BLOCK_M_SIZE_MAX >= 6 - SELECT_KERNEL(6); - #endif - #if BLOCK_M_SIZE_MAX >= 7 - SELECT_KERNEL(7); - #endif - #if BLOCK_M_SIZE_MAX >= 8 - SELECT_KERNEL(8); - #endif - return NULL; + bool first_block, const int m_count, const int bit) { +#define SELECT_KERNEL(M_COUNT) \ + if (m_count == M_COUNT) { \ + if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ + if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ + if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ + } +#if BLOCK_M_SIZE_MAX >= 1 + SELECT_KERNEL(1); +#endif +#if BLOCK_M_SIZE_MAX >= 2 + SELECT_KERNEL(2); +#endif +#if BLOCK_M_SIZE_MAX >= 3 + SELECT_KERNEL(3); +#endif +#if BLOCK_M_SIZE_MAX >= 4 + SELECT_KERNEL(4); +#endif +#if BLOCK_M_SIZE_MAX >= 5 + SELECT_KERNEL(5); +#endif +#if BLOCK_M_SIZE_MAX >= 6 + SELECT_KERNEL(6); +#endif +#if BLOCK_M_SIZE_MAX >= 7 + SELECT_KERNEL(7); +#endif +#if BLOCK_M_SIZE_MAX >= 8 + SELECT_KERNEL(8); +#endif + return NULL; } +void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* c, int size_m, int size_n, int size_k, + int m_count, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = + pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(a, b_q_weight, b_gptq_qzeros, + b_gptq_scales, c, size_m, size_n, + size_k, groups, b_q_perm); +} -void gemm_half_q_half_cuda_part -( - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_q_perm, - half* c, - int size_m, - int size_n, - int size_k, - int m_count, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); - gridDim.y = DIVIDE(size_m, m_count); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); +__global__ void reconstruct_exllama_8bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - a, - b_q_weight, - b_gptq_qzeros, - b_gptq_scales, - c, - size_m, - size_n, - size_k, - groups, - b_q_perm - ); -} + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; -__global__ void reconstruct_exllama_8bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - // b offset - int qk = offset_k / (32 / 8); + // b offset + int qk = offset_k / (32 / 8); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); - __syncthreads(); + __syncthreads(); - int k = offset_k; - int lk = 0; + int k = offset_k; + int lk = 0; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } - for (int p = 0; p < 4; p++) - { - int4 load_int4[2]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][4]; - dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, zeros[0] + 1); - dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, zeros[1] + 1); - dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, zeros[2] + 1); - dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, zeros[3] + 1); - - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + for (int p = 0; p < 4; p++) { + int4 load_int4[2]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][4]; + dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, + zeros[0] + 1); + dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, + zeros[1] + 1); + dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, + zeros[2] + 1); + dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, + zeros[3] + 1); + + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); } - k += 32; + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } } + k += 32; + } } -__global__ void reconstruct_exllama_4bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_4bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); } - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; - - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // b offset - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - - __syncthreads(); - - int k = offset_k; - int lk = 0; - - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + for (int p = 0; p < 4; p++) { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, + false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, + false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, + false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, + false); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); } - - for (int p = 0; p < 4; p++) - { - half2 dq[4][4]; - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + } else { + for (int j = 0; j < 4; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); } - k += 32; + } } + k += 32; + } } -__global__ void reconstruct_exllama_3bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_3bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - // b offset - int qk = offset_k / 32* 3; + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - __syncthreads(); + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - int k = offset_k; - int lk = 0; + // b offset + int qk = offset_k / 32 * 3; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + + __syncthreads(); - for (int p = 0; p < 1; p++) - { - int4 load_int4[3]; - load_int4[0] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*) b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*) b_ptr); b_ptr += size_n; - - half2 dq[4][16]; - dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n, zeros[0] + 1); - dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n, zeros[1] + 1); - dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n, zeros[2] + 1); - dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n, zeros[3] + 1); - - if (b_q_perm) - { - for (int j = 0; j < 16; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 16; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + } + + for (int p = 0; p < 1; p++) { + int4 load_int4[3]; + load_int4[0] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[1] = *((int4*)b_ptr); + b_ptr += size_n; + load_int4[2] = *((int4*)b_ptr); + b_ptr += size_n; + + half2 dq[4][16]; + dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], + size_n, zeros[0] + 1); + dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], + size_n, zeros[1] + 1); + dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], + size_n, zeros[2] + 1); + dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], + size_n, zeros[3] + 1); + + if (b_q_perm) { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 16; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); } - k += 32; + } } + k += 32; + } } -__global__ void reconstruct_exllama_2bit_kernel -( - const uint32_t* __restrict__ b_q_weight, - const int* __restrict__ b_q_perm, +__global__ void reconstruct_exllama_2bit_kernel( + const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - __shared__ int perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } + const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half* __restrict__ b) { + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - // Column - int n = offset_n + t * 4; - if (n >= size_n) return; + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - // Find initial group - int groupsize = size_k / groups; - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - // b offset - int qk = offset_k / (32 / 2); + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + if (b_q_perm) { + if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + } - // Initial zeros/scale - int zeros[4]; - half2 scales[4]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; - __syncthreads(); + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; - int k = offset_k; - int lk = 0; + // b offset + int qk = offset_k / (32 / 2); - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - } + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - for (int p = 0; p < 2; p++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 8; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 8; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - } - k += 32; - } -} + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); -void reconstruct_exllama -( - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_q_perm, - half* out, - int height, - int width, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + __syncthreads(); - auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; - if (bit == 2) { - reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; - } else if (bit == 3) { - reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; - } else if (bit == 8) { - reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + int k = offset_k; + int lk = 0; + + while (k < end_k) { + if (k == nextgroup) { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - reconstruct_exllama_kernel<<>> - ( - b_q_weight, - b_q_perm, - b_gptq_qzeros, - b_gptq_scales, - height, - width, - groups, - out - ); + for (int p = 0; p < 2; p++) { + const int4* b_ptr4 = (int4*)b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][8]; + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + + b_ptr += size_n; + // half* dqh = (half*)dq; + if (b_q_perm) { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), + __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), + __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } else { + for (int j = 0; j < 8; j++) { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), + __low2half(dq[1][j]), __low2half(dq[2][j]), + __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), + __high2half(dq[1][j]), __high2half(dq[2][j]), + __high2half(dq[3][j])); + } + } + } + k += 32; + } } +void reconstruct_exllama(const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_q_perm, + half* out, int height, int width, int groups, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto reconstruct_exllama_kernel = reconstruct_exllama_4bit_kernel; + if (bit == 2) { + reconstruct_exllama_kernel = reconstruct_exllama_2bit_kernel; + } else if (bit == 3) { + reconstruct_exllama_kernel = reconstruct_exllama_3bit_kernel; + } else if (bit == 8) { + reconstruct_exllama_kernel = reconstruct_exllama_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + reconstruct_exllama_kernel<<>>( + b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, + out); +} __global__ void gemm_half_q_half_alt_4bit_kernel( - const half2* __restrict__ vec, - const uint32_t* __restrict__ mat, - half* __restrict__ mul, - const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int height, - int width -) -{ - int zero_width = width / 8; - int vec_height = height * 4; - const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); - int h = BLOCK_KN_SIZE * blockIdx.z / 8; - int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; - int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - - __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; - if (threadIdx.x < h_end) { - for (int m = 0; m < b_end; ++m) { - blockvec[m][threadIdx.x] = - vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + - threadIdx.x]; - } + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; } - - __shared__ half2 deq2[256][8]; - int val = threadIdx.x / 8; - int off = threadIdx.x % 8; - for (; val < 256; val += BLOCK_KN_SIZE / 8) { - deq2[val][off] = __halves2half2( - __int2half_rn(val & 0xF), __int2half_rn(val >> 4) - ); + } + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = + __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4)); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - + 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; } - - if (blockIdx.z == 0) - { - for (int m = 0; m < b_end; m++) - mul[(b + m) * width + w] = __int2half_rn(0); - } - __syncthreads(); - - int i = width * h + w; - int g_h = h * 8; - int k = 0; - int z_w = w / 8; - int z_mod = (w % 8) * 4; - half2 res2; - half res[BLOCK_M_SIZE_MAX] = {}; - - unsigned int tmp; - while (k < h_end) { - tmp = mat[i]; - half2 scales_tmp[4]; - half2 zeros_tmp[4]; - for (int tmp_k = 0; tmp_k < 4; tmp_k++) { - int g = g_idx[g_h + (k + tmp_k) * 2]; - int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; - half scale_f = scales[g * width + w]; - half scale_f2 = scales[g2 * width + w]; - half2 scale = __halves2half2(scale_f, scale_f2); - half2 zero = __halves2half2( - __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), - __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)) - ); - scales_tmp[tmp_k] = scale; - zeros_tmp[tmp_k] = zero; - } - for (int m = 0; m < b_end; m++) { + for (int m = 0; m < b_end; m++) { #ifndef USE_ROCM - res2 = {}; + res2 = {}; #else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); #endif - res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); - res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), + blockvec[m][k + 2], res2); + res2 = __hfma2( + __hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), + blockvec[m][k + 3], res2); #ifndef USE_ROCM - res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); #else - res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); #endif - } - i += width; - k += 4; - } - for (int m = 0; m < b_end; m++) { - atomicAdd(&mul[(b + m) * width + w], res[m]); } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } } - __global__ void gemm_half_q_half_alt_8bit_kernel( - const half2* __restrict__ vec, - const uint32_t* __restrict__ mat, - half* __restrict__ mul, - const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int height, - int width -) -{ - int zero_width = width / 4; - int vec_height = height * 2; - const int blockwidth2 = BLOCK_KN_SIZE / 2; - int b = blockIdx.y * BLOCK_M_SIZE_MAX; - int b_end = min(BLOCK_M_SIZE_MAX, batch - b); - int h = BLOCK_KN_SIZE * blockIdx.z / 4; - int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; - int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - - __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; - if (threadIdx.x < h_end) { - for (int m = 0; m < b_end; ++m) { - blockvec[m][threadIdx.x] = - vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + - threadIdx.x]; - } + const half2* __restrict__ vec, const uint32_t* __restrict__ mat, + half* __restrict__ mul, const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + int batch, int height, int width) { + int zero_width = width / 4; + int vec_height = height * 2; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 4; + int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; } - - - if (blockIdx.z == 0) - { - for (int m = 0; m < b_end; m++) - mul[(b + m) * width + w] = __int2half_rn(0); + } + + if (blockIdx.z == 0) { + for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 4; + int k = 0; + int z_w = w / 4; + int z_mod = (w % 4) * 8; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[2]; + half2 zeros_tmp[2]; + for (int tmp_k = 0; tmp_k < 2; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, + __int2half_rn( + -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), + __hmul(scale_f2, + __int2half_rn( + -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; } - __syncthreads(); - - int i = width * h + w; - int g_h = h * 4; - int k = 0; - int z_w = w / 4; - int z_mod = (w % 4) * 8; - half2 res2; - half res[BLOCK_M_SIZE_MAX] = {}; - - unsigned int tmp; - while (k < h_end) { - tmp = mat[i]; - half2 scales_tmp[2]; - half2 zeros_tmp[2]; - for (int tmp_k = 0; tmp_k < 2; tmp_k++) { - int g = g_idx[g_h + (k + tmp_k) * 2]; - int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; - half scale_f = scales[g * width + w]; - half scale_f2 = scales[g2 * width + w]; - half2 scale = __halves2half2(scale_f, scale_f2); - half2 zero = __halves2half2( - __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), - __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)) - ); - scales_tmp[tmp_k] = scale; - zeros_tmp[tmp_k] = zero; - } - for (int m = 0; m < b_end; m++) { + for (int m = 0; m < b_end; m++) { #ifndef USE_ROCM - res2 = {}; + res2 = {}; #else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); #endif - half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), __int2half_rn((tmp >> 8) & 0xFF)); - res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); - half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), __int2half_rn((tmp >> 24) & 0xFF)); - res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF), + __int2half_rn((tmp >> 8) & 0xFF)); + res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]), + blockvec[m][k + 0], res2); + half2 v34 = __halves2half2(__int2half_rn((tmp >> 16) & 0xFF), + __int2half_rn((tmp >> 24) & 0xFF)); + res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]), + blockvec[m][k + 1], res2); #ifndef USE_ROCM - res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); #else - res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); + res[m] = __hadd( + res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); #endif - } - i += width; - k += 2; - } - for (int m = 0; m < b_end; m++) { - atomicAdd(&mul[(b + m) * width + w], res[m]); } + i += width; + k += 2; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } } -void gemm_half_q_half_alt -( - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* c, - int size_m, - int size_n, - int size_k, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); - gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); - gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); - - auto kernel = gemm_half_q_half_alt_4bit_kernel; - if (bit == 8) { - kernel = gemm_half_q_half_alt_8bit_kernel; - } - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - (const half2*) a, - b_q_weight, - c, - b_gptq_scales, - b_gptq_qzeros, - b_g_idx, - size_m, - size_k / 32 * bit, - size_n - ); +void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, int size_m, int size_n, int size_k, + int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + auto kernel = gemm_half_q_half_alt_4bit_kernel; + if (bit == 8) { + kernel = gemm_half_q_half_alt_8bit_kernel; + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>( + (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, + size_m, size_k / 32 * bit, size_n); } -template -__global__ void reconstruct_gptq_kernel -( - const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, - const int width, - const int group, - half* __restrict__ out -) -{ - // Start of block - - int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - int row = blockIdx.y * 32 / bit; - if (column >= width) return; - - // Views - - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, group, width); - T w_zeros_(w_zeros, group, width); - - uint32_t w_read = w[blockIdx.y * width + column]; - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += bit) - { - int group = g_idx[row + s / bit]; - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } +template +__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, const int width, + const int group, + half* __restrict__ out) { + // Start of block + + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32 / bit; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + T w_zeros_(w_zeros, group, width); + + uint32_t w_read = w[blockIdx.y * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int s = 0; s < 32; s += bit) { + int group = g_idx[row + s / bit]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = + __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), + w_scale); + *out_ptr = w_item; + out_ptr += out_.width; + } } -__global__ void reconstruct_gptq_3bit_kernel -( - const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, - const int width, - const int group, - half* __restrict__ out -) -{ - // Start of block - int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; - int row = blockIdx.y * 32; - if (column >= width) return; - - // Views - - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, group, width); - MatrixView_q3_row w_zeros_(w_zeros, group, width); - - uint32_t w1 = w[(blockIdx.y * 3) * width + column]; - uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; - uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int i = 0; i < 32; i += 1) - { - int group = g_idx[row + i]; - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - int w_item; - if (i == 10) { - w_item = (w1 >> 30) | ((w2 << 2) & 0x4); - } else if (i == 21) { - w_item = (w2 >> 31) | ((w3 << 1) & 0x6); - } else if (i < 10) { - w_item = ((w1 >> (i * 3)) & 0x7); - } else if (i < 21) { - w_item = ((w2 >> (i * 3 - 32)) & 0x7); - } else { - w_item = ((w3 >> (i * 3 - 64)) & 0x7); - } - *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); - out_ptr += out_.width; +__global__ void reconstruct_gptq_3bit_kernel( + const uint32_t* __restrict__ w, const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, + const int height, const int width, const int group, + half* __restrict__ out) { + // Start of block + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 32; + if (column >= width) return; + + // Views + + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q3_row w_zeros_(w_zeros, group, width); + + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; + uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; + uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; + half* out_ptr = out_.item_ptr(row, column); + +#pragma unroll + for (int i = 0; i < 32; i += 1) { + int group = g_idx[row + i]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + int w_item; + if (i == 10) { + w_item = (w1 >> 30) | ((w2 << 2) & 0x4); + } else if (i == 21) { + w_item = (w2 >> 31) | ((w3 << 1) & 0x6); + } else if (i < 10) { + w_item = ((w1 >> (i * 3)) & 0x7); + } else if (i < 21) { + w_item = ((w2 >> (i * 3 - 32)) & 0x7); + } else { + w_item = ((w3 >> (i * 3 - 64)) & 0x7); } + *out_ptr = __hmul(__int2half_rn(w_item - w_zero), w_scale); + out_ptr += out_.width; + } } -void reconstruct_gptq -( - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* out, - int height, - int width, - int groups, - int bit -) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.y = DIVIDE(height, 32 / bit); - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - - auto kernel = reconstruct_gptq_kernel; - if (bit == 2) { - kernel = reconstruct_gptq_kernel; - } else if (bit == 8) { - kernel = reconstruct_gptq_kernel; - } else if (bit == 3) { - kernel = reconstruct_gptq_3bit_kernel; - gridDim.y = DIVIDE(height, 32); - } - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - b_q_weight, - b_gptq_scales, - b_gptq_qzeros, - b_g_idx, - height, - width, - groups, - out - ); +void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, half* out, + int height, int width, int groups, int bit) { + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 32 / bit); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + auto kernel = reconstruct_gptq_kernel; + if (bit == 2) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 8) { + kernel = reconstruct_gptq_kernel; + } else if (bit == 3) { + kernel = reconstruct_gptq_3bit_kernel; + gridDim.y = DIVIDE(height, 32); + } + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + kernel<<>>(b_q_weight, b_gptq_scales, + b_gptq_qzeros, b_g_idx, height, + width, groups, out); } - -void gemm_half_q_half_cuda -( - cublasHandle_t cublas_handle, - const half* a, - const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, - const int* b_g_idx, - half* c, - half* temp_dq, - int size_m, - int size_n, - int size_k, - int groups, - bool use_exllama, - int bit -) -{ - bool use_reconstruct; +void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, const int* b_g_idx, + half* c, half* temp_dq, int size_m, int size_n, + int size_k, int groups, bool use_exllama, int bit) { + bool use_reconstruct; + if (use_exllama) { + use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || + (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + } else { + // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so + // we disabled them for now. + use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + } + if (use_reconstruct) { + // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { - use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || (bit != 8 && size_m > MAX_Q_GEMM_ROWS)); + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); } else { - // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so we disabled them for now. - use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS); + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups, bit); } - if (use_reconstruct) { - // Reconstruct FP16 matrix, then cuBLAS - if (use_exllama) { - reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, - size_k, size_n, groups, bit); - } - else - { - reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); - } - const half alpha = __float2half(1.0f); - const half beta = __float2half(0.0f); - cublasHgemm(cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - size_n, size_m, size_k, - &alpha, temp_dq, size_n, - a, size_k, - &beta, c, size_n); + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k, + &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); + } else if (use_exllama) { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) { + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, c, last_chunk, size_n, size_k, + BLOCK_M_SIZE_MAX, groups, bit); } - else if (use_exllama) - { - // Quantized matmul - int max_chunks = size_m / BLOCK_M_SIZE_MAX; - int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; - int last_chunk_size = size_m - last_chunk; - - if (max_chunks) - { - gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, - groups, bit); - } - if (last_chunk_size) - { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, - b_gptq_scales, b_g_idx, c + last_chunk * size_n, - last_chunk_size, size_n, size_k, last_chunk_size, - groups, bit); - } - } - else - { - gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, size_m, size_n, size_k, bit); + if (last_chunk_size) { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, + b_gptq_qzeros, b_gptq_scales, b_g_idx, + c + last_chunk * size_n, last_chunk_size, + size_n, size_k, last_chunk_size, groups, bit); } + } else { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k, bit); + } } -__global__ void shuffle_4bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } +__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_4bit_8(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 8; + } } -__global__ void shuffle_8bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; } +__global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_8bit_4(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 4; + } } -__global__ void shuffle_2bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; } +__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_2bit_16(b_ptr, size_n); + b_ptr += 1 * size_n; + k += 16; + } } -__global__ void shuffle_3bit_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; } +__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n) { + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { + shuffle_3bit_32(b_ptr, size_n); + b_ptr += 3 * size_n; + k += 32; + } } -__global__ void make_sequential_4bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 3; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 8; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_2bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 4; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 16; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 4; - int w2_subrow = source_row & 0x0f; - int w2_row_shift = w2_subrow << 1; - int wnew2_row_shift = i << 1; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000300000003; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 4; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 16; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 4; + int w2_subrow = source_row & 0x0f; + int w2_row_shift = w2_subrow << 1; + int wnew2_row_shift = i << 1; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000300000003; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_3bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w_column >= w_width) return; - int w_new_row = blockIdx.y * 3; - int q_perm_idx = blockIdx.y << 5; - uint32_t dst[3] = {0, 0, 0}; - - #pragma unroll - for (int i = 0; i < 32; i++) - { - int source_row = q_perm[q_perm_idx++]; - int z_w = (source_row / 32) * 3; - int z_mod = source_row % 32; - int z_bit; - - if (z_mod != 10){ - if (z_mod != 21){ - z_bit = z_mod; - if (z_bit > 21){ - z_bit *= 3; - z_bit -= 64; - z_w += 2; - } else if (z_bit > 10){ - z_bit *= 3; - z_bit -= 32; - z_w += 1; - } else { - z_bit *= 3; - } - } else { - z_w += 1; - } - } - - uint64_t src; - if (z_mod == 10) { - src = (w[z_w * w_width + w_column] >> 30) | ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); - } else if (z_mod == 21){ - src = (w[z_w * w_width + w_column] >> 31) | ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); +__global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + int w_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w_column >= w_width) return; + int w_new_row = blockIdx.y * 3; + int q_perm_idx = blockIdx.y << 5; + uint32_t dst[3] = {0, 0, 0}; + +#pragma unroll + for (int i = 0; i < 32; i++) { + int source_row = q_perm[q_perm_idx++]; + int z_w = (source_row / 32) * 3; + int z_mod = source_row % 32; + int z_bit; + + if (z_mod != 10) { + if (z_mod != 21) { + z_bit = z_mod; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; } else { - src = w[z_w * w_width + w_column]; - src >>= z_bit; - src &= 0x07; + z_bit *= 3; } + } else { + z_w += 1; + } + } - z_w = 0; - if (i != 10){ - if (i != 21){ - z_bit = i; - if (z_bit > 21){ - z_bit *= 3; - z_bit -= 64; - z_w += 2; - } else if (z_bit > 10){ - z_bit *= 3; - z_bit -= 32; - z_w += 1; - } else { - z_bit *= 3; - } - } else { - z_w += 1; - } - } - if (i == 10) { - dst[z_w] |= (src & 0x03) << 30; - dst[z_w + 1] |= ((src & 0x4) >> 2); - } else if (i == 21) { - dst[z_w] |= (src & 0x01) << 31; - dst[z_w + 1] |= ((src & 0x6) >> 1); + uint64_t src; + if (z_mod == 10) { + src = (w[z_w * w_width + w_column] >> 30) | + ((w[(z_w + 1) * w_width + w_column] << 2) & 0x4); + } else if (z_mod == 21) { + src = (w[z_w * w_width + w_column] >> 31) | + ((w[(z_w + 1) * w_width + w_column] << 1) & 0x6); + } else { + src = w[z_w * w_width + w_column]; + src >>= z_bit; + src &= 0x07; + } + + z_w = 0; + if (i != 10) { + if (i != 21) { + z_bit = i; + if (z_bit > 21) { + z_bit *= 3; + z_bit -= 64; + z_w += 2; + } else if (z_bit > 10) { + z_bit *= 3; + z_bit -= 32; + z_w += 1; } else { - dst[z_w] |= (src << z_bit); + z_bit *= 3; } + } else { + z_w += 1; + } + } + if (i == 10) { + dst[z_w] |= (src & 0x03) << 30; + dst[z_w + 1] |= ((src & 0x4) >> 2); + } else if (i == 21) { + dst[z_w] |= (src & 0x01) << 31; + dst[z_w + 1] |= ((src & 0x6) >> 1); + } else { + dst[z_w] |= (src << z_bit); } - w_new[w_new_row * w_width + w_column] = dst[0]; - w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; - w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; + } + w_new[w_new_row * w_width + w_column] = dst[0]; + w_new[(w_new_row + 1) * w_width + w_column] = dst[1]; + w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; } -__global__ void make_sequential_8bit_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - int w_new2_row = blockIdx.y; - int q_perm_idx = w_new2_row << 2; - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 4; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 2; - int w2_subrow = source_row & 0x03; - int w2_row_shift = w2_subrow << 3; - int wnew2_row_shift = i << 3; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x000000ff000000ff; - src <<= wnew2_row_shift; - dst |= src; - } - w_new2[w_new2_row * w2_stride + w2_column] = dst; +__global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_width) { + const uint64_t* w2 = (uint64_t*)w; + uint64_t* w_new2 = (uint64_t*)w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 2; + uint64_t dst = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 2; + int w2_subrow = source_row & 0x03; + int w2_row_shift = w2_subrow << 3; + int wnew2_row_shift = i << 3; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x000000ff000000ff; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; } +void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, + int width, int bit) { + if (q_perm) { + uint32_t* new_qweight = NULL; + cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); -void shuffle_exllama_weight -( - uint32_t* q_weight, - int* q_perm, - int height, - int width, - int bit -) -{ - if (q_perm) - { - uint32_t* new_qweight = NULL; - cudaMalloc(&new_qweight, height / 32 * bit * width * sizeof(uint32_t)); - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = 1; - gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = height / 32 * bit; - - auto kernel = make_sequential_4bit_kernel; - if (bit == 2) { - kernel = make_sequential_2bit_kernel; - } else if (bit == 3) { - kernel = make_sequential_3bit_kernel; - gridDim.y = height / 32; - } else if (bit == 8) { - kernel = make_sequential_8bit_kernel; - } - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>> - ( - q_weight, - new_qweight, - q_perm, - width - ); - // Replace qweights - cudaMemcpyAsync(q_weight, new_qweight, height / 32 * bit * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - // Cleanup - cudaDeviceSynchronize(); - cudaFree(new_qweight); - } dim3 blockDim, gridDim; blockDim.x = THREADS_X; blockDim.y = 1; gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = 1; - auto shuffle_kernel = shuffle_4bit_kernel; + gridDim.y = height / 32 * bit; + + auto kernel = make_sequential_4bit_kernel; if (bit == 2) { - shuffle_kernel = shuffle_2bit_kernel; + kernel = make_sequential_2bit_kernel; } else if (bit == 3) { - shuffle_kernel = shuffle_3bit_kernel; + kernel = make_sequential_3bit_kernel; + gridDim.y = height / 32; } else if (bit == 8) { - shuffle_kernel = shuffle_8bit_kernel; + kernel = make_sequential_8bit_kernel; } const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - shuffle_kernel<<>>(q_weight, height, width); + kernel<<>>(q_weight, new_qweight, q_perm, + width); + // Replace qweights + cudaMemcpyAsync(q_weight, new_qweight, + height / 32 * bit * width * sizeof(uint32_t), + cudaMemcpyDeviceToDevice); + // Cleanup + cudaDeviceSynchronize(); + cudaFree(new_qweight); + } + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + auto shuffle_kernel = shuffle_4bit_kernel; + if (bit == 2) { + shuffle_kernel = shuffle_2bit_kernel; + } else if (bit == 3) { + shuffle_kernel = shuffle_3bit_kernel; + } else if (bit == 8) { + shuffle_kernel = shuffle_8bit_kernel; + } + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + shuffle_kernel<<>>(q_weight, height, width); } } // namespace gptq } // namespace vllm -torch::Tensor gptq_gemm -( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama, - int bit -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); - at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); - - vllm::gptq::gemm_half_q_half_cuda - ( - at::cuda::getCurrentCUDABlasHandle(), - (const half*) a.data_ptr(), - (const uint32_t*) b_q_weight.data_ptr(), - (const uint32_t*)b_gptq_qzeros.data_ptr(), - (const half*) b_gptq_scales.data_ptr(), - b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), - (half*) c.data_ptr(), - (half*) temp_dq.data_ptr(), - c.size(0), // m - c.size(1), // n - a.size(1), // k - b_gptq_qzeros.size(0), // group number - use_exllama, - bit - ); - return c; +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama, int bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); + at::Tensor temp_dq = torch::empty( + {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options); + + vllm::gptq::gemm_half_q_half_cuda( + at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(), + (const uint32_t*)b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*)b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*)b_g_idx.data_ptr(), + (half*)c.data_ptr(), (half*)temp_dq.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + b_gptq_qzeros.size(0), // group number + use_exllama, bit); + return c; } -void gptq_shuffle -( - torch::Tensor q_weight, - torch::Tensor q_perm, - int bit -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); - vllm::gptq::shuffle_exllama_weight( - (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(), - q_weight.size(0) * 32 / bit, - q_weight.size(1), - bit - ); +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + vllm::gptq::shuffle_exllama_weight( + (uint32_t*)q_weight.data_ptr(), + q_perm.device().is_meta() || q_perm.numel() == 0 + ? NULL + : (int*)q_perm.data_ptr(), + q_weight.size(0) * 32 / bit, q_weight.size(1), bit); } diff --git a/csrc/quantization/gptq/qdq_2.cuh b/csrc/quantization/gptq/qdq_2.cuh index 295872a91de37..ca0f810608d1b 100644 --- a/csrc/quantization/gptq/qdq_2.cuh +++ b/csrc/quantization/gptq/qdq_2.cuh @@ -14,71 +14,60 @@ namespace gptq { // // ffddbb99 77553311 eeccaa88 66442200 -__forceinline__ __device__ void shuffle_2bit_16 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; +__forceinline__ __device__ void shuffle_2bit_16(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; - #pragma unroll - for (int i = 0; i < 8; i++) - { - uint32_t qa0 = qa & 0x03; - uint32_t qa1 = (qa & 0x0c) >> 2; - qa >>= 4; - qb |= (qa1 << (i * 2 + 16)); - qb |= (qa0 << (i * 2)); - } - q[0] = qb; +#pragma unroll + for (int i = 0; i < 8; i++) { + uint32_t qa0 = qa & 0x03; + uint32_t qa1 = (qa & 0x0c) >> 2; + qa >>= 4; + qb |= (qa1 << (i * 2 + 16)); + qb |= (qa0 << (i * 2)); + } + q[0] = qb; } -__forceinline__ __device__ void dequant_2bit_16 -( - const uint32_t q_0, - half2 (&dq)[8], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y4_ = __float2half_rn(1.0f / 4.0f); - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y4 = __halves2half2(y4_, y4_); - const half2 y16 = __halves2half2(y16_, y16_); - const half2 y64 = __halves2half2(y64_, y64_); +__forceinline__ __device__ void dequant_2bit_16(const uint32_t q_0, + half2 (&dq)[8], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y4_ = __float2half_rn(1.0f / 4.0f); + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y4 = __halves2half2(y4_, y4_); + const half2 y16 = __halves2half2(y16_, y16_); + const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z4 = __half2half2(z4_); - const half2 z16 = __half2half2(z16_); - const half2 z64 = __half2half2(z64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero)); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z4 = __half2half2(z4_); + const half2 z16 = __half2half2(z16_); + const half2 z64 = __half2half2(z64_); - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 - half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 - half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 - qa >>= 8; - half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 - half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 - half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 - half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024 + half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024 + half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024 + qa >>= 8; + half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024 + half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024 + half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024 + half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024 - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y4, z4); - dq[2] = __hfma2(q2.as_half2, y16, z16); - dq[3] = __hfma2(q3.as_half2, y64, z64); - dq[4] = __hadd2(q4.as_half2, z1); - dq[5] = __hfma2(q5.as_half2, y4, z4); - dq[6] = __hfma2(q6.as_half2, y16, z16); - dq[7] = __hfma2(q7.as_half2, y64, z64); + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y4, z4); + dq[2] = __hfma2(q2.as_half2, y16, z16); + dq[3] = __hfma2(q3.as_half2, y64, z64); + dq[4] = __hadd2(q4.as_half2, z1); + dq[5] = __hfma2(q5.as_half2, y4, z4); + dq[6] = __hfma2(q6.as_half2, y16, z16); + dq[7] = __hfma2(q7.as_half2, y64, z64); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_3.cuh b/csrc/quantization/gptq/qdq_3.cuh index 3e7ecde752ba3..0d5c2adf5dbbe 100644 --- a/csrc/quantization/gptq/qdq_3.cuh +++ b/csrc/quantization/gptq/qdq_3.cuh @@ -11,128 +11,136 @@ namespace gptq { // vjjjhhhf ffdddbbb uiiiggge eecccaaa // vtttrrrp ppnnnlll usssqqqo oommmkkk -__forceinline__ __device__ void shuffle_3bit_32 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0 * stride]; - uint32_t qb = q[1 * stride]; - uint32_t qc = q[2 * stride]; - - // qa: aa999888 77766655 54443332 22111000 - // qb: lkkkjjji iihhhggg fffeeedd dcccbbba - // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll - - uint32_t qd = qc >> 26; - qc <<= 4; - qc |= qb >> 28; - qb <<= 2; - qb |= qa >> 30; - - // qa: ..999888 77766655 54443332 22111000 - // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa - // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk - // qd: vvvuuu - - uint32_t za = 0; - uint32_t zb = 0; - uint32_t zc = 0; - - for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); } - for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); } - - // za: 9997775 55333111 8886664 44222000 - // zb: jjjhhhf ffdddbbb iiiggge eecccaaa - // zc: tttrrrp ppnnnlll sssqqqo oommmkkk - // qd: vvvuuu - - za |= ((qd & 0x01) >> 0) << 15; - zb |= ((qd & 0x02) >> 1) << 15; - zc |= ((qd & 0x04) >> 2) << 15; - za |= ((qd & 0x08) >> 3) << 31; - zb |= ((qd & 0x10) >> 4) << 31; - zc |= ((qd & 0x20) >> 5) << 31; - - // za: v9997775 55333111 u8886664 44222000 (u, v lsb) - // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa - // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk - - q[0 * stride] = za; - q[1 * stride] = zb; - q[2 * stride] = zc; -} - -__forceinline__ __device__ void dequant_3bit_32 -( - const uint32_t q_0, - const uint32_t q_1, - const uint32_t q_2, - half2 (&dq)[16], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y8_ = __float2half_rn(1.0f / 8.0f); - const half y64_ = __float2half_rn(1.0f / 64.0f); - const half2 y8 = __halves2half2(y8_, y8_); - const half2 y64 = __halves2half2(y64_, y64_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); - const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); - const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); - const half2 z8 = __halves2half2(z8_, z8_); - const half2 z64 = __halves2half2(z64_, z64_); - - uint32_t qa = q_0; - uint32_t qb = q_1; - uint32_t qc = q_2; - - half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 +__forceinline__ __device__ void shuffle_3bit_32(uint32_t* q, int stride) { + uint32_t qa = q[0 * stride]; + uint32_t qb = q[1 * stride]; + uint32_t qc = q[2 * stride]; + + // qa: aa999888 77766655 54443332 22111000 + // qb: lkkkjjji iihhhggg fffeeedd dcccbbba + // qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll + + uint32_t qd = qc >> 26; + qc <<= 4; + qc |= qb >> 28; + qb <<= 2; + qb |= qa >> 30; + + // qa: ..999888 77766655 54443332 22111000 + // qb: ..jjjiii hhhgggff feeedddc ccbbbaaa + // qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk + // qd: vvvuuu + + uint32_t za = 0; + uint32_t zb = 0; + uint32_t zc = 0; + + for (int i = 0; i < 5; i++) { + uint32_t t0 = qa & 0x07; + uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; - half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 - half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 - qa >>= 9; - qa &= 0x00010001; - half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 - half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + za |= (t0 << (i * 3)); + za |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qb & 0x07; + uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; - half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 - half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 - half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 - qb >>= 8; - qb &= 0x00020002; - half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 - half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + zb |= (t0 << (i * 3)); + zb |= (t1 << (i * 3 + 16)); + } + for (int i = 0; i < 5; i++) { + uint32_t t0 = qc & 0x07; + uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; - half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 - half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 - half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 - qc >>= 7; - qc &= 0x00040004; - half2_uint32 q15((qa | qb | qc) | c0); - - dq[ 0] = __hadd2( q0.as_half2, z1); - dq[ 1] = __hfma2( q1.as_half2, y8, z8); - dq[ 2] = __hadd2( q2.as_half2, z1); - dq[ 3] = __hfma2( q3.as_half2, y8, z8); - dq[ 4] = __hfma2( q4.as_half2, y64, z64); - dq[ 5] = __hadd2( q5.as_half2, z1); - dq[ 6] = __hfma2( q6.as_half2, y8, z8); - dq[ 7] = __hadd2( q7.as_half2, z1); - dq[ 8] = __hfma2( q8.as_half2, y8, z8); - dq[ 9] = __hfma2( q9.as_half2, y64, z64); - dq[10] = __hadd2(q10.as_half2, z1); - dq[11] = __hfma2(q11.as_half2, y8, z8); - dq[12] = __hadd2(q12.as_half2, z1); - dq[13] = __hfma2(q13.as_half2, y8, z8); - dq[14] = __hfma2(q14.as_half2, y64, z64); - dq[15] = __hadd2(q15.as_half2, z1); + zc |= (t0 << (i * 3)); + zc |= (t1 << (i * 3 + 16)); + } + + // za: 9997775 55333111 8886664 44222000 + // zb: jjjhhhf ffdddbbb iiiggge eecccaaa + // zc: tttrrrp ppnnnlll sssqqqo oommmkkk + // qd: vvvuuu + + za |= ((qd & 0x01) >> 0) << 15; + zb |= ((qd & 0x02) >> 1) << 15; + zc |= ((qd & 0x04) >> 2) << 15; + za |= ((qd & 0x08) >> 3) << 31; + zb |= ((qd & 0x10) >> 4) << 31; + zc |= ((qd & 0x20) >> 5) << 31; + + // za: v9997775 55333111 u8886664 44222000 (u, v lsb) + // zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa + // zc: vtttrrrp ppnnnlll usssqqqo oommmkkk + + q[0 * stride] = za; + q[1 * stride] = zb; + q[2 * stride] = zc; +} + +__forceinline__ __device__ void dequant_3bit_32(const uint32_t q_0, + const uint32_t q_1, + const uint32_t q_2, + half2 (&dq)[16], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y8_ = __float2half_rn(1.0f / 8.0f); + const half y64_ = __float2half_rn(1.0f / 64.0f); + const half2 y8 = __halves2half2(y8_, y8_); + const half2 y64 = __halves2half2(y64_, y64_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero)); + const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero)); + const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half); + const half2 z8 = __halves2half2(z8_, z8_); + const half2 z64 = __halves2half2(z64_, z64_); + + uint32_t qa = q_0; + uint32_t qb = q_1; + uint32_t qc = q_2; + + half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024 + qa >>= 6; + half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024 + half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024 + qa >>= 9; + qa &= 0x00010001; + half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024 + half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024 + qb >>= 6; + half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024 + half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024 + half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024 + qb >>= 8; + qb &= 0x00020002; + half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024 + half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024 + qc >>= 6; + half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024 + half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024 + half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024 + qc >>= 7; + qc &= 0x00040004; + half2_uint32 q15((qa | qb | qc) | c0); + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y8, z8); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y8, z8); + dq[4] = __hfma2(q4.as_half2, y64, z64); + dq[5] = __hadd2(q5.as_half2, z1); + dq[6] = __hfma2(q6.as_half2, y8, z8); + dq[7] = __hadd2(q7.as_half2, z1); + dq[8] = __hfma2(q8.as_half2, y8, z8); + dq[9] = __hfma2(q9.as_half2, y64, z64); + dq[10] = __hadd2(q10.as_half2, z1); + dq[11] = __hfma2(q11.as_half2, y8, z8); + dq[12] = __hadd2(q12.as_half2, z1); + dq[13] = __hfma2(q13.as_half2, y8, z8); + dq[14] = __hfma2(q14.as_half2, y64, z64); + dq[15] = __hadd2(q15.as_half2, z1); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/quantization/gptq/qdq_4.cuh index 881f353f6564d..7f65d2d2819b1 100644 --- a/csrc/quantization/gptq/qdq_4.cuh +++ b/csrc/quantization/gptq/qdq_4.cuh @@ -13,133 +13,112 @@ namespace gptq { // // 77775555 33331111 66664444 22220000 -__forceinline__ __device__ void shuffle_4bit_8 -( - uint32_t* q, - int stride -) -{ - uint32_t qa = q[0]; - uint32_t qb = 0; - - #pragma unroll - for (int i = 0; i < 4; i++) - { - uint32_t qa0 = qa & 0x0f; - uint32_t qa1 = (qa & 0xf0) >> 4; - qa >>= 8; - qb |= (qa1 << (i * 4 + 16)); - qb |= (qa0 << (i * 4)); - } - q[0] = qb; -} - -__forceinline__ __device__ void dequant_4bit_8 -( - const uint32_t q_0, - half2 (&dq)[4], - int stride, - const uint32_t zero -) -{ - const uint32_t c0 = 0x64006400; - const half y16_ = __float2half_rn(1.0f / 16.0f); - const half2 y16 = __halves2half2(y16_, y16_); - const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); - const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - const half2 z1 = __half2half2(z1_.as_half); - const half2 z16 = __half2half2(z16_); - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 +__forceinline__ __device__ void shuffle_4bit_8(uint32_t* q, int stride) { + uint32_t qa = q[0]; + uint32_t qb = 0; + +#pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} - dq[0] = __hadd2(q0.as_half2, z1); - dq[1] = __hfma2(q1.as_half2, y16, z16); - dq[2] = __hadd2(q2.as_half2, z1); - dq[3] = __hfma2(q3.as_half2, y16, z16); +__forceinline__ __device__ void dequant_4bit_8(const uint32_t q_0, + half2 (&dq)[4], int stride, + const uint32_t zero) { + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero); + const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + const half2 z1 = __half2half2(z1_.as_half); + const half2 z16 = __half2half2(z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); } -__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale -( - const uint32_t zero, - const half scale, - half2 (&z1z16)[2], - half2 (&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale( + const uint32_t zero, const half scale, half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - half2 scale2 = __half2half2(scale); + half2 scale2 = __half2half2(scale); - z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); - z1z16[1] = __hmul2(scale2, __half2half2(z16)); + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __hmul2(scale2, __half2half2(y1)); - y1y16[1] = __hmul2(scale2, __half2half2(y16)); + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); } -__forceinline__ __device__ void dequant_4bit_8_prep_zero -( - const uint32_t zero, - half2(&z1z16)[2], - half2(&y1y16)[2] -) -{ - half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); - half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); +__forceinline__ __device__ void dequant_4bit_8_prep_zero(const uint32_t zero, + half2 (&z1z16)[2], + half2 (&y1y16)[2]) { + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); - z1z16[0] = __half2half2(z1.as_half); - z1z16[1] = __half2half2(z16); + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); - const half y1 = __float2half_rn(1.0f); - const half y16 = __float2half_rn(1.0f / 16.0f); + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); - y1y16[0] = __half2half2(y1); - y1y16[1] = __half2half2(y16); + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); } - -__forceinline__ __device__ void dequant_4bit_8_gptq -( - const uint32_t q_0, - half2 (&dq)[4], - half2 (&z1z16)[2], - half2 (&y1y16)[2], - int stride, - bool scaled -) -{ - const uint32_t c0 = 0x64006400; - - uint32_t qa = q_0; - half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) - half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) - qa >>= 8; - half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) - half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) - - if (scaled) - { - dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) - dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); - } - else - { - dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) - dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) - dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) - dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) - } +__forceinline__ __device__ void dequant_4bit_8_gptq(const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, bool scaled) { + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | + c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | + c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | + c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | + c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) { + dq[0] = __hfma2(q0.as_half2, y1y16[0], + z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } else { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], + z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], + z1z16[1]); // half2( q[6] - z, q[7] - z ) + } } } // namespace gptq } // namespace vllm diff --git a/csrc/quantization/gptq/qdq_8.cuh b/csrc/quantization/gptq/qdq_8.cuh index 0c7ad7876140b..feb5d220424b0 100644 --- a/csrc/quantization/gptq/qdq_8.cuh +++ b/csrc/quantization/gptq/qdq_8.cuh @@ -10,28 +10,18 @@ Copied from https://github.com/turboderp/exllamav2 namespace vllm { namespace gptq { -__forceinline__ __device__ void shuffle_8bit_4 -( - uint32_t* q, - int stride -) -{ -} - -__forceinline__ __device__ void dequant_8bit_8 -( - const uint32_t q_0, - const uint32_t q_1, - half2 (&dq)[4], - int stride, - const uint32_t zero -) -{ - half dqh[8]; - for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero); - for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); - - for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +__forceinline__ __device__ void shuffle_8bit_4(uint32_t* q, int stride) {} + +__forceinline__ __device__ void dequant_8bit_8(const uint32_t q_0, + const uint32_t q_1, + half2 (&dq)[4], int stride, + const uint32_t zero) { + half dqh[8]; + for (int i = 0; i < 4; i++) dqh[i] = dq_ns(exb(q_0, i * 8, 0xff), zero); + for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero); + + for (int i = 0; i < 4; i++) + dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); } } // namespace gptq diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/quantization/gptq/qdq_util.cuh index 1722a9aa6cb34..9426408fec502 100644 --- a/csrc/quantization/gptq/qdq_util.cuh +++ b/csrc/quantization/gptq/qdq_util.cuh @@ -8,51 +8,47 @@ Copied from https://github.com/turboderp/exllamav2 namespace vllm { namespace gptq { -union half2_uint32 -{ - uint32_t as_uint32; - half2 as_half2; - __device__ half2_uint32(uint32_t val) : as_uint32(val) {} - __device__ half2_uint32(half2 val) : as_half2(val) {} +union half2_uint32 { + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} }; -union half_uint16 -{ - uint16_t as_uint16; - half as_half; - __device__ half_uint16(uint16_t val) : as_uint16(val) {} - __device__ half_uint16(half val) : as_half(val) {} +union half_uint16 { + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} }; // Max_scale premultiplied by 1/256 -__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) -{ - int qs_i = qs + 1; - half qs_h = __int2half_rn(qs_i * qs_i); - qs_h = __hmul(qs_h, max_scale); - return qs_h; +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) { + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; } -__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) -{ - return __hmul(__int2half_rn(q - qzero), scale); +__forceinline__ __device__ half dq(const int q, const int qzero, + const half scale) { + return __hmul(__int2half_rn(q - qzero), scale); } -__forceinline__ __device__ half dq_ns(const int q, const int qzero) -{ - //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); - return __int2half_rn(q - qzero); +__forceinline__ __device__ half dq_ns(const int q, const int qzero) { + // return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); } -__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) -{ - return (int)((q >> shift) & mask); +__forceinline__ __device__ int exb(const uint32_t q, const int shift, + const int mask) { + return (int)((q >> shift) & mask); } -__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) -{ - return (int)(__funnelshift_rc(q0, q1, shift) & mask); +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, + const int shift, const int mask) { + return (int)(__funnelshift_rc(q0, q1, shift) & mask); } } // namespace gptq diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 34950a5d13cf5..c573b9041065b 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -22,53 +22,58 @@ #include "gptq_marlin.cuh" #include "gptq_marlin_dtypes.cuh" -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\ - std::is_same::value || std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template inline std::string str(T x) { return std::to_string(x); } +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +template +inline std::string str(T x) { + return std::to_string(x); +} namespace gptq_marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, int size_m, +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) {} -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &g_idx, - torch::Tensor &perm, torch::Tensor &workspace, +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { TORCH_CHECK_NOT_IMPLEMENTED(false, @@ -81,24 +86,26 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. template -__device__ inline void mma(const typename ScalarType::FragA &a_frag, - const typename ScalarType::FragB &frag_b, - typename ScalarType::FragC &frag_c) { - const uint32_t *a = reinterpret_cast(&a_frag); - const uint32_t *b = reinterpret_cast(&frag_b); - float *c = reinterpret_cast(&frag_c); +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); if constexpr (std::is_same::value) { - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else if constexpr (std::is_same::value) { - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } @@ -107,8 +114,9 @@ __device__ inline void mma(const typename ScalarType::FragA &a_frag, // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template -__device__ inline void ldsm4(typename ScalarType::FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -118,7 +126,8 @@ __device__ inline void ldsm4(typename ScalarType::FragA &frag_a, const // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -140,8 +149,10 @@ __device__ inline uint32_t prmt(uint32_t a) { // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: -// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 template __device__ inline typename ScalarType::FragB dequant_4bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); @@ -161,16 +172,17 @@ __device__ inline typename ScalarType::FragB dequant_4bit(int q) { const int MUL = 0x2c002c00; const int ADD = 0xd480d480; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } template <> -__device__ inline typename ScalarType::FragB dequant_4bit(int q) { +__device__ inline typename ScalarType::FragB +dequant_4bit(int q) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; @@ -184,7 +196,7 @@ __device__ inline typename ScalarType::FragB dequant_4bit(&lo), + frag_b[0] = __hfma2(*reinterpret_cast(&lo), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); frag_b[1] = __hfma2(*reinterpret_cast(&hi), @@ -193,10 +205,12 @@ __device__ inline typename ScalarType::FragB dequant_4bit __device__ inline typename ScalarType::FragB dequant_8bit(int q) { STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); @@ -214,24 +228,26 @@ __device__ inline typename ScalarType::FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { +__device__ inline typename ScalarType::FragB +dequant_8bit(int q) { typename ScalarType::FragB frag_b; float fp32_intermediates[4]; - uint32_t * fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); fp32_intermediates[0] -= 8388736.f; @@ -240,8 +256,10 @@ __device__ inline typename ScalarType::FragB dequant_8bit(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); return frag_b; } @@ -249,30 +267,32 @@ __device__ inline typename ScalarType::FragB dequant_8bit -__device__ inline void scale(typename ScalarType::FragB &frag_b, - typename ScalarType::FragS &frag_s, int i) { +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Same as above, but for act_order (each K is multiplied individually) template -__device__ inline void scale4(typename ScalarType::FragB &frag_b, - typename ScalarType::FragS &frag_s_1, - typename ScalarType::FragS &frag_s_2, - typename ScalarType::FragS &frag_s_3, - typename ScalarType::FragS &frag_s_4, +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, int i) { - using scalar_t2 = typename ScalarType::scalar_t2; + using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s_val_1_2; - s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; scalar_t2 s_val_3_4; - s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); @@ -280,14 +300,15 @@ __device__ inline void scale4(typename ScalarType::FragB &frag_b, // Given 2 floats multiply by 2 scales (halves) template -__device__ inline void scale_float(float *c, typename ScalarType::FragS &s) { - scalar_t *s_ptr = reinterpret_cast(&s); +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -302,7 +323,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -321,11 +342,10 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. -__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, int size_m, +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; int finish_row = start_row + block_rows; if (finish_row > size_m) { @@ -341,9 +361,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, int offset = row * row_stride; - half const *a_row_half = - reinterpret_cast(a_int4_ptr + offset); - half *out_half = reinterpret_cast(out_int4_ptr + offset); + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); int base_k = 0; @@ -374,31 +393,32 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, } } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -445,11 +465,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice + int slice_iters; // number of threadblock tiles in the current slice int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers @@ -465,27 +485,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -605,7 +620,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; @@ -623,13 +638,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -639,30 +654,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_g_idx = sh_b + (stages * b_sh_stage); - int4 *sh_s = sh_g_idx + (stages * g_idx_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; int sh_first_group_id = -1; @@ -706,18 +721,18 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < b_thread_vecs; j++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } @@ -730,10 +745,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int full_pipe = a_off; int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; if (cur_k < prob_k && cur_k < slice_k_finish) { - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int4 const *cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); if (threadIdx.x < g_idx_stage) { cp_async4_pred(&sh_g_idx_stage[threadIdx.x], @@ -742,7 +757,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } else { if constexpr (group_blocks != -1) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + int4* sh_s_stage = sh_s + s_sh_stage * pipe; if constexpr (group_blocks >= thread_k_blocks) { // Only fetch scales if this tile starts a new group @@ -782,15 +797,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. auto fetch_to_registers = [&](int k, int pipe) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( + frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } }; @@ -805,8 +821,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk return; } - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); int group_id_1 = sh_g_idx_int_ptr[0]; int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; @@ -822,10 +838,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // No act-order case if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } else { int warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; @@ -838,9 +854,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int k_blocks = cur_k / 16; int cur_group_id = k_blocks / group_blocks; - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } } @@ -867,7 +883,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // thread-id) int warp_id = threadIdx.x / 32; int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N int warp_row = warp_id / n_warps; int warp_col = warp_id % n_warps; @@ -875,7 +891,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cur_k += warp_row * 16; int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix int s_col_shift = /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + @@ -883,45 +899,44 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (is_same_group[pipe]) { if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); } for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); } return; } - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread + 9}; // Tensor core offsets per thread -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; int group_id = sh_g_idx_int_ptr[actual_k]; int rel_group_id = group_id - sh_first_group_id; - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; } }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; @@ -933,7 +948,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk frag_b1 = dequant_4bit(b_quant_shift); } else { - int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; @@ -943,8 +958,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Apply scale to frag_b0 if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + scale4(frag_b0, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 0); } else { if constexpr (group_blocks != -1) { scale(frag_b0, frag_s[k % 2][j], 0); @@ -953,8 +969,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Apply scale to frag_b1 if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + scale4(frag_b1, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 1); } else { if constexpr (group_blocks != -1) { @@ -962,7 +979,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -987,38 +1004,38 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. -#pragma unroll + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -1049,39 +1066,39 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int row = (threadIdx.x % 32) / 4; if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + row < prob_m); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -1115,8 +1132,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS &s) { - scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); // For per-column quantization we finally apply the scale here (only for // 4-bit) @@ -1124,13 +1142,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk res = __hmul2(res, s[0]); } - ((scalar_t2 *)sh)[idx] = res; + ((scalar_t2*)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], @@ -1147,7 +1165,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -1162,7 +1180,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < stages - 1; i++) { if (has_act_order && i == 0) { int last_g_idx = slice_k_start + stages * tb_k * 2; @@ -1193,9 +1211,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // have even length meaning that the next iteration will always start at // index 0. -#pragma unroll + #pragma unroll for (int pipe = 0; pipe < stages;) { -#pragma unroll + #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); fetch_scales_to_registers(k + 1, pipe); @@ -1261,8 +1279,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } else { @@ -1270,8 +1288,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } } @@ -1282,31 +1300,35 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); } } } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; @@ -1315,13 +1337,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } // Update slice k/n for scales loading @@ -1341,23 +1362,24 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } -#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ - prob_k, locks); \ - } + #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ + prob_k, locks); \ + } typedef struct { int thread_k; @@ -1389,7 +1411,7 @@ thread_config_t large_batch_thread_configs[] = { }; -int get_scales_cache_size(thread_config_t const &th_config, int prob_m, +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full) { bool cache_scales_chunk = has_act_order && !is_k_full; @@ -1402,15 +1424,15 @@ int get_scales_cache_size(thread_config_t const &th_config, int prob_m, if (group_size == -1) { tb_groups = 1; } else if (group_size == 0) { - tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size } else { tb_groups = div_ceil(tb_k, group_size); } if (cache_scales_chunk) { int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; } else { @@ -1420,7 +1442,7 @@ int get_scales_cache_size(thread_config_t const &th_config, int prob_m, } } -bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int scales_cache_size, int max_shared_mem) { int pack_factor = 32 / num_bits; @@ -1451,12 +1473,12 @@ bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, float pipe_size = (a_size + b_size) * pipe_stages; - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); } -bool is_valid_config(thread_config_t const &th_config, int max_m_blocks, +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, int max_shared_mem) { @@ -1519,43 +1541,43 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, } } - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ - \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) template -void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, - void *g_idx, void *perm, void *a_tmp, int prob_m, - int prob_n, int prob_k, void *workspace, int num_bits, +void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, + void* g_idx, void* perm, void* a_tmp, int prob_m, + int prob_n, int prob_k, void* workspace, int num_bits, bool has_act_order, bool is_k_full, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, int sms, int max_par) { @@ -1639,15 +1661,15 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, } } - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; - const int *g_idx_ptr = (const int *)g_idx; - const int *perm_ptr = (const int *)perm; - int4 *a_tmp_ptr = (int4 *)a_tmp; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; - int *locks = (int *)workspace; + int* locks = (int*)workspace; if (has_act_order) { // Permute A columns @@ -1673,8 +1695,7 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_m = (16 * exec_cfg.max_m_blocks) * par; i += exec_cfg.max_m_blocks * (par - 1); thread_m_blocks = exec_cfg.max_m_blocks; @@ -1709,11 +1730,11 @@ void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, } } -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &g_idx, - torch::Tensor &perm, torch::Tensor &workspace, +torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& g_idx, + torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full) { // Verify num_bits @@ -1824,18 +1845,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, gptq_marlin::max_par); + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), + a_tmp.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups, + group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, gptq_marlin::max_par); } else if (a.scalar_type() == at::ScalarType::BFloat16) { gptq_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), - g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, - num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, gptq_marlin::max_par); + a.data_ptr(), b_q_weight.data_ptr(), + c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order, + is_k_full, num_groups, group_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, + gptq_marlin::max_par); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh index 35ea48aaba310..ba5368ea8835f 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -11,22 +11,23 @@ namespace gptq_marlin { -// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per -// schedule allows some more latency hiding. At the same time, we want relatively few warps to have -// many registers per warp and small tiles. +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. static constexpr int default_threads = 256; -static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory +static constexpr int pipe_stages = + 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; -static constexpr int max_par = 16; +static constexpr int max_par = 16; template struct Vec { - T elems[n]; + T elems[n]; __device__ T& operator[](int i) { return elems[i]; } }; @@ -35,30 +36,35 @@ using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - // No support for async +// No support for async #else -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } -__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} template __device__ inline void cp_async_wait() { @@ -67,4 +73,4 @@ __device__ inline void cp_async_wait() { #endif -} // namespace gptq_marlin +} // namespace gptq_marlin diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh index 7881abbe4cbbf..ca1b7099d6ec7 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh @@ -5,58 +5,73 @@ #include #include - namespace gptq_marlin { template -class ScalarType { -}; +class ScalarType {}; template <> class ScalarType { -public: - using scalar_t = half; - using scalar_t2 = half2; - - // Matrix fragments for tensor core instructions; their precise layout is - // documented here: - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; - - static __device__ float inline num2float(const half x) { return __half2float(x); } - - static __device__ half2 inline num2num2(const half x) { return __half2half2(x); } - - static __device__ half2 inline nums2num2(const half x1, const half x2) { return __halves2half2(x1, x2); } - - static __host__ __device__ half inline float2num(const float x) { return __float2half(x); } + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } }; template <> class ScalarType { -public: - using scalar_t = nv_bfloat16; - using scalar_t2 = nv_bfloat162; + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; - using FragA = Vec; - using FragB = Vec; - using FragC = Vec; - using FragS = Vec; + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - static __device__ float inline num2float(const nv_bfloat16 x) { return __bfloat162float(x); } - - static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { return __bfloat162bfloat162(x); } - - static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) { return __halves2bfloat162(x1, x2); } - - static __host__ __device__ nv_bfloat16 inline float2num(const float x) { return __float2bfloat16(x); } + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } #endif }; -} +} // namespace gptq_marlin #endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index 0d3da6240dbca..4adc158eb14ea 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -12,14 +12,14 @@ static constexpr int tile_n_size = tile_k_size * 4; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 template -__global__ void -marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, int size_k, int size_n) {} +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) {} -} // namespace gptq_marlin +} // namespace gptq_marlin -torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -30,10 +30,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, #else template -__global__ void -marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, int size_k, int size_n) { +__global__ void marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, + int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; @@ -61,8 +61,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, constexpr int perm_size = tile_k_size / 4; - int4 *sh_perm_ptr = sh; - int4 *sh_pipe_ptr = sh_perm_ptr; + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; if constexpr (has_perm) { sh_pipe_ptr += perm_size; } @@ -76,7 +76,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, auto load_perm_to_shared = [&](int k_tile_id) { int first_k_int4 = (k_tile_id * tile_k_size) / 4; - int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); if (threadIdx.x < perm_size) { sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; @@ -92,22 +92,22 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int first_n = n_tile_id * tile_n_size; - int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; if constexpr (has_perm) { if (threadIdx.x < stage_size) { int k_id = threadIdx.x / stage_n_threads; int n_id = threadIdx.x % stage_n_threads; - uint32_t const *sh_perm_int_ptr = - reinterpret_cast(sh_perm_ptr); + uint32_t const* sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); int src_k = sh_perm_int_ptr[k_id]; int src_k_packed = src_k / pack_factor; cp_async4( &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&( + reinterpret_cast(&( b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); } @@ -120,7 +120,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int first_k_packed = first_k / pack_factor; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( + reinterpret_cast( &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); } @@ -151,10 +151,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, constexpr int sh_stride = 64; constexpr uint32_t mask = (1 << num_bits) - 1; - int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; - uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); uint32_t vals[8]; @@ -176,17 +176,16 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } } else { - uint32_t b1_vals[tile_ints]; uint32_t b2_vals[tile_ints]; -#pragma unroll + #pragma unroll for (int i = 0; i < tile_ints; i++) { b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { int cur_elem = tc_row + tc_offsets[i]; int cur_int = cur_elem / pack_factor; @@ -206,7 +205,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; uint32_t res = 0; -#pragma unroll + #pragma unroll for (int i = 0; i < 8; i++) { res |= vals[pack_idx[i]] << (i * 4); } @@ -218,7 +217,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t res1 = 0; uint32_t res2 = 0; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { res1 |= vals[pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8); @@ -230,14 +229,14 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { -#pragma unroll + #pragma unroll for (int pipe = 0; pipe < repack_stages - 1; pipe++) { fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); } wait_for_stage(); }; -#pragma unroll + #pragma unroll for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { int n_tile_id = 0; @@ -248,7 +247,7 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, start_pipes(k_tile_id, n_tile_id); while (n_tile_id < n_tiles) { -#pragma unroll + #pragma unroll for (int pipe = 0; pipe < repack_stages; pipe++) { fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); @@ -260,21 +259,21 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } } -} // namespace gptq_marlin - -#define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - gptq_marlin::marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - gptq_marlin::marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ - } +} // namespace gptq_marlin + + #define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin::marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin::marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } -torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, +torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 @@ -318,11 +317,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, bool has_perm = perm.size(0) != 0; // Get ptrs - uint32_t const *b_q_weight_ptr = - reinterpret_cast(b_q_weight.data_ptr()); - uint32_t const *perm_ptr = - reinterpret_cast(perm.data_ptr()); - uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); + uint32_t const* b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); // Get dev info int dev = b_q_weight.get_device(); diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index 002a70001885d..03d66cecedf1f 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -25,7 +25,10 @@ #include -template inline std::string str(T x) { return std::to_string(x); } +template +inline std::string str(T x) { + return std::to_string(x); +} namespace marlin { @@ -38,9 +41,10 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } // corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee // this. -template struct Vec { +template +struct Vec { T elems[n]; - __device__ T &operator[](int i) { return elems[i]; } + __device__ T& operator[](int i) { return elems[i]; } }; using I4 = Vec; @@ -51,29 +55,32 @@ using I4 = Vec; using FragA = Vec; using FragB = Vec; using FragC = Vec; -using FragS = Vec; // quantization scales +using FragS = Vec; // quantization scales // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } // Asynchronous global->shared copy -__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. @@ -82,28 +89,30 @@ __device__ inline void cp_async_fence() { } // Wait until at most `n` async copy stages are still pending. -template __device__ inline void cp_async_wait() { +template +__device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, - FragC &frag_c) { - const uint32_t *a = reinterpret_cast(&a_frag); - const uint32_t *b = reinterpret_cast(&frag_b); - float *c = reinterpret_cast(&frag_c); - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), - "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -113,7 +122,8 @@ __device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -138,24 +148,24 @@ __device__ inline FragB dequant(int q) { const int MUL = 0x2c002c00; const int ADD = 0xd480d480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -170,7 +180,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -187,26 +197,27 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { } } -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -241,11 +252,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice + int slice_iters; // number of threadblock tiles in the current slice int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers @@ -261,27 +272,22 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -293,29 +299,30 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; init_slice(); - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // We typically use `constexpr` to indicate that this value is a compile-time // constant constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory + 8; // delta between subsequent A tiles in global memory int a_gl_rd_delta_i = a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile constexpr int a_sh_wr_delta = - a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads + (thread_n_blocks / 4)); // between shared memory tile reads constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile + a_sh_stride * 16; // within a shared memory tile constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile + a_sh_wr_delta); // number of shared write iterations for a tile int b_gl_stride = 16 * prob_n / 32; constexpr int b_sh_stride = 32 * thread_n_blocks / 4; @@ -368,7 +375,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; @@ -387,13 +394,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -403,16 +410,16 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_s = sh_b + (stages * b_sh_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2]; @@ -421,34 +428,33 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); B_ptr[i] += b_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) - cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -475,37 +481,35 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { int b_quant = frag_b_quant[k % 2][j]; int b_quant_shift = b_quant >> 8; FragB frag_b0 = dequant(b_quant); // If there are no groups, we can just scale the final output once and can // avoid doing so for each weight. - if (group_blocks != -1) - scale(frag_b0, frag_s[k % 2][j], 0); + if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) - scale(frag_b1, frag_s[k % 2][j], 1); -#pragma unroll + if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -530,38 +534,38 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. -#pragma unroll + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -571,9 +575,9 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -592,39 +596,39 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int row = (threadIdx.x % 32) / 4; if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + row < prob_m); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half *>(&c_red)[j]); + __half2float(reinterpret_cast<__half*>(&c_red)[j]); } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half *>(&c)[j] = - __float2half(reinterpret_cast( + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -658,17 +662,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS &s) { + auto write = [&](int idx, float c0, float c1, FragS& s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); if (group_blocks == - -1) // for per-column quantization we finally apply the scale here + -1) // for per-column quantization we finally apply the scale here res = __hmul2(res, s[0]); - ((half2 *)sh)[idx] = res; + ((half2*)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], @@ -685,7 +689,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -699,9 +703,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll - for (int i = 0; i < stages - 1; i++) - fetch_to_shared(i, i, i < slice_iters); + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); @@ -711,12 +714,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Main loop. while (slice_iters) { -// We unroll over both the global fetch and the register load pipeline to ensure -// all shared memory accesses are static. Note that both pipelines have even -// length meaning that the next iteration will always start at index 0. -#pragma unroll + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { -#pragma unroll + #pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); if (k == b_sh_wr_iters - 2) { @@ -728,8 +731,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk matmul(k); } slice_iters--; - if (slice_iters == 0) - break; + if (slice_iters == 0) break; } a_gl_rd += a_gl_rd_delta_o * stages; @@ -742,8 +744,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // For per-column scales, we only fetch them here in the final step before // write-out if (group_blocks == -1 && last) { - if (s_sh_wr_pred) - cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } thread_block_reduce(); @@ -751,17 +752,17 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; @@ -770,13 +771,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); @@ -787,26 +787,27 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk #else -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void -Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Marlin is not implemented yet for SM < 8.0 assert(false); @@ -819,10 +820,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory const int SHARED_MEM = - 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; @@ -831,7 +832,7 @@ static constexpr int tile_size = 16; static constexpr int max_par = 16; static constexpr int pack_factor_4bit = - 8; // We have 8 4-bit vals inside a 32 bit + 8; // We have 8 4-bit vals inside a 32 bit #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ GROUP_BLOCKS, NUM_THREADS) \ @@ -858,23 +859,23 @@ thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N }; thread_config_t large_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || @@ -907,7 +908,6 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, } thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - if (prob_m <= 16) { for (auto th_config : small_batch_thread_configs) { if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { @@ -926,20 +926,20 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) -void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, - int prob_n, int prob_k, void *workspace, int groupsize = -1, +void marlin_cuda(const void* A, const void* B, void* C, void* s, int prob_m, + int prob_n, int prob_k, void* workspace, int groupsize = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, int sms = -1, int max_par = 16) { int tot_m = prob_m; @@ -996,12 +996,12 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, " is not divisible by group_blocks = ", group_blocks); } - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; - int *locks = (int *)workspace; + int* locks = (int*)workspace; for (int i = 0; i < tot_m_blocks; i += 4) { int thread_m_blocks = tot_m_blocks - i; @@ -1011,8 +1011,7 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_m = 64 * par; i += 4 * (par - 1); thread_m_blocks = 4; @@ -1041,12 +1040,11 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, } } -} // namespace marlin +} // namespace marlin -torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, +torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_scales, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k) { - // Verify M TORCH_CHECK(size_m == a.size(0), "Shape mismatch: a.size(0) = " + str(a.size(0)) + @@ -1074,9 +1072,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); // Verify A device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); diff --git a/csrc/quantization/marlin/sparse/common/base.h b/csrc/quantization/marlin/sparse/common/base.h index 929b39d7642f1..16018d331bec2 100644 --- a/csrc/quantization/marlin/sparse/common/base.h +++ b/csrc/quantization/marlin/sparse/common/base.h @@ -26,12 +26,14 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } // corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee // this. -template struct Vec { +template +struct Vec { T elems[n]; - __device__ T &operator[](int i) { return elems[i]; } + __device__ T& operator[](int i) { return elems[i]; } }; -template struct ShapeBase { +template +struct ShapeBase { static constexpr int M = M_, N = N_, K = K_; }; @@ -44,6 +46,6 @@ using FragA = Vec; using FragB = Vec; using FragM = Vec; using FragC = Vec; -using FragS = Vec; // quantization scales +using FragS = Vec; // quantization scales -} // namespace marlin_24 +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mem.h b/csrc/quantization/marlin/sparse/common/mem.h index a49d15ca544eb..83e3578d2f511 100644 --- a/csrc/quantization/marlin/sparse/common/mem.h +++ b/csrc/quantization/marlin/sparse/common/mem.h @@ -21,41 +21,44 @@ namespace marlin_24 { // Predicated asynchronous global->shared copy; used for inputs A where we apply // predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred_zfill(void *smem_ptr, - const void *glob_ptr, +__device__ inline void cp_async4_pred_zfill(void* smem_ptr, + const void* glob_ptr, bool pred = true, const bool zfill = false) { const int BYTES = 16; int src_in_bytes = (zfill ? 0 : BYTES); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes)); } -__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } // Asynchronous global->shared copy -__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. @@ -64,22 +67,23 @@ __device__ inline void cp_async_fence() { } // Wait until at most `n` async copy stages are still pending. -template __device__ inline void cp_async_wait() { +template +__device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem)); } -__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_m); +__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_m); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) @@ -88,8 +92,8 @@ __device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) { // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. -__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { - uint32_t *a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n" @@ -98,7 +102,7 @@ __device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) { } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int *lock, int count) { +__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { int state = -1; do @@ -113,7 +117,7 @@ __device__ inline void barrier_acquire(int *lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int *lock, bool reset = false) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -129,4 +133,4 @@ __device__ inline void barrier_release(int *lock, bool reset = false) { : "l"(lock), "r"(val)); } } -} // namespace marlin_24 +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h index 9319456677d36..45ab67a78a1de 100644 --- a/csrc/quantization/marlin/sparse/common/mma.h +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -22,51 +22,56 @@ namespace marlin_24 { // m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32 // output/accumulation. -__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1, - const FragA &frag_b, FragC &frag_c, FragM &frag_m, +__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, + const FragA& frag_b, FragC& frag_c, FragM& frag_m, const int psel) { - const uint32_t *a0 = reinterpret_cast(&a_frag0); - const uint32_t *a1 = reinterpret_cast(&a_frag1); - const uint32_t *b = reinterpret_cast(&frag_b); - const uint32_t *e = reinterpret_cast(&frag_m); - float *c = reinterpret_cast(&frag_c); + const uint32_t* a0 = reinterpret_cast(&a_frag0); + const uint32_t* a1 = reinterpret_cast(&a_frag1); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* e = reinterpret_cast(&frag_m); + float* c = reinterpret_cast(&frag_c); if (psel == 0) { - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]), + "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), + "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), + "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), + "r"(e[0])); } else { - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), - "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), - "f"(c[2]), "f"(c[3]), "r"(e[0])); - asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " - "{%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) - : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), - "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), - "f"(c[6]), "f"(c[7]), "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]), + "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), + "r"(e[0])); + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " + "{%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]), + "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]), + "r"(e[0])); } } // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. -template __device__ inline int lop3(int a, int b, int c) { +template +__device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -120,11 +125,11 @@ __device__ inline FragB dequant_4bit(int q) { const int ADD = 0xd480d480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } @@ -143,24 +148,24 @@ __device__ inline FragB dequant_8bit(int q) { static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. -__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } -__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3, - FragS &s0, float *c4, float *c5, float *c6, - float *c7, FragS &s1) { +__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3, + FragS& s0, float* c4, float* c5, float* c6, + float* c7, FragS& s1) { *c0 = __fmul_rn(*c0, __half2float(s0[0].x)); *c1 = __fmul_rn(*c1, __half2float(s0[0].y)); *c2 = __fmul_rn(*c2, __half2float(s0[1].x)); @@ -172,4 +177,4 @@ __device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3, *c7 = __fmul_rn(*c7, __half2float(s1[1].y)); } -} // namespace marlin_24 +} // namespace marlin_24 diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 42b0566183a8d..54ad27676e207 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -32,12 +32,15 @@ #else -#include "common/mem.h" -#include "common/mma.h" + #include "common/mem.h" + #include "common/mma.h" #endif -template inline std::string str(T x) { return std::to_string(x); } +template +inline std::string str(T x) { + return std::to_string(x); +} namespace marlin_24 { @@ -45,7 +48,7 @@ namespace marlin_24 { // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. static constexpr int THREADS = 256; -static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory +static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 128; @@ -54,35 +57,36 @@ static constexpr int max_par = 16; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void Marlin_24( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4 - *__restrict__ meta, // 2bit metadata information about 2:4 format on B - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) {} -torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, int64_t num_bits, +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -92,29 +96,30 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, #else -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > __global__ void Marlin_24( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4 - *__restrict__ meta, // 2bit metadata information about 2:4 format on B - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4* __restrict__ meta, // 2bit metadata information about 2:4 + // format on B + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -174,27 +179,22 @@ __global__ void Marlin_24( auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; + if (col_off > 0) slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + if (col_off > 0) slice_idx--; } } if (slice_col == n_tiles) { @@ -207,7 +207,7 @@ __global__ void Marlin_24( init_slice(); // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // stride of an A matrix tile in shared memory constexpr int a_sh_stride = 32 * thread_k_blocks / 8; @@ -239,9 +239,9 @@ __global__ void Marlin_24( constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 constexpr int m_sh_stride = - (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); constexpr int m_sh_wr_delta = threads / 2; @@ -305,7 +305,7 @@ __global__ void Marlin_24( // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; } @@ -325,13 +325,13 @@ __global__ void Marlin_24( // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < thread_m_blocks; j++) { a_sh_rd_trans[0][i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -344,23 +344,23 @@ __global__ void Marlin_24( // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; - const int4 *meta_ptr[m_sh_iters]; -#pragma unroll + const int4* meta_ptr[m_sh_iters]; + #pragma unroll for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_s = sh_b + (stages * b_sh_stage); - int4 *sh_m = sh_s + (stages * s_sh_stage); + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + int4* sh_m = sh_s + (stages * s_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks][2]; I4 frag_b_quant[2][b_thread_vecs]; @@ -370,46 +370,43 @@ __global__ void Marlin_24( // Zero accumulators. auto zero_accums = [&]() { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll + #pragma unroll for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], - B_ptr[i] + j); + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } B_ptr[i] += b_gl_rd_delta_o; } - int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; -#pragma unroll + int4* sh_meta_stage = sh_m + m_sh_stage * pipe; + #pragma unroll for (int i = 0; i < m_sh_iters; i++) { if (m_sh_wr_pred) - cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], - meta_ptr[i]); + cp_async4(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], meta_ptr[i]); meta_ptr[i] += m_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) - cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -436,13 +433,13 @@ __global__ void Marlin_24( // theoretically better attempts have lead to bad instruction ordering by // the compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { - int4 *sh_s_stage = + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { ldsm4(frag_a[k % 2][i][0], &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); @@ -450,24 +447,24 @@ __global__ void Marlin_24( &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( + frag_b_quant[k % 2][i] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); } // Load meta with ldsm4 - int4 *sh_m_stage = sh_m + m_sh_stage * pipe; + int4* sh_m_stage = sh_m + m_sh_stage * pipe; ldsm4_m(frag_m[k % 2][0], &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll for (int j = 0; j < 4; j++) { FragB frag_b0; FragB frag_b1; @@ -480,7 +477,7 @@ __global__ void Marlin_24( frag_b1 = dequant_4bit(b_quant_shift); } else { - int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; @@ -497,7 +494,7 @@ __global__ void Marlin_24( scale(frag_b1, frag_s[k % 2][j], 1); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], frag_m[k % 2][j / 2], j % 2); @@ -518,41 +515,41 @@ __global__ void Marlin_24( int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); -// Parallel logarithmic shared memory reduction. We make sure to avoid any -// unnecessary read or write iterations, e.g., for two warps we write only once -// by warp 1 and read only once by warp 0. -#pragma unroll + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + #pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll + #pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll + #pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -562,9 +559,9 @@ __global__ void Marlin_24( }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partitioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { // We are very careful here to reduce directly in the output buffer to // maximize L2 cache utilization in this step. To do this, we write out @@ -574,7 +571,7 @@ __global__ void Marlin_24( int c_gl_stride = prob_n / 8; int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; int c_gl_wr_delta_i = - c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; c_gl_wr += (2 * thread_n_blocks) * slice_col; @@ -584,10 +581,10 @@ __global__ void Marlin_24( int col = 2 * ((threadIdx.x % 32) % 4); if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. -#pragma unroll + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + @@ -599,32 +596,32 @@ __global__ void Marlin_24( cp_async_wait<0>(); } -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + col + (i % 2) < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll + #pragma unroll for (int j2 = 0; j2 < 2; j2++) { -#pragma unroll + #pragma unroll for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + 4 * ((i % 4) / 2) + i % 2] += __half2float( - reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]); + reinterpret_cast<__half*>(&c_red)[(j2 * 4 + j1)]); } } } if (!last) { int4 c; -#pragma unroll + #pragma unroll for (int j2 = 0; j2 < 2; j2++) { -#pragma unroll + #pragma unroll for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] = - __float2half(reinterpret_cast( + reinterpret_cast<__half*>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + 4 * ((i % 4) / 2) + i % 2]); } @@ -643,9 +640,9 @@ __global__ void Marlin_24( auto write_result = [&]() { int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: - constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: - constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); @@ -654,22 +651,22 @@ __global__ void Marlin_24( c_gl_wr += (2 * thread_n_blocks) * slice_col; int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + - ((threadIdx.x % 32) / 4); // RLC: - c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) constexpr int c_sh_rd_delta = - c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + (threadIdx.x % (2 * 2 * thread_n_blocks)); int c_gl_wr_end = c_gl_stride * prob_m; - auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, - float c4, float c5, float c6, float c7, FragS &s1) { + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS& s0, + float c4, float c5, float c6, float c7, FragS& s1) { uint2 res[2]; res[0] = to_half4(c0, c1, c2, c3); res[1] = to_half4(c4, c5, c6, c7); - half2 *tmp = (half2 *)&res; + half2* tmp = (half2*)&res; // for per-column quantization we finally apply the scale here if constexpr (group_blocks == -1 && num_bits == 4) { tmp[0] = __hmul2(tmp[0], s0[0]); @@ -677,12 +674,12 @@ __global__ void Marlin_24( tmp[2] = __hmul2(tmp[2], s1[0]); tmp[3] = __hmul2(tmp[3], s1[1]); } - ((int4 *)sh)[idx] = *((int4 *)&res[0]); + ((int4*)sh)[idx] = *((int4*)&res[0]); }; // RLC: only warp 0 and 1 baseline example if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { int wr = c_sh_wr; write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], @@ -707,7 +704,7 @@ __global__ void Marlin_24( } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -721,9 +718,8 @@ __global__ void Marlin_24( // Start global fetch and register load pipelines. auto start_pipes = [&]() { -#pragma unroll - for (int i = 0; i < stages - 1; i++) - fetch_to_shared(i, i, i < slice_iters); + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); @@ -733,10 +729,10 @@ __global__ void Marlin_24( // Main loop. while (slice_iters) { -// We unroll over both the global fetch and the register load pipeline to ensure -// all shared memory accesses are static. Note that both pipelines have even -// length meaning that the next iteration will always start at index 0. -#pragma unroll + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); @@ -747,8 +743,7 @@ __global__ void Marlin_24( pipe++; slice_iters--; - if (slice_iters == 0) - break; + if (slice_iters == 0) break; } a_gl_rd += a_gl_rd_delta_o * stages; @@ -762,13 +757,11 @@ __global__ void Marlin_24( // write-out if constexpr (group_blocks == -1) { if constexpr (num_bits == 8) { - if (s_sh_wr_pred) - cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } else { if (last) { - if (s_sh_wr_pred) - cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } } @@ -780,14 +773,14 @@ __global__ void Marlin_24( cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); } } else { if (last) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + *(float4*)(frag_s) = *(float4*)(&sh_s[s_sh_rd]); } } } @@ -798,7 +791,7 @@ __global__ void Marlin_24( // overflow in fp16) if constexpr (group_blocks == -1 && num_bits == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { -#pragma unroll + #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { scale_floats(&frag_c[i][0][0][0], &frag_c[i][1][0][0], &frag_c[i][2][0][0], &frag_c[i][3][0][0], frag_s[0][0], @@ -827,13 +820,13 @@ __global__ void Marlin_24( } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; @@ -843,19 +836,17 @@ __global__ void Marlin_24( if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll + #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; -#pragma unroll + #pragma unroll for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; -#pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] -= m_gl_stride; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + #pragma unroll + for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); @@ -866,26 +857,26 @@ __global__ void Marlin_24( #endif -#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS) { \ - cudaFuncSetAttribute( \ - Marlin_24, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin_24 \ - <<>>(A_ptr, B_ptr, meta_ptr, \ - C_ptr, s_ptr, prob_n, \ - prob_m, prob_k, locks); \ +#define CALL_IF_2_4(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + cudaFuncSetAttribute( \ + Marlin_24, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin_24 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, prob_n, \ + prob_m, prob_k, locks); \ } -void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, - void *s, int prob_m, int prob_n, int prob_k, - void *workspace, int num_bits, int groupsize = -1, +void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, + void* s, int prob_m, int prob_n, int prob_k, + void* workspace, int num_bits, int groupsize = -1, int dev = 0, cudaStream_t stream = 0, int thread_k = -1, int thread_m = -1, int sms = -1, int max_par = 16) { int tot_n = prob_n; @@ -904,8 +895,8 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, if (thread_k == -1 || thread_m == -1) { if (prob_n <= 16) { - // For small batchizes, better partitioningif is slightly more important than - // better compute utilization + // For small batchizes, better partitioningif is slightly more important + // than better compute utilization thread_k = 128; thread_m = 128; } else { @@ -914,7 +905,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, } } - int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction int thread_m_blocks = thread_m / 16; int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; int blocks = sms; @@ -931,13 +922,13 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); - const int4 *A_ptr = (const int4 *)A; - const int4 *B_ptr = (const int4 *)B; - const int4 *meta_ptr = (const int4 *)meta; - int4 *C_ptr = (int4 *)C; - const int4 *s_ptr = (const int4 *)s; + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + const int4* meta_ptr = (const int4*)meta; + int4* C_ptr = (int4*)C; + const int4* s_ptr = (const int4*)s; - int *locks = (int *)workspace; + int* locks = (int*)workspace; for (int i = 0; i < tot_n_blocks; i += 4) { int thread_n_blocks = tot_n_blocks - i; prob_n = tot_n - 16 * i; @@ -946,8 +937,7 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_n_blocks - pad) / 64; - if (par > max_par) - par = max_par; + if (par > max_par) par = max_par; prob_n = 64 * par; i += 4 * (par - 1); thread_n_blocks = 4; @@ -956,16 +946,16 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, // For compilation speed, we only define the kernel configurations that have // seemed useful (in terms of performance) in our testing, however many more // are, in principle, possible. - + // the false is start of the CALL_IF macros - if (false) { - } // BMxBNxBK, group + if (false) { + } // BMxBNxBK, group // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 - CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 CALL_IF_2_4(4, 16, 2, 2, 4) CALL_IF_2_4(4, 16, 3, 2, -1) CALL_IF_2_4(4, 16, 3, 2, 4) @@ -973,11 +963,11 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, CALL_IF_2_4(4, 16, 4, 2, 4) // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 - CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 - CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 - CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 CALL_IF_2_4(8, 16, 2, 2, 4) CALL_IF_2_4(8, 16, 3, 2, -1) CALL_IF_2_4(8, 16, 3, 2, 4) @@ -997,12 +987,12 @@ void marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, } } -} // namespace marlin_24 +} // namespace marlin_24 -torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_meta, - torch::Tensor &b_scales, - torch::Tensor &workspace, int64_t num_bits, +torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, + torch::Tensor& b_meta, + torch::Tensor& b_scales, + torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k) { // Verify num_bits @@ -1037,9 +1027,9 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, " is not divisible by tile_size = " + str(marlin_24::tile_size)); int actual_size_n = (b_q_weight.size(1) / marlin_24::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + TORCH_CHECK( + size_n == actual_size_n, + "size_n = " + str(size_n) + ", actual_size_n = " + str(actual_size_n)); // Verify meta TORCH_CHECK(b_meta.size(0) == size_k / 8 / 2 / 2, @@ -1081,7 +1071,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, ", is not divisible by b_scales.size(0) = " + str(b_scales.size(0))); groupsize = size_k / b_scales.size(0); - groupsize /= 2; // Because of 24 + groupsize /= 2; // Because of 24 } // Verify groupsize diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 09964903622b4..1b339fa4b392b 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -22,27 +22,23 @@ __device__ inline unsigned int as_unsigned(int i) { // 4-bit matvec kernel (LUT-based) __global__ void NUQ4MatMulKernel( #ifndef USE_ROCM - const half2* __restrict__ vec, + const half2* __restrict__ vec, #else - const __half2* __restrict__ vec, + const __half2* __restrict__ vec, #endif - const int* __restrict__ mat, + const int* __restrict__ mat, #ifndef USE_ROCM - half2* __restrict__ mul, + half2* __restrict__ mul, #else - float2* __restrict__ mul, + float2* __restrict__ mul, #endif - const __half* __restrict__ lookup_table, - int height, - int width, - int batch, - int vec_height -) { + const __half* __restrict__ lookup_table, int height, int width, int batch, + int vec_height) { const int blockwidth2 = BLOCKWIDTH / 2; int row = BLOCKHEIGHT4 * blockIdx.x; - int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; + int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; #ifndef USE_ROCM __shared__ half2 blockvec[blockwidth2]; @@ -73,14 +69,16 @@ __global__ void NUQ4MatMulKernel( unsigned int tmp1; unsigned int lut_index1, lut_index2; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b) { i = width * row + col; res = __int2half_rd(0); k = 0; __syncthreads(); if (threadIdx.x < blockwidth2) - blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x]; + blockvec[threadIdx.x] = + vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + + threadIdx.x]; __syncthreads(); while (k < blockwidth2) { @@ -143,7 +141,8 @@ __global__ void NUQ4MatMulKernel( #ifndef USE_ROCM res = __hadd(__hadd(res2.x, res2.y), res); #else - res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); + res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), + res); #endif i += width; @@ -179,46 +178,38 @@ __global__ void NUQ4MatMulKernel( } } -} // namespace squeezellm -} // namespace vllm +} // namespace squeezellm +} // namespace vllm // 4-bit matvec kernel (LUT-based) -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table -) { +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table) { int height = mat.size(0); int width = mat.size(1); int batch = vec.size(0); int vec_height = vec.size(1); - dim3 blocks( - (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, - (width + BLOCKWIDTH - 1) / BLOCKWIDTH - ); + dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH); dim3 threads(BLOCKWIDTH); const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); vllm::squeezellm::NUQ4MatMulKernel<<>>( #ifndef USE_ROCM - (half2*) vec.data(), + (half2*)vec.data(), #else - (__half2*) vec.data_ptr(), + (__half2*)vec.data_ptr(), #endif - mat.data_ptr(), + mat.data_ptr(), #ifndef USE_ROCM - (half2*) mul.data(), - (__half*) lookup_table.data(), + (half2*)mul.data(), (__half*)lookup_table.data(), #else - (float2*) mul.data_ptr(), - (__half*) lookup_table.data_ptr(), + (float2*)mul.data_ptr(), + (__half*)lookup_table.data_ptr(), #endif - height, width, batch, vec_height - ); + height, width, batch, vec_height); } #undef BLOCKWIDTH diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bb5171f854d55..9af4aae516151 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -20,12 +21,12 @@ #include "cuda_compat.h" namespace vllm { -template +template __inline__ __device__ T warpReduceSum(T val) { static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0, "numLanes is not a positive power of 2!"); static_assert(numLanes <= WARP_SIZE); - #pragma unroll +#pragma unroll for (int mask = numLanes >> 1; mask > 0; mask >>= 1) val += VLLM_SHFL_XOR_SYNC(val, mask); return val; @@ -38,22 +39,23 @@ static constexpr int _nextPow2(unsigned int num) { } /* Calculate the sum of all elements in a block */ -template +template __inline__ __device__ T blockReduceSum(T val) { static_assert(maxBlockSize <= 1024); if constexpr (maxBlockSize > WARP_SIZE) { val = warpReduceSum(val); - // Calculates max number of lanes that need to participate in the last warpReduce + // Calculates max number of lanes that need to participate in the last + // warpReduce constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE; static __shared__ T shared[maxActiveLanes]; int lane = threadIdx.x % WARP_SIZE; int wid = threadIdx.x / WARP_SIZE; - if (lane == 0) - shared[wid] = val; + if (lane == 0) shared[wid] = val; __syncthreads(); - val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f); + val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] + : (T)(0.0f); val = warpReduceSum(val); } else { // A single warpReduce is equal to blockReduce @@ -62,4 +64,4 @@ __inline__ __device__ T blockReduceSum(T val) { return val; } -} // namespace vllm +} // namespace vllm diff --git a/format.sh b/format.sh index 5f6e20256d404..aaec25a8aa0dc 100755 --- a/format.sh +++ b/format.sh @@ -26,6 +26,7 @@ RUFF_VERSION=$(ruff --version | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}') CODESPELL_VERSION=$(codespell --version) ISORT_VERSION=$(isort --vn) +CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') # # params: tool name, tool version, required version tool_version_check() { @@ -40,6 +41,7 @@ tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)" tool_version_check "isort" "$ISORT_VERSION" "$(grep isort requirements-dev.txt | cut -d'=' -f3)" tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "clang-format" "$CLANGFORMAT_VERSION" "$(grep clang-format requirements-dev.txt | cut -d'=' -f3)" YAPF_FLAGS=( '--recursive' @@ -179,7 +181,6 @@ lint_changed() { } # Run Ruff -echo 'vLLM ruff:' ### This flag lints individual files. --files *must* be the first command line ### arg to use this option. if [[ "$1" == '--files' ]]; then @@ -192,6 +193,7 @@ else # Format only the files that changed in last commit. lint_changed fi +echo 'vLLM ruff: Done' # check spelling of specified files isort_check() { @@ -233,6 +235,59 @@ else fi echo 'vLLM isort: Done' +# Clang-format section +# Exclude some files for formatting because they are vendored +# NOTE: Keep up to date with .github/workflows/clang-format.yml +CLANG_FORMAT_EXCLUDES=( + 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu' + 'csrc/punica/bgmv/bgmv_config.h' + 'csrc/punica/bgmv/bgmv_impl.cuh' + 'csrc/punica/bgmv/vec_dtypes.cuh' + 'csrc/punica/punica_ops.cu' + 'csrc/punica/type_convert.h' +) + +# Format specified files with clang-format +clang_format() { + clang-format -i "$@" +} + +# Format files that differ from main branch with clang-format. +clang_format_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause clang-format to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # exist on both branches. + MERGEBASE="$(git merge-base origin/main HEAD)" + + # Get the list of changed files, excluding the specified ones + changed_files=$(git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.h' '*.cpp' '*.cu' '*.cuh' | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}")) + if [ -n "$changed_files" ]; then + echo "$changed_files" | xargs -P 5 clang-format -i + fi +} + +# Format all files with clang-format +clang_format_all() { + find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ + | grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \ + | xargs clang-format -i +} + +# Run clang-format +if [[ "$1" == '--files' ]]; then + clang_format "${@:2}" +elif [[ "$1" == '--all' ]]; then + clang_format_all +else + clang_format_changed +fi +echo 'vLLM clang-format: Done' + + if ! git diff --quiet &>/dev/null; then echo 'Reformatted files. Please review and stage the changes.' echo 'Changes not staged for commit:' diff --git a/requirements-dev.txt b/requirements-dev.txt index 4f6c27d95fe6a..cf2bb9bef22d9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,6 +5,7 @@ tomli==2.0.1 ruff==0.1.5 codespell==2.2.6 isort==5.13.2 +clang-format==18.1.5 # type checking mypy==1.9.0 From c74c913bfbefc5d7a1302557eb35cdcbecd91f67 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Wed, 22 May 2024 22:02:58 +0900 Subject: [PATCH 327/413] [misc] remove comments that were supposed to be removed (#4977) --- tests/lora/conftest.py | 1 - vllm/lora/models.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 95fc65cdd1a8f..e5cf9cd48b65d 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -185,7 +185,6 @@ def long_context_lora_files_32k(): return snapshot_download(repo_id="SangBinCho/long_context_32k_testing") -# SANG-TODO Download long lora files. @pytest.fixture(scope="session") def long_context_infos(long_context_lora_files_16k_1, long_context_lora_files_16k_2, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index d001d17144d98..a2092d31ea9aa 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -105,8 +105,6 @@ def convert_mapping( lora_offset: int = long_lora_context.offsets_by_lora_id.get( index_mapping_indices[i], 0) long_lora_offsets[i] = lora_offset - # SANG-TODO - # index_mapping_indices[i] = i indices_list: List[Union[List[int], torch.Tensor]] = [ index_mapping_indices, lora_indices, embedding_indices From 8674f9880e2d8574c2adc759027e0f27dc9b95de Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 22 May 2024 10:10:43 -0400 Subject: [PATCH 328/413] [Kernel] Fixup for CUTLASS kernels in CUDA graphs (#4954) Pass the CUDA stream into the CUTLASS GEMMs, to avoid future issues with CUDA graphs --- .../cutlass_w8a8/scaled_mm_dq_c2x.cu | 6 ++- .../cutlass_w8a8/scaled_mm_dq_c3x.cu | 5 ++- tests/kernels/test_cutlass.py | 41 +++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index e62fe731a98d3..3a6b8a226e18c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -1,6 +1,8 @@ #include #include +#include + // clang-format will break include orders // clang-format off #include "cute/tensor.hpp" @@ -189,8 +191,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, size_t workspace_size = gemm_op.get_workspace_size(args); cutlass::device_memory::allocation workspace(workspace_size); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + CUTLASS_CHECK(gemm_op.can_implement(args)); - cutlass::Status status = gemm_op(args, workspace.get()); + cutlass::Status status = gemm_op(args, workspace.get(), stream); CUTLASS_CHECK(status); } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 12efcac7bb919..5fd6d8ff20867 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -178,7 +180,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a, size_t workspace_size = gemm_op.get_workspace_size(args); TORCH_CHECK(workspace_size == 0); - cutlass::Status status = gemm_op.run(args); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + cutlass::Status status = gemm_op.run(args, stream); CUTLASS_CHECK(status); } } // namespace diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index fdfd1dee29ce6..2cf0e86e5ca44 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -190,3 +190,44 @@ def test_cutlass_subset(): b.to(dtype=torch.float32)).to(dtype=torch.bfloat16) assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) + + +# Test to make sure cuda graphs work +class CutlassLayer(torch.nn.Module): + + def __init__(self, b, scale_a, scale_b, out_dtype): + super().__init__() + self.b = b + self.scale_a = scale_a + self.scale_b = scale_b + self.out_dtype = out_dtype + + def forward(self, a): + return ops.cutlass_scaled_mm_dq(a, self.b, self.scale_a, self.scale_b, + self.out_dtype) + + +def test_cutlass_cuda_graph(): + m, n, k = 512, 512, 512 + + a = to_int8(torch.randn((m, k), device="cuda")) + b = to_int8(torch.randn((n, k), device="cuda").t()) + + scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32) / 10) + scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32) / 10) + + # Construct a trivial model with a single layer that calls a CUTLASS kernel + model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + out = model(a) + out.zero_() + g.replay() + + baseline = torch.mm(scale_a * a.to(dtype=torch.float32), + scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) + assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) From a3a73ab0696b6692f3eecf80271a01fa97bd001d Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 22 May 2024 13:28:20 -0700 Subject: [PATCH 329/413] [Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893) The 2nd PR for #4532. This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter). --- benchmarks/benchmark_latency.py | 14 ++-- benchmarks/benchmark_throughput.py | 12 ++- .../kernels/benchmark_paged_attention.py | 10 +-- tests/models/test_fp8.py | 80 ++++++++++++------- vllm/attention/layer.py | 27 ++++++- vllm/config.py | 8 +- vllm/engine/arg_utils.py | 7 +- .../model_executor/layers/quantization/fp8.py | 47 ++++++++++- vllm/model_executor/models/arctic.py | 3 +- vllm/model_executor/models/baichuan.py | 6 +- vllm/model_executor/models/bloom.py | 3 +- vllm/model_executor/models/chatglm.py | 13 ++- vllm/model_executor/models/commandr.py | 13 ++- vllm/model_executor/models/dbrx.py | 13 ++- vllm/model_executor/models/deepseek.py | 3 +- vllm/model_executor/models/falcon.py | 9 ++- vllm/model_executor/models/gemma.py | 3 +- vllm/model_executor/models/gpt2.py | 3 +- vllm/model_executor/models/gpt_bigcode.py | 3 +- vllm/model_executor/models/gpt_j.py | 3 +- vllm/model_executor/models/gpt_neox.py | 3 +- vllm/model_executor/models/internlm2.py | 3 +- vllm/model_executor/models/jais.py | 13 ++- vllm/model_executor/models/llama.py | 32 ++++---- vllm/model_executor/models/minicpm.py | 3 +- vllm/model_executor/models/mixtral.py | 29 +++++-- vllm/model_executor/models/mixtral_quant.py | 15 ++-- vllm/model_executor/models/mpt.py | 3 +- vllm/model_executor/models/olmo.py | 3 +- vllm/model_executor/models/opt.py | 3 +- vllm/model_executor/models/orion.py | 3 +- vllm/model_executor/models/phi.py | 3 +- vllm/model_executor/models/qwen.py | 3 +- vllm/model_executor/models/qwen2.py | 3 +- vllm/model_executor/models/qwen2_moe.py | 3 +- vllm/model_executor/models/stablelm.py | 3 +- vllm/model_executor/models/starcoder2.py | 15 ++-- vllm/model_executor/models/xverse.py | 3 +- vllm/utils.py | 2 + vllm/worker/model_runner.py | 17 ++-- 40 files changed, 284 insertions(+), 158 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index f84e3453947c9..a9657f7859750 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -153,15 +153,13 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help='enforce eager mode and disable CUDA graph') parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], - default='auto', - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 41f443968c3c4..7c8cb5ee8cea2 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -323,15 +323,13 @@ def main(args: argparse.Namespace): action="store_true", help="enforce eager execution") parser.add_argument( - "--kv-cache-dtype", + '--kv-cache-dtype', type=str, - choices=["auto", "fp8"], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index ca7967c1ab0d2..fc9621e885dc4 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -183,13 +183,11 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument( "--kv-cache-dtype", type=str, - choices=["auto", "fp8"], + choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"], default="auto", - help= - 'Data type for kv cache storage. If "auto", will use model data type. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version greater ' - 'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for ' - 'common inference criteria.') + help="Data type for kv cache storage. If 'auto', will use model " + "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " + "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") args = parser.parse_args() print(args) diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 664e951a89f2a..0a5819ea3f054 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -16,31 +16,55 @@ MAX_MODEL_LEN = 1024 MODELS = [ - "nm-testing/Meta-Llama-3-8B-Instruct-FP8", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", "meta-llama/Meta-Llama-3-8B-Instruct", ] EXPECTED_STRS_MAP = { - "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [ - 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred to a', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o', - ], - "meta-llama/Meta-Llama-3-8B-Instruct": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' - ], + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": { + "auto": [ + 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no' + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system made up of several basic components that work together to enable it to', + 'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk' + ] + }, + "meta-llama/Meta-Llama-3-8B-Instruct": { + "auto": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' + ], + "fp8": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', + 'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu' + ] + }, } capability = torch.cuda.get_device_capability() @@ -52,14 +76,14 @@ @pytest.mark.skipif(fp8_not_supported, reason="fp8 is not supported on this GPU type.") @pytest.mark.parametrize("model_name", MODELS) -def test_models( - example_prompts, - model_name, -) -> None: +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +def test_models(example_prompts, model_name, kv_cache_dtype) -> None: model = LLM(model=model_name, max_model_len=MAX_MODEL_LEN, + trust_remote_code=True, enforce_eager=True, - quantization="fp8") + quantization="fp8", + kv_cache_dtype=kv_cache_dtype) tokenizer = AutoTokenizer.from_pretrained(model_name) formatted_prompts = [ @@ -81,8 +105,8 @@ def test_models( generations.append(outputs[0].outputs[0].text) del model - print(generations) - expected_strs = EXPECTED_STRS_MAP[model_name] + print(model_name, kv_cache_dtype, generations) + expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype] for i in range(len(example_prompts)): generated_str = generations[i] expected_str = expected_strs[i] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 4299726bdca4b..dc7b3940bc9b7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,6 +7,8 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) class Attention(nn.Module): @@ -30,6 +32,7 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() if cache_config is not None: @@ -40,6 +43,27 @@ def __init__( block_size = 16 if num_kv_heads is None: num_kv_heads = num_heads + + # The default kv_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized kv_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self._kv_scale = 1.0 + quant_method = quant_config.get_quant_method( + self) if quant_config else None + if quant_method is not None: + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # When FP8 quantization is enabled, we make a parameter + # "kv_scale" so that it can be loaded from FP8 checkpoint. + # The kv_scale will then be converted back + # to self._kv_scale in a native float32 value after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() @@ -57,10 +81,9 @@ def forward( value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, - kv_scale: float = 1.0, ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, - kv_scale) + self._kv_scale) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore diff --git a/vllm/config.py b/vllm/config.py index 3256c11967914..b245a1a3ee6d3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -355,14 +355,12 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype == "fp8": + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " - "But it may cause slight accuracy drop without scaling " - "factors. FP8_E5M2 (without scaling) is only supported on " - "cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 " - "is instead supported for common inference criteria.") + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0a9ec7472fbca..538e3427e37fb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -191,12 +191,11 @@ def add_cli_args( parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' - 'data type. FP8_E5M2 (without scaling) is only supported on cuda ' - 'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria.') + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=nullable_str, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ff996741c1d00..b084b9cee4983 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -8,8 +8,9 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import print_warning_once ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -58,9 +59,13 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": activation_scheme=activation_scheme) def get_quant_method( - self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]: + self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): return Fp8LinearMethod(self) + if isinstance(layer, Attention): + return Fp8KVCacheMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -251,6 +256,44 @@ def apply(self, return torch.narrow(output, 0, 0, x.shape[0]) +class Fp8KVCacheMethod(QuantizeMethodBase): + """Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """Create "weight" (aka kv_scale) for an attention layer. + + Args: + layer: The layer that is using the QuantizeMethodBase factory. + """ + # Initialize the KV cache scale to 1.0 as the default value. + # If the kv_scale appears in the checkpoint, it will be + # overwritten when loading weights. + layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") + + def process_weights_after_loading(self, layer: Module) -> None: + # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + if layer.kv_cache_dtype != "auto": + kv_scale = layer.kv_scale.to("cpu").tolist() + if not isinstance(kv_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + layer._kv_scale = kv_scale + if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " + "cause accuracy issues. Please make sure kv-cache scaling " + "factor is available in the fp8 checkpoint.") + del layer.kv_scale + + def all_close_1d(x: torch.Tensor) -> bool: assert len(x.shape) == 1 return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index cb99939cbb17a..313762b1353d1 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -268,7 +268,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 58b3405d319d1..babb92e7cdcef 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -154,7 +154,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + quant_config=quant_config) else: self.rotary_emb = get_rope( self.head_dim, @@ -166,7 +167,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index fe2de87b20dc9..a29aee4cffb7d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -111,7 +111,8 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index ed65d76f7b5b9..e3a5e43e23e1c 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -86,13 +86,12 @@ def __init__( base=10000 * rope_ratio, is_neox_style=False, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 7354d11f98b15..84786921ce1b4 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -177,13 +177,12 @@ def __init__( rope_scaling=self.rope_scaling, is_neox_style=False, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, self.head_dim), diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 083ddf0159f71..8ff19a2015e0f 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -218,13 +218,12 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 62e04f9649915..8fbda2638aaa3 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -232,7 +232,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index ab9e1994be426..ba707adb03dfe 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -153,7 +153,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.inv_norm_factor, - num_kv_heads=self.num_kv_heads) + num_kv_heads=self.num_kv_heads, + quant_config=quant_config) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -165,13 +166,15 @@ def __init__( self.head_dim, self.inv_norm_factor, num_kv_heads=self.num_kv_heads, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + quant_config=quant_config) else: self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index d1502b718a773..27dda00b66af4 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -157,7 +157,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 0deaa58ed9eb5..cc83f6eb6d94d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -75,7 +75,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index c20fb3230c394..f488ef40039c0 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -88,7 +88,8 @@ def __init__( self.head_dim, scale=self.scale, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 5f4d8ec3d3a7a..47fd5788a4c35 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -88,7 +88,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_size, scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index dcb52ff666c95..eb0fcc8f26a58 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -89,7 +89,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_size, scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 65f7ddb8b082c..e75c567f589c8 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -117,7 +117,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index df30fd1ba0a37..869b8fc91fd64 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -105,13 +105,12 @@ def __init__( head_end = (tp_rank + 1) * self.num_heads alibi_slopes = _get_alibi_slopes(total_num_heads) alibi_slopes = alibi_slopes[head_start:head_end] - self.attn = Attention( - self.num_heads, - self.head_dim, - scale=self.scale, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f2996c240aaf4..23141124e69e1 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -47,7 +47,7 @@ default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from vllm.utils import is_hip +from vllm.utils import is_hip, print_warning_once class LlamaMLP(nn.Module): @@ -119,15 +119,6 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - # This will be overwritten by model initialization if we are using it. - # N.B. currently we only support per tensor scalar scaling factors - # & only applicable to ROCm (AMD GPU). - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - self.kv_scale = 1.0 - self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, @@ -155,7 +146,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=sliding_window, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -167,8 +159,7 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata, - self.kv_scale) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -421,6 +412,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + f"Found kv scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_kv_scale_name}). kv-scale is " + "not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -445,7 +449,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: # scaling_factor = tensor_amax / FPtype_max scaling_factor *= 2 if hasattr(layer_self_attn, "kv_scale"): - layer_self_attn.kv_scale = scaling_factor + layer_self_attn.attn._kv_scale = scaling_factor else: raise RuntimeError("Self attention has no KV cache scaling " "factor attribute!") diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 0b85cf1c94795..59fbf8e1b35f2 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -236,7 +236,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e3ac33e0452fe..ea95cf7380d54 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -308,14 +308,13 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + cache_config=cache_config, + quant_config=quant_config) def forward( self, @@ -581,6 +580,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index ee2626b1c1aa2..9b99ff729aadd 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -213,14 +213,13 @@ def __init__( base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 716ac51cde94d..5f9e4d86f3cd8 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -110,7 +110,8 @@ def __init__( scaling, alibi_slopes=alibi_slopes, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 69f23bbfb5d0a..39270f71ec46f 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -96,7 +96,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) # Attention output projection. self.o_proj = RowParallelLinear( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index d241756e50f4a..4bf59105dbabb 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -91,7 +91,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, scale=self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 59cd42e31b374..133a10e6bb3e8 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -121,7 +121,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 193a29d20c894..c8e61735a9bb6 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -110,7 +110,8 @@ def __init__(self, self.attn = Attention(self.num_heads, self.head_size, scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d158846a3a1f5..d22ea6b79de0f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -106,7 +106,8 @@ def __init__( self.attn = Attention(self.num_heads, self.head_dim, self.scaling, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 97ab6168c3230..ec203c3b9001a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -141,7 +141,8 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=self.sliding_window, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index a0d3b0406ef4a..564536f2dd248 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -241,7 +241,8 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 8b4a5507feade..a6ed3800bed0f 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -127,7 +127,8 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_key_value_heads, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 3c19d63276a77..91ffd0861c39d 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -97,14 +97,13 @@ def __init__(self, base=int(self.rope_theta), is_neox_style=True, ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, - cache_config=cache_config, - ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 6ef230a8ebbca..dda13d83f89a3 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -135,7 +135,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, sliding_window=sliding_window, - cache_config=cache_config) + cache_config=cache_config, + quant_config=quant_config) def forward( self, diff --git a/vllm/utils.py b/vllm/utils.py index 552b43e7f82b2..4cb9d905097bf 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -31,6 +31,8 @@ "bfloat16": torch.bfloat16, "float": torch.float, "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, } diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e264fede0ee64..9720363ac300e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,4 +1,5 @@ import time +import warnings from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union import numpy as np @@ -168,11 +169,21 @@ def load_model(self) -> None: self.model = self.lora_manager.create_lora_manager(self.model) if self.kv_cache_dtype == "fp8" and is_hip(): - # Currently scaled KV cache is only enabled on ROCm + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. if self.model_config.quantization_param_path is not None: if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is " + "deprecated and will be removed. Please include " + "kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2) self.model.load_kv_cache_scales( self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", + self.model_config.quantization_param_path) else: raise RuntimeError( "Using FP8 KV cache and scaling factors provided but " @@ -183,10 +194,6 @@ def load_model(self) -> None: "Using FP8 KV cache but no scaling factors " "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - elif self.model_config.quantization_param_path is not None: - logger.warning("KV cache scaling factors provided, " - "but the KV cache data type is not FP8. " - "KV cache scaling factors will not be used.") def save_sharded_state( self, From 97b030005c7f5cde7c1b97c718a8841db7d6220b Mon Sep 17 00:00:00 2001 From: raywanb <112235519+raywanb@users.noreply.github.com> Date: Thu, 23 May 2024 04:58:59 +0800 Subject: [PATCH 330/413] [Model] LoRA gptbigcode implementation (#3949) --- csrc/punica/bgmv/bgmv_config.h | 4 +++ tests/lora/test_punica.py | 2 ++ vllm/lora/models.py | 2 ++ vllm/model_executor/models/gpt_bigcode.py | 31 +++++++++++++++++++---- 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 98ac8de779e13..4b376261d30d2 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 3328) \ f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 4096) \ @@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 6144) \ + f(in_T, out_T, W_T, narrow, 6400) \ f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 7168) \ @@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 2752, narrow) \ f(in_T, out_T, W_T, 2816, narrow) \ f(in_T, out_T, W_T, 3072, narrow) \ + f(in_T, out_T, W_T, 3328, narrow) \ f(in_T, out_T, W_T, 3456, narrow) \ f(in_T, out_T, W_T, 3584, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \ @@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 5504, narrow) \ f(in_T, out_T, W_T, 5632, narrow) \ f(in_T, out_T, W_T, 6144, narrow) \ + f(in_T, out_T, W_T, 6400, narrow) \ f(in_T, out_T, W_T, 6848, narrow) \ f(in_T, out_T, W_T, 6912, narrow) \ f(in_T, out_T, W_T, 7168, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 193e3906997c4..f021c003b1322 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -58,6 +58,7 @@ def _lora_ref_impl( 2560, 2752, 3072, + 3328, 3456, 3584, 4096, @@ -66,6 +67,7 @@ def _lora_ref_impl( 5504, 5632, 6144, + 6400, 6848, 6912, 7168, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index a2092d31ea9aa..3e82856866d85 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -310,7 +310,9 @@ def from_local_checkpoint( if part_name not in expected_lora_modules: unexpected_modules.append(module) # loaded lora's target modules must be a subset of expected_lora_modules + if unexpected_modules: + print(unexpected_modules, "modules") raise ValueError( f"While loading {lora_dir}, expected" f" target modules in {expected_lora_modules}" diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index f488ef40039c0..69b75763e9a3d 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -191,14 +191,19 @@ def __init__( config: GPTBigCodeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config assert not config.add_cross_attention self.embed_dim = config.hidden_size - - self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.wte = VocabParallelEmbedding(self.vocab_size, + self.embed_dim, + org_num_embeddings=config.vocab_size) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ GPTBigCodeBlock(config, cache_config, quant_config) @@ -226,19 +231,35 @@ def forward( class GPTBigCodeForCausalLM(nn.Module): + packed_modules_mapping = {"c_attn": ["c_attn"]} + + supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"] + + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + + embedding_padding_modules = [] def __init__( self, config: GPTBigCodeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config self.quant_config = quant_config - self.transformer = GPTBigCodeModel(config, cache_config, quant_config) + self.transformer = GPTBigCodeModel(config, cache_config, quant_config, + lora_config) self.lm_head_weight = self.transformer.wte.weight - self.logits_processor = LogitsProcessor(config.vocab_size) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) self.sampler = Sampler() def forward( From eb6d3c264d0cd8e44dec16bca7947fbe96415ce9 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 22 May 2024 14:17:27 -0700 Subject: [PATCH 331/413] [Core] Eliminate parallel worker per-step task scheduling overhead (#4894) --- vllm/engine/async_llm_engine.py | 10 +- vllm/engine/llm_engine.py | 8 ++ vllm/executor/distributed_gpu_executor.py | 123 ++++++++++++++++----- vllm/executor/executor_base.py | 8 ++ vllm/executor/multiproc_gpu_executor.py | 73 ++++++++----- vllm/executor/ray_gpu_executor.py | 86 +++++++-------- vllm/spec_decode/ngram_worker.py | 4 +- vllm/spec_decode/spec_decode_worker.py | 125 +++++++++++----------- vllm/worker/embedding_model_runner.py | 5 +- vllm/worker/model_runner.py | 5 +- vllm/worker/worker.py | 103 +++++++++++------- vllm/worker/worker_base.py | 7 +- 12 files changed, 348 insertions(+), 209 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8a37bac02823a..5a15ed67e3327 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -234,6 +234,14 @@ async def step_async( # Log stats. self.do_log_stats(scheduler_outputs, output) + if not request_outputs: + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + await self.model_executor.stop_remote_worker_execution_loop_async() + return request_outputs async def encode_request_async( @@ -687,7 +695,7 @@ async def encode( multi_modal_data: Multi modal data per request. Yields: - The output `EmbeddingRequestOutput` objects from the LLMEngine + The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. Details: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 60e23d4df15bb..0631c0de76822 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -692,6 +692,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Log stats. self.do_log_stats(scheduler_outputs, output) + if not request_outputs: + # Stop the execute model loop in parallel workers until there are + # more requests to process. This avoids waiting indefinitely in + # torch.distributed ops which may otherwise timeout, and unblocks + # the RPC thread in the workers so that they can process any other + # queued control plane messages, such as add/remove lora adapters. + self.model_executor.stop_remote_worker_execution_loop() + return request_outputs def do_log_stats( diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index c5b1e61112afb..f7c608af1ad39 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -1,11 +1,12 @@ +import asyncio from abc import abstractmethod -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput logger = init_logger(__name__) @@ -13,6 +14,16 @@ class DistributedGPUExecutor(GPUExecutor): """Abstract superclass of multi-GPU executor implementations.""" + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + # Updated by implementations that require additional args to be passed + # to the _run_workers execute_model call + self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} + + super().__init__(*args, **kwargs) + def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks. @@ -52,13 +63,28 @@ def initialize_cache(self, num_gpu_blocks: int, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) - def execute_model(self, *args, **kwargs) -> List[SamplerOutput]: - all_outputs = self._run_workers("execute_model", - driver_args=args, - driver_kwargs=kwargs) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_remote_workers_only=True, + **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. - return all_outputs[0] + return self._driver_execute_model(execute_model_req) + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." @@ -88,39 +114,84 @@ def save_sharded_state( pattern=pattern, max_size=max_size) + @abstractmethod + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError + @abstractmethod def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, + async_run_remote_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: - """Runs the given method on all workers.""" + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ + raise NotImplementedError + + @abstractmethod + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" raise NotImplementedError class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. + return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + @abstractmethod - async def _run_workers_async( + async def _driver_execute_model_async( self, - method: str, - *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - raise NotImplementedError + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Execute the model asynchronously in the driver worker. - async def execute_model_async(self, *args, - **kwargs) -> List[SamplerOutput]: - all_outputs = await self._run_workers_async("execute_model", - driver_args=args, - driver_kwargs=kwargs) + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError - # Only the driver worker returns the sampling results. - return all_outputs[0] + @abstractmethod + async def _start_worker_execution_loop(self): + """Run execution loop on all workers. It guarantees all workers run + the loop or None of them is running the loop. Loop can be stopped by + `stop_remote_worker_execution_loop`. + The API is idempotent (guarantee only 1 loop run at any moment).""" + raise NotImplementedError diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 08aa58999b1ec..4d01939c2e38b 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -74,6 +74,10 @@ def execute_model( """Executes at least one model step on the given sequences.""" raise NotImplementedError + def stop_remote_worker_execution_loop(self) -> None: + """Releases parallel workers from model loop.""" + return + @abstractmethod def add_lora(self, lora_request: LoRARequest) -> bool: raise NotImplementedError @@ -109,6 +113,10 @@ async def execute_model_async( """Executes one model step on the given sequences.""" raise NotImplementedError + async def stop_remote_worker_execution_loop_async(self) -> None: + """Releases parallel workers from model loop.""" + return + async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 2a7b99c9dcbe1..8fa54454907b5 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -1,13 +1,14 @@ import asyncio import os from functools import partial -from typing import Any, Dict, Optional, Tuple +from typing import Any, List, Optional from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ResultHandler, WorkerMonitor) from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) @@ -71,16 +72,34 @@ def shutdown(self): None)) is not None: worker_monitor.close() + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_model( + execute_model_req=execute_model_req) + def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, + async_run_remote_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: - """Runs the given method on all workers.""" + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ if max_concurrent_workers: raise NotImplementedError( @@ -92,15 +111,12 @@ def _run_workers( for worker in self.workers ] - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs + if async_run_remote_workers_only: + # Just return futures + return worker_outputs - # Start the driver worker after all the ray workers. driver_worker_method = getattr(self.driver_worker, method) - driver_worker_output = driver_worker_method(*driver_args, - **driver_kwargs) + driver_worker_output = driver_worker_method(*args, **kwargs) # Get the results of the workers. return [driver_worker_output @@ -111,30 +127,29 @@ def check_health(self) -> None: if not self.worker_monitor.is_alive(): raise RuntimeError("Worker processes are not running") + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() + class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor, DistributedGPUExecutorAsync): - async def _run_workers_async( - self, - method: str, - *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_model = make_async(self.driver_worker.execute_model) - driver_executor = make_async(getattr(self.driver_worker, method)) + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_model(execute_model_req) - # Run all the workers asynchronously. - coros = [driver_executor(*driver_args, **driver_kwargs)] + [ - worker.execute_method_async(method, *args, **kwargs) + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method_async("start_worker_execution_loop") for worker in self.workers ] - return await asyncio.gather(*coros) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index dd3ee60682d30..bed356d1b6e58 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -42,6 +42,8 @@ def _init_executor(self) -> None: self.forward_dag = None if USE_RAY_COMPILED_DAG: self.forward_dag = self._compiled_ray_dag() + self.extra_execute_model_run_workers_kwargs[ + "use_ray_compiled_dag"] = True def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]: @@ -171,23 +173,23 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) - def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - all_outputs = self._run_workers( - "execute_model", - driver_kwargs={"execute_model_req": execute_model_req}, - use_ray_compiled_dag=USE_RAY_COMPILED_DAG) + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. - # Only the driver worker returns the sampling results. - return all_outputs[0] + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_method("execute_model", + execute_model_req) def _run_workers( self, method: str, *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, + async_run_remote_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, use_dummy_driver: bool = False, @@ -198,9 +200,11 @@ def _run_workers( """Runs the given method on all workers. Can be used in the following ways: + - async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than blocking + on the results. - args/kwargs: All workers share the same args/kwargs - - args/kwargs and driver_args/driver_kwargs: Driver worker has - different args - all_args/all_kwargs: args/kwargs for each worker are specified individually """ @@ -209,11 +213,6 @@ def _run_workers( raise NotImplementedError( "max_concurrent_workers is not supported yet.") - if driver_args is None: - driver_args = args if all_args is None else all_args[0] - if driver_kwargs is None: - driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] - count = len(self.workers) all_worker_args = repeat(args, count) if all_args is None \ else islice(all_args, 1, None) @@ -225,6 +224,7 @@ def _run_workers( # input. TODO(sang): Fix it. assert self.forward_dag is not None output_channels = self.forward_dag.execute(1) + ray_worker_outputs = [] else: # Start the ray workers first. ray_worker_outputs = [ @@ -234,6 +234,13 @@ def _run_workers( ) in zip(self.workers, all_worker_args, all_worker_kwargs) ] + if async_run_remote_workers_only: + # Just return futures + return ray_worker_outputs + + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + # Start the driver worker after all the ray workers. if not use_dummy_driver: driver_worker_output = self.driver_worker.execute_method( @@ -260,6 +267,11 @@ def _run_workers( return [driver_worker_output] + ray_worker_outputs + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + def _compiled_ray_dag(self): import pkg_resources required_version = "2.9" @@ -303,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.driver_executor = make_async(self.driver_worker.execute_method) + self.driver_exec_method = make_async(self.driver_worker.execute_method) - async def _run_workers_async( + async def _driver_execute_model_async( self, - method: str, - *args, - driver_args: Optional[Tuple[Any, ...]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - coros = [] - - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - - coros.append( - self.driver_executor(method, *driver_args, **driver_kwargs)) - - # Run the ray workers asynchronously. - for worker in self.workers: - coros.append(worker.execute_method.remote(method, *args, **kwargs)) - - all_outputs = await asyncio.gather(*coros) - return all_outputs + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_method("execute_model", + execute_model_req) + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 9628f7af5315a..c2b22f2acd7b4 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -47,7 +47,9 @@ def set_include_gpu_probs_tensor(self): # NGram don't need gpu sampler pass - def execute_model(self, execute_model_req: ExecuteModelRequest) -> None: + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None) -> None: """NGram doesn't depend on model execution, just pass this function""" pass diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index ef17b8c1e2cc0..3462a876c3e90 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -231,35 +231,6 @@ def initialize_cache(self, num_gpu_blocks: int, self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) - def _broadcast_control_flow_decision( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - disable_all_speculation: bool = False) -> Tuple[int, bool]: - """Broadcast how many lookahead slots are scheduled for this step, and - whether all speculation is disabled, to all non-driver workers. - - This is required as if the number of draft model runs changes - dynamically, the non-driver workers won't know unless we perform a - communication to inform then. - - Returns the broadcasted num_lookahead_slots and disable_all_speculation. - """ - - if self.rank == self._driver_rank: - assert execute_model_req is not None - - broadcast_dict = dict( - num_lookahead_slots=execute_model_req.num_lookahead_slots, - disable_all_speculation=disable_all_speculation, - ) - broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) - else: - assert execute_model_req is None - broadcast_dict = broadcast_tensor_dict(src=self._driver_rank) - - return (broadcast_dict["num_lookahead_slots"], - broadcast_dict["disable_all_speculation"]) - @torch.inference_mode() def execute_model( self, @@ -267,39 +238,58 @@ def execute_model( ) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ + if self.rank != self._driver_rank: + self._run_non_driver_rank() + return [] - disable_all_speculation = False - if self.rank == self._driver_rank: - disable_all_speculation = self._should_disable_all_speculation( - execute_model_req) - - (num_lookahead_slots, - disable_all_speculation) = self._broadcast_control_flow_decision( - execute_model_req, disable_all_speculation) - - if self.rank == self._driver_rank: - assert execute_model_req is not None - assert execute_model_req.seq_group_metadata_list is not None, ( - "speculative decoding requires non-None seq_group_metadata_list" - ) - - self._maybe_disable_speculative_tokens( - disable_all_speculation, - execute_model_req.seq_group_metadata_list) - - # If no spec tokens, call the proposer and scorer workers normally. - # Used for prefill. - if num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list) == 0: - return self._run_no_spec(execute_model_req, - skip_proposer=disable_all_speculation) - - return self._run_speculative_decoding_step(execute_model_req, - num_lookahead_slots) - else: - self._run_non_driver_rank(num_lookahead_slots) + if execute_model_req is None: + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) return [] + disable_all_speculation = self._should_disable_all_speculation( + execute_model_req) + num_lookahead_slots = execute_model_req.num_lookahead_slots + + # Broadcast how many lookahead slots are scheduled for this step, and + # whether all speculation is disabled, to all non-driver workers. + + # This is required as if the number of draft model runs changes + # dynamically, the non-driver workers won't know unless we perform a + # communication to inform then. + broadcast_dict = dict( + num_lookahead_slots=num_lookahead_slots, + disable_all_speculation=disable_all_speculation, + ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) + + assert execute_model_req.seq_group_metadata_list is not None, ( + "speculative decoding requires non-None seq_group_metadata_list") + + self._maybe_disable_speculative_tokens( + disable_all_speculation, execute_model_req.seq_group_metadata_list) + + # If no spec tokens, call the proposer and scorer workers normally. + # Used for prefill. + if num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list) == 0: + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all_speculation) + + return self._run_speculative_decoding_step(execute_model_req, + num_lookahead_slots) + + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop to perform speculative decoding + in parallel worker.""" + while self._run_non_driver_rank(): + pass + def _should_disable_all_speculation( self, execute_model_req: ExecuteModelRequest) -> bool: # When the batch size is too large, disable speculative decoding @@ -346,13 +336,19 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, sampler_output.logprobs = None return [sampler_output] - def _run_non_driver_rank(self, num_lookahead_slots: int) -> None: + def _run_non_driver_rank(self) -> bool: """Run proposer and verifier model in non-driver workers. This is used for both speculation cases (num_lookahead_slots>0) and non-speculation cases (e.g. prefill). + + Returns True iff there are remaining sequences to process. """ - # In non-driver workers the input is None - execute_model_req = None + assert self.rank != self._driver_rank + + data = broadcast_tensor_dict(src=self._driver_rank) + if not data: + return False + num_lookahead_slots = data["num_lookahead_slots"] # Even if num_lookahead_slots is zero, we want to run the proposer model # as it may have KV. @@ -360,9 +356,10 @@ def _run_non_driver_rank(self, num_lookahead_slots: int) -> None: # We run the proposer once per lookahead slot. In the future we should # delegate how many times it runs to the proposer. for _ in range(max(num_lookahead_slots, 1)): - self.proposer_worker.execute_model(execute_model_req) + self.proposer_worker.execute_model() - self.scorer_worker.execute_model(execute_model_req) + self.scorer_worker.execute_model() + return True @nvtx_range("spec_decode_worker._run_speculative_decoding_step") def _run_speculative_decoding_step( diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 91f30978ead87..ef02de95fc54e 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -47,7 +47,7 @@ def __init__( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[torch.Tensor], ) -> Optional[PoolerOutput]: (input_tokens, input_positions, attn_metadata, pooling_metadata, @@ -84,10 +84,11 @@ def execute_model( def prepare_input_tensors( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: + assert seq_group_metadata_list is not None # Prepare input tensors. ( input_tokens, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9720363ac300e..87d5f5c1b9d67 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -609,10 +609,11 @@ def _prepare_model_input( def prepare_input_tensors( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: + assert seq_group_metadata_list is not None # Prepare input tensors. ( input_tokens, @@ -676,7 +677,7 @@ def prepare_input_tensors( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: List[SequenceGroupMetadata], + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 97b3873b2a9f6..10411a2bf7a10 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -226,48 +226,42 @@ def execute_model( self, execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[Union[SamplerOutput, PoolerOutput]]: + if not self.is_driver_worker: + self._execute_model_non_driver() + return [] if execute_model_req is None: - seq_group_metadata_list = None - else: - seq_group_metadata_list = execute_model_req.seq_group_metadata_list + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) + return [] - blocks_to_swap_in: torch.Tensor - blocks_to_swap_out: torch.Tensor - blocks_to_copy: torch.Tensor - if self.is_driver_worker: - assert seq_group_metadata_list is not None - assert execute_model_req is not None - num_seq_groups = len(seq_group_metadata_list) - # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. - # they contain parameters to launch cudamemcpyasync. - blocks_to_swap_in = torch.tensor( - execute_model_req.blocks_to_swap_in, - device="cpu", - dtype=torch.int64).view(-1, 2) - blocks_to_swap_out = torch.tensor( - execute_model_req.blocks_to_swap_out, - device="cpu", - dtype=torch.int64).view(-1, 2) - # `blocks_to_copy` is a gpu tensor. The src and tgt of - # blocks to copy are in the same device, and `blocks_to_copy` - # can be used directly within cuda kernels. - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device=self.device, + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + num_seq_groups = len(seq_group_metadata_list) + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + device="cpu", dtype=torch.int64).view(-1, 2) - data: Dict[str, Any] = { - "num_seq_groups": num_seq_groups, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - } - broadcast_tensor_dict(data, src=0) - else: - data = broadcast_tensor_dict(src=0) - num_seq_groups = data["num_seq_groups"] - blocks_to_swap_in = data["blocks_to_swap_in"] - blocks_to_swap_out = data["blocks_to_swap_out"] - blocks_to_copy = data["blocks_to_copy"] + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) + data: Dict[str, Any] = { + "num_seq_groups": num_seq_groups, + "blocks_to_swap_in": blocks_to_swap_in, + "blocks_to_swap_out": blocks_to_swap_out, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) @@ -282,6 +276,39 @@ def execute_model( # to conform to interface. return [output] + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. + + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. + """ + while self._execute_model_non_driver(): + pass + + def _execute_model_non_driver(self) -> bool: + """Execute model in parallel worker. + + Returns True iff there are remaining sequences to process. + """ + assert not self.is_driver_worker + data = broadcast_tensor_dict(src=0) + if not data: + return False + + num_seq_groups = data.get("num_seq_groups", 0) + blocks_to_swap_in = data.get("blocks_to_swap_in") + blocks_to_swap_out = data.get("blocks_to_swap_out") + blocks_to_copy = data.get("blocks_to_copy") + self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return False + + self.model_runner.execute_model(None, self.gpu_cache) + return True + def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 1f04f821eb0f0..dbac1b5ba339b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,7 +1,7 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -48,8 +48,9 @@ def initialize_cache(self, num_gpu_blocks: int, @abstractmethod def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" raise NotImplementedError From a36de682d4283c60777bc3022ed3ce71cd90b904 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 22 May 2024 15:26:56 -0700 Subject: [PATCH 332/413] [Minor] Fix small typo in llama.py: QKVParallelLinear -> QuantizationConfig (#4991) --- vllm/model_executor/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 23141124e69e1..f43a40a0bfd34 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -57,7 +57,7 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QKVParallelLinear] = None, + quant_config: Optional[QuantizationConfig] = None, bias: bool = False, ) -> None: super().__init__() From ee3eea0a1b2c690557455d97074d8829d5a98320 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 22 May 2024 15:55:56 -0700 Subject: [PATCH 333/413] [Misc] Take user preference in attention selector (#4960) --- tests/kernels/test_attention_selector.py | 84 +++++++++++++ vllm/attention/backends/flashinfer.py | 1 + vllm/attention/selector.py | 145 +++++++++++++---------- 3 files changed, 169 insertions(+), 61 deletions(-) create mode 100644 tests/kernels/test_attention_selector.py diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py new file mode 100644 index 0000000000000..f439afa9b7d2b --- /dev/null +++ b/tests/kernels/test_attention_selector.py @@ -0,0 +1,84 @@ +import os +from unittest.mock import patch + +import pytest +import torch + +from vllm.attention.selector import which_attn_to_use + + +@pytest.mark.parametrize( + "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) +@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) +def test_env(name: str, device: str): + """Test that the attention selector can be set via environment variable. + Note that we do not test FlashAttn because it is the default backend. + """ + name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) + os.environ["VLLM_ATTENTION_BACKEND"] = name + + if device == "cpu": + with patch("vllm.attention.selector.is_cpu", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "TORCH_SDPA" + elif device == "hip": + with patch("vllm.attention.selector.is_hip", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "ROCM_FLASH" + else: + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == name + + if name_backup is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + + +def test_flash_attn(): + """Test FlashAttn validation.""" + name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) + os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" + + # Unsupported CUDA arch + with patch("torch.cuda.get_device_capability", return_value=[7, 5]): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported data type + backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported kv cache data type + backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported block size + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8) + assert backend.name != "FLASH_ATTN" + + # Unsupported sliding window + backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + # flash-attn is not installed + with patch.dict('sys.modules', {'vllm_flash_attn': None}): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + # Unsupported head size + backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16) + assert backend.name != "FLASH_ATTN" + + if name_backup is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup + + +def test_invalid_env(): + """Throw an exception if the backend name is invalid.""" + name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None) + os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID" + with pytest.raises(ValueError): + which_attn_to_use(8, 16, 8, None, torch.float16, None, 16) + os.environ["VLLM_ATTENTION_BACKEND"] = name_backup diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7210fefbd8162..7b7959d257fac 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -218,6 +218,7 @@ def forward( ) if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. assert prefill_meta.block_tables is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: output = flash_attn_varlen_func( diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 51c25a81b4130..f191461dcd3b7 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -30,24 +30,16 @@ def get_attn_backend( kv_cache_dtype: Optional[str], block_size: int, ) -> Type[AttentionBackend]: - backend = _which_attn_to_use(num_heads, head_size, num_kv_heads, - sliding_window, dtype, kv_cache_dtype, - block_size) + """Determine which attention backend to use and only import + the selected backend module. + """ + backend = which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype, kv_cache_dtype, + block_size) if backend == _Backend.FLASH_ATTN: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) - - # We check it here not in _which_attn_to_use because we cannot know - # the head size until we import FlashAttentionBackend. - supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size in supported_head_sizes: - logger.info("Using FlashAttention-2 backend.") - return FlashAttentionBackend - logger.info( - "Cannot use FlashAttention-2 backend for head size %d. " - "Using XFormers backend instead.", head_size) - backend = _Backend.XFORMERS - + return FlashAttentionBackend if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -64,14 +56,15 @@ def get_attn_backend( return TorchSDPABackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning("Eager mode is enforced for the Flashinfer backend.") + logger.warning("Eager mode is required for the Flashinfer backend. " + "Please make sure --enforce-eager is set.") from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend else: raise ValueError("Invalid attention backend.") -def _which_attn_to_use( +def which_attn_to_use( num_heads: int, head_size: int, num_kv_heads: int, @@ -81,54 +74,84 @@ def _which_attn_to_use( block_size: int, ) -> _Backend: """Returns which flash attention backend to use.""" + + # Default case. + selected_backend = _Backend.FLASH_ATTN + + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + backend_members = _Backend.__members__ + if backend_by_env_var not in backend_members: + raise ValueError( + f"Invalid attention backend '{backend_by_env_var}'. " + f"Available backends: {', '.join(backend_members)} " + "(case-sensitive).") + selected_backend = _Backend[backend_by_env_var] + if is_cpu(): + if selected_backend != _Backend.TORCH_SDPA: + logger.info("Cannot use %s backend on CPU.", selected_backend) return _Backend.TORCH_SDPA if is_hip(): # AMD GPUs. - if torch.cuda.get_device_capability()[0] != 9: - # not Instinct series GPUs. - logger.info("flash_atten is not supported on NAVI GPUs.") + selected_backend = (_Backend.ROCM_FLASH if selected_backend + == _Backend.FLASH_ATTN else selected_backend) + if selected_backend == _Backend.ROCM_FLASH: + if torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_attn is not supported on NAVI GPUs.") + else: + logger.info("%s is not supported in AMD GPUs.", selected_backend) return _Backend.ROCM_FLASH - # NVIDIA GPUs. - if torch.cuda.get_device_capability()[0] < 8: - # Volta and Turing NVIDIA GPUs. - logger.info("Cannot use FlashAttention-2 backend for Volta and Turing " - "GPUs.") - return _Backend.XFORMERS - - if dtype not in (torch.float16, torch.bfloat16): - logger.info("Cannot use FlashAttention-2 backend for dtype other than " - "torch.float16 or torch.bfloat16.") - return _Backend.XFORMERS - - if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): - logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.") - return _Backend.XFORMERS - - if block_size % 16 != 0: - logger.info("Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - return _Backend.XFORMERS - - if sliding_window is not None: - logger.info( - "Cannot use FlashAttention-2 backend due to sliding window.") - return _Backend.XFORMERS - - try: - import vllm_flash_attn # noqa: F401 - except ImportError: - logger.info( - "Cannot use FlashAttention-2 backend because the vllm_flash_attn " - "package is not found. `pip install vllm-flash-attn` for better " - "performance.") - return _Backend.XFORMERS - - backend_by_env_var = envs.VLLM_ATTENTION_BACKEND - if backend_by_env_var is not None: - return _Backend[backend_by_env_var] - - # Default case. - return _Backend.FLASH_ATTN + # FlashAttn in NVIDIA GPUs. + if selected_backend == _Backend.FLASH_ATTN: + if torch.cuda.get_device_capability()[0] < 8: + # Volta and Turing NVIDIA GPUs. + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " + "GPUs.") + selected_backend = _Backend.XFORMERS + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16.") + selected_backend = _Backend.XFORMERS + elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + logger.info( + "Cannot use FlashAttention-2 backend for FP8 KV cache.") + selected_backend = _Backend.XFORMERS + elif block_size % 16 != 0: + logger.info( + "Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + selected_backend = _Backend.XFORMERS + elif sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window.") + selected_backend = _Backend.XFORMERS + + # FlashAttn is valid for the model, checking if the package is installed. + if selected_backend == _Backend.FLASH_ATTN: + try: + import vllm_flash_attn # noqa: F401 + + from vllm.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend) + + supported_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in supported_sizes: + logger.info( + "Cannot use FlashAttention-2 backend for head size %d.", + head_size) + selected_backend = _Backend.XFORMERS + except ImportError: + logger.info( + "Cannot use FlashAttention-2 backend because the " + "vllm_flash_attn package is not found. " + "`pip install vllm-flash-attn` for better performance.") + selected_backend = _Backend.XFORMERS + + return selected_backend From 606625329648e6eff1883e23040adfad82f219cf Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Thu, 23 May 2024 02:39:27 -0400 Subject: [PATCH 334/413] Marlin 24 prefill performance improvement (about 25% better on average) (#4983) --- benchmarks/kernels/benchmark_marlin.py | 74 ++++++++++++++++--- .../marlin/sparse/marlin_24_cuda_kernel.cu | 55 ++++++++++---- tests/kernels/test_marlin_gemm.py | 2 +- .../layers/quantization/gptq_marlin_24.py | 8 +- 4 files changed, 107 insertions(+), 32 deletions(-) diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 5dcffc284f3d4..b771911781574 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -6,9 +6,13 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, marlin_quantize) + MarlinWorkspace, marlin_24_quantize, marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) @@ -44,6 +48,10 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size, marlin_rand_perm, ) = marlin_quantize(b, num_bits, group_size, act_order) + # Marlin_24 quant + (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s) = marlin_24_quantize(b, num_bits, group_size) + # GPTQ quant (w_ref, q_w, s, g_idx, rand_perm) = quantize_weights(b, num_bits, group_size, act_order) @@ -56,28 +64,43 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size, (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) # Prepare - marlin_workspace = MarlinWorkspace(size_n) + marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_MAX_PARALLEL) globals = { + # Gen params + "num_bits": num_bits, + "group_size": group_size, + "size_m": size_m, + "size_n": size_n, + "size_k": size_k, + "a": a, + "a_tmp": a_tmp, + # Marlin params "marlin_w_ref": marlin_w_ref, "marlin_q_w": marlin_q_w, "marlin_s": marlin_s, "marlin_g_idx": marlin_g_idx, "marlin_sort_indices": marlin_sort_indices, "marlin_rand_perm": marlin_rand_perm, + "marlin_workspace": marlin_workspace, + "is_k_full": is_k_full, + # Marlin_24 params + "marlin_24_w_ref": marlin_24_w_ref, + "marlin_24_q_w_comp": marlin_24_q_w_comp, + "marlin_24_meta": marlin_24_meta, + "marlin_24_s": marlin_24_s, + "marlin_24_workspace": marlin_24_workspace, + # GPTQ params "q_w_gptq": q_w_gptq, "repack_sort_indices": repack_sort_indices, - "num_bits": num_bits, - "group_size": group_size, - "size_m": size_m, - "size_n": size_n, - "size_k": size_k, - "is_k_full": is_k_full, - "a": a, - "a_tmp": a_tmp, + # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, + "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, "gptq_marlin_repack": ops.gptq_marlin_repack, - "marlin_workspace": marlin_workspace, } min_run_time = 1 @@ -105,6 +128,18 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size, description="gptq_marlin_gemm", ).blocked_autorange(min_run_time=min_run_time)) + if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_24_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + results.append( benchmark.Timer( stmt= @@ -135,8 +170,20 @@ def main(args): continue for act_order in ACT_ORDER_OPTS: + if len(args.limit_act_order + ) > 0 and act_order not in args.limit_act_order: + continue + for is_k_full in K_FULL_OPTS: + if len(args.limit_k_full + ) > 0 and is_k_full not in args.limit_k_full: + continue + for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + if len(args.limit_num_bits + ) > 0 and num_bits not in args.limit_num_bits: + continue + for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: if len( args.limit_group_size @@ -159,7 +206,7 @@ def main(args): # For quick benchmarking use: -# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501 +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 # if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -178,6 +225,9 @@ def main(args): parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[]) + parser.add_argument("--limit-act-order", nargs="+", type=int, default=[]) + parser.add_argument("--limit-k-full", nargs="+", type=int, default=[]) args = parser.parse_args() main(args) diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 54ad27676e207..686dd7851e6af 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -48,12 +48,12 @@ namespace marlin_24 { // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. static constexpr int THREADS = 256; -static constexpr int STAGES = 4; // 4 pipeline stages fit into shared memory +static constexpr int STAGES = 4; static constexpr int min_thread_n = 128; static constexpr int tile_size = 16; -static constexpr int max_par = 16; +static constexpr int max_par = 64; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -736,10 +736,10 @@ __global__ void Marlin_24( for (int pipe = 0; pipe < stages;) { fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + matmul(pipe); wait_for_stage(); fetch_to_registers(pipe + 1, (pipe + 1) % stages); - matmul(pipe); pipe++; slice_iters--; @@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, // than better compute utilization thread_k = 128; thread_m = 128; - } else { + } else if (prob_n <= 256) { thread_k = 64; thread_m = 256; + } else { + thread_k = 32; + thread_m = 512; } } @@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, int4* C_ptr = (int4*)C; const int4* s_ptr = (const int4*)s; + constexpr int max_m_blocks = 4; + int* locks = (int*)workspace; - for (int i = 0; i < tot_n_blocks; i += 4) { + for (int i = 0; i < tot_n_blocks; i += max_m_blocks) { int thread_n_blocks = tot_n_blocks - i; prob_n = tot_n - 16 * i; int par = 1; - if (thread_n_blocks > 4) { + if (thread_n_blocks > max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * thread_n_blocks - pad) / 64; + par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16); if (par > max_par) par = max_par; - prob_n = 64 * par; - i += 4 * (par - 1); - thread_n_blocks = 4; + prob_n = (max_m_blocks * 16) * par; + i += max_m_blocks * (par - 1); + thread_n_blocks = max_m_blocks; } // For compilation speed, we only define the kernel configurations that have @@ -951,8 +956,9 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, if (false) { } // BMxBNxBK, group // 4-bit - CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(4, 16, 1, 2, -1) // e.g., 16x256x64 CALL_IF_2_4(4, 16, 1, 2, 4) // e.g., 16x256x64, 64 CALL_IF_2_4(4, 16, 2, 2, -1) // e.g.. 32x256x64 @@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, CALL_IF_2_4(4, 16, 4, 2, -1) CALL_IF_2_4(4, 16, 4, 2, 4) + CALL_IF_2_4(4, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(4, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(4, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(4, 32, 2, 1, 4) + CALL_IF_2_4(4, 32, 3, 1, -1) + CALL_IF_2_4(4, 32, 3, 1, 4) + CALL_IF_2_4(4, 32, 4, 1, -1) + CALL_IF_2_4(4, 32, 4, 1, 4) + // 8-bit - CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 - CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64 + CALL_IF_2_4(8, 16, 1, 2, -1) // e.g., 16x256x64 CALL_IF_2_4(8, 16, 1, 2, 4) // e.g., 16x256x64, 64 CALL_IF_2_4(8, 16, 2, 2, -1) // e.g.. 32x256x64 @@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, CALL_IF_2_4(8, 16, 3, 2, 4) CALL_IF_2_4(8, 16, 4, 2, -1) CALL_IF_2_4(8, 16, 4, 2, 4) + + CALL_IF_2_4(8, 32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_2_4(8, 32, 1, 1, 4) // e.g., 16x256x64, 64 + CALL_IF_2_4(8, 32, 2, 1, -1) // e.g.. 32x256x64 + CALL_IF_2_4(8, 32, 2, 1, 4) + CALL_IF_2_4(8, 32, 3, 1, -1) + CALL_IF_2_4(8, 32, 3, 1, 4) + CALL_IF_2_4(8, 32, 4, 1, -1) + CALL_IF_2_4(8, 32, 4, 1, 4) else { throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]" + @@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, int thread_k = -1; int thread_m = -1; int sms = -1; - int max_par = 16; + int max_par = marlin_24::max_par; int groupsize = -1; if (b_scales.size(0) > 1) { diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 587fc3901eb7c..1f8d94bad26d9 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -27,7 +27,7 @@ MARLIN_N_CHUNKS = [64, 128, 256] MARLIN_24_K_CHUNKS = [128] -MARLIN_24_N_CHUNKS = [256] +MARLIN_24_N_CHUNKS = [512] MNK_FACTORS = [ (1, 1, 1), diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index f5345c0443029..6bcfc405afe71 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -15,7 +15,7 @@ GPTQ_MARLIN_24_TILE = 16 GPTQ_MARLIN_24_MIN_THREAD_N = 128 GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 16 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] @@ -53,14 +53,14 @@ def __init__( self.tile_size = 16 # Min out_features dim - self.min_n_threads = 128 + self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N # Min in_features dim - self.min_k_threads = 128 + self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K # Max parallel problems to solve at once (improves large # batch performance) - self.max_parallel = 16 + self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL # Permutation length used by the marlin kernels. self.perm_len = 1024 From 2ba80bed2732edf42b1014ea4e34757849fc93d0 Mon Sep 17 00:00:00 2001 From: Letian Li Date: Thu, 23 May 2024 17:08:58 +0100 Subject: [PATCH 335/413] [Bugfix] Update Dockerfile.cpu to fix NameError: name 'vllm_ops' is not defined (#5009) --- .buildkite/run-cpu-test.sh | 2 +- Dockerfile.cpu | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f187d1f181724..414045fe163e5 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -11,4 +11,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py +docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 4251fddd6cc3b..aec79824213f3 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -17,4 +17,6 @@ RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.py RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +WORKDIR /workspace/ + CMD ["/bin/bash"] From 5eda2ea02a01b2457f4d6ac2a217f2fa8a2e5d5f Mon Sep 17 00:00:00 2001 From: Murali Andoorveedu <37849411+andoorve@users.noreply.github.com> Date: Thu, 23 May 2024 09:54:48 -0700 Subject: [PATCH 336/413] [Core][1/N] Support send/recv in PyNCCL Groups (#4988) Signed-off-by: Muralidhar Andoorveedu --- tests/distributed/test_pynccl.py | 75 +++++++++++++++++-- vllm/distributed/communication_op.py | 18 +++-- .../device_communicators/pynccl.py | 34 +++++++++ .../device_communicators/pynccl_wrapper.py | 26 +++++++ vllm/distributed/parallel_state.py | 34 +++++++-- 5 files changed, 170 insertions(+), 17 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 529e75fb2c9e3..0218295a3e3f9 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -3,6 +3,7 @@ import pytest import torch +import torch.distributed from vllm.distributed.communication_op import ( # noqa graph_capture, tensor_model_parallel_all_reduce) @@ -68,7 +69,7 @@ def test_pynccl(): @worker_fn_wrapper -def multiple_tp_worker_fn(): +def multiple_allreduce_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") groups = [ torch.distributed.new_group(ranks=[0, 1], backend="gloo"), @@ -92,14 +93,14 @@ def multiple_tp_worker_fn(): @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test.") -def test_pynccl_multiple_tp(): +def test_pynccl_multiple_allreduce(): # this tests pynccl for multiple tp groups, in a standalone way # i.e. call `pynccl_comm.all_reduce` directly - distributed_run(multiple_tp_worker_fn, 4) + distributed_run(multiple_allreduce_worker_fn, 4) @worker_fn_wrapper -def multiple_tp_with_vllm_worker_fn(): +def multiple_allreduce_with_vllm_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") ensure_model_parallel_initialized(2, 2) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) @@ -118,10 +119,10 @@ def multiple_tp_with_vllm_worker_fn(): @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Need at least 4 GPUs to run the test.") -def test_pynccl_multiple_tp_with_vllm(): +def test_pynccl_multiple_allreduce_with_vllm(): # this tests pynccl for multiple tp groups, together with vllm # i.e. call `tensor_model_parallel_all_reduce` - distributed_run(multiple_tp_with_vllm_worker_fn, 4) + distributed_run(multiple_allreduce_with_vllm_worker_fn, 4) @worker_fn_wrapper @@ -151,6 +152,68 @@ def test_pynccl_with_cudagraph(): distributed_run(worker_fn_with_cudagraph, 2) +@worker_fn_wrapper +def send_recv_worker_fn(): + pynccl_comm = PyNcclCommunicator() + if pynccl_comm.rank == 0: + tensor = torch.ones(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + else: + tensor = torch.empty(16, 1024, 1024, + dtype=torch.float32).cuda(pynccl_comm.rank) + with pynccl_comm.change_state(enable=True): + if pynccl_comm.rank == 0: + pynccl_comm.send(tensor) + else: + pynccl_comm.recv(tensor) + result = tensor.mean().cpu().item() + assert result == 1 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_send_recv(): + distributed_run(send_recv_worker_fn, 2) + + +@worker_fn_wrapper +def multiple_send_recv_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + groups = [ + torch.distributed.new_group(ranks=[0, 2], backend="gloo"), + torch.distributed.new_group(ranks=[1, 3], backend="gloo") + ] + group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] + pynccl_comm = PyNcclCommunicator(group=group, device=device) + if torch.distributed.get_rank() == 0: + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) + elif torch.distributed.get_rank() == 1: + tensor = 2 * torch.ones( + 16, 1024, 1024, dtype=torch.float32, device=device) + else: + tensor = torch.empty(16, + 1024, + 1024, + dtype=torch.float32, + device=device) + with pynccl_comm.change_state(enable=True): + if torch.distributed.get_rank() in [0, 1]: + pynccl_comm.send(tensor) + else: + pynccl_comm.recv(tensor) + result = tensor.mean().cpu().item() + if torch.distributed.get_rank() in [0, 2]: + assert result == 1 + else: + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_multiple_send_recv(): + distributed_run(multiple_send_recv_worker_fn, 4) + + def test_ncclGetUniqueId(): lib = NCCLLibrary() unique_id = lib.ncclGetUniqueId() diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 937fd4d392713..2b38ec472de66 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -6,7 +6,7 @@ import torch from torch.distributed import ProcessGroup -from .parallel_state import (get_cpu_world_group, +from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -54,13 +54,19 @@ def graph_capture(): # graph, we use either custom all-reduce kernel or PyTorch NCCL. # We always prioritize using custom all-reduce kernel but fall back # to PyTorch or pynccl if it is disabled or not supported. - pynccl_comm = get_tp_pynccl_communicator() - if pynccl_comm is None: - maybe_pynccl_context = nullcontext() + tp_pynccl_comm = get_tp_pynccl_communicator() + pp_pynccl_comm = get_pp_pynccl_communicator() + if not tp_pynccl_comm: + maybe_tp_pynccl_context = nullcontext() else: - maybe_pynccl_context = pynccl_comm.change_state( + maybe_tp_pynccl_context = tp_pynccl_comm.change_state( enable=True, stream=torch.cuda.current_stream()) - with maybe_pynccl_context: + if not pp_pynccl_comm: + maybe_pp_pynccl_context = nullcontext() + else: + maybe_pp_pynccl_context = pp_pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_tp_pynccl_context, maybe_pp_pynccl_context: yield graph_capture_context diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 092a0910329ad..f5f1de0c71615 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -126,6 +126,40 @@ def all_reduce(self, ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + def send(self, + tensor: torch.Tensor, + dst: Optional[int] = None, + stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + if dst is None: + dst = (self.rank + 1) % self.world_size + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + self.comm, cudaStream_t(stream.cuda_stream)) + + def recv(self, + tensor: torch.Tensor, + src: Optional[int] = None, + stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + if src is None: + src = (self.rank - 1) % self.world_size + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + @contextmanager def change_state(self, enable: Optional[bool] = None, diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 43d85674b23d0..3aa3744d0d827 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -151,6 +151,22 @@ class NCCLLibrary: ncclRedOp_t, ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -248,6 +264,16 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, datatype, op, comm, stream)) + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d24104e3ed276..0ebd7a15eab9b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -22,6 +22,8 @@ _TP_CA_COMMUNICATOR = None # Pipeline model parallel group that the current rank belongs to. _PP_DEVICE_GROUP: Optional[ProcessGroup] = None +_PP_CPU_GROUP: Optional[ProcessGroup] = None +_PP_PYNCCL_COMMUNICATOR = None # when people blindly call `torch.distributed.all_reduce` etc, # it will use this group. It is initialized with the `backend` @@ -55,6 +57,11 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def get_pp_pynccl_communicator(): + global _PP_PYNCCL_COMMUNICATOR + return _PP_PYNCCL_COMMUNICATOR + + def get_tp_pynccl_communicator(): global _TP_PYNCCL_COMMUNICATOR return _TP_PYNCCL_COMMUNICATOR @@ -180,10 +187,11 @@ def initialize_model_parallel( _TP_CPU_GROUP = cpu_group from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( - group=_TP_CPU_GROUP, - device=_LOCAL_RANK, - ) + if tensor_model_parallel_size > 1: + _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( + group=_TP_CPU_GROUP, + device=_LOCAL_RANK, + ) # Initialize a custom fast all-reduce implementation. if _ENABLE_CUSTOM_ALL_REDUCE: @@ -195,17 +203,26 @@ def initialize_model_parallel( ) # Build the pipeline model-parallel groups. - global _PP_DEVICE_GROUP + global _PP_DEVICE_GROUP, _PP_CPU_GROUP + global _PP_PYNCCL_COMMUNICATOR global _PP_GLOBAL_RANKS assert _PP_DEVICE_GROUP is None, ( "pipeline model parallel group is already initialized") for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group = torch.distributed.new_group(ranks, backend=backend) + cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: _PP_DEVICE_GROUP = group + _PP_CPU_GROUP = cpu_group _PP_GLOBAL_RANKS = ranks + if pipeline_model_parallel_size > 1: + _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( + group=_PP_CPU_GROUP, + device=_LOCAL_RANK, + ) + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, @@ -267,6 +284,13 @@ def get_pipeline_model_parallel_group(): return _PP_DEVICE_GROUP +def get_pipeline_model_parallel_cpu_group(): + """Get the pipeline model parallel cpu group the caller rank belongs to.""" + assert _PP_CPU_GROUP is not None, ( + "pipeline model parallel cpu group is not initialized") + return _PP_CPU_GROUP + + def get_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" return torch.distributed.get_world_size( From a1242324c99ff8b1e29981006dfb504da198c7c3 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 23 May 2024 17:29:18 -0400 Subject: [PATCH 337/413] [Kernel] Initial Activation Quantization Support (#4525) Co-authored-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- CMakeLists.txt | 1 + csrc/ops.h | 3 + csrc/pybind.cpp | 3 + .../compressed_tensors/int8_quant_kernels.cu | 59 +++++ tests/kernels/test_int8_quant.py | 31 +++ tests/quantization/test_compressed_tensors.py | 36 +++ vllm/_custom_ops.py | 18 ++ vllm/model_executor/layers/linear.py | 244 ++++++++++++------ .../layers/quantization/__init__.py | 3 + .../compressed_tensors/__init__.py | 0 .../compressed_tensors/compressed_tensors.py | 151 +++++++++++ .../compressed_tensors/schemes/__init__.py | 5 + .../schemes/compressed_tensors_scheme.py | 33 +++ .../schemes/compressed_tensors_unquantized.py | 39 +++ .../compressed_tensors_w8a8_statictensor.py | 119 +++++++++ .../model_loader/weight_utils.py | 7 + vllm/model_executor/models/llama.py | 25 +- 17 files changed, 683 insertions(+), 94 deletions(-) create mode 100644 csrc/quantization/compressed_tensors/int8_quant_kernels.cu create mode 100644 tests/kernels/test_int8_quant.py create mode 100644 tests/quantization/test_compressed_tensors.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/__init__.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 35846fd1cfa99..b668cbc97de15 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -167,6 +167,7 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" diff --git a/csrc/ops.h b/csrc/ops.h index f5e0e423bb65d..b839eaf0d26c8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -93,6 +93,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, #endif +void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, + float scale); + void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index cba07f0ae9f2a..cdbec4a34d77f 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -67,6 +67,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Aligning the number of tokens to be processed by each expert such " "that it is divisible by the block size."); + ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, + "Compute int8 quantized tensor for given scaling factor"); + // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def("swap_blocks", &swap_blocks, diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu new file mode 100644 index 0000000000000..4902e4c23434c --- /dev/null +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -0,0 +1,59 @@ +#include +#include +#include + +#include "../../dispatch_utils.h" + +static inline __device__ int8_t float_to_int8_rn(float x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +namespace vllm { + +template +__global__ void static_scaled_int8_quant_kernel( + const scalar_t* __restrict__ input, int8_t* __restrict__ out, + scale_type scale, const int hidden_size) { + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = + float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale); + } +} +} // namespace vllm + +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + float scale) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { + vllm::static_scaled_int8_quant_kernel + <<>>(input.data_ptr(), + out.data_ptr(), scale, + hidden_size); + }); +} diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py new file mode 100644 index 0000000000000..b9aa00ce13f56 --- /dev/null +++ b/tests/kernels/test_int8_quant.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from vllm._C import ops + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SEEDS = [0] +SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("scale", SCALE) +@torch.inference_mode() +def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype, + seed: int, scale: float) -> None: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 + + out1 = (x / scale).round().clamp( + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max).to(torch.int8) + out2 = torch.empty_like(x, dtype=torch.int8) + ops.static_scaled_int8_quant(out2, x, scale) + assert torch.allclose(out1, out2, + atol=1) # big atol to account for rounding errors diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py new file mode 100644 index 0000000000000..b83286992da3d --- /dev/null +++ b/tests/quantization/test_compressed_tensors.py @@ -0,0 +1,36 @@ +"""Test model set-up and weight loading for sparseml-quantized models. + +Run `pytest tests/quantization/test_compressed_tensors.py`. +""" + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor) + + +def test_compressed_tensors_w8a8_static_setup(vllm_runner): + model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed" + llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True) + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) + + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor) + + assert qkv_proj.weight.dtype is torch.int8 + assert o_proj.weight.dtype is torch.int8 + assert gate_up_proj.weight.dtype is torch.int8 + + assert qkv_proj.weight_scale.shard_splitter is not None + assert qkv_proj.weight_scale.logical_widths is not None + assert qkv_proj.input_scale.dtype is torch.float32 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9e7d0d96bf004..f0fab4d8aa26d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -251,6 +251,24 @@ def scaled_fp8_quant( return output, scale +# int8 +def static_scaled_int8_quant(input: torch.Tensor, + scale: float) -> torch.Tensor: + """ + Quantize the input tensor to int8 and return the quantized tensor. + + Args: + input: The input tensor to be quantized to int8. + scale: Scaling factor for the int8 quantization. + + Returns: + torch.Tensor: Output tensor in int8. + """ + q = torch.empty_like(input, dtype=torch.int8) + vllm_ops.static_scaled_int8_quant(q, input, scale) + return q + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 7726dcb9a5fbd..34fbfa8e33ef9 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -56,7 +56,6 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights in layer to the input tensor. - Expects create_weights to have been called before on the layer.""" raise NotImplementedError @@ -77,8 +76,7 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - output_size_per_partition = sum(output_partition_sizes) - weight = Parameter(torch.empty(output_size_per_partition, + weight = Parameter(torch.empty(sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype), requires_grad=False) @@ -149,15 +147,13 @@ class ReplicatedLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -210,17 +206,15 @@ class ColumnParallelLinear(LinearBase): the list would be size 3. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None, - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -228,18 +222,26 @@ def __init__( # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, tp_size) + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) + for output_size in self.output_sizes + ] + if output_sizes is None: output_sizes = [output_size] - # All the linear layer supports quant method. - assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size, - [x // tp_size for x in output_sizes], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -317,22 +319,24 @@ class MergedColumnParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, quant_config, - self.output_sizes) + super().__init__(input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -343,6 +347,26 @@ def weight_loader(self, output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # Special case for Fp8 scales. fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", None) @@ -403,6 +427,13 @@ def weight_loader(self, shard_size = loaded_weight.shape[0] shard_offset = loaded_shard_id * shard_size param_data = param_data.narrow(0, shard_offset, shard_size) + + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + # Special case for Fp8 scales. elif fp8_scales_shard_indexer is not None: param_data, loaded_weight = fp8_scales_shard_indexer( @@ -415,6 +446,14 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") + + if fp8_scales_shard_indexer is None: + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -443,17 +482,15 @@ class QKVParallelLinear(ColumnParallelLinear): quant_config: Quantization configure. """ - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -473,14 +510,19 @@ def __init__( input_size = self.hidden_size output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - output_sizes = [ - self.num_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size, - self.num_kv_heads * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj ] - super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, quant_config, output_sizes) + super().__init__(input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config) def weight_loader(self, param: Parameter, @@ -490,6 +532,26 @@ def weight_loader(self, output_dim = getattr(param, "output_dim", None) # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # Special case for Fp8 scales. fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", None) @@ -528,6 +590,8 @@ def weight_loader(self, tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. if output_dim is not None: if loaded_shard_id == "q": shard_offset = 0 @@ -567,6 +631,12 @@ def weight_loader(self, shard_index = ["q", "k", "v"].index(loaded_shard_id) param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + # Special case for Fp8 scales. elif fp8_scales_shard_indexer is not None: param_data, loaded_weight = fp8_scales_shard_indexer( @@ -578,6 +648,13 @@ def weight_loader(self, "Loading a weight without `output_dim` attribute in " "QKVParallelLinear, assume the weight is the same " "for all partitions.") + + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -608,17 +685,15 @@ class RowParallelLinear(LinearBase): quant_config: Quantization configure. """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -628,16 +703,15 @@ def __init__( # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) - # All the linear layer supports quant method. assert self.quant_method is not None - self.quant_method.create_weights(self, - self.input_size_per_partition, - [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) - + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") @@ -665,12 +739,16 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + # Special case for Fp8 scales. elif fp8_scales_shard_indexer is not None: param_data, loaded_weight = fp8_scales_shard_indexer(param_data, loaded_weight, shard_id=0) + if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index f938e7d37ec5f..7b9abe1b629a1 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config @@ -27,6 +29,7 @@ "gptq_marlin": GPTQMarlinConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, + "sparseml": CompressedTensorsConfig, } diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 0000000000000..19e464bd64325 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,151 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor) + + +class CompressedTensorsConfig(QuantizationConfig): + + def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]): + self.ignore = ignore + self.layer_quant_details = layer_quant_details + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16] + + # Need to figure it out + def get_min_capability(self) -> int: + return 60 + + def get_name(self) -> str: + return "compressed_tensors" + + def get_quant_method( + self, layer: torch.nn.Module + ) -> Optional["CompressedTensorsLinearMethod"]: + if isinstance(layer, LinearBase): + return CompressedTensorsLinearMethod(self) + return None + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + layer_quant_details: Dict[str, Any] = dict() + ignore: List[str] = config.get("ignore", None) + + for key, quant_config in config["config_groups"].items(): + targets = quant_config.get("targets") + for target in targets: + layer_quant_details[target] = {} + layer_quant_details[target]["weight"] = quant_config.get( + "weights") + layer_quant_details[target]["input"] = quant_config.get( + "input_activations") + + return cls(layer_quant_details=layer_quant_details, ignore=ignore) + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + def _get_schema(self, weight_quant: Dict, input_quant: Dict): + # TODO: Refactor as additional cases are supported + + weight_bit = weight_quant.get("num_bits") + input_bit = input_quant.get("num_bits") + + weight_strategy = weight_quant.get("strategy") + input_strategy = input_quant.get("strategy") + + weight_symmetric = weight_quant.get("symmetric") + input_symmetric = input_quant.get("symmetric") + + is_8_bits = weight_bit == input_bit == 8 + is_tensor = weight_strategy == input_strategy == "tensor" + is_symmetric = weight_symmetric and input_symmetric + + if is_8_bits and is_tensor and is_symmetric and \ + torch.cuda.is_available(): + # CompressedTensorsW8A8StaticTensor only supports CUDA path for + # now. + return CompressedTensorsW8A8StaticTensor() + raise NotImplementedError( + "Scheme not supported. Only CUDA, 8-bit static symmtetric " + "per tensor quantization is currently supported") + + def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": + + # TODO: update with matching function from `compressed_tensors` + layer_type_name = None + layer_name_class = type(layer).__name__.lower() + for target in self.layer_quant_details: + if target.lower() in layer_name_class: + layer_type_name = target + break + if layer_type_name is None: + raise ValueError(f"Could not matching target for layer {layer}") + + layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( + layer_type_name, None) + if layer_quant_details is None: + raise ValueError( + f"Could not find quantization details for {layer}.") + + return self._get_schema(weight_quant=layer_quant_details["weight"], + input_quant=layer_quant_details["input"]) + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. + """ + weight_loader = extra_weight_attrs.get("weight_loader") + + scheme = self.quantization_config.get_scheme(layer=layer) + scheme.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + layer.scheme = scheme + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. + """ + + if bias is not None: + raise ValueError("bias is not supported for this linear method") + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 0000000000000..831905b63e2c9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,5 @@ +from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401 +from .compressed_tensors_unquantized import ( # noqa: F401 + CompressedTensorsUnquantized) +from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 + CompressedTensorsW8A8StaticTensor) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 0000000000000..3a5904208656e --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,33 @@ +from abc import ABC, abstractmethod + +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. + """ + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: toch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py new file mode 100644 index 0000000000000..0cfac13d1ca25 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py @@ -0,0 +1,39 @@ +from typing import Callable, List + +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsUnquantized"] + + +class CompressedTensorsUnquantized(CompressedTensorsScheme): + """ + Implements the scheme for all layers which are ignored + in the CompressedTensors config. The input and loaded weight are used + in a linear transformation. + """ + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + device="cuda", + dtype=params_dtype), + requires_grad=False) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + return F.linear(x, weight) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py new file mode 100644 index 0000000000000..d16e570d12202 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -0,0 +1,119 @@ +from typing import Callable, List, Tuple, Union + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as custom_ops +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.utils import set_weight_attrs + +__all__ = ["CompressedTensorsW8A8StaticTensor"] + + +class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + assert isinstance(shard_id, str) + qkv_idxs = {"q": 0, "k": 1, "v": 2} + assert shard_id in qkv_idxs + return qkv_idxs[shard_id] + + def scales_shard_splitter( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int], + logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shard_id = self._shard_id_as_int(shard_id) + offset = sum(logical_widths[:shard_id]) + size = logical_widths[shard_id] + # update loaded weight with copies for broadcast. + loaded_weight = loaded_weight.repeat(size) + return param[offset:offset + size], loaded_weight + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + # TODO: remove zero_point parameters once the configs given remove them + + # Note on input/weight scales and zero_points + # + # When the scales have a single value, it is required that they be + # on the CPU for 2 reasons, + # 1. Performance: + # When the scales (input_scale/weight_scales) have only a single + # value, we perform a scalar broadcast of that value during the + # quant/dequant operations. The "quant" and the "gemm+dequant" + # kernels accept the Scalar by-value. These tensors are allocated + # on the CPU in order to avoid the GPU-to-CPU copy when passing + # by-value. + # + # 2. CUDA Graphs: + # CUDA Graphs don't support GPU-to-CPU copy operations during + # stream capture. + # + # TODO: zero-points are not supported yet. But we expect a similar + # pattern. + + is_tensor_partitioned = len(output_partition_sizes) != 1 + weight_scale_dim = sum( + output_partition_sizes) if is_tensor_partitioned else 1 + weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda" + + input_scale = Parameter(torch.empty(1, + device="cpu", + dtype=torch.float32), + requires_grad=False) + input_zero_point = Parameter(torch.empty(1, + device="cpu", + dtype=torch.int8), + requires_grad=False) + + weight_scale = Parameter(torch.empty(weight_scale_dim, + device=weight_scale_device, + dtype=torch.float32), + requires_grad=False) + weight_zero_point = Parameter(torch.empty(1, + device="cpu", + dtype=torch.int8), + requires_grad=False) + + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + requires_grad=False) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + + set_weight_attrs(weight, {"weight_loader": weight_loader}) + + layer.register_parameter("input_scale", input_scale) + set_weight_attrs(input_scale, {"weight_loader": weight_loader}) + layer.register_parameter("input_zero_point", input_zero_point) + set_weight_attrs(input_zero_point, {"weight_loader": weight_loader}) + layer.register_parameter("weight_scale", weight_scale) + set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) + set_weight_attrs( + weight_scale, { + "shard_splitter": self.scales_shard_splitter, + "logical_widths": output_partition_sizes + }) + layer.register_parameter("weight_zero_point", weight_zero_point) + set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): + weight = layer.weight + weight_scale = layer.weight_scale + act_scale = layer.input_scale + + # Input quantize + x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item()) + + return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, + weight_scale, x.dtype) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index a1642baa2c90c..f9b1dc60dd006 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -120,6 +120,13 @@ def get_quant_config(model_config: ModelConfig, # Read the quantization config from the HF model config, if available. hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) + if hf_quant_config is None: + compression_config = getattr(model_config.hf_config, + "compression_config", None) + if compression_config is not None: + hf_quant_config = compression_config.get("quantization_config", + None) + if hf_quant_config is not None: return quant_cls.from_config(hf_quant_config) model_name_or_path = model_config.model diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f43a40a0bfd34..086f9294c4f1c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -62,11 +62,12 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, bias=bias, quant_config=quant_config) if hidden_act != "silu": @@ -120,16 +121,16 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, ) self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, bias=bias, quant_config=quant_config, ) @@ -263,8 +264,10 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config) + for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) From e3470f87538ec86d1094ac4747519c6300213088 Mon Sep 17 00:00:00 2001 From: Elisei Smirnov <61423871+kezouke@users.noreply.github.com> Date: Fri, 24 May 2024 01:04:24 +0300 Subject: [PATCH 338/413] [Core]: Option To Use Prompt Token Ids Inside Logits Processor (#4985) Co-authored-by: Elisei Smirnov --- vllm/model_executor/layers/logits_processor.py | 17 ++++++++++++++--- vllm/sampling_params.py | 15 ++++++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 91eb96998c3cf..d450c46455d49 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -1,4 +1,5 @@ """A layer that compute logits from hidden_stats.""" +import inspect from typing import Optional import torch @@ -95,15 +96,25 @@ def _apply_logits_processors( seq_ids = seq_group.seq_ids sampling_params = seq_group.sampling_params logits_processors = sampling_params.logits_processors - if logits_processors: found_logits_processors = True + for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices): logits_row = logits[logits_row_idx] - token_ids = seq_group.seq_data[seq_id].output_token_ids + past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids + prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids + for logits_processor in logits_processors: - logits_row = logits_processor(token_ids, logits_row) + parameters = inspect.signature(logits_processor).parameters + if len(parameters) == 3: + logits_row = logits_processor(prompt_tokens_ids, + past_tokens_ids, + logits_row) + else: + logits_row = logits_processor(past_tokens_ids, + logits_row) + logits[logits_row_idx] = logits_row logits_processed += len(seq_group.sample_indices) + len( diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 5fa94eb149ffb..9d8a361353e26 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -18,10 +18,14 @@ class SamplingType(IntEnum): BEAM = 3 -LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] -"""LogitsProcessor is a function that takes a list of previously generated -tokens and a tensor of the logits for the next token, and returns a modified -tensor of logits to sample from.""" +LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], + Callable[[List[int], List[int], torch.Tensor], + torch.Tensor]] +"""LogitsProcessor is a function that takes a list +of previously generated tokens, the logits tensor +for the next token and, optionally, prompt tokens as a +first argument, and returns a modified tensor of logits +to sample from.""" class SamplingParams: @@ -95,7 +99,8 @@ class SamplingParams: spaces_between_special_tokens: Whether to add spaces between special tokens in the output. Defaults to True. logits_processors: List of functions that modify logits based on - previously generated tokens. + previously generated tokens, and optionally prompt tokens as + a first argument. truncate_prompt_tokens: If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). Defaults to None (i.e., no truncation). From 6a50f4cafaf9f734b3f6ad11e6af38838aa3baf8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 23 May 2024 16:21:54 -0700 Subject: [PATCH 339/413] [Doc] add ccache guide in doc (#5012) Co-authored-by: Michael Goin --- docs/source/getting_started/installation.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 0c81f7ec6d2a9..ba23e7468dcc1 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -56,6 +56,10 @@ You can also build and install vLLM from source: $ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability $ pip install -e . # This may take 5-10 minutes. +.. tip:: + + Building from source requires quite a lot compilation. If you are building from source for multiple times, it is beneficial to cache the compilation results. For example, you can install `ccache `_ via either `conda install ccache` or `apt install ccache` . As long as `which ccache` command can find the `ccache` binary, it will be used automatically by the build system. After the first build, the subsequent builds will be much faster. + .. tip:: To avoid your system being overloaded, you can limit the number of compilation jobs to be run simultaneously, via the environment variable `MAX_JOBS`. For example: From 919770957f26d71a5a6eda7a1a7443dfeb5ba0ee Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Fri, 24 May 2024 14:28:27 +0200 Subject: [PATCH 340/413] [Bugfix] Fix Mistral v0.3 Weight Loading (#5005) Co-authored-by: Cody Yu --- tests/models/test_mistral.py | 1 + vllm/model_executor/model_loader/loader.py | 17 ++++- .../model_loader/weight_utils.py | 64 ++++++++++++++++++- 3 files changed, 79 insertions(+), 3 deletions(-) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index d0a5bfbfcd922..76b248cf14e98 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -8,6 +8,7 @@ MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.3", ] diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 45ea8160a801b..b7b5b5e7695f4 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -23,7 +23,8 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( - download_weights_from_hf, filter_files_not_needed_for_inference, + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, get_quant_config, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models.vlm_base import VisionLanguageModelBase @@ -188,7 +189,19 @@ def _prepare_weights(self, model_name_or_path: str, use_safetensors = True break - if not use_safetensors: + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, self.load_config.download_dir, + revision) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder) + else: hf_weights_files = filter_files_not_needed_for_inference( hf_weights_files) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f9b1dc60dd006..53e21eba8fae3 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -12,9 +12,10 @@ import huggingface_hub.constants import numpy as np import torch -from huggingface_hub import HfFileSystem, snapshot_download +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import LoadConfig, ModelConfig from vllm.logger import init_logger @@ -218,6 +219,67 @@ def download_weights_from_hf( return hf_folder +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. + hf_hub_download( + repo_id=model_name_or_path, + filename=SAFE_WEIGHTS_INDEX_NAME, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have SAFE_WEIGHTS_INDEX_NAME. + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the SAFE_WEIGHTS_INDEX_NAME to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files(hf_weights_files: List[str], + hf_folder: str) -> List[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name) as index_file: + weight_map = json.load(index_file)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add( + os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [ + f for f in hf_weights_files if f in weight_files_in_index + ] + return hf_weights_files + + def filter_files_not_needed_for_inference( hf_weights_files: List[str]) -> List[str]: """ From e64fde4b013cb8bb2321f59ba78aca50b02071cb Mon Sep 17 00:00:00 2001 From: leiwen83 Date: Sat, 25 May 2024 01:07:09 +0800 Subject: [PATCH 341/413] [Core][Bugfix]: fix prefix caching for blockv2 (#4764) Co-authored-by: Lei Wen --- tests/core/block/test_prefix_caching_block.py | 117 ++++++++++++++++++ vllm/core/block/prefix_caching_block.py | 41 +++--- 2 files changed, 141 insertions(+), 17 deletions(-) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index c4c680e109a84..bcf08cda09f46 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -410,6 +410,123 @@ def test_get_common_computed_block_ids(num_blocks: int, block_size: int, assert (len(res) == zero_point_blocks) + # Test case that assume those prompted block after first immutable would + # be freed into hashless allocator, while first immutable block get ref + # increased. + @staticmethod + @pytest.mark.parametrize("num_blocks", [3]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(10))) + def test_alloc_promotion(num_blocks: int, block_size: int, seed: int): + random.seed(seed) + + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + token_ids = list(range(block_size)) + + block = allocator.allocate_immutable(prev_block=None, + token_ids=token_ids) + + assert allocator._refcounter.get(block.block_id) == 1 + m = allocator.allocate_mutable(prev_block=None) + + block_id = m.block_id + for i in range(block_size): + m.append_token_ids([i]) + # After block get promoted to immutable from mutable, if there is + # already same content hash block, then it shall be released into + # hashless_allocator + # And first immutable block's ref get increased by 1 + assert m.block_id == block.block_id + assert block_id in allocator._hashless_allocator._free_block_indices + assert allocator._refcounter.get(block.block_id) == 2 + + # Test case when eviction and allocation are mixed, + # make sure they work as expected + @staticmethod + @pytest.mark.parametrize("num_blocks", [3]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(10))) + def test_eviction_alloc_mixed(num_blocks: int, block_size: int, seed: int): + random.seed(seed) + + all_blocks_list = [i for i in range(num_blocks)] + zero_ref = {i: 0 for i in range(num_blocks)} + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + token_ids = list(range(num_blocks * block_size)) + + # now we have num_blocks free blocks in hashless allocator + # with internal tracking list _blocks _cached_blocks and evictor + # empty and block's ref shall be 0 + assert list(allocator._hashless_allocator._free_block_indices + ) == all_blocks_list + assert len(allocator._blocks.keys()) == 0 + assert len(allocator._cached_blocks.values()) == 0 + assert len(allocator.evictor.free_table.keys()) == 0 + assert allocator._refcounter._refcounts == zero_ref + + # Allocate immutable chains with only one block residuled in + new_block = [] + for i in range(num_blocks): + block = allocator.allocate_immutable( + prev_block=None, + token_ids=token_ids[block_size * i:block_size * (i + 1)]) + new_block.append(block) + + # Free all blocks, and now all blocks shall be in the evictor + # there shall be no tracking data left in _blocks + # all blocks shall be tracked in _cached_blocks + # all blocks' ref shall be zero + for block in new_block: + allocator.free(block) + + assert len(allocator._blocks.keys()) == 0 + assert len(allocator._hashless_allocator._free_block_indices) == 0 + assert list(allocator._cached_blocks.values()) == all_blocks_list + assert list(allocator.evictor.free_table.keys()) == all_blocks_list + assert allocator._refcounter._refcounts == zero_ref + + # Allocate a mutable block, and the first block shall be evicted + # and set its content hash into None, ref to 1 + mutable = allocator.allocate_mutable(prev_block=None) + + assert mutable.block_id == 0 + assert mutable.content_hash is None + assert 0 in allocator._blocks + assert allocator._refcounter.get(0) == 1 + assert 0 not in allocator._cached_blocks + assert 0 not in allocator.evictor + + # Since this mutable block has no hash yet, it shall be released into + # hashless allocator + allocator.free(mutable) + + assert len(allocator._blocks.keys()) == 0 + assert allocator._refcounter._refcounts == zero_ref + assert 0 not in allocator._cached_blocks + assert 0 not in allocator.evictor + assert 0 in allocator._hashless_allocator._free_block_indices + + # when allocate immutable with first block_size tokens, we + # shall get free block from hashless allocator, thus no block left + # in hashless + block = allocator.allocate_immutable(prev_block=None, + token_ids=token_ids[:block_size]) + + assert block.block_id == 0 + assert len(allocator._hashless_allocator._free_block_indices) == 0 + assert 0 in allocator._blocks + assert 0 in allocator._cached_blocks.values() + assert allocator._refcounter.get(0) == 1 + assert 0 not in allocator.evictor + + # allocate mutable block again, it shall be popped from evictor + mutable = allocator.allocate_mutable(prev_block=None) + assert len(allocator._hashless_allocator._free_block_indices) == 0 + assert mutable.block_id not in allocator.evictor.free_table + assert allocator._refcounter.get(mutable.block_id) == 1 + # Test case where two last accessed times are equal @staticmethod @pytest.mark.parametrize("num_blocks", [1024]) diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 882f301c1f697..4eb32f145b05b 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -160,21 +160,17 @@ def allocate_mutable(self, # If the evictor has blocks available for eviction, evict a block # and return it. if self.evictor.num_blocks > 0: + # here we get an evicted block, which is only added + # into evictor if its ref counter is 0 + # and since its content would be changed, we need + # to remove it from _cached_blocks's tracking list block_id, content_hash_to_evict = self.evictor.evict() - # Here we may have scenario that several blocks have - # the same content hash, but due to the latter coming block - # is coming from mutable to immutable path, their physical - # block is added into evictor. - # However in this case, we shall not pop the _cached_blocks, - # as the same content is still used by others, which means - # we need to check ref before decide to pop the list. - _block_id = self._cached_blocks[content_hash_to_evict] - refcount = self._refcounter.get(_block_id) - if refcount == 1: - self._cached_blocks.pop(content_hash_to_evict) - assert _block_id == block_id + assert self._refcounter.get(_block_id) == 0 + assert _block_id == block_id + + self._cached_blocks.pop(content_hash_to_evict) self._refcounter.incr(block_id) @@ -199,7 +195,11 @@ def allocate_mutable(self, def _incr_refcount_cached_block(self, block: Block, block_id: BlockId) -> None: - # since block is already computed, mark it + # now _incr_refcount_cached_block comes from two place + # allocate_immutable/promote_to_immutable_block where hit + # _cached_blocks hash key. + # In both cases, it means that already exists a already + # computed block which shared with block now block.computed = True refcount = self._refcounter.incr(block_id) @@ -228,13 +228,19 @@ def _free_block_id_for_block(self, block_id: BlockId, block: Block) -> None: assert isinstance(block, PrefixCachingBlock) - if block.content_hash is None: + # if we comes from promote_to_immutable_block, it means that + # block.content_hash is never None. + # However we need to release the same content block, so that + # physical block could get reused. + if block.block_id != block_id or block.content_hash is None: refcount = self._refcounter.get(block_id) # We have fork case where block would get more than one ref, # so we cannot free it from tracking if ref cnt large than 1 - if refcount <= 1: - assert block.block_id is not None + assert block.block_id is not None + refcount = self._refcounter.get(block.block_id) + if refcount == 1: del self._blocks[block.block_id] + return self._hashless_allocator.free(block) refcount = self._refcounter.decr(block_id) @@ -317,7 +323,8 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: if block.content_hash not in self._cached_blocks: self._cached_blocks[block.content_hash] = block.block_id else: - self._free_block_id_for_block(block.block_id, block) + self._free_block_id_for_block( + self._cached_blocks[block.content_hash], block) self._incr_refcount_cached_block( block, self._cached_blocks[block.content_hash]) From 8e192ff967b44b186ea02d30e49fddf656fdfe50 Mon Sep 17 00:00:00 2001 From: Eric Xihui Lin Date: Sat, 25 May 2024 01:00:52 -0400 Subject: [PATCH 342/413] [Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model (#4799) Co-authored-by: beagleski Co-authored-by: bapatra Co-authored-by: Barun Patra Co-authored-by: Michael Goin --- csrc/attention/attention_kernels.cu | 185 ++++++-- csrc/cpu/attention.cpp | 37 +- csrc/ops.h | 35 +- docs/source/models/supported_models.rst | 4 + tests/kernels/test_blocksparse_attention.py | 442 +++++++++++++++++ vllm/_custom_ops.py | 30 +- vllm/attention/backends/abstract.py | 1 + vllm/attention/backends/blocksparse_attn.py | 410 ++++++++++++++++ vllm/attention/backends/flash_attn.py | 5 +- vllm/attention/backends/rocm_flash_attn.py | 5 +- vllm/attention/backends/torch_sdpa.py | 5 +- vllm/attention/backends/xformers.py | 5 +- vllm/attention/layer.py | 10 +- .../ops/blocksparse_attention/__init__.py | 0 .../blocksparse_attention_kernel.py | 423 +++++++++++++++++ .../ops/blocksparse_attention/interface.py | 238 ++++++++++ .../ops/blocksparse_attention/utils.py | 216 +++++++++ vllm/attention/ops/paged_attn.py | 25 +- vllm/attention/selector.py | 7 + vllm/entrypoints/openai/serving_engine.py | 1 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/phi3_small.py | 447 ++++++++++++++++++ vllm/transformers_utils/config.py | 2 +- 23 files changed, 2446 insertions(+), 88 deletions(-) create mode 100644 tests/kernels/test_blocksparse_attention.py create mode 100644 vllm/attention/backends/blocksparse_attn.py create mode 100644 vllm/attention/ops/blocksparse_attention/__init__.py create mode 100644 vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py create mode 100644 vllm/attention/ops/blocksparse_attention/interface.py create mode 100644 vllm/attention/ops/blocksparse_attention/utils.py create mode 100644 vllm/model_executor/models/phi3_small.py diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index d6203174e7275..45edc3252380c 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -85,6 +85,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Grid: (num_heads, num_seqs, max_num_partitions). template // Zero means no partitioning. __device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -104,7 +105,9 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale) { + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -202,11 +205,55 @@ __device__ void paged_attention_kernel( // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { // NOTE(woosuk): The block number is stored in int32. However, we cast it to // int64 because int32 can lead to overflow when this variable is multiplied // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } const int64_t physical_block_number = static_cast(block_table[block_idx]); @@ -335,6 +382,15 @@ __device__ void paged_attention_kernel( // NOTE(woosuk): The block number is stored in int32. However, we cast it to // int64 because int32 can lead to overflow when this variable is multiplied // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } const int64_t physical_block_number = static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; @@ -441,8 +497,8 @@ __device__ void paged_attention_kernel( // Grid: (num_heads, num_seqs, 1). template + int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, + bool IS_BLOCK_SPARSE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -457,18 +513,23 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale) { + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { paged_attention_kernel( + KV_DTYPE, IS_BLOCK_SPARSE>( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, kv_scale); + kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs, max_num_partitions). template __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -488,12 +549,16 @@ __global__ void paged_attention_v2_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float kv_scale) { + const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { paged_attention_kernel( + KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, kv_scale); + kv_block_stride, kv_head_stride, kv_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs). @@ -607,25 +672,32 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel< \ - T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>), \ + ((void*)vllm::paged_attention_v1_kernel), \ shared_mem_size); \ vllm::paged_attention_v1_kernel \ + NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \ <<>>( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - kv_scale); + kv_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. template + vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, + int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float kv_scale) { + const c10::optional& alibi_slopes, float kv_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -691,23 +763,36 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v1_launcher( \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, kv_scale); + seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ switch (block_size) { \ case 8: \ - CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ break; \ case 16: \ - CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ break; \ case 32: \ - CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -727,18 +812,26 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int block_size, int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale){ + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) +} - DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, - CALL_V1_LAUNCHER_BLOCK_SIZE)} #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ vllm::paged_attention_v2_kernel \ + NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \ + PARTITION_SIZE> \ <<>>( \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, kv_scale); \ + kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ @@ -746,14 +839,17 @@ void paged_attention_v1( max_num_partitions); template + vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, + int NUM_THREADS = 128, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const c10::optional& alibi_slopes, float kv_scale) { + const c10::optional& alibi_slopes, float kv_scale, + const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -824,24 +920,36 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE) \ - paged_attention_v2_launcher( \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - kv_scale); + kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ switch (block_size) { \ case 8: \ - CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE); \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ break; \ case 16: \ - CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE); \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ break; \ case 32: \ - CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE); \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -865,7 +973,10 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int block_size, int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale) { + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) } @@ -873,4 +984,4 @@ void paged_attention_v2( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 54df69b7379d6..438e9bdb19f50 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -415,14 +415,17 @@ void paged_attention_v1_impl_launcher( } } // namespace -void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, - torch::Tensor& key_cache, torch::Tensor& value_cache, - int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, - int block_size, int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale) { +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -726,16 +729,18 @@ void paged_attention_v2_impl_launcher( } } // namespace -void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, - torch::Tensor& max_logits, torch::Tensor& tmp_out, - torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, - float scale, torch::Tensor& block_tables, - torch::Tensor& seq_lens, int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale) { +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); + TORCH_CHECK(blocksparse_vert_stride <= 1, + "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/ops.h b/csrc/ops.h index b839eaf0d26c8..567d9fae4bd2a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -2,23 +2,24 @@ #include -void paged_attention_v1(torch::Tensor& out, torch::Tensor& query, - torch::Tensor& key_cache, torch::Tensor& value_cache, - int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, - int block_size, int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale); - -void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums, - torch::Tensor& max_logits, torch::Tensor& tmp_out, - torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, - float scale, torch::Tensor& block_tables, - torch::Tensor& seq_lens, int block_size, - int max_seq_len, - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale); +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step); + +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, + int max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step); void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, float epsilon); diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 31d4b53bd4409..e4bae80343a2c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -123,6 +123,10 @@ Alongside each architecture, we include some popular models that use it. - Phi-3 - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc. - + * - :code:`Phi3SmallForCausalLM` + - Phi-3-Small + - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. + - * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py new file mode 100644 index 0000000000000..9da13ca6e2310 --- /dev/null +++ b/tests/kernels/test_blocksparse_attention.py @@ -0,0 +1,442 @@ +import random +from typing import List, Optional, Tuple + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn) +from vllm.utils import get_max_shared_memory_bytes, is_hip + +from .allclose_default import get_default_atol, get_default_rtol + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# MAX_SEQ_LEN = 2771 + +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. +NUM_BLOCKS = 4321 # Arbitrary values for testing +PARTITION_SIZE = 512 +DTYPES = [torch.half, torch.bfloat16] +NUM_GEN_SEQS = [3] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing + +HEAD_SIZES = [64, 112] +BLOCK_SIZES = [16, 32] +USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto", "fp8"] +SEEDS = [0] +CUDA_DEVICES = ['cuda:0'] +BLOCKSPARSE_LOCAL_BLOCKS = [16] +BLOCKSPARSE_VERT_STRIDES = [8] + +BLOCKSPARSE_BLOCK_SIZES = [64] +BLOCKSPARSE_HEADS_SLIDINGS = [0, 2, -1] +BLOCKSPARSE_HOMO_HEADS = [True, False] + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 1, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + seq_lens = seq_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + seq_len = int(seq_lens[i]) + + keys = [] + values = [] + for j in range(seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + if blocksparse_vert_stride >= 1: + bsize = blocksparse_block_size + hsliding = blocksparse_head_sliding_step + vert = blocksparse_vert_stride + locals = blocksparse_local_blocks + qb = (seq_len - 1) // bsize + attn_mask = q.new_zeros( + (num_query_heads, 1, seq_len)).float() - torch.inf + for h in range(num_query_heads): + if hsliding >= 0: # slide with q heads + bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1 + else: # slide with kv heads + bs_offset = (tp_rank * num_kv_heads + + h // num_queries_per_kv) * (-hsliding) + 1 + for kb in range(qb + 1): + kj = kb * bsize + if (qb - kb) < locals or \ + (kb + bs_offset) % vert == 0: + attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0 + if alibi_bias is not None: + attn_mask += alibi_bias + else: + attn_mask = alibi_bias + + out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("version", ["v1", "v2"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) +@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) +@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) +@pytest.mark.parametrize("blocksparse_head_sliding_step", + BLOCKSPARSE_HEADS_SLIDINGS) +def test_paged_attention( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, + blocksparse_local_blocks: int, + blocksparse_vert_stride: int, + blocksparse_block_size: int, + blocksparse_head_sliding_step: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.rand(num_query_heads, dtype=torch.float) + + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) + + # Create the block tables. + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + kv_scale = 1.0 + tp_rank = 0 + + # Call the paged attention kernel. + output = torch.empty_like(query) + if version == "v1": + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step, + ) + elif version == "v2": + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + tp_rank=tp_rank, + blocksparse_local_blocks=blocksparse_local_blocks, + blocksparse_vert_stride=blocksparse_vert_stride, + blocksparse_block_size=blocksparse_block_size, + blocksparse_head_sliding_step=blocksparse_head_sliding_step, + ) + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_key_cache, key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_value_cache, value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + seq_lens, + scale, + alibi_slopes, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = get_default_atol(output) if is_hip() else 1e-3 + rtol = get_default_rtol(output) if is_hip() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype) + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) +@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) +@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) +@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_varlen_blocksparse_attention_prefill( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + blocksparse_local_blocks: int, + blocksparse_vert_stride: int, + blocksparse_block_size: int, + blocksparse_homo_heads: bool, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + seq_lens = random.sample(range(1, max_len), num_seqs) + cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0) + num_tokens = sum(seq_lens) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + + qkv = torch.empty(num_tokens, + num_query_heads + 2 * num_kv_heads, + head_size, + dtype=dtype) + qkv.uniform_(-scale, scale) + query, key, value = qkv.split( + [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + + bs_attn_op = LocalStridedBlockSparseAttn( + num_query_heads, + max_len, + local_blocks=blocksparse_local_blocks, + vert_stride=blocksparse_vert_stride, + block_size=blocksparse_block_size, + device=device, + dtype=dtype, + homo_head=blocksparse_homo_heads) + + output = bs_attn_op(query, + key, + value, + cu_seq_lens.to(device), + sm_scale=scale) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) + + ref_output = ref_multi_query_kv_attention( + cu_seq_lens, + query, + key, + value, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index f0fab4d8aa26d..22cf5a44e341f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -45,11 +45,17 @@ def paged_attention_v1( alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, seq_lens, - block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, kv_scale) + vllm_ops.paged_attention_v1( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) def paged_attention_v2( @@ -69,12 +75,18 @@ def paged_attention_v2( alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, - key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, - max_seq_len, alibi_slopes, kv_cache_dtype, - kv_scale) + vllm_ops.paged_attention_v2( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) # pos encoding ops diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 94ab64de30a94..6396103bf5efa 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -111,6 +111,7 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py new file mode 100644 index 0000000000000..dce2b83615b7a --- /dev/null +++ b/vllm/attention/backends/blocksparse_attn.py @@ -0,0 +1,410 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn, get_head_sliding_step) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + +@dataclass +class BlocksparseParams: + max_seqlen: int + + # Num q heads per tensor-parallel rank/partition + num_heads: int # per TP partition + # Num kv heads per tensor-parallel rank/partition + num_kv_heads: int + + # block size used for blocksparse attention. + # This is the block_size used in `local_blocks`, `vert_stride`. + block_size: int + + # Number of blocks for local attention, i.e., number of + # local attended tokens / `sparse_block_size` + local_blocks: int + + # Attend to one block per every `vert_stride` blocks. + # Controlling the sparsity + vert_stride: int + """ + If to use the same vertical stride offset for all heads, + i.e., attend to the same block of tokens on all heads. + By default, it is False, i.e., attention on the non-local + blocks depends on the `head_idx`, that is on + blocks satisfying + `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` + where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, + `block_idx = position_id // sparse_block_size`. + See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` + for more detail. + """ + homo_head: bool = False + + # If within a group, the kv offsets that each q attends is the same or no. + homo_head_group: bool = False + + # Decided by homo_head and homo_head group + head_sliding_step: int = field(init=False) + + # range of q heads to for a TP rank + active_head_range: Tuple = field(init=False) + + def __post_init__(self): + assert self.block_size > 0 + assert self.local_blocks >= 0 + assert self.vert_stride >= 1 + assert self.num_heads % self.num_kv_heads == 0 + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + total_heads = tp_size * self.num_heads + total_kv_heads = tp_size * self.num_kv_heads + + if self.homo_head: + self.head_sliding_step = 0 + elif self.homo_head_group: + head_sliding_step = get_head_sliding_step(total_kv_heads, + self.vert_stride) + # negative indicates sliding along kv heads, i.e., homo q group + self.head_sliding_step = -head_sliding_step + else: + self.head_sliding_step = get_head_sliding_step( + total_heads, self.vert_stride) + + self.active_head_range = ( + tp_rank * self.num_heads, + (tp_rank + 1) * self.num_heads, + ) + + +class BlocksparseFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: + return BlocksparseFlashAttentionImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata": + return BlocksparseFlashAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class BlocksparseFlashAttentionMetadata(AttentionMetadata): + """A copy of Metadata for FlashAttentionBackend, + to avoid having to install flash_attn. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + + @property + def prefill_metadata( + self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class BlocksparseFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + ) -> None: + assert blocksparse_params is not None + assert alibi_slopes is None, ValueError( + "Alibi not support for blocksparse flash attention.") + assert sliding_window is None, ValueError( + "sliding_window is invalid for blocksparse attention.") + + if "num_heads" not in blocksparse_params: + blocksparse_params["num_heads"] = num_heads + if "num_kv_heads" not in blocksparse_params: + blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads + self.blocksparse_params = BlocksparseParams(**blocksparse_params) + self.kv_cache_dtype = kv_cache_dtype + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.alibi_slopes = alibi_slopes + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.local_blocks = self.blocksparse_params.local_blocks + self.vert_stride = self.blocksparse_params.vert_stride + self.sparse_block_size = self.blocksparse_params.block_size + self.head_sliding_step = self.blocksparse_params.head_sliding_step + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + total_num_heads = num_heads * self.tp_size + self.bs_attn = LocalStridedBlockSparseAttn( + total_num_heads, + self.blocksparse_params.max_seqlen, + self.blocksparse_params.local_blocks, + self.blocksparse_params.vert_stride, + self.blocksparse_params.block_size, + homo_head=self.blocksparse_params.homo_head, + active_head_range=self.blocksparse_params.active_head_range, + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + kv_scale: float = 1.0, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + kv_scale, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + + # Prompt run. + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + + assert kv_cache is None \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0, \ + "Does not support prefix-enabled attention." + + output = self.bs_attn( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + sm_scale=self.scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + self.blocksparse_params.max_seqlen, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + kv_scale, + tp_rank=self.tp_rank, + blocksparse_local_blocks=self.local_blocks, + blocksparse_vert_stride=self.vert_stride, + blocksparse_block_size=self.sparse_block_size, + blocksparse_head_sliding_step=self.head_sliding_step, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 0f4568070cfc4..0b9d6283493f2 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -1,6 +1,6 @@ """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache @@ -219,7 +219,10 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 94f3f55636ed6..e92e6c5e2dc8d 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,6 +1,6 @@ """Attention layer ROCm GPUs.""" from dataclasses import dataclass -from typing import List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch @@ -201,7 +201,10 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "ROCFlashAttention does not support blocksparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index a19c97e1e0e35..9b50adec5244d 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -1,7 +1,7 @@ """ Attention layer with torch scaled_dot_product_attention and PagedAttention.""" from dataclasses import dataclass -from typing import List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch from torch.nn.functional import scaled_dot_product_attention @@ -100,7 +100,10 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "Torch SPDA does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 96169da6cf92c..99a3e88bc07b6 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -1,6 +1,6 @@ """Attention layer with xFormers and PagedAttention.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import torch from xformers import ops as xops @@ -212,7 +212,10 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: + assert blocksparse_params is None, ValueError( + "XFormer does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index dc7b3940bc9b7..b67f04c51d493 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,5 +1,5 @@ """Attention layer.""" -from typing import List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn as nn @@ -33,6 +33,7 @@ def __init__( sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + blocksparse_params: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() if cache_config is not None: @@ -69,10 +70,12 @@ def __init__( dtype = torch.get_default_dtype() attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads, sliding_window, dtype, kv_cache_dtype, - block_size) + block_size, blocksparse_params + is not None) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype) + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params) def forward( self, @@ -90,4 +93,5 @@ def extra_repr(self) -> str: s += f", num_heads={self.impl.num_heads}" # type: ignore s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore s += f", scale={self.impl.scale}" # type: ignore + s += f", backend={self.impl.__class__.__name__}" return s diff --git a/vllm/attention/ops/blocksparse_attention/__init__.py b/vllm/attention/ops/blocksparse_attention/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py new file mode 100644 index 0000000000000..ec1c37c5bcb0e --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -0,0 +1,423 @@ +import torch +import triton +import triton.language as tl + + +def blocksparse_flash_attn_varlen_fwd( + q, + k, + v, # (#tokens, n_heads, head_size) + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + sparse_layout, + *, + block_size=64, + q_block_size=None, + max_seqlen=None): + # split q to blocks + + assert isinstance(sparse_layout, (list, tuple)) + + _, n_heads, head_size = q.shape + batch_size = cu_seqlens_k.size(0) - 1 + q_block_size = q_block_size or block_size + + assert q.dim() == k.dim() == v.dim() == 3 + assert q.size(1) % k.size(1) == 0 + assert q.size(2) == k.size(2) + # TODO(linxihui): allow k, v to have different head_size + assert k.shape == v.shape + assert cu_seqlens_k.dim() == 1 + + q_k_ratio = q.size(1) // k.size(1) + + if cu_seqlens_q is None: + if q.size(0) == batch_size: # decoding only + cu_seqlens_q = torch.arange( + 0, + batch_size + 1, + dtype=cu_seqlens_k.dtype, + device=cu_seqlens_k.device, + ) + elif q.size(0) == k.size(0): + cu_seqlens_q = cu_seqlens_k + else: + raise ValueError("cu_seqlens_q must be specified\ + if it mix of prefilling and decoding.") + else: + assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) + + # switch to use cpu to avoid too many kernel launches when iterated over + q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() + k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() + + assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( + "length of q should either be 1 (decoding) or same as k (prefilling).") + + if max_seqlen: + assert k_lens.max() <= max_seqlen + + n_blocks = (q_lens + q_block_size - 1) // q_block_size + + q_batch_ids = torch.tensor( + [i for i, n in enumerate(n_blocks) for _ in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + q_start_sids = torch.tensor( + [i * q_block_size for n in n_blocks for i in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + + out = q.new_empty(q.shape) + cu_seqlens_q = cu_seqlens_q.contiguous() + cu_seqlens_k = cu_seqlens_k.contiguous() + + layout_crow_indices, layout_col_indices = sparse_layout + block_d = triton.next_power_of_2(head_size) + + decoding_only = (q_lens == 1).all().item() + grid = (len(q_start_sids), n_heads, 1) + + _fwd_kernel_batch_inference[grid]( + q, + k, + v, + out, + sm_scale, + cu_seqlens_q[:-1], + cu_seqlens_q[1:], + cu_seqlens_k[:-1], + cu_seqlens_k[1:], + q_batch_ids, + q_start_sids, + 0, + *q.stride(), + 0, + *k.stride(), + 0, + *v.stride(), + 0, + *out.stride(), + layout_crow_indices, + layout_col_indices, + *layout_crow_indices.stride(), + *layout_col_indices.stride(), + q_k_ratio, + HAS_BATCH_DIM=False, + D_HEAD=head_size, + BLOCK_M=q_block_size, + BLOCK_N=block_size, + BLOCK_D=block_d, + BLOCK_M_LOADING=(16 if decoding_only else + q_block_size), # smaller for decoding + EVEN_D=block_d == head_size, + num_warps=1 if decoding_only else 4, + num_stages=3) + + return out + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + BLOCK_N: tl.constexpr, + D_HEAD: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + + k_block_col_idx * layout_col_stride_m).to(tl.int32) + start_n = k_block_id * BLOCK_N + if LAST_K_BLOCK: + if EVEN_D: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=offs_n[None, :] + start_n < k_seqlen, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=(offs_n[None, :] + start_n < k_seqlen) & + (offs_d[:, None] < D_HEAD), + ) + else: + if EVEN_D: + k = tl.load(k_ptrs + start_n * stride_kt) + else: + k = tl.load(k_ptrs + start_n * stride_kt, + mask=offs_d[:, None] < D_HEAD) + + qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK | M_LT_N: + qk += tl.where( + offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), + 0, + float("-inf"), + ) + + # flash-attn2 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + # update m_i + m_i = m_ij + l_i = l_i * alpha + l_ij + + p = p.to(Q.dtype.element_ty) + # update acc + if LAST_K_BLOCK: + if EVEN_D: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=offs_n[:, None] + start_n < k_seqlen, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=(offs_n[:, None] + start_n < k_seqlen) & + (offs_d[None, :] < D_HEAD), + ) + else: + if EVEN_D: + v = tl.load(v_ptrs + start_n * stride_vt) + else: + v = tl.load(v_ptrs + start_n * stride_vt, + mask=offs_d[None, :] < D_HEAD) + + acc += tl.dot(p, v) + + return acc, l_i, m_i + + +@triton.heuristics({ + "M_LT_N": + lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], +}) +@triton.jit +def _fwd_kernel_batch_inference( + Q, + K, + V, + Out, + sm_scale, + q_batch_starts, + q_batch_ends, + k_batch_starts, + k_batch_ends, + q_batch_ids, + q_start_sids, + stride_qb, + stride_qt, + stride_qh, + stride_qd, + stride_kb, + stride_kt, + stride_kh, + stride_kd, + stride_vb, + stride_vt, + stride_vh, + stride_vd, + stride_ob, + stride_ot, + stride_oh, + stride_od, + layout_crow_ptr, + layout_col_ptr, + layout_crow_stride_h, + layout_crow_stride_m, + layout_col_stride_h, + layout_col_stride_m, + q_k_ratio, + HAS_BATCH_DIM: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + """ + NOTATION: + pid: position id + sid: storage id + sbid: storage block id + pbid: position block id + offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) + + TODO(linxihui): + Optimize grouped-attn + """ + off_zm = tl.program_id(0) + off_h = tl.program_id(1) + + off_h_for_kv = off_h // q_k_ratio + + if HAS_BATCH_DIM: + off_z = tl.program_id(2) + Q += off_z * stride_qb + K += off_z * stride_kb + V += off_z * stride_vb + Out += off_z * stride_ob + start_m = off_zm + q_start_sid = start_m * BLOCK_M # always 0 for decoding + else: + off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] + q_start_sid = tl.load(q_start_sids + off_zm) + start_m = q_start_sid // BLOCK_M # q_sbid + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) + q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start + k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) + k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start + past_len = k_seqlen - q_seqlen + + Q += q_cu_start * stride_qt + off_h * stride_qh + K += k_cu_start * stride_kt + off_h_for_kv * stride_kh + V += k_cu_start * stride_vt + off_h_for_kv * stride_vh + Out += q_cu_start * stride_ot + off_h * stride_oh + + q_pbid = (past_len + q_start_sid) // BLOCK_M + + if EVEN_D: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=offs_m[:, None] < q_seqlen, + ) + else: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + other=0, + ) + + sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + + q_pbid * layout_crow_stride_m) + + # TODO(linxihui): load at once, with any Triton version + # that supports `tl.split`, e.g., Triton 3.0 + k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) + k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) + + m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) + acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) + + k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + sm_scale *= ( + 1.44269504 # 1/log2 as we use base2 for exponential and logarithm + ) + + for k_block_col_idx in range(k_block_start, k_block_end - 1): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + False, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_end - 1, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + True, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + # flash-attn 2 + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + + # write output + if EVEN_D: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=offs_m[:, None] < q_seqlen, + ) + else: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + ) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py new file mode 100644 index 0000000000000..300211e70bb79 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -0,0 +1,238 @@ +import math + +import torch + +from vllm.utils import is_cpu, is_hip + +from .utils import (dense_to_crow_col, get_head_sliding_step, + get_sparse_attn_mask) + +IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 8) + +if IS_COMPUTE_8_OR_ABOVE: + from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd + + +class LocalStridedBlockSparseAttn(torch.nn.Module): + + def __init__( + self, + n_heads, + max_seqlen, + local_blocks, + vert_stride, + block_size, + device=None, + dtype=None, + homo_head=False, + active_head_range=None, + q_block_size=None, + use_spda=None, + ): + super().__init__() + if use_spda is None: + use_spda = is_hip() or is_cpu() or not \ + IS_COMPUTE_8_OR_ABOVE + device = device or (torch.cuda.current_device() + if torch.cuda.is_available() else "cpu") + device = torch.device(device) + # NOTE: vllm CPU backend support BF16 instead of FP16. + dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE + or device.type == "cpu" else torch.half) + + self.n_heads = n_heads + self.max_seqlen = max_seqlen + self.local_blocks = local_blocks + self.vert_stride = vert_stride + self.use_spda = use_spda + self.dtype = dtype + self.device = device + self.block_size = block_size + self.q_block_size = q_block_size + self.homo_head = homo_head + self.active_head_range = active_head_range + self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, + homo_head) + + sparse_layout, sparse_pattern, self.dense_attn_mask = ( + self.get_attn_pattern(dtype, device)) + + if q_block_size is not None and q_block_size != block_size: + if q_block_size > block_size: + assert q_block_size % block_size == 0 + blocks_to_merge = q_block_size // block_size + shape = sparse_pattern.shape + sparse_pattern = sparse_pattern.view(shape[0], -1, + blocks_to_merge, + shape[-1]) + sparse_pattern = sparse_pattern.sum(2) + sparse_layout = dense_to_crow_col(sparse_pattern) + else: + raise ValueError( + "Does not support smaller q_block_size. It will be slower." + ) + + self.sparse_layout = sparse_layout + + def get_attn_pattern(self, dtype, device): + sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( + self.n_heads, + self.max_seqlen, + self.max_seqlen, + dtype, + device, + block_size=self.block_size, + local_blocks=self.local_blocks, + vert_stride=self.vert_stride, + homo_head=self.homo_head, + return_dense=self.use_spda, + dense_mask_type="bias", + ) + if (not self.homo_head) and (self.active_head_range is not None): + assert isinstance(self.active_head_range, tuple) + assert (len(self.active_head_range) == 2) + h_start, h_end = self.active_head_range + sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) + if self.use_spda: + dense_attn_mask = dense_attn_mask[h_start:h_end] + return sparse_layout, sparse_pattern, dense_attn_mask + + def varlen_attn(self, + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=None, + sm_scale=None): + """ + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), + indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify is when q is a mix of + prefilling and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert ( + IS_COMPUTE_8_OR_ABOVE + ), "Requires compute capability of 8 or above (Ampere or newer) to use \ + Triton kernel." + + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + + return blocksparse_flash_attn_varlen_fwd( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + self.sparse_layout, + block_size=self.block_size, + q_block_size=self.q_block_size, + max_seqlen=self.max_seqlen, + ) + + @staticmethod + def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): + """ + :param x: (total_tokens, n_heads, head_size) + :return: (batch, n_heads, length, head_size) + """ + x_padded = x.new_empty( + len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) + cu_seqlens = cu_seqlens.cpu() + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, + 1).unsqueeze(1)) + return x_padded.flatten(1, 2) + + @staticmethod + def transpose_and_unpad(x_padded, cu_seqlens): + """ + :param x_padded: (batch, n_heads, length, head_size) + :return: (total_tokens, n_heads, head_size) + """ + cu_seqlens = cu_seqlens.cpu() + total_n_tokens = cu_seqlens[-1] + x = x_padded.new_empty(total_n_tokens, x_padded.size(1), + x_padded.size(3)) + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) + return x + + def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """For CPU, V100 or other older GPUs. + NOTE: torch SPDA supports nested tensor, + but seems extremely slow. Choose to pad instead. + """ + assert (cu_seqlens_q is None or + (cu_seqlens_q + == cu_seqlens_k).all()), "Can only handle prompt with SPDA." + assert q.size(0) == k.size(0), "can only handle prompt with SPDA." + + assert q.size(1) % k.size(1) == 0 + q_k_ratio = q.size(1) // k.size(1) + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + cu_seqlens = cu_seqlens_k.cpu() + maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + if (self.dense_attn_mask.dtype != q.dtype + or self.dense_attn_mask.device != q.device): + _, _, self.dense_attn_mask = self.get_attn_pattern( + q.dtype, q.device) + attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] + + q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) + k2, v2 = [ + self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) + for x in [k, v] + ] + spda_output = torch.nn.functional.scaled_dot_product_attention( + q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) + return self.transpose_and_unpad(spda_output, cu_seqlens) + + def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """Dispatch to `varlen_attn` (Ampere or newer) or + `self.spda`(cpu, Volta, Turing or older)based on + the type of device used and cuda compute capability. + + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify + is when q is a mix of prefilling + and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert k.dim() == 3 + if self.use_spda: + return self.spda( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale, + ) + return self.varlen_attn(q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale) \ No newline at end of file diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py new file mode 100644 index 0000000000000..0d90dd971e156 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/utils.py @@ -0,0 +1,216 @@ +# Helper functions for 3D sparse pattern +# These function are not optimized and very inefficient. +# Avoid calling them too frequent or use a cache mechanism. + +from functools import lru_cache + +import torch +import triton +from scipy import sparse + + +def dense_to_crow_col(x: torch.Tensor): + """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. + NOTE: col_indices padded -1 + """ + device = x.device + pad = -1 + dim = x.dim() + assert x.dim() in (2, 3) + if x.dim() == 2: + x = x[None] + x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x] + crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) + cols = [torch.from_numpy(xi.indices) for xi in x] + max_cols = max(len(xi) for xi in cols) + cols = [ + torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) + for xi in cols + ] + cols = torch.vstack(cols) + if dim == 2: + crows = crows[0] + cols = cols[0] + return crows.to(device), cols.to(device) + + +def crow_col_to_dense(crows: torch.Tensor, + cols: torch.Tensor, + dtype: torch.dtype = torch.float16): + dim = crows.dim() + if dim == 1: + crows = crows[None] + cols = cols[None] + device = crows.device + crows, cols = crows.cpu(), cols.cpu() # faster in cpu + shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) + x = torch.zeros(shape, dtype=dtype) + for i in range(shape[0]): + for j in range(shape[1]): + x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 + if dim == 1: + x = x[0] + return x.to(device) + + +def dense_to_ccol_row(x: torch.Tensor): + """Similar, but to CSC format""" + x = x.transpose(-2, -1) + return dense_to_crow_col(x) + + +def ccol_row_to_dense(ccol: torch.Tensor, + rows: torch.Tensor, + dtype: torch.dtype = torch.float16): + return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() + + +def _get_sparse_attn_mask_homo_head( + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 128, + local_blocks: int = 4, + vert_stride: int = 4, + return_dense: bool = False, +): + """ + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be + OOM if it is too big) if `return_dense==True`, + otherwise, None + """ + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[:, None] + k_pos = torch.arange(num_blocks)[None] + mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = (dense_to_crow_col( + block_mask_dense[-num_blocks_q:].contiguous())) + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask + return ( + block_mask_dense_output, + block_mask_dense, + mask_dense, + ) + else: + return ( + block_mask_dense_output, + block_mask_dense, + None, + ) + + +def binary_mask_to_bias(mask_dense: torch.Tensor): + mask_dense = 1 - mask_dense + mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) + return mask_dense + + +def get_head_sliding_step(n_heads: int, + vert_stride: int, + homo_head: bool = False): + if homo_head: + return 0 + return max(1, int(vert_stride / n_heads)) + + +@lru_cache +def get_sparse_attn_mask( + n_heads: int, + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 64, + local_blocks: int = 4, + vert_stride: int = 4, + homo_head: bool = True, + return_dense: bool = False, + dense_mask_type: str = "binary", +): + """ + :param dense_mask_type: "binary" (0 for skip token, 1 for others) + or "bias" (-inf for skip token, 0 or others) + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be OOM if it + is too big) if `return_dense==True`, otherwise, None + """ + assert dense_mask_type in ("binary", "bias") + if homo_head: + with torch.no_grad(): + (crow, col), block_mask_dense, mask_dense = ( + _get_sparse_attn_mask_homo_head( + q_len, + max_seqlen, + dtype, + device, + block_size, + local_blocks, + vert_stride, + return_dense, + )) + crow = crow[None].expand(n_heads, crow.shape[0]) + col = col[None].expand(n_heads, col.shape[0]) + if return_dense: + mask_dense = mask_dense[None].expand(n_heads, + *mask_dense.shape) + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + return (crow, col), block_mask_dense, mask_dense + + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[None, :, None] + k_pos = torch.arange(num_blocks)[None, None] + head_sliding_step = get_head_sliding_step(n_heads, vert_stride) + mask_vert_strided = [ + (torch.arange(num_blocks) + h * head_sliding_step + 1) % + vert_stride == 0 for h in range(n_heads) + ] + mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + mask_dense, + ) + else: + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + None, + ) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 30feaa4da254d..e119fdcf11113 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -91,9 +91,21 @@ def forward_decode( scale: float, alibi_slopes: Optional[torch.Tensor], kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, ) -> torch.Tensor: - output = torch.empty_like(query) + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // @@ -107,6 +119,7 @@ def forward_decode( # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: # Run PagedAttention V1. ops.paged_attention_v1( @@ -123,6 +136,11 @@ def forward_decode( alibi_slopes, kv_cache_dtype, kv_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, ) else: # Run PagedAttention V2. @@ -155,6 +173,11 @@ def forward_decode( alibi_slopes, kv_cache_dtype, kv_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, ) return output diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index f191461dcd3b7..9ceda3431b898 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -29,7 +29,14 @@ def get_attn_backend( dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, + is_blocksparse: bool = False, ) -> Type[AttentionBackend]: + + if is_blocksparse: + logger.info("Using BlocksparseFlashAttention backend.") + from vllm.attention.backends.blocksparse_attn import ( + BlocksparseFlashAttentionBackend) + return BlocksparseFlashAttentionBackend """Determine which attention backend to use and only import the selected backend module. """ diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index db3fc85decd70..0df0223b9dbb2 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -100,6 +100,7 @@ def _create_logprobs( token_logprob = step_top_logprobs[token_id].logprob token = step_top_logprobs[token_id].decoded_token logprobs.tokens.append(token) + token_logprob = max(token_logprob, -9999.0) logprobs.token_logprobs.append(token_logprob) if num_output_top_logprobs: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 6aec104be8da4..a92abe6b5b8dc 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -56,6 +56,7 @@ "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py new file mode 100644 index 0000000000000..0c5298eb6f100 --- /dev/null +++ b/vllm/model_executor/models/phi3_small.py @@ -0,0 +1,447 @@ +import math +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput + + +def load_column_parallel_weight(param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + tp = get_tensor_model_parallel_world_size() + rk = get_tensor_model_parallel_rank() + assert param.size(0) * tp == loaded_weight.size(0) + s = rk * param.size(0) + e = (rk + 1) * param.size(0) + loaded_weight = loaded_weight[s:e] + assert param.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + + +class HeadMajorQKVParallelLinear(QKVParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +class HeadMajorColumnParallelLinear(MergedColumnParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +@torch.jit.script +def quick_gelu(x): + return x * torch.sigmoid(1.702 * x) + + +@torch.jit.script +def gegelu(input, limit: Optional[float] = None): + a_gelu, a_linear = input[..., ::2], input[..., 1::2] + if limit is not None: + a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, + a_gelu.clamp(min=None, max=limit)) + a_linear = torch.where( + torch.isinf(a_linear), + a_linear, + a_linear.clamp(min=-limit, max=limit), + ) + out_gelu = quick_gelu(a_gelu) + return out_gelu * (a_linear + 1) + + +class Phi3SmallMLP(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + assert (self.config.hidden_act == "gegelu" + ), "Only `gegelu` is supported for the 4.7 series of models .." + self.hidden_size = config.hidden_size + self.gegelu_limit = config.gegelu_limit + self.intermediate_size = config.intermediate_size + + self.up_proj = HeadMajorColumnParallelLinear( + self.hidden_size, + 2 * [self.intermediate_size], + bias=True, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + gate_up, _ = self.up_proj(x) + x = gegelu(gate_up) + x, _ = self.down_proj(x) + return x + + +class Phi3SmallSelfAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.sparse_block_size = config.blocksparse_block_size + self.homo_heads = config.blocksparse_homo_head_pattern + self.local_blocks = config.blocksparse_num_local_blocks + self.vert_stride = config.blocksparse_vert_stride + + assert (config.blocksparse_block_size == + config.blocksparse_triton_kernel_block_size) + + self.hidden_size = config.hidden_size + # Number of Query Heads + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // self.num_heads + self.tp_size = get_tensor_model_parallel_world_size() + # Number of total Key Value Heads before tensor parallel + self.num_key_value_heads = config.num_key_value_heads + self.num_q_per_kv = self.num_heads // self.num_key_value_heads + if self.tp_size > 1: + assert self.num_key_value_heads % self.tp_size == 0 + self.num_kv_heads_per_partion = max( + 1, self.num_key_value_heads // self.tp_size) + self.num_heads_per_partition = self.num_heads // self.tp_size + + self.max_position_embeddings = config.max_position_embeddings + self.rope_embedding_base = config.rope_embedding_base + self.rope_position_scale = config.rope_position_scale + self.is_causal = True + + norm_factor = None + if config.mup_use_scaling: + norm_factor = self.head_dim / config.mup_attn_multiplier + else: + norm_factor = math.sqrt(self.head_dim) + self.scale = 1 / norm_factor + + self.query_key_value = HeadMajorQKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=True, + quant_config=quant_config, + ) + + self.dense = RowParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config) + + if getattr(self.config, "rope_scaling", None) is not None: + rope_scaling = self.config.rope_scaling + for key in rope_scaling: + if isinstance(rope_scaling[key], list): + rope_scaling[key] = tuple(rope_scaling[key]) + + if "factor" not in rope_scaling: + rope_scaling["factor"] = self.rope_position_scale + else: + rope_scaling = { + "type": "linear", + "factor": self.rope_position_scale, + } + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_embedding_base, + rope_scaling=rope_scaling, + ) + + # blocksparse params + self.blocksparse_block_size = config.blocksparse_block_size + self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks + self.blocksparse_vert_stride = config.blocksparse_vert_stride + + use_dense_attn = (getattr(self.config, + "dense_attention_every_n_layers", None) + and (self.layer_idx + 1) % + self.config.dense_attention_every_n_layers == 0) + + bs_params = None + if not use_dense_attn: + bs_params = { + 'max_seqlen': self.max_position_embeddings, + 'num_heads': self.num_heads_per_partition, + "num_kv_heads": self.num_kv_heads_per_partion, + "block_size": self.sparse_block_size, + "local_blocks": self.local_blocks, + "vert_stride": self.vert_stride, + "homo_head": self.homo_heads + } + + self.attn = Attention( + self.num_heads_per_partition, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads_per_partion, + cache_config=cache_config, + quant_config=quant_config, + blocksparse_params=bs_params, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + qkv, _ = self.query_key_value(hidden_states) + + qkv = qkv.view(qkv.shape[:-1] + + (-1, (self.num_q_per_kv + 2), self.head_dim)) + q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2) + + # NOTE: this is required by RotaryEmbed, which indeed does not have to + # TODO: allow 3D QK for rotary forward + q = q.reshape(-1, self.head_dim * self.num_heads_per_partition) + k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata) + output, _ = self.dense(attn_output) + + return output + + +class Phi3SmallDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3SmallSelfAttention(config, + layer_idx, + cache_config=cache_config, + quant_config=quant_config) + self.mlp = Phi3SmallMLP(config, quant_config) + + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Phi3SmallModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.mup_embedding_multiplier = config.mup_embedding_multiplier + self.layers = nn.ModuleList([ + Phi3SmallDecoderLayer(config, layer_idx, cache_config, + quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata = None, + ): + hidden_states = self.embed_tokens(input_ids) + if (self.mup_embedding_multiplier is not None + and self.mup_embedding_multiplier > 0.0): + hidden_states = hidden_states * self.mup_embedding_multiplier + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + attn_metadata, + ) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class Phi3SmallForCausalLM(nn.Module): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Phi3SmallModel(config, cache_config, quant_config) + self.vocab_size = config.vocab_size + self.mup_width_multiplier = config.mup_width_multiplier + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + # tokens in tiktoken but not used + if hasattr(config, 'dummy_token_indices'): + device = self.lm_head.weight.device + self.register_buffer('dummy_token_indices', + torch.LongTensor( + config.dummy_token_indices).to(device), + persistent=False) + else: + self.dummy_token_indices = None + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, value): + self.lm_head = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + if self.dummy_token_indices is not None and logits is not None: + logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) + return logits + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + output_hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + output_hidden_states = output_hidden_states + return output_hidden_states + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + + next_tokens = self.sampler(logits / self.mup_width_multiplier, + sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f36d84dbdf7f9..044eec6410a54 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -63,4 +63,4 @@ def get_hf_text_config(config: PretrainedConfig): assert hasattr(config.text_config, "num_attention_heads") return config.text_config else: - return config + return config \ No newline at end of file From 325c119961698c27d8d11d61d019a6d57c814c51 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 24 May 2024 23:49:49 -0700 Subject: [PATCH 343/413] [Misc] add logging level env var (#5045) --- .github/ISSUE_TEMPLATE/400-bug report.yml | 2 ++ vllm/envs.py | 5 +++++ vllm/logger.py | 3 ++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index 08120ad8e5a60..ce980c3f4a01d 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -59,6 +59,8 @@ body: Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues. + If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. placeholder: | A clear and concise description of what the bug is. diff --git a/vllm/envs.py b/vllm/envs.py index 56ff79e0cdea9..bef343d08429c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -22,6 +22,7 @@ VLLM_DO_NOT_TRACK: bool = False VLLM_USAGE_SOURCE: str = "" VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_CONFIG_PATH: Optional[str] = None VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None @@ -178,6 +179,10 @@ "VLLM_LOGGING_CONFIG_PATH": lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), + # this is used for configuring the default logging level + "VLLM_LOGGING_LEVEL": + lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"), + # Trace function calls # If set to 1, vllm will trace function calls # Useful for debugging diff --git a/vllm/logger.py b/vllm/logger.py index 153cdfb373bb4..3c6bf0803a624 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -14,6 +14,7 @@ VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH +VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" @@ -30,7 +31,7 @@ "vllm": { "class": "logging.StreamHandler", "formatter": "vllm", - "level": "INFO", + "level": VLLM_LOGGING_LEVEL, "stream": "ext://sys.stdout", }, }, From d5a16977729928a2ceafb5ec8764081f40f7cdff Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Sat, 25 May 2024 10:00:14 -0700 Subject: [PATCH 344/413] [Dynamic Spec Decoding] Minor fix for disabling speculative decoding (#5000) --- .../spec_decode/e2e/test_ngram_correctness.py | 41 +++++++++++++++++++ tests/spec_decode/test_dynamic_spec_decode.py | 16 +++++--- vllm/spec_decode/spec_decode_worker.py | 17 +++++--- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index c2004ff061a1e..d475d37af6425 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator, batch_size, max_output_len=output_len, force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that ngram speculative decoding produces exact equality + to without spec decode with many different values of k and + different ngram_prompt_lookup_max. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py index 948a74b22f0ae..48fa862b2e41a 100644 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ b/tests/spec_decode/test_dynamic_spec_decode.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest import torch @@ -13,9 +13,9 @@ from .utils import create_batch, mock_worker -@pytest.mark.parametrize('queue_size', [2, 4]) -@pytest.mark.parametrize('batch_size', [1, 2, 3, 6]) -@pytest.mark.parametrize('k', [1, 2, 5, 7, 10]) +@pytest.mark.parametrize('queue_size', [4]) +@pytest.mark.parametrize('batch_size', [1]) +@pytest.mark.parametrize('k', [1]) @torch.inference_mode() def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): """Verify that speculative tokens are disabled when the batch size @@ -42,8 +42,12 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int): num_lookahead_slots=k, running_queue_size=queue_size) - with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(execute_model_req=execute_model_req) + if queue_size > disable_by_batch_size: + with patch.object(worker, + '_run_no_spec', + side_effect=ValueError(exception_secret)), \ + pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) # When the batch size is larger than the threshold, # we expect no speculative tokens (0). diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 3462a876c3e90..150e8db0c8aad 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -273,10 +273,17 @@ def execute_model( self._maybe_disable_speculative_tokens( disable_all_speculation, execute_model_req.seq_group_metadata_list) - # If no spec tokens, call the proposer and scorer workers normally. - # Used for prefill. + # Speculative decoding is disabled in the following cases: + # 1. Prefill phase: Speculative decoding is not + # used during the prefill phase. + # 2. Auto-disable enabled: The running queue size exceeds + # the specified threshold. + # 3. No request: There are no requests in the batch. + # In any of these cases, the proposer and scorer workers + # are called normally. if num_lookahead_slots == 0 or len( - execute_model_req.seq_group_metadata_list) == 0: + execute_model_req.seq_group_metadata_list + ) == 0 or disable_all_speculation: return self._run_no_spec(execute_model_req, skip_proposer=disable_all_speculation) @@ -316,8 +323,8 @@ def _maybe_disable_speculative_tokens( @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec(self, execute_model_req: ExecuteModelRequest, skip_proposer: bool) -> List[SamplerOutput]: - """Run a prefill step, without any speculation. The input is sent to - the proposer and scorer model so that the KV cache is consistent + """Run a single generation step without any speculation. The input is + sent to the proposer and scorer model so that the KV cache is consistent between the two. When skip_proposer is True, the proposer model is not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. From f17a1a8f9665bb237a3dddda7dc93f259e5e81e0 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Sat, 25 May 2024 10:28:16 -0700 Subject: [PATCH 345/413] [Misc] Make Serving Benchmark More User-friendly (#5044) --- benchmarks/backend_request_func.py | 6 ++++++ benchmarks/benchmark_serving.py | 29 ++++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index f9d167590fe47..58dcc6167efa6 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -89,6 +89,9 @@ async def async_request_tgi( output.latency = most_recent_timestamp - st output.success = True output.generated_text = data["generated_text"] + else: + output.error = response.reason or "" + output.success = False except Exception: output.success = False exc_info = sys.exc_info() @@ -276,6 +279,9 @@ async def async_request_openai_completions( output.generated_text = generated_text output.success = True output.latency = latency + else: + output.error = response.reason or "" + output.success = False except Exception: output.success = False exc_info = sys.exc_info() diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9c3fed4817de2..f3d71de775f82 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -215,6 +215,11 @@ def calculate_metrics( else: actual_output_lens.append(0) + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -226,9 +231,9 @@ def calculate_metrics( 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, - mean_tpot_ms=np.mean(tpots) * 1000, - median_tpot_ms=np.median(tpots) * 1000, - p99_tpot_ms=np.percentile(tpots, 99) * 1000, + mean_tpot_ms=np.mean(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, ) return metrics, actual_output_lens @@ -250,6 +255,24 @@ async def benchmark( else: raise ValueError(f"Unknown backend: {backend}") + print("Starting initial single prompt test run...") + test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=api_url, + prompt_len=test_prompt_len, + output_len=test_output_len, + best_of=best_of, + use_beam_search=use_beam_search, + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") print(f"Traffic request rate: {request_rate}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) From 1102bef2195a102c6a5489a28329e543a600b4d8 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 27 May 2024 15:18:17 -0700 Subject: [PATCH 346/413] [Bugfix / Core] Prefix Caching Guards (merged with main) (#4846) Co-authored-by: rsnm2 Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> --- .../test_disable_sliding_window.py | 44 ++++++++++++ tests/test_config.py | 24 +++++++ vllm/attention/layer.py | 3 +- vllm/config.py | 67 ++++++++++++++++++- vllm/engine/arg_utils.py | 12 +++- vllm/model_executor/models/llama.py | 4 -- vllm/model_executor/models/mixtral.py | 22 +++--- vllm/model_executor/models/mixtral_quant.py | 4 -- vllm/model_executor/models/qwen2.py | 25 ++++--- vllm/model_executor/models/starcoder2.py | 2 - vllm/model_executor/models/xverse.py | 4 -- 11 files changed, 167 insertions(+), 44 deletions(-) create mode 100644 tests/prefix_caching/test_disable_sliding_window.py diff --git a/tests/prefix_caching/test_disable_sliding_window.py b/tests/prefix_caching/test_disable_sliding_window.py new file mode 100644 index 0000000000000..eeac6ab43c05f --- /dev/null +++ b/tests/prefix_caching/test_disable_sliding_window.py @@ -0,0 +1,44 @@ +"""Compare the with and without prefix caching. + +Run `pytest tests/prefix_caching/test_prefix_caching.py`. +""" +import pytest + +from tests.conftest import cleanup +from vllm import LLM + +MODEL_LEN_LEN = [ + # Example models with sliding window. + ("bigcode/starcoder2-3b", 4096, 16384), + # ("mistralai/Mistral-7B-v0.1", 4096, 32768), << OOM in CI + + # Confirm model with sliding window works. + # config has "use_sliding_window": false + ("Qwen/Qwen1.5-0.5B-Chat", 32768, 32768), + # config has no sliding window attribute. + ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 2048, 2048), +] + + +@pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN) +def test_disable_sliding_window(model_len_len, ): + model, sliding_len, full_len = model_len_len + vllm_disabled_model = LLM(model, disable_sliding_window=True) + vllm_disabled_model.generate("Hi my name is") + model_config = vllm_disabled_model.llm_engine.model_config + assert model_config.max_model_len == sliding_len, ( + "Max len expected to equal sliding_len of %s, but got %s", sliding_len, + model_config.max_model_len) + + del vllm_disabled_model + cleanup() + + vllm_enabled_model = LLM(model, disable_sliding_window=False) + vllm_enabled_model.generate("Hi my name is") + model_config = vllm_enabled_model.llm_engine.model_config + assert model_config.max_model_len == full_len, ( + "Max len expected to equal full_len of %s, but got %s", full_len, + model_config.max_model_len) + + del vllm_enabled_model + cleanup() diff --git a/tests/test_config.py b/tests/test_config.py index 6bc51a53dc07c..7cbdaeca9c4d4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,29 @@ +import pytest + from vllm.config import ModelConfig +MODEL_IDS_EXPECTED = [ + ("Qwen/Qwen1.5-7B", 32768), + ("mistralai/Mistral-7B-v0.1", 4096), + ("mistralai/Mistral-7B-Instruct-v0.2", 32768), +] + + +@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED) +def test_disable_sliding_window(model_id_expected): + model_id, expected = model_id_expected + model_config = ModelConfig( + model_id, + model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + disable_sliding_window=True, + ) + assert model_config.max_model_len == expected + def test_get_sliding_window(): TEST_SLIDING_WINDOW = 4096 diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b67f04c51d493..db55a31476fed 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -30,7 +30,6 @@ def __init__( scale: float, num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, - sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, @@ -39,9 +38,11 @@ def __init__( if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size + sliding_window = cache_config.sliding_window else: kv_cache_dtype = "auto" block_size = 16 + sliding_window = None if num_kv_heads is None: num_kv_heads = num_heads diff --git a/vllm/config.py b/vllm/config.py index b245a1a3ee6d3..4b256d00a32df 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -69,6 +69,10 @@ class ModelConfig: max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode + disable_sliding_window: Whether to disable sliding window. If True, + we will disable the sliding window functionality of the model. + If the model does not support sliding window, this argument is + ignored. skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. served_model_name: The model name used in metrics tag `model_name`, @@ -96,6 +100,7 @@ def __init__( max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 5, + disable_sliding_window: bool = False, skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, List[str]]] = None, ) -> None: @@ -118,14 +123,18 @@ def __init__( self.max_seq_len_to_capture = (max_seq_len_to_capture or max_context_len_to_capture) self.max_logprobs = max_logprobs + self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, code_revision, rope_scaling) self.hf_text_config = get_hf_text_config(self.hf_config) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) - self.max_model_len = _get_and_verify_max_len(self.hf_text_config, - max_model_len) + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window()) self.served_model_name = get_served_model_name(model, served_model_name) if not self.skip_tokenizer_init: @@ -220,7 +229,7 @@ def verify_with_parallel_config( "must be divisible by pipeline parallel size " f"({pipeline_parallel_size}).") - def get_sliding_window(self) -> Optional[int]: + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled. """ @@ -232,6 +241,15 @@ def get_sliding_window(self) -> Optional[int]: return None return getattr(self.hf_text_config, "sliding_window", None) + def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + def get_vocab_size(self) -> int: return self.hf_text_config.vocab_size @@ -336,6 +354,7 @@ def __init__( self.enable_prefix_caching = enable_prefix_caching self._verify_args() self._verify_cache_dtype() + self._verify_prefix_caching() # Will be set after profiling. self.num_gpu_blocks = None @@ -364,6 +383,19 @@ def _verify_cache_dtype(self) -> None: else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + if self.cache_dtype == "fp8": + raise NotImplementedError( + "Prefix caching is not supported for fp8 cache_dtype. " + "Run with --kv-cache-dtype auto to use prefix caching.") + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", @@ -1116,6 +1148,8 @@ def _get_and_verify_dtype( def _get_and_verify_max_len( hf_config: PretrainedConfig, max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[int], ) -> int: """Get and verify the model's maximum length.""" derived_max_model_len = float("inf") @@ -1135,6 +1169,7 @@ def _get_and_verify_max_len( "max_seq_length", "seq_len", ] + # Choose the smallest "max_length" from the possible keys. max_len_key = None for key in possible_keys: max_len = getattr(hf_config, key, None) @@ -1142,6 +1177,16 @@ def _get_and_verify_max_len( max_len_key = key if max_len < derived_max_model_len \ else max_len_key derived_max_model_len = min(derived_max_model_len, max_len) + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + max_len_key = "sliding_window" \ + if sliding_window_len < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, sliding_window_len) + + # If none of the keys were found in the config, use a default and + # log a warning. if derived_max_model_len == float("inf"): if max_model_len is not None: # If max_model_len is specified, we use it. @@ -1157,6 +1202,13 @@ def _get_and_verify_max_len( rope_scaling = getattr(hf_config, "rope_scaling", None) if rope_scaling is not None and rope_scaling["type"] != "su": + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. Please raise an issue so we can " + "investigate.") assert "factor" in rope_scaling scaling_factor = rope_scaling["factor"] if rope_scaling["type"] == "yarn": @@ -1164,6 +1216,8 @@ def _get_and_verify_max_len( "original_max_position_embeddings"] derived_max_model_len *= scaling_factor + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. if max_model_len is None: max_model_len = int(derived_max_model_len) elif max_model_len > derived_max_model_len: @@ -1172,6 +1226,13 @@ def _get_and_verify_max_len( # with model_max_length and allow this override when it's smaller. model_max_length = getattr(hf_config, "model_max_length", None) if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "model_max_length in the config. Please raise an issue " + "so we can investigate.") pass else: raise ValueError( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 538e3427e37fb..3267c8c9f44d2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -41,6 +41,7 @@ class EngineArgs: max_parallel_loading_workers: Optional[int] = None block_size: int = 16 enable_prefix_caching: bool = False + disable_sliding_window: bool = False use_v2_block_manager: bool = False swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 @@ -267,6 +268,10 @@ def add_cli_args( parser.add_argument('--enable-prefix-caching', action='store_true', help='Enables automatic prefix caching.') + parser.add_argument('--disable-sliding-window', + action='store_true', + help='Disables sliding window, ' + 'capping to sliding window size') parser.add_argument('--use-v2-block-manager', action='store_true', help='Use BlockSpaceMangerV2.') @@ -558,8 +563,8 @@ def create_engine_config(self, ) -> EngineConfig: self.max_model_len, self.quantization, self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, self.max_seq_len_to_capture, - self.max_logprobs, self.skip_tokenizer_init, - self.served_model_name) + self.max_logprobs, self.disable_sliding_window, + self.skip_tokenizer_init, self.served_model_name) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, @@ -645,7 +650,8 @@ def create_engine_config(self, ) -> EngineConfig: if (model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled): raise ValueError( - "Chunked prefill is not supported with sliding window.") + "Chunked prefill is not supported with sliding window. " + "Set --disable-sliding-window to disable sliding window.") return EngineConfig(model_config=model_config, cache_config=cache_config, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 086f9294c4f1c..2ca55f9270fc7 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -94,7 +94,6 @@ def __init__( max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, - sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() @@ -146,7 +145,6 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -183,7 +181,6 @@ def __init__( config.original_max_position_embeddings) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - sliding_window = getattr(config, "sliding_window", None) # Support abacusai/Smaug-72B-v0.1 with attention_bias # Support internlm/internlm-7b with bias attention_bias = getattr(config, "attention_bias", False) or getattr( @@ -198,7 +195,6 @@ def __init__( max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=attention_bias, - sliding_window=sliding_window, cache_config=cache_config, ) self.mlp = LlamaMLP( diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ea95cf7380d54..d6dd7fa1fe9e2 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -246,15 +246,16 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -276,7 +277,6 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window if isinstance( quant_config, @@ -312,7 +312,6 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -349,7 +348,6 @@ def __init__( max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - sliding_window=config.sliding_window, cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE( diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 9b99ff729aadd..1894c05e167d6 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -166,7 +166,6 @@ def __init__( max_position: int = 4096 * 32, rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() @@ -190,7 +189,6 @@ def __init__( self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window self.qkv_proj = QKVParallelLinear( hidden_size, @@ -217,7 +215,6 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -254,7 +251,6 @@ def __init__( max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - sliding_window=config.sliding_window, cache_config=cache_config, quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index ec203c3b9001a..9a4829a27873e 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -86,10 +86,8 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - use_sliding_window: bool = False, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - sliding_window: Optional[int] = None, rope_scaling: Optional[Tuple] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -112,7 +110,6 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.sliding_window = sliding_window if use_sliding_window else None self.qkv_proj = QKVParallelLinear( hidden_size, @@ -140,7 +137,6 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -164,7 +160,6 @@ class Qwen2DecoderLayer(nn.Module): def __init__( self, config: Qwen2Config, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -173,18 +168,14 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) - use_sliding_window = (config.use_sliding_window - and layer_idx < config.max_window_layers) self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - use_sliding_window=use_sliding_window, cache_config=cache_config, quant_config=quant_config, - sliding_window=config.sliding_window, rope_scaling=rope_scaling) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, @@ -244,8 +235,8 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config) - for layer_idx in range(config.num_hidden_layers) + Qwen2DecoderLayer(config, cache_config, quant_config) + for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -302,6 +293,18 @@ def __init__( lora_config: Optional[LoRAConfig] = None, ) -> None: del lora_config + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = %s is less than " + "`num_hidden_layers` = %s. Please open an issue " + "to discuss this feature." % ( + config.max_window_layers, + config.num_hidden_layers, + )) + super().__init__() self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 91ffd0861c39d..4324bf50d4ad1 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -74,7 +74,6 @@ def __init__(self, self.rope_theta = config.rope_theta self.max_position_embeddings = config.max_position_embeddings self.use_bias = config.use_bias - self.sliding_window = config.sliding_window self.qkv_proj = QKVParallelLinear( self.hidden_size, @@ -101,7 +100,6 @@ def __init__(self, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window, cache_config=cache_config, quant_config=quant_config) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index dda13d83f89a3..1e5280dde3ff9 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -88,7 +88,6 @@ def __init__( max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, - sliding_window: Optional[int] = None, cache_config: Optional[CacheConfig] = None, ) -> None: super().__init__() @@ -134,7 +133,6 @@ def __init__( self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, - sliding_window=sliding_window, cache_config=cache_config, quant_config=quant_config) @@ -167,7 +165,6 @@ def __init__( rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - sliding_window = getattr(config, "sliding_window", None) self.self_attn = XverseAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -178,7 +175,6 @@ def __init__( max_position_embeddings=max_position_embeddings, quant_config=quant_config, bias=getattr(config, "bias", False), - sliding_window=sliding_window, cache_config=cache_config, ) self.mlp = XverseMLP( From fbdb7b3ee2c3f7b58841c12b52ed4c83f508babb Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Mon, 27 May 2024 22:26:14 +0000 Subject: [PATCH 347/413] [Core] Allow AQLM on Pascal (#5058) --- vllm/model_executor/layers/quantization/aqlm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index 83e24fadc1405..730595c3d36d1 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -192,7 +192,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - return 70 + return 60 @classmethod def get_config_filenames(cls) -> List[str]: From 890aa93d275a2b75313629614ab9ed278a13f6d7 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 28 May 2024 07:41:43 +0800 Subject: [PATCH 348/413] [Model] Add support for falcon-11B (#5069) --- vllm/model_executor/models/falcon.py | 55 ++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index ba707adb03dfe..9618652f70d23 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -41,7 +41,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput @@ -246,18 +246,26 @@ def __init__( self.mlp = FalconMLP(config, quant_config) self.config = config - if config.new_decoder_architecture: - # The layer norm before self-attention - self.ln_attn = LayerNorm(hidden_size, - eps=config.layer_norm_epsilon) - # The layer norm before the MLP - self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - else: + if (config.num_ln_in_parallel_attn is None + and config.new_decoder_architecture): + config.num_ln_in_parallel_attn = 2 + + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - if not config.parallel_attn: - self.post_attention_layernorm = LayerNorm( - hidden_size, eps=config.layer_norm_epsilon) + else: + if config.num_ln_in_parallel_attn == 2: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) @@ -271,7 +279,7 @@ def forward( ) -> torch.Tensor: residual = hidden_states - if self.config.new_decoder_architecture: + if self.config.num_ln_in_parallel_attn == 2: attention_layernorm_out = self.ln_attn(hidden_states) mlp_layernorm_out = self.ln_mlp(hidden_states) else: @@ -294,6 +302,10 @@ def forward( residual += attention_output mlp_layernorm_out = self.post_attention_layernorm(residual) + if (self.config.new_decoder_architecture and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1): + mlp_layernorm_out = attention_layernorm_out + # MLP. mlp_output, mlp_bias = self.mlp(mlp_layernorm_out) if self.reduce_row_parallel_results and mlp_bias is not None: @@ -375,7 +387,20 @@ def __init__( self.config = config self.quant_config = quant_config self.transformer = FalconModel(config, cache_config, quant_config) - self.lm_head_weight = self.transformer.word_embeddings.weight + # only Falcon-11B doesn't share lm_head weight with word embeddings + # and previous Falcon model doesn't have tie_word_embeddings config + # so we set tie_word_embeddings to True by default + self.tie_word_embeddings = (config.tie_word_embeddings + if config.tie_word_embeddings is not None + else True) + if self.tie_word_embeddings: + self.lm_head_weight = self.transformer.word_embeddings.weight + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + ) + self.lm_head_weight = self.lm_head.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -419,8 +444,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: - if name == "lm_head.weight": - # Falcon uses tied embeddings. + if name == "lm_head.weight" and self.tie_word_embeddings: + # Falcon uses tied embeddings except Falcon-11b. continue # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: From d4f398590786f0015d474b03a3d078db1e7d1be2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Moskal?= Date: Mon, 27 May 2024 19:07:07 -0700 Subject: [PATCH 349/413] [Core] Sliding window for block manager v2 (#4545) Co-authored-by: Ruth Evans --- tests/core/block/e2e/conftest.py | 26 +++ tests/core/block/e2e/test_correctness.py | 11 +- .../e2e/test_correctness_sliding_window.py | 168 ++++++++++++++++++ tests/core/block/test_block_manager_v2.py | 69 +++++++ vllm/attention/ops/prefix_prefill.py | 6 +- vllm/core/block/block_table.py | 34 +++- vllm/core/block/cpu_gpu_block_allocator.py | 74 ++++++++ vllm/core/block/interfaces.py | 9 + vllm/core/block_manager_v2.py | 24 ++- vllm/engine/arg_utils.py | 3 +- vllm/worker/cache_engine.py | 5 +- vllm/worker/model_runner.py | 73 +++++--- 12 files changed, 457 insertions(+), 45 deletions(-) create mode 100644 tests/core/block/e2e/test_correctness_sliding_window.py diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index b0d62c8993d3f..e870597b7a011 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -1,3 +1,5 @@ +from typing import Callable, Iterable, Optional + import pytest from vllm import LLM @@ -40,3 +42,27 @@ def generator_inner(): for llm in generator_inner(): yield llm del llm + + +def get_text_from_llm_generator(llm_generator: Iterable[LLM], + prompts, + sampling_params, + llm_cb: Optional[Callable[[LLM], + None]] = None): + for llm in llm_generator: + if llm_cb: + llm_cb(llm) + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + text = [output.outputs[0].text for output in outputs] + del llm + + return text + + +def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): + for llm in llm_generator: + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + del llm + + return token_ids diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index c3666da7542b5..3713ef2fed4d1 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -4,6 +4,8 @@ from vllm import SamplingParams +from .conftest import get_token_ids_from_llm_generator + @pytest.mark.parametrize( "common_llm_kwargs", @@ -444,12 +446,3 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator, assert expected_token_ids == actual_token_ids assert baseline_token_ids == test_token_ids - - -def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): - for llm in llm_generator: - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - token_ids = [output.outputs[0].token_ids for output in outputs] - del llm - - return token_ids diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py new file mode 100644 index 0000000000000..e98292e807d73 --- /dev/null +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -0,0 +1,168 @@ +import random +from typing import List + +import pytest + +from vllm import LLM, SamplingParams + +from .conftest import get_text_from_llm_generator + +# relatively small model with 4k sliding window +MODEL = "bigcode/starcoder2-3b" +BLOCK_SIZE = 16 + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": MODEL, + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "block_size": BLOCK_SIZE, + # needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008 + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, + batch_size, seed): + """ + The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then + asks for value of one of them (which is outside the sliding window). + If we tell it upfront which we are going to be looking for, then + it answers correctly (mostly). + + Additionally, we compare the results of the v1 and v2 managers. + """ + sampling_params = SamplingParams( + max_tokens=1024, + ignore_eos=True, + temperature=0.0, + ) + + prompts, answer, indices = prep_prompts(batch_size) + + print('Getting token ids from block manager v1') + baseline_texts = get_text_from_llm_generator(baseline_llm_generator, + prompts, + sampling_params, + llm_cb=check_window(prompts)) + + check_answers(indices, answer, baseline_texts) + + print('Getting token ids from block manager v2') + test_texts = get_text_from_llm_generator(test_llm_generator, prompts, + sampling_params) + check_answers(indices, answer, test_texts) + + cmp = [ + expected_text == actual_text + for expected_text, actual_text in zip(baseline_texts, test_texts) + ] + print(cmp) + # make sure it's mostly OK; this is possibly because https://github.com/vllm-project/vllm/pull/4768 + # however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290 + # states that xformers and flash_attn have different ideas about the window + # size anyways + assert sum(cmp) > 0.7 * len(cmp) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": MODEL, + + # skip cuda graph creation for fast test. + "enforce_eager": True, + "block_size": BLOCK_SIZE, + "num_gpu_blocks_override": 100000 // BLOCK_SIZE, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "use_v2_block_manager": True, + "enable_chunked_prefill": True +}]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): + """ + This is similar to test_sliding_window_retrival, however, it doesn't + compare against the v1 block manager since v1 doesn't support + chunked prefill with sliding window. + + The results with and without chunked prefill are not the same due to + numerical instabilities. + """ + sampling_params = SamplingParams( + max_tokens=10, + ignore_eos=True, + temperature=0.0, + ) + + prompts, answer, indices = prep_prompts(batch_size) + + # We don't compare with the baseline model here, since the results + # slightly different due to different tailing in attention. + test_texts = get_text_from_llm_generator(test_llm_generator, + prompts, + sampling_params, + llm_cb=check_window(prompts)) + check_answers(indices, answer, test_texts) + + +def prep_prompts(batch_size: int): + """ + Generate prompts which a bunch of assignments, + then asking for the value of one of them. + The prompt is just under 10k tokens; sliding window is 4k + so the answer is outside sliding window, but should still be correct. + """ + prompts: List[str] = [] + answer: List[int] = [] + indices: List[int] = [] + random.seed(1) + for _ in range(batch_size): + idx = random.randint(30, 90) + indices.append(idx) + prompt = "```python\n# We set a number of variables, " + \ + f"x{idx} will be important later\n" + ln = random.randint(800, 1100) + for k in range(30, ln): + v = random.randint(10, 99) + if k == idx: + answer.append(v) + prompt += f"x{k} = {v}\n" + prompt += f"# Now, we check the value of x{idx}:\n" + prompt += f"assert x{idx} == " + prompts.append(prompt) + return prompts, answer, indices + + +def check_answers(indices: List[int], answer: List[int], outputs: List[str]): + answer2 = [int(text[0:2].strip()) for text in outputs] + print(list(zip(indices, zip(answer, answer2)))) + numok = 0 + for a1, a2 in zip(answer, answer2): + if a1 == a2: + numok += 1 + frac_ok = numok / len(answer) + print(f"Num OK: {numok}/{len(answer)} {frac_ok}") + assert frac_ok > 0.7 + + +def check_window(prompts: List[str]): + + def inner(llm: LLM): + sliding_window = llm.llm_engine.model_config.get_sliding_window() + assert sliding_window and sliding_window > 0 + assert any( + len(llm.get_tokenizer().tokenize(prompt)) > sliding_window + for prompt in prompts) + + return inner diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 1e8e4ccdfb151..91b047f0e183e 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -101,3 +101,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, range(prompt_len + num_slots_to_append + num_lookahead_slots)), block_size)) - len(chunk_list(list(range(prompt_len)), block_size)) assert num_consumed_blocks == expected_consumed_blocks + + +@pytest.mark.parametrize("block_size", [8, 16]) +@pytest.mark.parametrize("prompt_len", [10, 300, 1000]) +@pytest.mark.parametrize("num_slots_to_append", [50]) +@pytest.mark.parametrize("sliding_window", [20, 32, 200, 512]) +def test_sliding_window(block_size, prompt_len, num_slots_to_append, + sliding_window): + """Verify append_slots consumes the correct number of blocks from the block + table. + """ + + num_gpu_blocks = 1024 + watermark = 0.1 + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + watermark=watermark, + sliding_window=sliding_window, + ) + + def check_used(min_n, max_n=None): + if max_n is None: + max_n = min_n + used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks() + #print("check", min_n, used, max_n) + assert min_n <= used + assert used <= max_n + + def num_blocks(num_tokens): + return (num_tokens + block_size - 1) // block_size + + check_used(0) + + seq_group = create_seq_group( + seq_prompt_len=prompt_len, + seq_output_lens=[0], + ) + + check_used(0) + + # Allocate seq + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + + check_used(num_blocks(prompt_len)) + + # Seq seq to RUNNING + seq = seq_group.get_seqs()[0] + seq.status = SequenceStatus.RUNNING + + seq.data.update_num_computed_tokens(prompt_len) + check_used(num_blocks(prompt_len)) + + # this is how we compute it in BlockSpaceManagerV2.__init__ + sliding_blocks = (sliding_window // block_size) + 2 + # plus one block for null block + sliding_blocks += 1 + + # Append tokens to the sequeqnce + for token_id in range(num_slots_to_append): + seq.append_token_id(token_id, {token_id: Logprob(0.0)}) + seq.data.update_num_computed_tokens(1) + block_manager.append_slots(seq, num_lookahead_slots=0) + if prompt_len < sliding_window + 10: + check_used(0, sliding_blocks + 1) + else: + check_used(sliding_blocks, sliding_blocks + 1) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 997b25e887e30..b99cf9a50d105 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -697,6 +697,10 @@ def context_attention_fwd(q, grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + num_warps = 8 if Lk <= 64 else 8 if alibi_slopes is not None: _fwd_kernel_alibi[grid]( @@ -794,7 +798,7 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, - SLIDING_WINDOW=sliding_window if sliding_window is not None else 0, + SLIDING_WINDOW=sliding_window, num_warps=num_warps, num_stages=1, ) diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index b0d9511fba521..26c704b8de901 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -20,6 +20,10 @@ class BlockTable: _blocks (Optional[List[Block]], optional): An optional list of existing blocks to initialize the BlockTable with. If not provided, an empty BlockTable is created. + max_block_sliding_window (Optional[int], optional): The number of + blocks to keep around for each sequance. If None, all blocks + are kept (eg., when sliding window is not used). + It should at least fit the sliding window size of the model. Attributes: _block_size (int): The maximum number of tokens that can be stored in a @@ -37,6 +41,7 @@ def __init__( block_size: int, block_allocator: DeviceAwareBlockAllocator, _blocks: Optional[List[Block]] = None, + max_block_sliding_window: Optional[int] = None, ): self._block_size = block_size self._allocator = block_allocator @@ -44,6 +49,7 @@ def __init__( _blocks = [] self._blocks: List[Block] = _blocks + self._max_block_sliding_window = max_block_sliding_window # Use helper method instead of directly calculating, as blocks # may not be allocated. self._num_full_slots = len(self._get_all_token_ids()) @@ -89,7 +95,8 @@ def allocate(self, def append_token_ids(self, token_ids: List[int], - num_lookahead_slots: int = 0) -> None: + num_lookahead_slots: int = 0, + num_computed_slots: Optional[int] = None) -> None: """Appends a sequence of token IDs to the existing blocks in the BlockTable. @@ -104,13 +111,35 @@ def append_token_ids(self, Args: token_ids (List[int]): The sequence of token IDs to be appended. + num_computed_slots (Optional[int]): The number of KV cache slots + that are already filled (computed). + When sliding window is enabled, this is used to compute how many + blocks to drop at the front of the sequence. + Without sliding window, None can be passed. + Without chunked prefill, it should be the same as + _num_full_slots. """ - assert self._is_allocated + assert self._is_allocated, "no blocks have been allocated" assert len(self._blocks) > 0 + # Drop blocks that are no longer needed due to sliding window + if self._max_block_sliding_window is not None: + null_block = self._allocator.allocate_or_get_null_block() + assert num_computed_slots is not None + end_block_idx = (num_computed_slots // + self._block_size) - self._max_block_sliding_window + for idx in range(0, end_block_idx): + b = self._blocks[idx] + if b is not null_block: + self._allocator.free(b) + self._blocks[idx] = null_block + + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) + # Update the blocks with the new tokens blocks = self._blocks[self._num_full_slots // self._block_size:] token_blocks = self._chunk_token_blocks_for_append(token_ids) @@ -168,6 +197,7 @@ def fork(self) -> "BlockTable": block_size=self._block_size, block_allocator=self._allocator, _blocks=forked_blocks, + max_block_sliding_window=self._max_block_sliding_window, ) def free(self) -> None: diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 0577ca76ea971..d28a684376974 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -105,11 +105,19 @@ def __init__( Device.GPU: gpu_block_allocator, } + self._null_block: Optional[Block] = None + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} for _, allocator in self._allocators.items(): for block_id in allocator.all_block_ids: self._block_ids_to_allocator[block_id] = allocator + def allocate_or_get_null_block(self) -> Block: + if self._null_block is None: + self._null_block = NullBlock( + self.allocate_mutable(None, Device.GPU)) + return self._null_block + def allocate_mutable(self, prev_block: Optional[Block], device: Device) -> Block: """Allocates a new mutable block on the specified device. @@ -149,6 +157,9 @@ def free(self, block: Block) -> None: Args: block (Block): The block to be freed. """ + # Null block should never be freed + if isinstance(block, NullBlock): + return block_id = block.block_id assert block_id is not None allocator = self._block_ids_to_allocator[block_id] @@ -165,6 +176,8 @@ def fork(self, last_block: Block) -> List[Block]: List[Block]: A new list of blocks that shares the same memory as the original sequence. """ + # do not attempt to fork the null block + assert not isinstance(last_block, NullBlock) block_id = last_block.block_id assert block_id is not None allocator = self._block_ids_to_allocator[block_id] @@ -226,3 +239,64 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: raise NotImplementedError + + +class NullBlock(Block): + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + This implementation just wraps an ordinary block and prevents it from + being modified. It also allows for testing if a block is NullBlock + via isinstance(). + """ + + def __init__(self, proxy: Block): + super().__init__() + self._proxy = proxy + + def append_token_ids(self, token_ids: List[BlockId]): + raise ValueError("null block should not be modified") + + @property + def block_id(self): + return self._proxy.block_id + + @block_id.setter + def block_id(self, value: Optional[BlockId]): + raise ValueError("null block should not be modified") + + @property + def token_ids(self) -> List[BlockId]: + return self._proxy.token_ids + + @property + def num_empty_slots(self) -> BlockId: + return self._proxy.num_empty_slots + + @property + def is_full(self): + return self._proxy.is_full + + @property + def prev_block(self): + return self._proxy.prev_block + + @property + def computed(self): + return self._proxy.computed + + @computed.setter + def computed(self, value): + self._proxy.computed = value + + @property + def last_accessed(self) -> float: + return self._proxy.last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._proxy.last_accessed = last_accessed_ts + + @property + def content_hash(self): + return self._proxy.content_hash diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 140fbbb0949cc..8fc4c601106cd 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -203,3 +203,12 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: pass + + @abstractmethod + def allocate_or_get_null_block(self) -> Block: + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + There is at most one null block per allocator. + """ + pass diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index f0bc96564050a..834436c25e160 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -66,9 +66,18 @@ def __init__( self.num_total_gpu_blocks = num_gpu_blocks self.num_total_cpu_blocks = num_cpu_blocks - assert sliding_window is None, "Sliding window not yet supported" - - self.block_sliding_window = None + self.sliding_window = sliding_window + # max_block_sliding_window is the max number of blocks that need to be + # allocated + self.max_block_sliding_window = None + if sliding_window is not None: + # +1 here because // rounds down + num_blocks = sliding_window // block_size + 1 + # +1 here because the last block may not be full, + # and so the sequence stretches one more block at the beginning + # For example, if sliding_window is 3 and block_size is 4, + # we may need 2 blocks when the second block only holds 1 token. + self.max_block_sliding_window = num_blocks + 1 self.watermark = watermark assert watermark >= 0.0 @@ -96,10 +105,9 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) - assert self.block_sliding_window is None - if self.block_sliding_window is not None: + if self.max_block_sliding_window is not None: num_required_blocks = min(num_required_blocks, - self.block_sliding_window) + self.max_block_sliding_window) num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( device=Device.GPU) @@ -125,8 +133,9 @@ def allocate(self, seq_group: SequenceGroup) -> None: block_table = BlockTable( block_size=self.block_size, block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, ) - assert self.block_sliding_window is None + block_table.allocate(seq.get_token_ids()) self.block_tables[seq.seq_id] = block_table @@ -174,6 +183,7 @@ def append_slots( block_table.append_token_ids( token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), num_lookahead_slots=num_lookahead_slots, + num_computed_slots=seq.data.get_num_computed_tokens(), ) # Return any new copy-on-writes. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 3267c8c9f44d2..11485aa2438c0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -648,7 +648,8 @@ def create_engine_config(self, ) -> EngineConfig: guided_decoding_backend=self.guided_decoding_backend) if (model_config.get_sliding_window() is not None - and scheduler_config.chunked_prefill_enabled): + and scheduler_config.chunked_prefill_enabled + and not scheduler_config.use_v2_block_manager): raise ValueError( "Chunked prefill is not supported with sliding window. " "Set --disable-sliding-window to disable sliding window.") diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 07d51dca226bd..2f0e59f7ae7c9 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -68,8 +68,11 @@ def _allocate_kv_cache( pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] for _ in range(self.num_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. kv_cache.append( - torch.empty(kv_cache_shape, + torch.zeros(kv_cache_shape, dtype=self.dtype, pin_memory=pin_memory, device=device)) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 87d5f5c1b9d67..5ddd2d1b65f81 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -269,6 +269,12 @@ def _prepare_model_input( if len(seq_group_metadata_list) == 0: return ModelInput.empty(self.device) + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window + self.block_size - + 1) // self.block_size + block_aligned_sliding_window = \ + sliding_window_blocks * self.block_size + for seq_group_metadata in seq_group_metadata_list: seq_ids = list(seq_group_metadata.seq_data.keys()) is_prompt = seq_group_metadata.is_prompt @@ -309,6 +315,30 @@ def _prepare_model_input( and self.sliding_window is None and is_prompt) + # These are seq_len/context_len capped to the sliding window. + # They are passed to decode kernel. + # We still need original seq_len/context_len to compute slot + # mapping (and input position) below. + curr_sliding_window_blocks = None + sliding_seq_len = seq_len + sliding_context_len = context_len + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if (self.sliding_window is not None and not is_prompt): + curr_sliding_window_blocks = sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = seq_len % self.block_size + sliding_seq_len = min( + seq_len, block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_blocks += 1 + else: + sliding_seq_len = min(seq_len, self.sliding_window) + sliding_context_len = sliding_seq_len - 1 + # TODO(sang): Combine chunked prefill and prefix caching by # only allowing multiple of block_size chunk size. # NOTE: This only works for oooooooxxx style attention. @@ -316,6 +346,13 @@ def _prepare_model_input( assert computed_block_nums is not None context_len = len(computed_block_nums) * self.block_size tokens = tokens[context_len:] + + # need to think what to set it to when we have both sliding + # window and prefix caching... + assert self.sliding_window is None, \ + "Prefix caching is not supported with sliding window" + sliding_context_len = context_len + if self.attn_backend.get_name() == "flash-attn": # NOTE(woosuk): For flash-attn, the block table should # include the entries for the incoming prefill tokens. @@ -329,14 +366,9 @@ def _prepare_model_input( if seq_group_metadata.block_tables is not None: # chunked prefill or decode block_table = seq_group_metadata.block_tables[seq_id] - if self.sliding_window is not None: - # chunked prefill doesn't support sliding window. - assert (not self.scheduler_config. - chunked_prefill_enabled) - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - + if curr_sliding_window_blocks is not None: + block_table = block_table[ + -curr_sliding_window_blocks:] if self.attn_backend.get_name() == "flashinfer": paged_kv_indices.extend(block_table) paged_kv_indptr.append(paged_kv_indptr[-1] + @@ -354,16 +386,9 @@ def _prepare_model_input( block_table = [] block_tables.append(block_table) - # TODO(sang): This is a hack to make sliding window work with - # paged attn. We can remove it if we make paged attn kernel - # to properly handle slinding window attn. - if (self.sliding_window is not None and not is_prompt): - seq_len = min(seq_len, self.sliding_window) - context_len = seq_len - 1 - - seq_lens.append(seq_len) - context_lens.append(context_len) - query_len = seq_len - context_len + seq_lens.append(sliding_seq_len) + context_lens.append(sliding_context_len) + query_len = sliding_seq_len - sliding_context_len query_lens.append(query_len) input_tokens.extend(tokens) input_positions.extend(list(range(context_len, seq_len))) @@ -380,16 +405,15 @@ def _prepare_model_input( "seq_len: {}, context_len: {}, query_len: {}".format( seq_len, context_len, query_len)) num_decode_tokens += query_len - decode_seq_lens.append(seq_len) + decode_seq_lens.append(sliding_seq_len) if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (seq_len - context_len) + lora_index_mapping += [lora_id] * query_len lora_prompt_mapping.extend( [lora_id] * - (seq_len - - context_len if seq_group_metadata.sampling_params + (query_len if seq_group_metadata.sampling_params and seq_group_metadata.sampling_params.prompt_logprobs else 1)) @@ -417,9 +441,10 @@ def _prepare_model_input( start_idx = 0 if self.sliding_window is not None: if is_prompt: - assert context_len == 0, ( + assert self.scheduler_config.use_v2_block_manager \ + or context_len == 0, ( "Prefix caching is currently not supported with " - "sliding window attention") + "sliding window attention in V1 block manager") # It is an optimization. When it is decoding, it is always # 0. When prefill, we use it to not write slots to kv cache # to save memory. From 9ba415588aeda8d99bda8889f90010f0d7330e89 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 28 May 2024 08:32:42 -0700 Subject: [PATCH 350/413] [BugFix] Fix Embedding Models with TP>1 (#5075) --- vllm/worker/embedding_model_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index ef02de95fc54e..0ba1200696cab 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -79,6 +79,10 @@ def execute_model( execute_model_kwargs.update({"image_input": multi_modal_input}) hidden_states = model_executable(**execute_model_kwargs) + # Only perform pooling in the driver worker. + if not self.is_driver_worker: + return None + return self.model.pooler(hidden_states=hidden_states, pooling_metadata=pooling_metadata) From dd8de11f0a15f9ef48cd1dcac02dd8e2a8bcb494 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Tue, 28 May 2024 11:03:23 -0500 Subject: [PATCH 351/413] [Kernel][ROCm][AMD] Add fused_moe Triton configs for MI300X (#4951) This PR adds Triton kernel configs for the MoE kernel for MI300X --- ...14336,device_name=AMD_Instinct_MI300X.json | 128 ++++++++++++++++++ ...=1792,device_name=AMD_Instinct_MI300X.json | 110 +++++++++++++++ ...=3584,device_name=AMD_Instinct_MI300X.json | 128 ++++++++++++++++++ ...=7168,device_name=AMD_Instinct_MI300X.json | 128 ++++++++++++++++++ 4 files changed, 494 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..93472eb08a462 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,128 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_stages": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_stages": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_stages": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..5bd9d71e8f9bb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8 + }, + "48": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..02e66280c1a3a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,128 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_stages": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_stages": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_stages": 1 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_stages": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 0000000000000..34c3b593d9799 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,128 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_stages": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_stages": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_stages": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_stages": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_stages": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_stages": 0 + } +} From 290f4ada2bf42174a53ae6aab2873e115c8ae11b Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Tue, 28 May 2024 12:29:09 -0500 Subject: [PATCH 352/413] [Docs] Add Dropbox as sponsors (#5089) --- README.md | 1 + docs/source/community/sponsors.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 627e45d4de0c4..d63819c3815c0 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ vLLM is a community project. Our compute resources for development and testing a - Crusoe Cloud - Databricks - DeepInfra +- Dropbox - Lambda Lab - NVIDIA - Replicate diff --git a/docs/source/community/sponsors.md b/docs/source/community/sponsors.md index 532ce77beb7b8..d167b66267a4d 100644 --- a/docs/source/community/sponsors.md +++ b/docs/source/community/sponsors.md @@ -12,6 +12,7 @@ vLLM is a community project. Our compute resources for development and testing a - Crusoe Cloud - Databricks - DeepInfra +- Dropbox - Lambda Lab - NVIDIA - Replicate From 8f2f226d90a94a38e87626c3cc9bca3e4675aa92 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 28 May 2024 19:11:55 +0000 Subject: [PATCH 353/413] Removed HIP specific matvec logic that is duplicated from tuned_gemm.py and doesn't support bf16 --- vllm/model_executor/layers/linear.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 227ec7c848d4b..fd809d6107835 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -15,7 +15,6 @@ from vllm.model_executor.parallel_utils.utils import ( divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import is_hip logger = init_logger(__name__) @@ -75,30 +74,6 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: weight = weights["weight"] - if is_hip() and x.view(-1, x.size(-1)).shape[0] == 1: - batched = False - if x.dim() == 3: - inp = x.view(-1, x.size(-1)) - batched = True - else: - inp = x - m, k = weight.shape[0], inp.shape[1] - out = torch.empty(inp.shape[0], - weight.shape[0], - dtype=inp.dtype, - device='cuda') - if (k == 8192 and - (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): - _custom_C.LLMM1(weight, inp, out, 8) - elif k <= 8192 and k % 8 == 0 and m % 4 == 0: - _custom_C.LLMM1(weight, inp, out, 4) - else: - out = F.linear(inp, weight) - if batched: - out = out.view(x.shape[0], x.shape[1], weight.shape[0]) - if bias is not None: - out = out + bias - return out if self.separate_bias_add: if bias is not None: return F.linear(x, weight) + bias From e9fdf71a70373d1dab2ad9d8590868d774efb1b3 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Tue, 28 May 2024 19:14:24 +0000 Subject: [PATCH 354/413] And an unused import --- vllm/model_executor/layers/linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index fd809d6107835..954da086a061c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,7 +5,6 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm import _custom_C from vllm.logger import init_logger from vllm.model_executor.layers.tuned_gemm import tgemm from vllm.model_executor.parallel_utils.communication_op import ( From 5ae5ed1e6047d4095149e26526a618be0529a118 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 29 May 2024 04:29:31 +0800 Subject: [PATCH 355/413] [Core] Consolidate prompt arguments to LLM engines (#4328) Co-authored-by: Roger Wang --- .buildkite/test-pipeline.yaml | 9 +- benchmarks/benchmark_latency.py | 11 +- .../{ => dev}/offline_inference/llm.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 14 + .../dev/offline_inference/offline_index.rst | 8 + .../sampling_params.rst | 0 docs/source/index.rst | 11 +- .../serving/openai_compatible_server.md | 4 +- examples/llava_example.py | 25 +- pyproject.toml | 7 + tests/async_engine/test_async_llm_engine.py | 2 +- tests/async_engine/test_openapi_server_ray.py | 2 +- tests/conftest.py | 23 +- tests/core/test_block_manager.py | 15 +- tests/core/utils.py | 15 +- tests/engine/test_skip_tokenizer_init.py | 2 +- tests/entrypoints/openai/test_serving_chat.py | 4 + tests/entrypoints/test_guided_processors.py | 2 + tests/entrypoints/test_llm_encode.py | 144 +++++++ tests/entrypoints/test_llm_generate.py | 137 +++++- tests/entrypoints/test_openai_server.py | 34 +- .../test_server_oot_registration.py | 11 +- tests/lora/test_long_context.py | 8 +- tests/samplers/test_logits_processor.py | 11 +- tests/samplers/test_seeded_generate.py | 6 +- tests/test_cache_block_hashing.py | 11 +- tests/test_inputs.py | 53 +++ tests/test_utils.py | 63 +++ tests/tokenization/test_detokenize.py | 7 +- tests/utils.py | 14 + vllm/__init__.py | 4 + vllm/engine/async_llm_engine.py | 171 ++++---- vllm/engine/llm_engine.py | 269 ++++++++---- vllm/engine/output_processor/util.py | 10 +- vllm/entrypoints/llm.py | 404 +++++++++++++----- vllm/entrypoints/openai/serving_chat.py | 12 +- vllm/entrypoints/openai/serving_completion.py | 17 +- vllm/entrypoints/openai/serving_embedding.py | 42 +- vllm/entrypoints/openai/serving_engine.py | 6 +- vllm/inputs.py | 130 ++++++ vllm/outputs.py | 42 +- vllm/sequence.py | 38 +- vllm/utils.py | 43 +- 43 files changed, 1404 insertions(+), 439 deletions(-) rename docs/source/{ => dev}/offline_inference/llm.rst (86%) create mode 100644 docs/source/dev/offline_inference/llm_inputs.rst create mode 100644 docs/source/dev/offline_inference/offline_index.rst rename docs/source/{offline_inference => dev}/sampling_params.rst (100%) create mode 100644 tests/entrypoints/test_llm_encode.py create mode 100644 tests/test_inputs.py create mode 100644 tests/test_utils.py create mode 100644 vllm/inputs.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index def8a460e84a7..08e132d0c68bf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -63,9 +63,9 @@ steps: mirror_hardwares: [amd] commands: - # these tests have to be separated, because each one will allocate all posible GPU memory - - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - - pytest -v -s entrypoints/test_server_oot_registration.py + - pytest -v -s test_inputs.py + - pytest -v -s entrypoints -m llm + - pytest -v -s entrypoints -m openai - label: Examples Test working_dir: "/vllm-workspace/examples" @@ -110,6 +110,9 @@ steps: mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py +- label: Utils Test + command: pytest -v -s test_utils.py + - label: Worker Test mirror_hardwares: [amd] command: pytest -v -s worker diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a9657f7859750..3146fb33cc27e 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -3,13 +3,14 @@ import json import time from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.inputs import PromptStrictInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -48,7 +49,9 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() + dummy_inputs: List[PromptStrictInputs] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: @@ -59,13 +62,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/docs/source/offline_inference/llm.rst b/docs/source/dev/offline_inference/llm.rst similarity index 86% rename from docs/source/offline_inference/llm.rst rename to docs/source/dev/offline_inference/llm.rst index 1a443ea406994..83ba1b6987c6d 100644 --- a/docs/source/offline_inference/llm.rst +++ b/docs/source/dev/offline_inference/llm.rst @@ -1,5 +1,5 @@ LLM Class -========== +========= .. autoclass:: vllm.LLM :members: diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst new file mode 100644 index 0000000000000..31c3d16a3c8eb --- /dev/null +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -0,0 +1,14 @@ +LLM Inputs +========== + +.. autodata:: vllm.inputs.PromptStrictInputs + +.. autoclass:: vllm.inputs.TextPrompt + :show-inheritance: + :members: + :member-order: bysource + +.. autoclass:: vllm.inputs.TokensPrompt + :show-inheritance: + :members: + :member-order: bysource diff --git a/docs/source/dev/offline_inference/offline_index.rst b/docs/source/dev/offline_inference/offline_index.rst new file mode 100644 index 0000000000000..27dfb0e9df90e --- /dev/null +++ b/docs/source/dev/offline_inference/offline_index.rst @@ -0,0 +1,8 @@ +Offline Inference +================================= + +.. toctree:: + :maxdepth: 1 + + llm + llm_inputs diff --git a/docs/source/offline_inference/sampling_params.rst b/docs/source/dev/sampling_params.rst similarity index 100% rename from docs/source/offline_inference/sampling_params.rst rename to docs/source/dev/sampling_params.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 5db1c9346c45d..5f18fe9ae0a73 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -68,13 +68,6 @@ Documentation getting_started/quickstart getting_started/examples/examples_index -.. toctree:: - :maxdepth: 1 - :caption: Offline Inference - - offline_inference/llm - offline_inference/sampling_params - .. toctree:: :maxdepth: 1 :caption: Serving @@ -108,7 +101,9 @@ Documentation .. toctree:: :maxdepth: 2 :caption: Developer Documentation - + + dev/sampling_params + dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention dev/dockerfile/dockerfile diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a775c6addf1d9..15a8761eb5738 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -48,7 +48,7 @@ completion = client.chat.completions.create( ``` ### Extra Parameters for Chat API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python @@ -65,7 +65,7 @@ The following extra parameters are supported: ``` ### Extra Parameters for Completions API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python diff --git a/examples/llava_example.py b/examples/llava_example.py index 3d22b492654bf..60250c4303fbf 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -23,11 +23,15 @@ def run_llava_pixel_values(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - images = torch.load("images/stop_sign_pixel_values.pt") + image = torch.load("images/stop_sign_pixel_values.pt") + + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) for o in outputs: generated_text = o.outputs[0].text print(generated_text) @@ -46,11 +50,14 @@ def run_llava_image_features(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - images = torch.load("images/stop_sign_image_features.pt") - - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) + image = torch.load("images/stop_sign_image_features.pt") + + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/pyproject.toml b/pyproject.toml index 96f78c37cfefb..0e9096fb4c035 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" [tool.isort] use_parentheses = true skip_gitignore = true + +[tool.pytest.ini_options] +markers = [ + "skip_global_cleanup", + "llm: run tests for vLLM API only", + "openai: run tests for OpenAI API only", +] diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index b69cdc0a21409..10a46422887e3 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -25,7 +25,7 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] - async def encode_request_async(self, *args, **kwargs): + async def process_model_inputs_async(self, *args, **kwargs): pass def generate(self, request_id): diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index ace4c53916c71..7a8d4b3915617 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -29,7 +29,7 @@ def server(): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", diff --git a/tests/conftest.py b/tests/conftest.py index c1a44a606e1bf..af04cfbbb9902 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.sequence import MultiModalData @@ -402,12 +403,22 @@ def generate( ) -> List[Tuple[List[int], str]]: if images is not None: assert len(prompts) == images.shape[0] - req_outputs = self.model.generate( - prompts, - sampling_params=sampling_params, - multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE, - data=images) - if images is not None else None) + + prompt_inputs: List[PromptInputs] = [] + for i, prompt in enumerate(prompts): + image = None if images is None else images[i:i + 1] + mm_data = None if image is None else MultiModalData( + type=MultiModalData.Type.IMAGE, + data=image, + ) + + prompt_inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data, + }) + + req_outputs = self.model.generate(prompt_inputs, + sampling_params=sampling_params) outputs = [] for req_output in req_outputs: prompt_str = req_output.prompt diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 22a9f0cf47d32..88cd4f98091f9 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -133,8 +133,11 @@ def test_append_slot_cow(): # Allocate prompt to gpu block. There is one slot left in the block. prompt = Sequence(seq_id=1, - prompt="one two three", - prompt_token_ids=[1, 2, 3], + inputs={ + "prompt": "one two three", + "prompt_token_ids": [1, 2, 3], + "multi_modal_data": None + }, block_size=block_size) # Fork the sequence, such that a COW will be required when we append a new @@ -304,7 +307,13 @@ def test_sliding_window_multi_seq(): assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - parent = Sequence(1, "one two three", [0, 1, 2], block_size) + parent = Sequence(seq_id=1, + inputs={ + "prompt": "one two three", + "prompt_token_ids": [0, 1, 2], + "multi_modal_data": None + }, + block_size=block_size) seq_group = SequenceGroup(request_id="1", seqs=[parent], arrival_time=time.time(), diff --git a/tests/core/utils.py b/tests/core/utils.py index 8fb13177a2d6c..1c5724090b69b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -21,7 +21,13 @@ def create_dummy_prompt( # and prompt "0 ... block_size". prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + prompt = Sequence(int(request_id), + inputs={ + "prompt": prompt_str, + "prompt_token_ids": prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) seq_group = SequenceGroup(request_id=request_id, seqs=[prompt], arrival_time=time.time(), @@ -51,8 +57,11 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index baa463a316902..338b208723ba9 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str): with pytest.raises(ValueError) as err: llm.generate("abc", sampling_params) assert "prompts must be None if" in str(err.value) - outputs = llm.generate(prompt_token_ids=[[1, 2, 3]], + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params) assert len(outputs) > 0 completions = outputs[0].outputs diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 74b49726734b5..c45f02fe564a3 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,11 +1,15 @@ import asyncio from dataclasses import dataclass +import pytest + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" +pytestmark = pytest.mark.openai + @dataclass class MockModelConfig: diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 41c871ca40bc8..5d4163e96fd87 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -52,6 +52,8 @@ TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") +pytestmark = pytest.mark.openai + def test_guided_logits_processors(): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py new file mode 100644 index 0000000000000..7c3fbe43a8384 --- /dev/null +++ b/tests/entrypoints/test_llm_encode.py @@ -0,0 +1,144 @@ +import weakref +from typing import List + +import pytest + +from vllm import LLM, EmbeddingRequestOutput, PoolingParams + +from ..conftest import cleanup + +MODEL_NAME = "intfloat/e5-mistral-7b-instruct" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +TOKEN_IDS = [ + # Using ID={0, 1, 2, 3} results in NaN values, + # so we add this offset of 1000 + [1000], + [1000, 1001], + [1000, 1002, 1001], + [1000, 1003, 1001, 1002], +] + +pytestmark = pytest.mark.llm + + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup() + + +def assert_outputs_equal(o1: List[EmbeddingRequestOutput], + o2: List[EmbeddingRequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) + + v2_output = llm.encode(prompt, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=prompt_token_ids, + pooling_params=pooling_params) + + v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, + pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) + + v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode( + [{ + "prompt": p + } for p in PROMPTS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, + pooling_params=pooling_params) + + v2_output = llm.encode( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_multiple_pooling_params(llm: LLM): + pooling_params = [ + PoolingParams(), + PoolingParams(), + PoolingParams(), + PoolingParams(), + ] + + # Multiple PoolingParams should be matched with each prompt + outputs = llm.encode(PROMPTS, pooling_params=pooling_params) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) + + # Single PoolingParams should be applied to every prompt + single_pooling_params = PoolingParams() + outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) + assert len(PROMPTS) == len(outputs) + + # pooling_params is None, default params should be applied + outputs = llm.encode(PROMPTS, pooling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 5e8b7ca4d9977..a00fff91a310e 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -1,21 +1,124 @@ +import weakref +from typing import List + import pytest -from vllm import LLM, SamplingParams +from vllm import LLM, RequestOutput, SamplingParams + +from ..conftest import cleanup + +MODEL_NAME = "facebook/opt-125m" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +TOKEN_IDS = [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], +] -def test_multiple_sampling_params(): +pytestmark = pytest.mark.llm - llm = LLM(model="facebook/opt-125m", + +@pytest.fixture(scope="module") +def llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, max_num_batched_tokens=4096, - tensor_parallel_size=1) + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup() + + +def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=prompt, + sampling_params=sampling_params) + + v2_output = llm.generate(prompt, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate({"prompt": prompt}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params) + + v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=PROMPTS, + sampling_params=sampling_params) + + v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate( + [{ + "prompt": p + } for p in PROMPTS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, + sampling_params=sampling_params) + + v2_output = llm.generate( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] +@pytest.mark.skip_global_cleanup +def test_multiple_sampling_params(llm: LLM): sampling_params = [ SamplingParams(temperature=0.01, top_p=0.95), SamplingParams(temperature=0.3, top_p=0.95), @@ -24,18 +127,18 @@ def test_multiple_sampling_params(): ] # Multiple SamplingParams should be matched with each prompt - outputs = llm.generate(prompts, sampling_params=sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params) + assert len(PROMPTS) == len(outputs) # Exception raised, if the size of params does not match the size of prompts with pytest.raises(ValueError): - outputs = llm.generate(prompts, sampling_params=sampling_params[:3]) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3]) # Single SamplingParams should be applied to every prompt single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95) - outputs = llm.generate(prompts, sampling_params=single_sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params) + assert len(PROMPTS) == len(outputs) # sampling_params is None, default params should be applied - outputs = llm.generate(prompts, sampling_params=None) - assert len(prompts) == len(outputs) \ No newline at end of file + outputs = llm.generate(PROMPTS, sampling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 1b04e3205c4b8..2463ccde2bc8b 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -71,7 +71,7 @@ "Swift", "Kotlin" ] -pytestmark = pytest.mark.asyncio +pytestmark = pytest.mark.openai @pytest.fixture(scope="session") @@ -91,6 +91,8 @@ def server(zephyr_lora_files): "--max-model-len", "8192", "--enforce-eager", + "--gpu-memory-utilization", + "0.75", # lora config below "--enable-lora", "--lora-modules", @@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files): # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", + "--enforce-eager", + "--gpu-memory-utilization", + "0.75", "--max-model-len", "8192", - "--enforce-eager", ]) ray.get(server_runner.ready.remote()) yield server_runner @@ -136,6 +140,7 @@ def client(): yield client +@pytest.mark.asyncio async def test_check_models(server, client: openai.AsyncOpenAI): models = await client.models.list() models = models.data @@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI): assert lora_models[1].id == "zephyr-lora2" +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, completion.choices[0].text) >= 5 +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, assert choice.logprobs.top_logprobs is None +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, model_name: str): @@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI, assert texts[0] == texts[1] +@pytest.mark.asyncio async def test_logits_bias(server, client: openai.AsyncOpenAI): prompt = "Hello, my name is" max_tokens = 5 @@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_completion(server, client: openai.AsyncOpenAI, @@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI, jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_chat(server, client: openai.AsyncOpenAI, @@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI, assert json1["age"] != json2["age"] +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, @@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, @@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, assert ip1 != ip2 +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, @@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, assert completion.choices[i].text in TEST_CHOICE +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, @@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, assert choice1 != choice2 +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, @@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, @@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, for token, logprob in token_dict.items()) +@pytest.mark.asyncio async def test_response_format_json_object(server, client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( @@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI): assert loaded == {"result": 2}, loaded +@pytest.mark.asyncio async def test_extra_fields(server, client: openai.AsyncOpenAI): with pytest.raises(BadRequestError) as exc_info: await client.chat.completions.create( @@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI): assert "extra_forbidden" in exc_info.value.message +@pytest.mark.asyncio async def test_complex_message_content(server, client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, @@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI): assert content == "2" +@pytest.mark.asyncio async def test_custom_role(server, client: openai.AsyncOpenAI): # Not sure how the model handles custom roles so we just check that # both string and complex message content are handled in the same way @@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI): assert content1 == content2 +@pytest.mark.asyncio async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement @@ -847,6 +873,7 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI): assert content.strip() == ground_truth +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, assert len(logprobs.tokens) > 5 +@pytest.mark.asyncio async def test_long_seed(server, client: openai.AsyncOpenAI): for seed in [ torch.iinfo(torch.long).min - 1, @@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) +@pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], @@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 5 +@pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index 22e65bf7e7da1..3e55d7f4297fb 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -1,7 +1,7 @@ -import multiprocessing import sys import time +import pytest import torch from openai import OpenAI, OpenAIError @@ -10,6 +10,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.utils import get_open_port +pytestmark = pytest.mark.openai + class MyOPTForCausalLM(OPTForCausalLM): @@ -26,15 +28,16 @@ def server_function(port): # register our dummy model ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) sys.argv = ["placeholder.py"] + \ - ("--model facebook/opt-125m --dtype" - f" float32 --api-key token-abc123 --port {port}").split() + ("--model facebook/opt-125m --gpu-memory-utilization 0.10 " + f"--dtype float32 --api-key token-abc123 --port {port}").split() import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') def test_oot_registration_for_api_server(): port = get_open_port() - server = multiprocessing.Process(target=server_function, args=(port, )) + ctx = torch.multiprocessing.get_context() + server = ctx.Process(target=server_function, args=(port, )) server.start() client = OpenAI( base_url=f"http://localhost:{port}/v1", diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 15189f421a539..4361e5452cdff 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -86,20 +86,18 @@ def generate( def batched_generate( - llm, + llm: vllm.LLM, inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], ): for input in inputs: prompt, sampling_param, lora_req = input - requests_data = llm._validate_and_prepare_requests( + # Add requests to the engine and run the engine + llm._validate_and_add_requests( prompt, sampling_param, lora_request=lora_req, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - llm._add_request(**request_data) outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index be4c2ea1b7810..0ccbabfff6403 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -35,28 +35,25 @@ def pick_vllm(token_ids, logits): # test logits_processors when prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[0], + example_prompts[0], params=params_with_logprobs, - prompt_token_ids=None, ) # test prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[1], + example_prompts[1], params=SamplingParams( prompt_logprobs=3, max_tokens=max_tokens, ), - prompt_token_ids=None, ) # test grouped requests vllm_model.model._add_request( - prompt=example_prompts[2], + example_prompts[2], params=SamplingParams(max_tokens=max_tokens), - prompt_token_ids=None, ) - outputs = vllm_model.model._run_engine(False) + outputs = vllm_model.model._run_engine(use_tqdm=False) assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index ce4501bbf71e5..fef5ff3fb9e8e 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -57,11 +57,7 @@ def test_random_sample_with_seed( sampling_params_seed_1, sampling_params_seed_2, ): - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - params=params, - ) + llm._add_request(prompt, params=params) results = llm._run_engine(use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 3b257ac062f56..97864af88e40a 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, for prompt in prompts: hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - tokenizer.tokenizer.eos_token_id, lora_request) + seq = Sequence(seq_id, + inputs={ + "prompt": prompt, + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=block_size, + eos_token_id=tokenizer.tokenizer.eos_token_id, + lora_request=lora_request) num_blocks = len(prompt_token_ids) // block_size for idx in range(num_blocks): diff --git a/tests/test_inputs.py b/tests/test_inputs.py new file mode 100644 index 0000000000000..887c7101decda --- /dev/null +++ b/tests/test_inputs.py @@ -0,0 +1,53 @@ +from typing import List + +import pytest + +from vllm.inputs import parse_and_batch_prompt + +STRING_INPUTS = [ + '', + 'foo', + 'foo bar', + 'foo baz bar', + 'foo bar qux baz', +] + +TOKEN_INPUTS = [ + [-1], + [1], + [1, 2], + [1, 3, 4], + [1, 2, 4, 3], +] + +INPUTS_SLICES = [ + slice(None, None, -1), + slice(None, None, 2), + slice(None, None, -2), +] + + +def test_parse_single_batch_empty(): + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([]) + + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([[]]) + + +@pytest.mark.parametrize('string_input', STRING_INPUTS) +def test_parse_single_batch_string_consistent(string_input: str): + assert parse_and_batch_prompt(string_input) \ + == parse_and_batch_prompt([string_input]) + + +@pytest.mark.parametrize('token_input', TOKEN_INPUTS) +def test_parse_single_batch_token_consistent(token_input: List[int]): + assert parse_and_batch_prompt(token_input) \ + == parse_and_batch_prompt([token_input]) + + +@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES) +def test_parse_single_batch_string_slice(inputs_slice: slice): + assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ + == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000..54dc5c6f5bfba --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,63 @@ +import pytest + +from vllm.utils import deprecate_kwargs + +from .utils import error_on_warning + + +def test_deprecate_kwargs_always(): + + @deprecate_kwargs("old_arg", is_deprecated=True) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_never(): + + @deprecate_kwargs("old_arg", is_deprecated=False) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_dynamic(): + is_deprecated = True + + @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + is_deprecated = False + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_additional_message(): + + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="abcd"): + dummy(old_arg=1) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 9bc9becb2a6f1..1d4c74d6bd8da 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None): prompt_token_ids = prompt_token_ids or [1] return Sequence( seq_id=0, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/tests/utils.py b/tests/utils.py index 689d8c8c5ba8a..329842911e159 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,8 @@ import subprocess import sys import time +import warnings +from contextlib import contextmanager import ray import requests @@ -87,3 +89,15 @@ def multi_process_tensor_parallel( ray.get(refs) ray.shutdown() + + +@contextmanager +def error_on_warning(): + """ + Within the scope of this context manager, tests will fail if any warning + is emitted. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error") + + yield diff --git a/vllm/__init__.py b/vllm/__init__.py index 74674ca0d12af..a0e154d24087c 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,6 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -16,6 +17,9 @@ __all__ = [ "LLM", "ModelRegistry", + "PromptStrictInputs", + "TextPrompt", + "TokensPrompt", "SamplingParams", "RequestOutput", "CompletionOutput", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5a15ed67e3327..d4289c715d9e6 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -12,12 +12,13 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -244,64 +245,69 @@ async def step_async( return request_outputs - async def encode_request_async( + async def process_model_inputs_async( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = await self.tokenizer.encode_async( + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = await tokenizer.encode_async( request_id=request_id, - prompt=prompt, + prompt=inputs["prompt"], lora_request=lora_request) - return prompt_token_ids + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) async def add_request_async( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = await self.encode_request_async( + + processed_inputs = await self.process_model_inputs_async( + request_id=request_id, inputs=inputs, lora_request=lora_request) + + self._add_processed_request( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - return self.add_request(request_id, - prompt=prompt, - params=params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + ) async def check_health_async(self) -> None: self.model_executor.check_health() class AsyncLLMEngine: - """An asynchronous wrapper for LLMEngine. + """An asynchronous wrapper for :class:`LLMEngine`. - This class is used to wrap the LLMEngine class to make it asynchronous. It - uses asyncio to create a background loop that keeps processing incoming - requests. The LLMEngine is kicked by the generate method when there - are requests in the waiting queue. The generate method yields the outputs - from the LLMEngine to the caller. + This class is used to wrap the :class:`LLMEngine` class to make it + asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The :class:`LLMEngine` is kicked by the + generate method when there are requests in the waiting queue. The generate + method yields the outputs from the :class:`LLMEngine` to the caller. - NOTE: For the comprehensive list of arguments, see `LLMEngine`. + NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`. Args: worker_use_ray: Whether to use Ray for model workers. Required for @@ -315,8 +321,8 @@ class AsyncLLMEngine: being printed in log. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. - *args: Arguments for LLMEngine. - *kwargs: Arguments for LLMEngine. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. """ _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine @@ -526,22 +532,26 @@ async def run_engine_loop(self): async def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncStream: if self.log_requests: - shortened_prompt = prompt - shortened_token_ids = prompt_token_ids - if self.max_log_len is not None: + if isinstance(inputs, str): + shortened_prompt = inputs + shortened_token_ids = None + else: + shortened_prompt = inputs.get("prompt") + shortened_token_ids = inputs.get("prompt_token_ids") + + max_log_len = self.max_log_len + if max_log_len is not None: if shortened_prompt is not None: - shortened_prompt = shortened_prompt[:self.max_log_len] + shortened_prompt = shortened_prompt[:max_log_len] if shortened_token_ids is not None: - shortened_token_ids = shortened_token_ids[:self. - max_log_len] + shortened_token_ids = shortened_token_ids[:max_log_len] + logger.info( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " @@ -562,39 +572,33 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - prompt_token_ids = await ( - self.engine.encode_request_async.remote( # type: ignore + processed_inputs = await self.engine.process_model_inputs_async \ + .remote( # type: ignore request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request)) + inputs=inputs, + lora_request=lora_request) else: - prompt_token_ids = await self.engine.encode_request_async( + processed_inputs = await self.engine.process_model_inputs_async( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, + inputs=inputs, lora_request=lora_request) stream = self._request_tracker.add_request( request_id, - prompt=prompt, + inputs=processed_inputs, params=params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) return stream async def generate( self, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -603,14 +607,12 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data per request. Yields: The output `RequestOutput` objects from the LLMEngine @@ -659,24 +661,20 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in self.process_request( + async for output in self._process_request( request_id, - prompt, + inputs, sampling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, RequestOutput) async def encode( self, - prompt: Optional[str], + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[EmbeddingRequestOutput]: """Generate outputs for a request from an embedding model. @@ -685,14 +683,12 @@ async def encode( from the LLMEngine to the caller. Args: - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data per request. Yields: The output `EmbeddingRequestOutput` objects from the LLMEngine @@ -739,24 +735,21 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in self.process_request( + async for output in self._process_request( request_id, - prompt, + inputs, pooling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, EmbeddingRequestOutput) - async def process_request( + async def _process_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, + *, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -764,12 +757,10 @@ async def process_request( stream = await self.add_request( request_id, - prompt, + inputs, params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) try: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0631c0de76822..08bccf209b7c4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,5 +1,8 @@ import time -from typing import Iterable, List, Optional, Type, Union +from contextlib import contextmanager +from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Type, TypeVar, Union from transformers import GenerationConfig, PreTrainedTokenizer @@ -18,6 +21,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.executor.executor_base import ExecutorBase from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, @@ -25,8 +29,8 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - MultiModalData, PoolerOutput, SamplerOutput, - Sequence, SequenceGroup, SequenceGroupMetadata, + PoolerOutput, SamplerOutput, Sequence, + SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, @@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig): return {} +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -60,11 +67,11 @@ class LLMEngine: iteration-level scheduling and efficient memory management to maximize the serving throughput. - The `LLM` class wraps this class for offline batched inference and the - `AsyncLLMEngine` class wraps this class for online serving. + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. - NOTE: The config arguments are derived from the `EngineArgs` class. For the - comprehensive list of arguments, see `EngineArgs`. + NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs` + class. For the comprehensive list of arguments, see :ref:`engine_args`. Args: model_config: The configuration related to the LLM model. @@ -81,9 +88,60 @@ class LLMEngine: executor_class: The model executor class for managing distributed execution. log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection + usage_context: Specified entry point, used for usage info collection. """ + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return output + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + tokenizer: Optional[BaseTokenizerGroup] + def __init__( self, model_config: ModelConfig, @@ -151,12 +209,11 @@ def __init__( self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: - self.tokenizer: BaseTokenizerGroup - self._init_tokenizer() + self.tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(self.tokenizer) else: - self.detokenizer = None self.tokenizer = None + self.detokenizer = None self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( @@ -318,14 +375,26 @@ def __del__(self): if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() + MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + def get_tokenizer_group( + self, + fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup: + if self.tokenizer is None: + raise ValueError(fail_msg) + + return self.tokenizer + def get_tokenizer(self) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(None) + return self.get_tokenizer_group().get_lora_tokenizer(None) def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + return self.get_tokenizer_group().get_lora_tokenizer( + sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs): + def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: init_kwargs = dict( tokenizer_id=self.model_config.tokenizer, enable_lora=bool(self.lora_config), @@ -335,8 +404,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer = get_tokenizer_group( - self.parallel_config.tokenizer_pool_config, **init_kwargs) + + return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, + **init_kwargs) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -346,29 +416,85 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) - def encode_request( + def _get_eos_token_id( + self, lora_request: Optional[LoRARequest]) -> Optional[int]: + if self.tokenizer is None: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + return None + + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + + def _add_processed_request( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + processed_inputs: LLMInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + ) -> None: + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = self._get_eos_token_id(lora_request) + + seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def process_model_inputs( + self, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - return prompt_token_ids + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["prompt"], + lora_request=lora_request) + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: """Add a request to the engine's request pool. @@ -378,15 +504,14 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt string. Can be None if prompt_token_ids is - provided. - params: Parameters for sampling or pooling. SamplingParams - for text generation. PoolingParams for pooling. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. + params: Parameters for sampling or pooling. + :class:`~vllm.SamplingParams` for text generation. + :class:`~vllm.PoolingParams` for pooling. arrival_time: The arrival time of the request. If None, we use the current monotonic time. - multi_modal_data: Multi modal data per request. Details: - Set arrival_time to the current time if it is None. @@ -417,59 +542,26 @@ def add_request( "not enabled!") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = self.encode_request( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = None - if self.tokenizer: - eos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).eos_token_id - else: - logger.warning("Use None for EOS token id because tokenizer is " - "not initialized") - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - eos_token_id, lora_request) - # Create a SequenceGroup based on SamplingParams or PoolingParams - if isinstance(params, SamplingParams): - seq_group = self._create_sequence_group_with_sampling( - request_id, - seq, - params, - arrival_time, - lora_request, - multi_modal_data, - ) - elif isinstance(params, PoolingParams): - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, - params, - arrival_time, - lora_request, - multi_modal_data, - ) - else: - raise ValueError( - "Either SamplingParams or PoolingParams must be provided.") + processed_inputs = self.process_model_inputs(request_id=request_id, + inputs=inputs, + lora_request=lora_request) - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + ) def _create_sequence_group_with_sampling( self, request_id: str, seq: Sequence, sampling_params: SamplingParams, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, + arrival_time: float, + lora_request: Optional[LoRARequest], ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -495,8 +587,7 @@ def _create_sequence_group_with_sampling( seqs=[seq], arrival_time=arrival_time, sampling_params=sampling_params, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + lora_request=lora_request) return seq_group @@ -505,9 +596,8 @@ def _create_sequence_group_with_pooling( request_id: str, seq: Sequence, pooling_params: PoolingParams, - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, + arrival_time: float, + lora_request: Optional[LoRARequest], ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -517,7 +607,6 @@ def _create_sequence_group_with_pooling( seqs=[seq], arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, pooling_params=pooling_params) return seq_group @@ -570,7 +659,7 @@ def _process_sequence_group_outputs( def _process_model_outputs( self, - output: List[Union[SamplerOutput, PoolerOutput]], + output: GenericSequence[Union[SamplerOutput, PoolerOutput]], scheduled_seq_groups: List[ScheduledSequenceGroup], ignored_seq_groups: List[SequenceGroup], seq_group_metadata_list: List[SequenceGroupMetadata], @@ -585,7 +674,7 @@ def _process_model_outputs( # Organize outputs by [sequence group][step] instead of # [step][sequence group]. output_by_sequence_group = create_output_by_sequence_group( - sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) + output, num_seq_groups=len(scheduled_seq_groups)) # Update the scheduled sequence groups with the model outputs. for scheduled_seq_group, outputs, seq_group_meta in zip( diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 9816e966c1e36..57cc33d911183 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -1,18 +1,20 @@ from typing import List +from typing import Sequence as GenericSequence +from typing import Union -from vllm.sequence import SamplerOutput, SequenceGroupOutput +from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput def create_output_by_sequence_group( - sampler_outputs: List[SamplerOutput], + outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]], num_seq_groups: int) -> List[List[SequenceGroupOutput]]: """Helper method which transforms a 2d list organized by [step][sequence group] into [sequence group][step]. """ - output_by_sequence_group: List[List[SamplerOutput]] = [ + output_by_sequence_group: List[List[SequenceGroupOutput]] = [ [] for _ in range(num_seq_groups) ] - for step in sampler_outputs: + for step in outputs: for i, sequence_group_output in enumerate(step): output_by_sequence_group[i].append(sequence_group_output) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 25f4428100b27..9759d05577796 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,11 +1,14 @@ -from typing import List, Optional, Union +from contextlib import contextmanager +from typing import ClassVar, List, Optional, Sequence, Union, cast, overload -import torch from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, + TextTokensPrompt, TokensPrompt, + parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -13,7 +16,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter +from vllm.utils import Counter, deprecate_kwargs logger = init_logger(__name__) @@ -28,8 +31,10 @@ class LLM: mechanism and efficient memory management. NOTE: This class is intended to be used for offline inference. For online - serving, use the `AsyncLLMEngine` class instead. - NOTE: For the comprehensive list of arguments, see `EngineArgs`. + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. + + NOTE: For the comprehensive list of arguments, see + :class:`~vllm.EngineArgs`. Args: model: The name or path of a HuggingFace Transformers model. @@ -81,6 +86,18 @@ class LLM: disable_custom_all_reduce: See ParallelConfig """ + DEPRECATE_LEGACY: ClassVar[bool] = False + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + @classmethod + @contextmanager + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True + + yield + + cls.DEPRECATE_LEGACY = False + def __init__( self, model: str, @@ -138,15 +155,101 @@ def set_tokenizer( ) -> None: self.llm_engine.tokenizer.tokenizer = tokenizer + @overload # LEGACY: single (prompt + optional token ids) + def generate( + self, + prompts: str, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) def generate( self, - prompts: Optional[Union[str, List[str]]] = None, + prompts: List[str], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + def generate( + self, + prompts: Optional[str] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def generate( + self, + prompts: Optional[List[str]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def generate( + self, + prompts: None, + sampling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload + def generate( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + ) -> List[RequestOutput]: + ... + + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") + def generate( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -155,49 +258,138 @@ def generate( into a single list and pass it to this method. Args: - prompts: A list of prompts to generate completions for. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. - prompt_token_ids: A list of token IDs for the prompts. If None, we - use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data. Returns: A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ + if prompt_token_ids is not None or multi_modal_data is not None: + inputs = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + else: + inputs = cast( + Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts) + if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() - requests_data = self._validate_and_prepare_requests( - prompts, - sampling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + self._validate_and_add_requests( + inputs=inputs, + params=sampling_params, + lora_request=lora_request, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - self._add_request(**request_data) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) - return self._run_engine(use_tqdm) + @overload # LEGACY: single (prompt + optional token ids) + def encode( + self, + prompts: str, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + @overload # LEGACY: multi (prompt + optional token ids) def encode( self, - prompts: Optional[Union[str, List[str]]] = None, + prompts: List[str], pooling_params: Optional[Union[PoolingParams, - List[PoolingParams]]] = None, + Sequence[PoolingParams]]] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + def encode( + self, + prompts: Optional[str] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def encode( + self, + prompts: Optional[List[str]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def encode( + self, + prompts: None, + pooling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload + def encode( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") + def encode( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. @@ -206,124 +398,133 @@ def encode( into a single list and pass it to this method. Args: - prompts: A list of prompts to generate completions for. + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptStrictInputs` + for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. - prompt_token_ids: A list of token IDs for the prompts. If None, we - use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data. Returns: A list of `EmbeddingRequestOutput` objects containing the generated embeddings in the same order as the input prompts. """ + if prompt_token_ids is not None or multi_modal_data is not None: + inputs = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + else: + inputs = cast( + Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts) + if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() - requests_data = self._validate_and_prepare_requests( - prompts, - pooling_params, - prompt_token_ids, - lora_request, - multi_modal_data, + self._validate_and_add_requests( + inputs=inputs, + params=pooling_params, + lora_request=lora_request, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - self._add_request(**request_data) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) - return self._run_engine(use_tqdm) - - def _validate_and_prepare_requests( + # LEGACY + def _convert_v1_inputs( self, prompts: Optional[Union[str, List[str]]], - params: Union[Union[SamplingParams, PoolingParams], - List[Union[SamplingParams, - PoolingParams]]], # Unified parameter - prompt_token_ids: Optional[List[List[int]]] = None, - lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, - ) -> List[dict]: - """Validates and prepares request data for adding to the engine. + prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + multi_modal_data: Optional[MultiModalData], + ): + # skip_tokenizer_init is now checked in engine - Ensures prompts and token IDs are consistent, and returns a list of - dictionaries with request data for further processing. - """ - if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") - if self.llm_engine.model_config.skip_tokenizer_init \ - and prompts is not None: - raise ValueError("prompts must be None if skip_tokenizer_init " - "is True") - if isinstance(prompts, str): - # Convert a single prompt to a list. - prompts = [prompts] - if (prompts is not None and prompt_token_ids is not None - and len(prompts) != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + if prompts is not None: + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + num_requests = None if prompts is not None: num_requests = len(prompts) - else: - assert prompt_token_ids is not None + if prompt_token_ids is not None: + if (num_requests is not None + and num_requests != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + num_requests = len(prompt_token_ids) + if num_requests is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + + inputs: List[PromptInputs] = [] + for i in range(num_requests): + if prompts is not None: + if prompt_token_ids is not None: + item = TextTokensPrompt( + prompt=prompts[i], + prompt_token_ids=prompt_token_ids[i]) + else: + item = TextPrompt(prompt=prompts[i]) + else: + if prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) + else: + raise AssertionError + + if multi_modal_data is not None: + item["multi_modal_data"] = multi_modal_data + + inputs.append(item) + + return inputs + + def _validate_and_add_requests( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, + Sequence[PoolingParams]], + lora_request: Optional[LoRARequest], + ) -> None: + if isinstance(inputs, (str, dict)): + # Convert a single prompt to a list. + inputs = [inputs] + + num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") - if multi_modal_data: - multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. - requests_data = [] - for i in range(num_requests): - prompt = prompts[i] if prompts is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[ - i] - - multi_modal_item = MultiModalData( - type=multi_modal_data.type, - data=multi_modal_data.data[i].unsqueeze(0), - ) if multi_modal_data else None - - requests_data.append({ - "prompt": - prompt, - "params": - params[i] if isinstance(params, list) else params, - "prompt_token_ids": - token_ids, - "lora_request": - lora_request, - "multi_modal_data": - multi_modal_item, - }) - - return requests_data + for i, request_inputs in enumerate(inputs): + self._add_request( + request_inputs, + params[i] if isinstance(params, Sequence) else params, + lora_request=lora_request, + ) def _add_request( self, - prompt: Optional[str], + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, - prompt, + inputs, params, - prompt_token_ids, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + lora_request=lora_request) def _run_engine( - self, use_tqdm: bool + self, *, use_tqdm: bool ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: @@ -355,5 +556,4 @@ def _run_engine( # Sort the outputs by request ID. # This is necessary because some requests may be finished earlier than # its previous requests. - outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs + return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 7e179362eef8a..33daabd881df0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -176,9 +176,15 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt_text, sampling_params, - request_id, prompt_ids, - lora_request) + result_generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + request_id, + lora_request, + ) # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 158d8ed7fbbf5..d1812c8f44f41 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -119,12 +119,17 @@ async def create_completion(self, request: CompletionRequest, truncate_prompt_tokens) prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - sampling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids, - lora_request=lora_request)) + generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + f"{request_id}-{i}", + lora_request=lora_request, + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 7a57be0c88915..5a3448de3d7a4 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,5 +1,5 @@ import time -from typing import AsyncIterator, List, Tuple +from typing import AsyncIterator, List, Optional, Tuple from fastapi import Request @@ -100,11 +100,16 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - pooling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids)) + generator = self.engine.encode( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + pooling_params, + f"{request_id}-{i}", + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -113,16 +118,21 @@ async def create_embedding(self, request: EmbeddingRequest, int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) - async for i, res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") - # TODO: Use a vllm-specific Validation Error - return self.create_error_response("Client disconnected") - final_res_batch[i] = res - response = request_output_to_embedding_response( - final_res_batch, request_id, created_time, model_name) + final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch = [None] * len(prompts) + try: + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await self.engine.abort(f"{request_id}-{i}") + # TODO: Use a vllm-specific Validation Error + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_embedding_response( + final_res_batch, request_id, created_time, model_name) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) return response diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 0df0223b9dbb2..708b0dad102c4 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -143,7 +143,8 @@ def create_streaming_error_response( return json_str async def _check_model( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[ErrorResponse]: if request.model in self.served_model_names: return None @@ -155,7 +156,8 @@ async def _check_model( status_code=HTTPStatus.NOT_FOUND) def _maybe_get_lora( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[LoRARequest]: if request.model in self.served_model_names: return None diff --git a/vllm/inputs.py b/vllm/inputs.py new file mode 100644 index 0000000000000..f5d99b1b66b70 --- /dev/null +++ b/vllm/inputs.py @@ -0,0 +1,130 @@ +from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, + TypedDict, Union, cast, overload) + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from vllm.sequence import MultiModalData + + +class ParsedText(TypedDict): + content: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + content: List[int] + is_tokens: Literal[True] + + +# https://github.com/vllm-project/vllm/pull/4028 +@overload +def parse_and_batch_prompt( + prompt: Union[str, List[str]]) -> Sequence[ParsedText]: + ... + + +@overload +def parse_and_batch_prompt( + prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: + ... + + +def parse_and_batch_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]], +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedText(content=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0], str): + # case 2: array of strings + return [ + ParsedText(content=elem, is_tokens=False) + for elem in cast(List[str], prompt) + ] + if isinstance(prompt[0], int): + # case 3: array of tokens + elem = cast(List[int], prompt) + return [ParsedTokens(content=elem, is_tokens=True)] + if isinstance(prompt[0], list): + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0][0], int): + # case 4: array of token arrays + return [ + ParsedTokens(content=elem, is_tokens=True) + for elem in cast(List[List[int]], prompt) + ] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class TextPrompt(TypedDict): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class TokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" + + prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +class TextTokensPrompt(TypedDict): + """It is assumed that :attr:`prompt` is consistent with + :attr:`prompt_token_ids`. This is currently used in + :class:`AsyncLLMEngine` for logging both the text and token IDs.""" + + prompt: str + """The prompt text.""" + + prompt_token_ids: List[int] + """The token IDs of the prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + multi_modal_data: NotRequired["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +""" +The inputs to the LLM, which can take one of the following forms: + +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +""" + +PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] +"""Same as :const:`PromptStrictInputs` but additionally accepts +:class:`TextTokensPrompt`.""" + + +class LLMInputs(TypedDict): + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalData"] diff --git a/vllm/outputs.py b/vllm/outputs.py index f9bce9e683f22..49f526b5f9300 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,4 +1,5 @@ import time +from dataclasses import dataclass from typing import List, Optional, Union from vllm.lora.request import LoRARequest @@ -6,6 +7,7 @@ SequenceGroup, SequenceStatus) +@dataclass class CompletionOutput: """The output data of one completion output of a request. @@ -24,25 +26,14 @@ class CompletionOutput: lora_request: The LoRA request that was used to generate the output. """ - def __init__( - self, - index: int, - text: str, - token_ids: List[int], - cumulative_logprob: float, - logprobs: Optional[SampleLogprobs], - finish_reason: Optional[str] = None, - stop_reason: Union[int, str, None] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.index = index - self.text = text - self.token_ids = token_ids - self.cumulative_logprob = cumulative_logprob - self.logprobs = logprobs - self.finish_reason = finish_reason - self.stop_reason = stop_reason - self.lora_request = lora_request + index: int + text: str + token_ids: List[int] + cumulative_logprob: float + logprobs: Optional[SampleLogprobs] + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + lora_request: Optional[LoRARequest] = None def finished(self) -> bool: return self.finish_reason is not None @@ -57,6 +48,7 @@ def __repr__(self) -> str: f"stop_reason={self.stop_reason})") +@dataclass class EmbeddingOutput: """The output data of one completion output of a request. @@ -65,15 +57,11 @@ class EmbeddingOutput: length of vector depends on the model as listed in the embedding guide. """ - def __init__( - self, - embedding: List[float], - ) -> None: - self.embedding = embedding + embedding: List[float] def __repr__(self) -> str: return (f"EmbeddingOutput(" - f"embedding={len(self.embedding)}") + f"embedding={len(self.embedding)})") class RequestOutput: @@ -93,7 +81,7 @@ class RequestOutput: def __init__( self, request_id: str, - prompt: str, + prompt: Optional[str], prompt_token_ids: List[int], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], @@ -183,7 +171,7 @@ class EmbeddingRequestOutput: finished (bool): A flag indicating whether the embedding is completed. """ - def __init__(self, request_id: str, outputs: 'EmbeddingOutput', + def __init__(self, request_id: str, outputs: "EmbeddingOutput", prompt_token_ids: List[int], finished: bool): self.request_id = request_id self.prompt_token_ids = prompt_token_ids diff --git a/vllm/sequence.py b/vllm/sequence.py index aa759448d82b1..f8e9da6c7965a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock +from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -210,8 +211,7 @@ class Sequence: Args: seq_id: The ID of the sequence. - prompt: The prompt of the sequence. - prompt_token_ids: The token IDs of the prompt. + inputs: The inputs of the sequence. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. @@ -220,25 +220,24 @@ class Sequence: def __init__( self, seq_id: int, - prompt: str, - prompt_token_ids: List[int], + inputs: LLMInputs, block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id - self.prompt = prompt + self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request - self.data: SequenceData = SequenceData(prompt_token_ids) + self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. - self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(self.prompt_token_ids) self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None @@ -248,6 +247,18 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def prompt(self) -> Optional[str]: + return self.inputs["prompt"] + + @property + def prompt_token_ids(self) -> List[int]: + return self.inputs["prompt_token_ids"] + + @property + def multi_modal_data(self) -> Optional["MultiModalData"]: + return self.inputs["multi_modal_data"] + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -415,7 +426,6 @@ class SequenceGroup: sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. lora_request: LoRA request. - multi_modal_data: Multi modal data associated with the request. embeddings: The embeddings vectors of the prompt of the sequence group for an embedding model. pooling_params: The pooling parameters used to generate the pooling @@ -429,7 +439,6 @@ def __init__( arrival_time: float, sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, ) -> None: @@ -444,12 +453,11 @@ def __init__( self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() - self.multi_modal_data = multi_modal_data self.embeddings = embeddings self.pooling_params = pooling_params @property - def prompt(self) -> str: + def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).prompt @@ -458,7 +466,13 @@ def prompt(self) -> str: def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return next(iter(self.seqs_dict.values())).data.prompt_token_ids + return next(iter(self.seqs_dict.values())).prompt_token_ids + + @property + def multi_modal_data(self) -> Optional[MultiModalData]: + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).multi_modal_data @property def lora_int_id(self) -> int: diff --git a/vllm/utils.py b/vllm/utils.py index 4cb9d905097bf..c8bc54dab41b3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -11,7 +11,7 @@ import uuid import warnings from collections import defaultdict -from functools import lru_cache, partial +from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Tuple, TypeVar, @@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None: filename) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) + + +def identity(value: T) -> T: + return value + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def deprecate_kwargs( + *kws: str, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper From dfba529b4024fba9ce1346467318b35e8f2fa9d9 Mon Sep 17 00:00:00 2001 From: Junichi Sato Date: Wed, 29 May 2024 09:15:35 +0900 Subject: [PATCH 356/413] [Bugfix] Remove the last EOS token unless explicitly specified (#5077) --- .../output_processor/test_stop_checker.py | 86 +++++++++++++++++++ vllm/engine/output_processor/stop_checker.py | 5 ++ 2 files changed, 91 insertions(+) create mode 100644 tests/engine/output_processor/test_stop_checker.py diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py new file mode 100644 index 0000000000000..ae54c83605e11 --- /dev/null +++ b/tests/engine/output_processor/test_stop_checker.py @@ -0,0 +1,86 @@ +from unittest.mock import MagicMock + +import pytest +from transformers import PreTrainedTokenizer + +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob, Sequence, SequenceStatus + + +def sequence_with_eos(text: str, eos_token: str, + eos_token_id: int) -> Sequence: + """ + Create a Sequence that ends with an EOS token. + """ + seq = Sequence( + seq_id=0, + prompt="", + prompt_token_ids=[], + block_size=16, + eos_token_id=eos_token_id, + ) + seq.output_text = text + eos_token + + offset = eos_token_id + 1 + for i in range(offset, len(text) + offset): + seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)}) + seq.append_token_id(token_id=eos_token_id, + logprobs={eos_token_id: Logprob(0.0)}) + + seq.status = SequenceStatus.RUNNING + + return seq + + +@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [ + ("This text ends with EOS token", "", 2), +]) +@pytest.mark.parametrize("ignore_eos", [True, False, None]) +@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None]) +@pytest.mark.skip_global_cleanup +def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, + ignore_eos: bool, include_stop_str_in_output: bool): + """ + Test the behavior of the StopChecker's maybe_stop_sequence method + when an EOS token is encountered. + + This test covers: + - When the EOS token should stop the sequence and be removed from the output + - When the EOS token should stop the sequence and be included in the output + - When the EOS token should be ignored, and the sequence continues + """ + + tokenizer = MagicMock(spec=PreTrainedTokenizer) + get_tokenizer_for_seq = MagicMock(return_value=tokenizer) + stop_checker = StopChecker(max_model_len=1024, + get_tokenizer_for_seq=get_tokenizer_for_seq) + + seq = sequence_with_eos( + text=text_wo_eos, + eos_token=eos_token, + eos_token_id=eos_token_id, + ) + new_char_count = len(eos_token) + + # Note that `stop` and `stop_token_ids` are not specified + sampling_params = SamplingParams( + min_tokens=1, + ignore_eos=ignore_eos, + include_stop_str_in_output=include_stop_str_in_output) + + stop_checker.maybe_stop_sequence( + seq=seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) + + if ignore_eos: + assert seq.status == SequenceStatus.RUNNING + assert seq.output_text == text_wo_eos + eos_token + elif include_stop_str_in_output: + assert seq.status == SequenceStatus.FINISHED_STOPPED + assert seq.output_text == text_wo_eos + eos_token + else: + assert seq.status == SequenceStatus.FINISHED_STOPPED + assert seq.output_text == text_wo_eos diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 5fb11b32bad6d..96f0d1142611b 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -48,6 +48,11 @@ def maybe_stop_sequence( # Check if the sequence has generated the EOS token. if ((not sampling_params.ignore_eos) and seq.get_last_token_id() == seq.eos_token_id): + # Remove the last EOS token unless explicitly specified + # This prevents unintended exposure of the EOS token + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + seq.output_text = seq.output_text[:-new_char_count] seq.status = SequenceStatus.FINISHED_STOPPED return From 616e600e0b092050213e79fd2a10baabb30dcf6d Mon Sep 17 00:00:00 2001 From: Marut Pandya Date: Tue, 28 May 2024 17:16:18 -0700 Subject: [PATCH 357/413] [Misc] add gpu_memory_utilization arg (#5079) Signed-off-by: pandyamarut --- benchmarks/benchmark_latency.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 3146fb33cc27e..f69d91a086a9f 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -35,7 +35,8 @@ def main(args: argparse.Namespace): use_v2_block_manager=args.use_v2_block_manager, enable_chunked_prefill=args.enable_chunked_prefill, download_dir=args.download_dir, - block_size=args.block_size) + block_size=args.block_size, + gpu_memory_utilization=args.gpu_memory_utilization) sampling_params = SamplingParams( n=args.n, @@ -214,5 +215,11 @@ def run_to_completion(profile_dir: Optional[str] = None): type=str, default=None, help='Path to save the latency results in JSON format.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') args = parser.parse_args() main(args) From 5bd3c650721cc5de451f034bcbed37d1a1a4116c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 28 May 2024 22:13:52 -0700 Subject: [PATCH 358/413] [Core][Optimization] remove vllm-nccl (#5091) --- .buildkite/test-pipeline.yaml | 1 - requirements-cuda.txt | 1 - setup.py | 7 +-- tests/distributed/test_pynccl_library.py | 43 ------------------- .../device_communicators/pynccl_wrapper.py | 20 +++------ vllm/utils.py | 43 ++++--------------- vllm/worker/worker_base.py | 6 ++- 7 files changed, 21 insertions(+), 100 deletions(-) delete mode 100644 tests/distributed/test_pynccl_library.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 08e132d0c68bf..21cbd9ba13780 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -37,7 +37,6 @@ steps: working_dir: "/vllm-workspace/tests" num_gpus: 2 commands: - - pytest -v -s distributed/test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py diff --git a/requirements-cuda.txt b/requirements-cuda.txt index acb0164007dba..5109f17356178 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -4,7 +4,6 @@ # Dependencies for NVIDIA GPUs ray >= 2.9 nvidia-ml-py # for pynvml package -vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library torch == 2.3.0 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0 diff --git a/setup.py b/setup.py index a66af2c5d556f..b4baebb0d4801 100644 --- a/setup.py +++ b/setup.py @@ -358,11 +358,8 @@ def _read_requirements(filename: str) -> List[str]: cuda_major, cuda_minor = torch.version.cuda.split(".") modified_requirements = [] for req in requirements: - if "vllm-nccl-cu12" in req: - req = req.replace("vllm-nccl-cu12", - f"vllm-nccl-cu{cuda_major}") - elif ("vllm-flash-attn" in req - and not (cuda_major == "12" and cuda_minor == "1")): + if ("vllm-flash-attn" in req + and not (cuda_major == "12" and cuda_minor == "1")): # vllm-flash-attn is built only for CUDA 12.1. # Skip for other versions. continue diff --git a/tests/distributed/test_pynccl_library.py b/tests/distributed/test_pynccl_library.py deleted file mode 100644 index ec60a5ed3114d..0000000000000 --- a/tests/distributed/test_pynccl_library.py +++ /dev/null @@ -1,43 +0,0 @@ -import multiprocessing -import tempfile - - -def target_fn(env, filepath): - from vllm.utils import update_environment_variables - update_environment_variables(env) - from vllm.utils import nccl_integrity_check - nccl_integrity_check(filepath) - - -def test_library_file(): - # note: don't import vllm.distributed.device_communicators.pynccl - # before running this test, otherwise the library file will be loaded - # and it might interfere with the test - from vllm.utils import find_nccl_library - so_file = find_nccl_library() - with open(so_file, 'rb') as f: - content = f.read() - try: - # corrupt the library file, should raise an exception - with open(so_file, 'wb') as f: - f.write(content[:len(content) // 2]) - p = multiprocessing.Process(target=target_fn, args=({}, so_file)) - p.start() - p.join() - assert p.exitcode != 0 - - # move the library file to a tmp path - # test VLLM_NCCL_SO_PATH - fd, path = tempfile.mkstemp() - with open(path, 'wb') as f: - f.write(content) - p = multiprocessing.Process(target=target_fn, - args=({ - "VLLM_NCCL_SO_PATH": path - }, path)) - p.start() - p.join() - assert p.exitcode == 0 - finally: - with open(so_file, 'wb') as f: - f.write(content) diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 3aa3744d0d827..50d6719fbfe62 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -28,7 +28,7 @@ from torch.distributed import ReduceOp from vllm.logger import init_logger -from vllm.utils import find_nccl_library, nccl_integrity_check +from vllm.utils import find_nccl_library logger = init_logger(__name__) @@ -188,28 +188,22 @@ def __init__(self, so_file: Optional[str] = None): so_file = so_file or find_nccl_library() try: - # load the library in another process. - # if it core dumps, it will not crash the current process - nccl_integrity_check(so_file) + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] except Exception as e: logger.error( "Failed to load NCCL library from %s ." "It is expected if you are not running on NVIDIA/AMD GPUs." "Otherwise, the nccl library might not exist, be corrupted " "or it does not support the current platform %s." - "One solution is to download libnccl2 version 2.18 from " - "https://developer.download.nvidia.com/compute/cuda/repos/ " - "and extract the libnccl.so.2 file. If you already have the " - "library, please set the environment variable VLLM_NCCL_SO_PATH" + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" " to point to the correct nccl library path.", so_file, platform.platform()) raise e - if so_file not in NCCLLibrary.path_to_dict_mapping: - lib = ctypes.CDLL(so_file) - NCCLLibrary.path_to_library_cache[so_file] = lib - self.lib = NCCLLibrary.path_to_library_cache[so_file] - if so_file not in NCCLLibrary.path_to_dict_mapping: _funcs = {} for func in NCCLLibrary.exported_functions: diff --git a/vllm/utils.py b/vllm/utils.py index c8bc54dab41b3..85e045cb3b768 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2,7 +2,6 @@ import datetime import enum import gc -import glob import os import socket import subprocess @@ -565,28 +564,6 @@ def init_cached_hf_modules(): init_hf_modules() -def nccl_integrity_check(filepath): - """ - when the library is corrupted, we cannot catch - the exception in python. it will crash the process. - instead, we use the exit code of `ldd` to check - if the library is corrupted. if not, we will return - the version of the library. - """ - exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null") - if exit_code != 0: - raise RuntimeError(f"Failed to load NCCL library from {filepath} .") - import ctypes - - nccl = ctypes.CDLL(filepath) - version = ctypes.c_int() - nccl.ncclGetVersion.restype = ctypes.c_int - nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)] - result = nccl.ncclGetVersion(ctypes.byref(version)) - assert result == 0 - return version.value - - @lru_cache(maxsize=None) def find_library(lib_name: str) -> str: """ @@ -616,17 +593,13 @@ def find_library(lib_name: str) -> str: def find_nccl_library(): + """ + We either use the library file specified by the `VLLM_NCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be + found by `ctypes` automatically. + """ so_file = envs.VLLM_NCCL_SO_PATH - VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT - - # check if we have vllm-managed nccl - vllm_nccl_path = None - if torch.version.cuda is not None: - cuda_major = torch.version.cuda.split(".")[0] - path = os.path.expanduser( - f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*") - files = glob.glob(path) - vllm_nccl_path = files[0] if files else None # manually load the nccl library if so_file: @@ -635,9 +608,9 @@ def find_nccl_library(): so_file) else: if torch.version.cuda is not None: - so_file = vllm_nccl_path or find_library("libnccl.so.2") + so_file = "libnccl.so.2" elif torch.version.hip is not None: - so_file = find_library("librccl.so.1") + so_file = "librccl.so.1" else: raise ValueError("NCCL only supports CUDA and ROCm backends.") logger.info("Found nccl from library %s", so_file) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index dbac1b5ba339b..258f31de17d87 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -121,12 +121,14 @@ def update_environment_variables(envs: Dict[str, str]) -> None: def init_worker(self, *args, **kwargs): """ - Actual initialization of the worker class, and set up - function tracing if required. + Here we inject some common logic before initializing the worker. Arguments are passed to the worker class constructor. """ enable_trace_function_call_for_thread() + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) self.worker = worker_class(*args, **kwargs) From 18c1f16d86d5130ca989d32a3f05142a6652ba0d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 29 May 2024 15:16:41 +0800 Subject: [PATCH 359/413] [Bugfix] Fix arguments passed to `Sequence` in stop checker test (#5092) --- tests/engine/output_processor/test_stop_checker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py index ae54c83605e11..1d9c878ddde50 100644 --- a/tests/engine/output_processor/test_stop_checker.py +++ b/tests/engine/output_processor/test_stop_checker.py @@ -15,8 +15,11 @@ def sequence_with_eos(text: str, eos_token: str, """ seq = Sequence( seq_id=0, - prompt="", - prompt_token_ids=[], + inputs={ + "prompt": "", + "prompt_token_ids": [], + "multi_modal_data": None, + }, block_size=16, eos_token_id=eos_token_id, ) From 594392d27a0dc3b1df84246afb46cc229946c0f3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 29 May 2024 04:29:07 -0700 Subject: [PATCH 360/413] [Core][Distributed] improve p2p access check (#4992) --- .../device_communicators/custom_all_reduce.py | 3 +- .../custom_all_reduce_utils.py | 186 ++++++++++++++++++ vllm/distributed/utils.py | 90 +-------- 3 files changed, 189 insertions(+), 90 deletions(-) create mode 100644 vllm/distributed/device_communicators/custom_all_reduce_utils.py diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 30ee9d1f8a1e9..a3902aecb3793 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -6,6 +6,8 @@ from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm.distributed.device_communicators.custom_all_reduce_utils import ( + gpu_p2p_access_check) from vllm.distributed.parallel_state import ( get_local_rank, get_tensor_model_parallel_cpu_group) from vllm.logger import init_logger @@ -65,7 +67,6 @@ def _is_full_nvlink(device_ids: List[int]) -> bool: def _can_p2p(rank: int, world_size: int) -> bool: - from vllm.distributed.utils import gpu_p2p_access_check for i in range(world_size): if i == rank: continue diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py new file mode 100644 index 0000000000000..24ef3cb45b19d --- /dev/null +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -0,0 +1,186 @@ +import json +import os +import sys +import tempfile +import time +from contextlib import contextmanager +from typing import Callable, Dict, List, Optional + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@contextmanager +def mute_output(): + with open(os.devnull, "w") as f: + sys.stderr = f + sys.stdout = f + yield + + +def producer(i: int, + init_method: str, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + with mute_output(): + dist.init_process_group( + backend="gloo", + init_method=init_method, + world_size=2, + rank=0, + ) + # produce a tensor in GPU i + data = torch.zeros((128, ), device=f"cuda:{i}") + # get the information to reconstruct the shared tensor + func, args = torch.multiprocessing.reductions.reduce_tensor(data) + args = list(args) + dist.broadcast_object_list([(func, args)], src=0) + dist.barrier() + torch.cuda.synchronize() + assert torch.all(data == 1).item() + + +def consumer(j: int, + init_method: str, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + with mute_output(): + dist.init_process_group( + backend="gloo", + init_method=init_method, + world_size=2, + rank=1, + ) + torch.cuda.set_device(j) + recv = [None] + dist.broadcast_object_list(recv, src=0) + func: Callable + args: List + func, args = recv[0] # type: ignore + # `args[6]` is the device id + # by default pytorch will use `i` from the producer + # here we need to set it to `j` to test P2P access + args[6] = j + data = func(*args) + data += 1 + dist.barrier() + torch.cuda.synchronize() + assert torch.all(data == 1).item() + + +def can_actually_p2p(i, j): + """ + Usually, checking if P2P access is enabled can be done by + `torch.cuda.can_device_access_peer(i, j)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)` + returns `True` even if P2P access is not actually possible. + See https://github.com/vllm-project/vllm/issues/2728 and + https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 + Therefore, we have to perform a real P2P access to check if it is actually + possible. + + Note on p2p and cuda IPC: + Usually, one process uses one GPU: + GPU i --> cuda context i --> tensor i --> process i + + We need to combine p2p and cuda IPC, so that: + GPU i --> cuda context i --> tensor i --> process i + |shared| + GPU j --> cuda context j --> tensor j --> process j + That is to say, process i creates a tensor in GPU i, passes IPC handle to + process j, and process j accesses the tensor in GPU j. Any operation on the + tensor in process j will be reflected in the tensor in process i, because + they are the same memory segment. + It is important to note that process j accesses the tensor in GPU j, not + GPU i. That's why we need p2p access. # noqa + """ + cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None) + # pass the CUDA_VISIBLE_DEVICES to the child process + # to make sure they see the same set of GPUs + + # make sure the temp file is not the same across different calls + temp_path = tempfile.mktemp() + str(time.time()) + # create an empty file + with open(temp_path, "w"): + pass + init_method = f"file://{temp_path}" + + # make sure the processes are spawned + smp = mp.get_context("spawn") + pi = smp.Process(target=producer, + args=(i, init_method, cuda_visible_devices)) + pj = smp.Process(target=consumer, + args=(j, init_method, cuda_visible_devices)) + pi.start() + pj.start() + pi.join() + pj.join() + return pi.exitcode == 0 and pj.exitcode == 0 + + +# why do we need this cache? +# we are testing peer-to-peer (p2p) access between GPUs,across processes. +# if we test it every time, it will be very slow, because we need to create +# N * N * 2 processes, where N is the world size. This is very slow. +# to reduce the time, we use a cache file to store the p2p access status. +# the cache file is generated by the master process if it does not exist. +# then all the processes can read the cache file to check the p2p access status. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(i: int, j: int) -> bool: + """Check if GPU i can access GPU j.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{i}->{j}"] + + is_distributed = dist.is_initialized() + + num_dev = torch.cuda.device_count() + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT + path = os.path.expanduser( + f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" + ) + os.makedirs(os.path.dirname(path), exist_ok=True) + if ((not is_distributed or get_local_rank() == 0) + and (not os.path.exists(path))): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info("generating GPU P2P access cache for in %s", path) + cache = {} + for _i in range(num_dev): + for _j in range(num_dev): + cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j) + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + cpu_world_group = get_cpu_world_group() + dist.barrier(cpu_world_group) + logger.info("reading GPU P2P access cache from %s", path) + with open(path, "r") as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{i}->{j}"] + + +__all__ = ["gpu_p2p_access_check"] diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 1965d4c1d3cbc..0cd420c8e11b5 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -2,19 +2,9 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -import json -import os -from typing import Dict, Optional, Sequence +from typing import Sequence import torch -import torch.distributed as dist - -import vllm.envs as envs -from vllm.logger import init_logger - -from .parallel_state import get_cpu_world_group, get_local_rank - -logger = init_logger(__name__) def ensure_divisibility(numerator, denominator): @@ -56,81 +46,3 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list - - -# code partly borrowed from -# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10 -# License: MIT -def _can_actually_p2p(idx_a, idx_b): - dev_i = f"cuda:{idx_a}" - dev_j = f"cuda:{idx_b}" - a = torch.randn(5, device=dev_i) + 123.0 - b = a.to(dev_j) - c = b.to(dev_i) - return torch.all(a == c).cpu().item() - - -# why do we need this cache? -# 1. we can have runtime checks for P2P access, where every process checks -# P2P access to all other GPUs. Unfortunately, the test might cost many -# (world_size * world_size) cuda context, and reduce the memory available -# for the model. see https://github.com/vllm-project/vllm/issues/3821 -# 2. alternatively, we can have a p2p map that is generated by the master -# process and broadcasted to all other processes. This still requires -# #world_size of cuda context, belonging to the master process, on each GPU. -# 3. we can have a cache file, that records the p2p access status. The first -# time the master process checks the p2p access, it will generate the cache -# file, at the cost of #world_size of cuda context. Later on, all processes -# can read the cache file to check the p2p access status without any cost of -# additional cuda context. -# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we -# can have different cache files for different CUDA_VISIBLE_DEVICES settings, -# e.g. used by different vllm engines. The device id in the cache file is a -# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number -# of visible devices in the vllm engine. -_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None - - -def gpu_p2p_access_check(i: int, j: int) -> bool: - """Check if GPU i can access GPU j.""" - - # if the cache variable is already calculated, - # read from the cache instead of checking it again - global _gpu_p2p_access_cache - if _gpu_p2p_access_cache is not None: - return _gpu_p2p_access_cache[f"{i}->{j}"] - - is_distributed = dist.is_initialized() - - num_dev = torch.cuda.device_count() - cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES - if cuda_visible_devices is None: - cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) - VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT - path = os.path.expanduser( - f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" - ) - os.makedirs(os.path.dirname(path), exist_ok=True) - if (not is_distributed or get_local_rank() == 0) \ - and (not os.path.exists(path)): - # only the local master process (with local_rank == 0) can - # enter this block to calculate the cache - logger.info("generating GPU P2P access cache for in %s", path) - cache = {} - for _i in range(num_dev): - for _j in range(num_dev): - # on some platforms, P2P support might be buggy and we need - # additional checks. See also: - # https://github.com/vllm-project/vllm/issues/2728 - cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer( - _i, _j) and _can_actually_p2p(_i, _j) - with open(path, "w") as f: - json.dump(cache, f, indent=4) - if is_distributed: - cpu_world_group = get_cpu_world_group() - dist.barrier(cpu_world_group) - logger.info("reading GPU P2P access cache from %s", path) - with open(path, "r") as f: - cache = json.load(f) - _gpu_p2p_access_cache = cache - return _gpu_p2p_access_cache[f"{i}->{j}"] From 4238bc82f24d5887784b04a353ed93e2360623b4 Mon Sep 17 00:00:00 2001 From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com> Date: Wed, 29 May 2024 12:09:13 -0400 Subject: [PATCH 361/413] [Core] Cross-attention KV caching and memory-management (towards eventual encoder/decoder model support) (#4837) --- tests/core/block/test_block_manager_v2.py | 154 ++++++++++++++- tests/core/test_block_manager.py | 220 +++++++++++++++++++++- tests/core/utils.py | 99 +++++++++- vllm/core/block/utils.py | 56 ++++++ vllm/core/block_manager_v1.py | 187 ++++++++++++------ vllm/core/block_manager_v2.py | 65 ++++++- vllm/sequence.py | 23 +++ 7 files changed, 735 insertions(+), 69 deletions(-) create mode 100644 vllm/core/block/utils.py diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 91b047f0e183e..f98fc0e217278 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -1,11 +1,13 @@ import pytest +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.block_manager_v2 import BlockSpaceManagerV2 from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list -from ..utils import create_seq_group +from ..utils import create_seq_group, create_seq_group_encoder_decoder @pytest.mark.parametrize("block_size", [16]) @@ -52,6 +54,156 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, assert can_allocate_result == AllocStatus.LATER +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16, 80, 160]) +@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_seq_group_encoder_decoder(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + ) + num_watermark_blocks = int(watermark * num_gpu_blocks) + + num_output_blocks_per_seq = 1 + + # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but + # the current implementation assumes all seqs are new prompts / don't have + # different output lens. + num_output_blocks = num_output_blocks_per_seq + + for bdx, num_prompt_blocks in enumerate( + range(1, num_gpu_blocks - num_output_blocks)): + num_cross_blocks_per_seq = num_prompt_blocks + + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id=str(bdx)) + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + can_allocate_result = block_manager.can_allocate(seq_group) + + num_required_blocks = num_prompt_blocks + \ + num_output_blocks + \ + num_cross_blocks_per_seq + + if num_gpu_blocks - num_required_blocks < num_watermark_blocks: + assert can_allocate_result == AllocStatus.NEVER + elif num_gpu_blocks >= num_required_blocks: + assert can_allocate_result == AllocStatus.OK + else: + assert can_allocate_result == AllocStatus.LATER + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16]) +@pytest.mark.parametrize("num_seqs_per_group", [1]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, + num_seqs_per_group: int, + num_gpu_blocks: int, + watermark: float): + ''' + SWA short for Sliding Window Attention. + + At time of writing block manager v2 does not support SWA. + + However even when SWA is implemented for block manager v2, + there will still most likely be a separate workstream required + to enable SWA for encoder/decoder models. + + Therefore this test enforces that one of the following cases + hold true: + 1. Block manager v2 does not support SWA at all (true at time of writing) + 2. Block manager v2 fails with NotImplementError when SWA is enabled + AND a SequenceGroup with an encoder sequence (i.e. in support of an + encoder/decoder model) is passed into can_allocate() as an argument + + The setup for this test is stripped down version of + test_can_allocate_seq_group_encoder_decoder() + ''' + + with pytest.raises((NotImplementedError, AssertionError)) as exc_info: + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + sliding_window=5 # SWA + ) + + num_output_blocks_per_seq = 1 + num_prompt_blocks = 1 + num_output_blocks = num_output_blocks_per_seq + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id="0") + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + block_manager.can_allocate(seq_group) + + # Assert that either + # 1. Block manager v2 constructor fails with assertion that sliding window + # is not yet supported (most likely near-term outcome at time of + # writing), or + # 2. can_allocate() fails with NotImplementedError due to combination of + # encoder/decoder and sliding window attention + if isinstance(exc_info.value, NotImplementedError): + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA + elif isinstance(exc_info.value, AssertionError): + assert str(exc_info.value) == "Sliding window not yet supported" + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [16]) +@pytest.mark.parametrize("num_seqs_per_group", [1]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_encoder_decoder_fails_with_prefix_cache( + block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, + watermark: float): + + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + enable_caching=True # Prefix cache + ) + + num_output_blocks_per_seq = 1 + num_prompt_blocks = 1 + num_output_blocks = num_output_blocks_per_seq + seq_group = create_seq_group_encoder_decoder( + seq_prompt_len=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + request_id="0") + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + # Assert that either can_allocate() fails with NotImplementedError + # due to combination of encoder/decoder and prefix cache + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE + + @pytest.mark.parametrize("block_size", [1, 8]) @pytest.mark.parametrize("prompt_len", [1, 7, 8]) @pytest.mark.parametrize("num_slots_to_append", [1, 8, 129]) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 88cd4f98091f9..ddd843174f7b1 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -6,13 +6,15 @@ from vllm import SamplingParams from vllm.block import PhysicalTokenBlock +from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, UncachedBlockAllocator) from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -from .utils import create_dummy_prompt +from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder def test_block_allocator_allocate(): @@ -73,7 +75,7 @@ def test_allocate(): # Allocate same sequence group to all available gpu blocks. for i in range(num_gpu_blocks): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK @@ -85,11 +87,107 @@ def test_allocate(): watermark=1 / num_gpu_blocks) for i in range(num_gpu_blocks - 1): _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK block_manager.allocate(seq_group) assert block_manager.can_allocate(seq_group) != AllocStatus.OK +def test_allocate_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same sequence group to all available gpu blocks. + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + # Allocate same sequence group to all available gpu blocks. + # Use watermark to reserve one gpu block. + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=1 / num_gpu_blocks) + for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + str(i), + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + assert block_manager.can_allocate(seq_group) == AllocStatus.OK + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + +def test_allocate_encoder_decoder_fails_with_swa(): + # SWA short for sliding window attention + + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + sliding_window=5) # swa + + # Allocate same sequence group to all available gpu blocks. + _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA + + # Assert that allocate() fails due to SWA + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA + + +def test_allocate_encoder_decoder_fails_with_prefix_caching(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=True) # Prefix cache + + # Allocate same sequence group to all available gpu blocks. + _, _, seq_group = create_dummy_prompt_encoder_decoder( + "0", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + + # Assert that can_allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.can_allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE + + # Assert that allocate() fails due to prefix caching + with pytest.raises(NotImplementedError) as exc_info: + block_manager.allocate(seq_group) + + assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE + + def test_append_slot_single_seq(): block_size = 4 num_cpu_blocks = 4 @@ -244,6 +342,62 @@ def test_swap(): assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) +def test_swap_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + decoder_prompt.status = SequenceStatus.WAITING + encoder_prompt.status = SequenceStatus.WAITING + block_manager.allocate(seq_group) + + # Emulate a forward pass by appending a single token. + # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + token_id = 0 + decoder_prompt.status = SequenceStatus.RUNNING + decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Swap encoder/decoder seq group from GPU -> CPU. + decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt) + cross_gpu_blocks = block_manager.get_cross_block_table(seq_group) + gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks + assert block_manager.can_swap_out(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_out(seq_group) + assert [x[0] for x in mapping] == gpu_blocks + #assert list(mapping.keys()) == gpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) + assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks + decoder_prompt.status = SequenceStatus.SWAPPED + + # Swap encoder/decoder seq group from CPU -> GPU. + decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) + cross_cpu_blocks = block_manager.get_cross_block_table(seq_group) + cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks + assert block_manager.can_swap_in(seq_group) == AllocStatus.OK + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_in(seq_group) + assert [x[0] for x in mapping] == cpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks + assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + + def test_free(): block_size = 4 num_cpu_blocks = 4 @@ -268,6 +422,41 @@ def test_free(): block_manager.get_block_table(prompt) +def test_free_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + decoder_prompt, encoder_prompt, seq_group = \ + create_dummy_prompt_encoder_decoder( + "1", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + block_manager.allocate(seq_group) + + # Free allocated seq. + decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt)) + encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group)) + prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks + before_blocks = block_manager.get_num_free_gpu_blocks() + block_manager.free(decoder_prompt) + block_manager.free_cross(seq_group) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert after_blocks == before_blocks + prompt_blocks + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(decoder_prompt) + + # Block table for freed encoder & decoder seq's are deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(encoder_prompt) + + def test_reset(): block_size = 4 num_cpu_blocks = 4 @@ -289,6 +478,31 @@ def test_reset(): assert block_manager.get_num_free_gpu_blocks() == original_blocks +def test_reset_encoder_decoder(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_req_per_seq_group = 2 + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same seq group on all available gpu blocks. + original_blocks = block_manager.get_num_free_gpu_blocks() + for i in range(num_gpu_blocks // block_req_per_seq_group): + _, _, seq_group = create_dummy_prompt_encoder_decoder( + f"{i}", + decoder_prompt_length=block_size, + encoder_prompt_length=block_size) + block_manager.allocate(seq_group) + assert block_manager.get_num_free_gpu_blocks() == 0 + + # Resetting block manager frees all allocated blocks. + block_manager.reset() + assert block_manager.get_num_free_gpu_blocks() == original_blocks + + def test_sliding_window_multi_seq(): """ Tests that memory allocation and deallocation is handled diff --git a/tests/core/utils.py b/tests/core/utils.py index 1c5724090b69b..cd2045b8a1889 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -39,6 +39,52 @@ def create_dummy_prompt( return prompt, seq_group +def create_dummy_prompt_encoder_decoder( + request_id: str, + decoder_prompt_length: int, + encoder_prompt_length: int, + block_size: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + use_beam_search: bool = False, + best_of: int = 1, +) -> Tuple[Sequence, SequenceGroup]: + if not block_size: + block_size = decoder_prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". + decoder_prompt_tokens = list(range(decoder_prompt_length)) + decoder_prompt_str = " ".join([str(t) for t in decoder_prompt_tokens]) + + decoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": decoder_prompt_str, + "prompt_token_ids": decoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) + + encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length)))) + encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens]) + encoder_prompt = Sequence(int(request_id), + inputs={ + "prompt": encoder_prompt_str, + "prompt_token_ids": encoder_prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) + seq_group = SequenceGroup(request_id=request_id, + seqs=[decoder_prompt], + sampling_params=SamplingParams( + use_beam_search=use_beam_search, + best_of=best_of), + arrival_time=time.time(), + lora_request=lora_request, + encoder_seq=encoder_prompt) + + return decoder_prompt, encoder_prompt, seq_group + + def create_seq_group( seq_prompt_len: int = 1024, seq_output_lens: Iterable[int] = (128, ), @@ -82,5 +128,56 @@ def create_seq_group( return seq_group +def create_seq_group_encoder_decoder( + seq_prompt_len: int = 1024, + seq_output_lens: Iterable[int] = (128, ), + request_id: str = '0', + seq_id_start: int = 0, + sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: + + assert len(seq_output_lens) > 0 + + if sampling_params is None: + sampling_params = SamplingParams() + + prompt_token_ids = [0] * seq_prompt_len + + seqs = [] + for seq_id_offset, output_len in enumerate(seq_output_lens): + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=16, + ) + + for i in range(output_len): + seq.append_token_id( + token_id=i, + logprobs={i: Logprob(0.0)}, + ) + seqs.append(seq) + + # Encoder sequence + encoder_seq = Sequence( + seq_id=seq_id_start + len(seq_output_lens), + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=16, + ) + + return SequenceGroup(request_id=request_id, + seqs=seqs, + sampling_params=sampling_params, + arrival_time=time.time(), + encoder_seq=encoder_seq) + + def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size + return (seq_len + block_size - 1) // block_size \ No newline at end of file diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py new file mode 100644 index 0000000000000..2c412a8f472e0 --- /dev/null +++ b/vllm/core/block/utils.py @@ -0,0 +1,56 @@ +"""Block manager utils.""" +from vllm.sequence import SequenceGroup + +# Exception strings for non-implemented block manager enc/dec scenarios + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + + +def _get_block_mgr_sliding_window_attr(block_mgr): + ''' + BlockManagerV1 and BlockManagerV2 have slightly different + members related to sliding window attention (SWA). This + function extracts the appropriate member to use for determining + whether SWA is enabled. + + Arguments: + + * block_mgr: BlockManagerV1 or BlockManagerV2 instance + ''' + + if hasattr(block_mgr, 'block_sliding_window'): + return block_mgr.block_sliding_window + if hasattr(block_mgr, 'max_block_sliding_window'): + return block_mgr.max_block_sliding_window + + raise AttributeError("Block manager instance has neither " + \ + "block_sliding_window nor " + \ + "max_block_sliding_window attributes.") + + +def check_no_caching_or_swa_for_blockmgr_encdec( + block_mgr, seq_group: SequenceGroup) -> None: + ''' + Enforce that prefix caching & sliding-window attention (SWA) + are currently unsupported *specifically* for encoder/decoder models. + + Raises NotImplementedError if unsupported scenario is detected. + + Arguments: + + * block_mgr: BlockSpaceManager instance + * seq_group: SequenceGroup passed to block_mgr + ''' + + if seq_group.is_encoder_decoder(): + if _get_block_mgr_sliding_window_attr(block_mgr) is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if block_mgr.enable_caching: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 52a170d79e4e7..201cba309f6ef 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,6 +8,7 @@ from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger @@ -255,14 +256,30 @@ def __init__( Device.CPU, block_size, num_cpu_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} + # Mapping: req_id -> BlockTable + # Note that each SequenceGroup has a unique + # request ID + self.cross_block_tables: Dict[str, BlockTable] = {} + + def _get_seq_num_required_blocks(self, seq: Sequence) -> int: + return 0 if seq is None \ + else len(seq.logical_token_blocks) def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - num_required_blocks = len(seq.logical_token_blocks) + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + self_num_required_blocks = self._get_seq_num_required_blocks( + seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + cross_num_required_blocks = self._get_seq_num_required_blocks( + seq_group.get_encoder_seq()) + num_required_blocks = self_num_required_blocks + \ + cross_num_required_blocks if self.block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() @@ -276,11 +293,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER - def allocate(self, seq_group: SequenceGroup) -> None: - # NOTE: Here we assume that all sequences in the group have the same - # prompt. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] - + def _allocate_sequence(self, \ + seq: Sequence, \ + ref_count: int, \ + is_encoder_decoder: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = len(seq.logical_token_blocks) @@ -290,21 +306,46 @@ def allocate(self, seq_group: SequenceGroup) -> None: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() - elif self.enable_caching: + block.ref_count = ref_count + elif not is_encoder_decoder and self.enable_caching: block = self.gpu_allocator.allocate( seq.hash_of_block(logical_idx), seq.num_hashed_tokens_of_block(logical_idx)) else: block = self.gpu_allocator.allocate() # Set the reference counts of the token blocks. - block.ref_count = seq_group.num_seqs() + block.ref_count = ref_count block_table.append(block) - # Assign the block table for each sequence. + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + is_encoder_decoder = seq_group.is_encoder_decoder() + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + # Allocate decoder sequences + # + # NOTE: Here we assume that all sequences in the group have the same + # decoder prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + block_table: BlockTable = \ + self._allocate_sequence(seq, + seq_group.num_seqs(), + is_encoder_decoder) + + # Assign the self-attention block tables for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() + # Allocate encoder sequence + if is_encoder_decoder: + # A SequenceGroup has only a single encoder sequence (at most), + # thus allocate with a ref count of 1 + block_table = self._allocate_sequence(seq_group.get_encoder_seq(), + 1, is_encoder_decoder) + # Assign the cross-attention block table for the SequenceGroup. + self.cross_block_tables[seq_group.request_id] = block_table + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> bool: @@ -443,13 +484,18 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def _get_physical_blocks( self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. + request_id = seq_group.request_id blocks: Set[PhysicalTokenBlock] = set() for seq in seq_group.get_seqs(): if seq.is_finished(): continue blocks.update(self.block_tables[seq.seq_id]) + # Cross-attention blocks + if seq_group.is_encoder_decoder(): + blocks.update(self.cross_block_tables[request_id]) return list(blocks) def can_swap_in(self, @@ -457,8 +503,11 @@ def can_swap_in(self, num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) + if seq_group.is_encoder_decoder(): + num_swapped_seqs += 1 num_free_blocks = self.gpu_allocator.get_num_free_blocks() # NOTE: Conservatively, we assume that every sequence will allocate # at least one free block right after the swap-in. @@ -471,70 +520,81 @@ def can_swap_in(self, else: return AllocStatus.LATER + def _swap_block_table( + self, block_table: BlockTable, src_allocator: BlockAllocatorBase, + dest_allocator: BlockAllocatorBase, + mapping: Dict[PhysicalTokenBlock, + PhysicalTokenBlock]) -> BlockTable: + new_block_table = [] + + for from_block in block_table: + if from_block in mapping: + to_block = mapping[from_block] + to_block.ref_count += 1 + else: + to_block = dest_allocator.allocate( + from_block.block_hash, from_block.num_hashed_tokens) + mapping[from_block] = to_block + new_block_table.append(to_block) + # Free the source block swapped in to destination. + src_allocator.free(from_block) + + return new_block_table + def swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int = 0) -> List[Tuple[int, int]]: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" + request_id = seq_group.request_id + # CPU block -> GPU block. # dict is efficient in lookup `if cpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for cpu_block in block_table: - if cpu_block in mapping: - gpu_block = mapping[cpu_block] - gpu_block.ref_count += 1 - else: - gpu_block = self.gpu_allocator.allocate( - cpu_block.block_hash, cpu_block.num_hashed_tokens) - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) - self.block_tables[seq.seq_id] = new_block_table - - block_number_mapping = { - cpu_block.block_number: gpu_block.block_number - for cpu_block, gpu_block in mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) + + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + request_id = seq_group.request_id + # GPU block -> CPU block. # dict is efficient in lookup `if gpu_block in mapping` mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - new_block_table: BlockTable = [] - block_table = self.block_tables[seq.seq_id] - - for gpu_block in block_table: - if gpu_block in mapping: - cpu_block = mapping[gpu_block] - cpu_block.ref_count += 1 - else: - cpu_block = self.cpu_allocator.allocate( - gpu_block.block_hash, gpu_block.num_hashed_tokens) - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. - self.gpu_allocator.free(gpu_block) - self.block_tables[seq.seq_id] = new_block_table - - block_number_mapping = { - gpu_block.block_number: cpu_block.block_number - for gpu_block, cpu_block in mapping.items() - } - # convert to list of tuples once here - return list(block_number_mapping.items()) + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) + + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] def _free_block_table(self, block_table: BlockTable) -> None: # when using a sliding window, each seq will only use up @@ -559,15 +619,32 @@ def free(self, seq: Sequence) -> None: self._free_block_table(block_table) del self.block_tables[seq.seq_id] + def free_cross(self, seq_group: SequenceGroup) -> None: + if seq_group.request_id not in self.cross_block_tables: + # Already freed or hasn't ben scheduled yet. + return + block_table = self.cross_block_tables[seq_group.request_id] + self._free_block_table(block_table) + del self.cross_block_tables[seq_group.request_id] + def reset(self) -> None: + # Free decoder block tables for block_table in self.block_tables.values(): self._free_block_table(block_table) self.block_tables.clear() + # Free cross-attention block tables + for block_table in self.cross_block_tables.values(): + self._free_block_table(block_table) + self.cross_block_tables.clear() def get_block_table(self, seq: Sequence) -> List[int]: block_table = self.block_tables[seq.seq_id] return [block.block_number for block in block_table] + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + block_table = self.cross_block_tables[seq_group.request_id] + return [block.block_number for block in block_table] + def get_num_free_gpu_blocks(self) -> int: return self.gpu_allocator.get_num_free_blocks() diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 834436c25e160..cad42ab3c1ba2 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -5,11 +5,13 @@ from vllm.core.block.block_table import BlockTable from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device SeqId = int +EncoderSeqId = str class BlockSpaceManagerV2(BlockSpaceManager): @@ -94,17 +96,26 @@ def __init__( ) self.block_tables: Dict[SeqId, BlockTable] = {} + self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = BlockTable.get_num_required_blocks( seq.get_token_ids(), block_size=self.block_size, ) + if seq_group.is_encoder_decoder(): + num_required_blocks += BlockTable.get_num_required_blocks( + seq_group.get_encoder_seq().get_token_ids(), + block_size=self.block_size, + ) + if self.max_block_sliding_window is not None: num_required_blocks = min(num_required_blocks, self.max_block_sliding_window) @@ -121,7 +132,19 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: else: return AllocStatus.LATER + def _allocate_sequence(self, seq: Sequence) -> BlockTable: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, + ) + block_table.allocate(seq.get_token_ids()) + + return block_table + def allocate(self, seq_group: SequenceGroup) -> None: + + # Allocate self-attention block tables for decoder sequences waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert not (set(seq.seq_id for seq in waiting_seqs) & self.block_tables.keys()), "block table already exists" @@ -129,20 +152,29 @@ def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. seq = waiting_seqs[0] - - block_table = BlockTable( - block_size=self.block_size, - block_allocator=self.block_allocator, - max_block_sliding_window=self.max_block_sliding_window, - ) - - block_table.allocate(seq.get_token_ids()) + block_table: BlockTable = self._allocate_sequence(seq) self.block_tables[seq.seq_id] = block_table # Assign the block table for each sequence. for seq in waiting_seqs[1:]: self.block_tables[seq.seq_id] = block_table.fork() + # Allocate cross-attention block table for encoder sequence + # + # NOTE: Here we assume that all sequences in the group have the same + # encoder prompt. + request_id = seq_group.request_id + + assert (request_id + not in self.cross_block_tables), \ + "block table already exists" + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + if seq_group.is_encoder_decoder(): + block_table = self._allocate_sequence(seq_group.get_encoder_seq()) + self.cross_block_tables[request_id] = block_table + def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> bool: """Determine if there is enough space in the GPU KV cache to continue @@ -197,12 +229,27 @@ def free(self, seq: Sequence) -> None: self.block_tables[seq.seq_id].free() del self.block_tables[seq.seq_id] + def free_cross(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.cross_block_tables: + # Already freed or hasn't been scheduled yet. + return + self.cross_block_tables[request_id].free() + del self.cross_block_tables[request_id] + def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids assert all(b is not None for b in block_ids) return block_ids # type: ignore + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.cross_block_tables + block_ids = self.cross_block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids # type: ignore + def access_all_blocks_in_seq(self, seq: Sequence, now: float): # Update the last accessed time of all the blocks accessed # in this step. diff --git a/vllm/sequence.py b/vllm/sequence.py index f8e9da6c7965a..ee8c94bbf06f7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -430,6 +430,8 @@ class SequenceGroup: for an embedding model. pooling_params: The pooling parameters used to generate the pooling for an embedding model. + encoder_seq: Optional, the single encoder sequence. Should be None + unless you are working with an encoder/decoder model. """ def __init__( @@ -441,6 +443,7 @@ def __init__( lora_request: Optional[LoRARequest] = None, embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, + encoder_seq: Optional[Sequence] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -455,6 +458,7 @@ def __init__( self.state = SequenceGroupState() self.embeddings = embeddings self.pooling_params = pooling_params + self.encoder_seq = encoder_seq @property def prompt(self) -> Optional[str]: @@ -538,6 +542,12 @@ def get_seqs( seq for seq in self.seqs_dict.values() if seq.status == status ] + def is_encoder_decoder(self) -> bool: + return self.encoder_seq is not None + + def get_encoder_seq(self) -> Optional[Sequence]: + return self.encoder_seq + def get_unfinished_seqs(self) -> List[Sequence]: return [ seq for seq in self.seqs_dict.values() if not seq.is_finished() @@ -621,6 +631,15 @@ class SequenceGroupMetadata: used in prefix caching. state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. + encoder_seq_data: Optional sequence data for encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. + cross_block_table: Optional cross-attention block table associated + with the encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. """ def __init__( @@ -637,6 +656,8 @@ def __init__( computed_block_nums: Optional[List[int]] = None, state: Optional[SequenceGroupState] = None, multi_modal_data: Optional[MultiModalData] = None, + encoder_seq_data: Optional[SequenceData] = None, + cross_block_table: Optional[List[int]] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -648,6 +669,8 @@ def __init__( self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state + self.encoder_seq_data = encoder_seq_data + self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size self.do_sample = do_sample From ae495c74eab390e52bcade098ee8313679fa8802 Mon Sep 17 00:00:00 2001 From: Ronen Schaffer Date: Thu, 30 May 2024 01:26:33 +0300 Subject: [PATCH 362/413] [Doc]Replace deprecated flag in readme (#4526) --- examples/production_monitoring/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/production_monitoring/README.md b/examples/production_monitoring/README.md index 29b611caeda23..268f2e771018f 100644 --- a/examples/production_monitoring/README.md +++ b/examples/production_monitoring/README.md @@ -29,7 +29,8 @@ python3 ../../benchmarks/benchmark_serving.py \ --model mistralai/Mistral-7B-v0.1 \ --tokenizer mistralai/Mistral-7B-v0.1 \ --endpoint /v1/completions \ - --dataset ShareGPT_V3_unfiltered_cleaned_split.json \ + --dataset-name sharegpt \ + --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \ --request-rate 3.0 ``` From eecd864388cba75421215411d42bde1c328fa518 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 30 May 2024 07:02:25 +0800 Subject: [PATCH 363/413] [Bugfix][CI/Build] Fix test and improve code for `merge_async_iterators` (#5096) --- .../test_merge_async_iterators.py | 41 ------------- tests/test_utils.py | 57 ++++++++++++++++++- vllm/utils.py | 9 ++- 3 files changed, 62 insertions(+), 45 deletions(-) delete mode 100644 tests/async_engine/test_merge_async_iterators.py diff --git a/tests/async_engine/test_merge_async_iterators.py b/tests/async_engine/test_merge_async_iterators.py deleted file mode 100644 index ea453526c77f8..0000000000000 --- a/tests/async_engine/test_merge_async_iterators.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -from typing import AsyncIterator, Tuple - -import pytest - -from vllm.utils import merge_async_iterators - - -@pytest.mark.asyncio -async def test_merge_async_iterators(): - - async def mock_async_iterator(idx: int) -> AsyncIterator[str]: - try: - while True: - yield f"item from iterator {idx}" - await asyncio.sleep(0.1) - except asyncio.CancelledError: - pass - - iterators = [mock_async_iterator(i) for i in range(3)] - merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( - *iterators) - - async def stream_output(generator: AsyncIterator[Tuple[int, str]]): - async for idx, output in generator: - print(f"idx: {idx}, output: {output}") - - task = asyncio.create_task(stream_output(merged_iterator)) - await asyncio.sleep(0.5) - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - for iterator in iterators: - try: - await asyncio.wait_for(anext(iterator), 1) - except StopAsyncIteration: - # All iterators should be cancelled and print this message. - print("Iterator was cancelled normally") - except (Exception, asyncio.CancelledError) as e: - raise AssertionError() from e diff --git a/tests/test_utils.py b/tests/test_utils.py index 54dc5c6f5bfba..a6c3896fa43bf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,64 @@ +import asyncio +import sys +from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol, + Tuple, TypeVar) + import pytest -from vllm.utils import deprecate_kwargs +from vllm.utils import deprecate_kwargs, merge_async_iterators from .utils import error_on_warning +if sys.version_info < (3, 10): + if TYPE_CHECKING: + _AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any]) + _AwaitableT_co = TypeVar("_AwaitableT_co", + bound=Awaitable[Any], + covariant=True) + + class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]): + + def __anext__(self) -> _AwaitableT_co: + ... + + def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT": + return i.__anext__() + + +@pytest.mark.asyncio +async def test_merge_async_iterators(): + + async def mock_async_iterator(idx: int) -> AsyncIterator[str]: + try: + while True: + yield f"item from iterator {idx}" + await asyncio.sleep(0.1) + except asyncio.CancelledError: + pass + + iterators = [mock_async_iterator(i) for i in range(3)] + merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( + *iterators) + + async def stream_output(generator: AsyncIterator[Tuple[int, str]]): + async for idx, output in generator: + print(f"idx: {idx}, output: {output}") + + task = asyncio.create_task(stream_output(merged_iterator)) + await asyncio.sleep(0.5) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + for iterator in iterators: + try: + await asyncio.wait_for(anext(iterator), 1) + except StopAsyncIteration: + # All iterators should be cancelled and print this message. + print("Iterator was cancelled normally") + except (Exception, asyncio.CancelledError) as e: + raise AssertionError() from e + def test_deprecate_kwargs_always(): diff --git a/vllm/utils.py b/vllm/utils.py index 85e045cb3b768..26140e15636a4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,6 +5,7 @@ import os import socket import subprocess +import sys import tempfile import threading import uuid @@ -234,9 +235,11 @@ async def consumer(): yield item except (Exception, asyncio.CancelledError) as e: for task in _tasks: - # NOTE: Pass the error msg in cancel() - # when only Python 3.9+ is supported. - task.cancel() + if sys.version_info >= (3, 9): + # msg parameter only supported in Python 3.9+ + task.cancel(e) + else: + task.cancel() raise e await asyncio.gather(*_tasks) From eb6c50cdc2bfb58591bd524ff08c8016e7c0411a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 30 May 2024 07:02:54 +0800 Subject: [PATCH 364/413] [Bugfix][CI/Build] Fix codespell failing to skip files in `git diff` (#5097) --- format.sh | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/format.sh b/format.sh index aaec25a8aa0dc..d110855f8c273 100755 --- a/format.sh +++ b/format.sh @@ -113,8 +113,11 @@ mypy vllm/logging --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml +# If git diff returns a file that is in the skip list, the file may be checked anyway: +# https://github.com/codespell-project/codespell/issues/1915 +# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem CODESPELL_EXCLUDES=( - '--skip' '*docs/source/_build/**,./tests/lora/data' + '--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,tests/lora/data/**,build/**' ) # check spelling of specified files From b1c255630db60e08c394964b8ed6c0154d31a29f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 30 May 2024 07:05:01 +0800 Subject: [PATCH 365/413] [Core] Avoid the need to pass `None` values to `Sequence.inputs` (#5099) --- tests/core/test_block_manager.py | 2 -- tests/core/utils.py | 7 +------ tests/engine/output_processor/test_stop_checker.py | 6 +----- tests/test_cache_block_hashing.py | 1 - tests/tokenization/test_detokenize.py | 1 - vllm/inputs.py | 4 ++-- vllm/sequence.py | 4 ++-- 7 files changed, 6 insertions(+), 19 deletions(-) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index ddd843174f7b1..cd306b9e4d3cc 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -234,7 +234,6 @@ def test_append_slot_cow(): inputs={ "prompt": "one two three", "prompt_token_ids": [1, 2, 3], - "multi_modal_data": None }, block_size=block_size) @@ -525,7 +524,6 @@ def test_sliding_window_multi_seq(): inputs={ "prompt": "one two three", "prompt_token_ids": [0, 1, 2], - "multi_modal_data": None }, block_size=block_size) seq_group = SequenceGroup(request_id="1", diff --git a/tests/core/utils.py b/tests/core/utils.py index cd2045b8a1889..2fbf099c5f90b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -25,7 +25,6 @@ def create_dummy_prompt( inputs={ "prompt": prompt_str, "prompt_token_ids": prompt_tokens, - "multi_modal_data": None, }, block_size=block_size) seq_group = SequenceGroup(request_id=request_id, @@ -103,11 +102,7 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - inputs={ - "prompt": "", - "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, - }, + inputs={"prompt_token_ids": prompt_token_ids}, block_size=16, ) diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py index 1d9c878ddde50..f795403e3d8ad 100644 --- a/tests/engine/output_processor/test_stop_checker.py +++ b/tests/engine/output_processor/test_stop_checker.py @@ -15,11 +15,7 @@ def sequence_with_eos(text: str, eos_token: str, """ seq = Sequence( seq_id=0, - inputs={ - "prompt": "", - "prompt_token_ids": [], - "multi_modal_data": None, - }, + inputs={"prompt_token_ids": []}, block_size=16, eos_token_id=eos_token_id, ) diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 97864af88e40a..0fbe3dae1ff08 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -74,7 +74,6 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int, inputs={ "prompt": prompt, "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, }, block_size=block_size, eos_token_id=tokenizer.tokenizer.eos_token_id, diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 1d4c74d6bd8da..8d019fe5f38ca 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -126,7 +126,6 @@ def create_sequence(prompt_token_ids=None): inputs={ "prompt": "", "prompt_token_ids": prompt_token_ids, - "multi_modal_data": None, }, block_size=16, ) diff --git a/vllm/inputs.py b/vllm/inputs.py index f5d99b1b66b70..85c9cd84f5ed5 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -126,5 +126,5 @@ class TextTokensPrompt(TypedDict): class LLMInputs(TypedDict): prompt_token_ids: List[int] - prompt: Optional[str] - multi_modal_data: Optional["MultiModalData"] + prompt: NotRequired[Optional[str]] + multi_modal_data: NotRequired[Optional["MultiModalData"]] diff --git a/vllm/sequence.py b/vllm/sequence.py index ee8c94bbf06f7..ac5c234d052bd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -249,7 +249,7 @@ def __init__( @property def prompt(self) -> Optional[str]: - return self.inputs["prompt"] + return self.inputs.get("prompt") @property def prompt_token_ids(self) -> List[int]: @@ -257,7 +257,7 @@ def prompt_token_ids(self) -> List[int]: @property def multi_modal_data(self) -> Optional["MultiModalData"]: - return self.inputs["multi_modal_data"] + return self.inputs.get("multi_modal_data") @property def lora_int_id(self) -> int: From 7c3604fb68031da36567151a9bdfe69e04de44b8 Mon Sep 17 00:00:00 2001 From: Itay Etelis <92247226+Etelis@users.noreply.github.com> Date: Thu, 30 May 2024 02:13:22 +0300 Subject: [PATCH 366/413] [Bugfix] logprobs is not compatible with the OpenAI spec #4795 (#5031) --- vllm/entrypoints/openai/protocol.py | 5 ++--- vllm/entrypoints/openai/serving_chat.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 41e2f77fe56f1..e6eae689d7e03 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -109,7 +109,7 @@ class ChatCompletionRequest(OpenAIBaseModel): frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None logprobs: Optional[bool] = False - top_logprobs: Optional[int] = None + top_logprobs: Optional[int] = 0 max_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 @@ -192,8 +192,7 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params def to_sampling_params(self) -> SamplingParams: - if self.logprobs and not self.top_logprobs: - raise ValueError("Top logprobs must be set when logprobs is.") + # We now allow logprobs being true without top_logrobs. logits_processors = None if self.logit_bias: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 33daabd881df0..8cb50e33e58d1 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -286,7 +286,7 @@ async def chat_completion_stream_generator( logprobs = self._create_logprobs( token_ids=delta_token_ids, top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, + num_output_top_logprobs=request.top_logprobs, initial_text_offset=len(previous_texts[i]), ) else: @@ -373,7 +373,7 @@ async def chat_completion_full_generator( logprobs = self._create_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, + num_output_top_logprobs=request.top_logprobs, ) else: logprobs = None From 4fbcb0f27e78df75de47c0248ce6901cd081c8ff Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 29 May 2024 16:51:18 -0700 Subject: [PATCH 367/413] [Doc][Build] update after removing vllm-nccl (#5103) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- Dockerfile | 6 ------ docs/source/serving/deploying_with_docker.rst | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/Dockerfile b/Dockerfile index ddca95c0e8786..eb96bf3c1db2b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -79,12 +79,6 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ COPY .buildkite/check-wheel-size.py check-wheel-size.py RUN python3 check-wheel-size.py dist -# the `vllm_nccl` package must be installed from source distribution -# pip is too smart to store a wheel in the cache, and other CI jobs -# will directly use the wheel from the cache, which is not what we want. -# we need to remove it manually -RUN --mount=type=cache,target=/root/.cache/pip \ - pip cache remove vllm_nccl* #################### EXTENSION Build IMAGE #################### #################### vLLM installation IMAGE #################### diff --git a/docs/source/serving/deploying_with_docker.rst b/docs/source/serving/deploying_with_docker.rst index cfc462ff33b90..fa82bc8e3bd33 100644 --- a/docs/source/serving/deploying_with_docker.rst +++ b/docs/source/serving/deploying_with_docker.rst @@ -51,4 +51,4 @@ To run vLLM: .. note:: - vLLM docker image is currently designed to be run under the root user (contribution welcomed for changing this!). It will try to load library at runtime under the root user's home directory, e.g. `/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1` . If you are running the container under a different user, you may need to change the permissions of the library (and all the parent directories) to allow the user to access it. Then run vLLM with environment variable `VLLM_NCCL_SO_PATH=/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1` . + **For `v0.4.1` and `v0.4.2` only** - the vLLM docker images under these versions are supposed to be run under the root user since a library under the root user's home directory, i.e. ``/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1`` is required to be loaded during runtime. If you are running the container under a different user, you may need to first change the permissions of the library (and all the parent directories) to allow the user to access it, then run vLLM with environment variable ``VLLM_NCCL_SO_PATH=/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1`` . From 5bf185a1c48fdca524dd76aec4a1424b3a09c9a1 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Wed, 29 May 2024 20:30:18 -0400 Subject: [PATCH 368/413] [Bugfix] gptq_marlin: Ensure g_idx_sort_indices is not a Parameter (#5108) --- .../layers/quantization/gptq_marlin.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 4374fd98012f6..ae440743fdf8e 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -298,14 +298,10 @@ def create_weights( }, ) - g_idx_sort_indices = Parameter( - torch.empty( - g_idx.shape, - dtype=torch.int32, - ), - requires_grad=False, + g_idx_sort_indices = torch.empty( + g_idx.shape, + dtype=torch.int32, ) - set_weight_attrs(g_idx_sort_indices, extra_weight_attrs) # Scales scales = Parameter( @@ -356,9 +352,9 @@ def create_weights( layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) - layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) + layer.g_idx_sort_indices = g_idx_sort_indices layer.workspace = workspace layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition From e07aff9e52342dc82b73c803ba69601242801bc4 Mon Sep 17 00:00:00 2001 From: omkar kakarparthi <75638701+okakarpa@users.noreply.github.com> Date: Wed, 29 May 2024 22:27:39 -0500 Subject: [PATCH 369/413] [CI/Build] Docker cleanup functionality for amd servers (#5112) Co-authored-by: Alexey Kondratiev Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Co-authored-by: Alexei V. Ivanov Co-authored-by: omkarkakarparthi --- .buildkite/run-amd-test.sh | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 7452423479521..bde8ab6184d3c 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -5,6 +5,34 @@ set -ex echo "--- ROCm info" rocminfo +# cleanup older docker images +cleanup_docker() { + # Get Docker's root directory + docker_root=$(docker info -f '{{.DockerRootDir}}') + if [ -z "$docker_root" ]; then + echo "Failed to determine Docker root directory." + exit 1 + fi + echo "Docker root directory: $docker_root" + # Check disk usage of the filesystem where Docker's root directory is located + disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//') + # Define the threshold + threshold=70 + if [ "$disk_usage" -gt "$threshold" ]; then + echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." + # Remove dangling images (those that are not tagged and not used by any container) + docker image prune -f + # Remove unused volumes + docker volume prune -f + echo "Docker images and volumes cleanup completed." + else + echo "Disk usage is below $threshold%. No cleanup needed." + fi +} + +# Call the cleanup docker function +cleanup_docker + echo "--- Resetting GPUs" echo "reset" > /opt/amdgpu/etc/gpu_state From 87d41c849d2cde9279fb08a3a0d97123e3d8fe2f Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Thu, 30 May 2024 11:52:14 +0200 Subject: [PATCH 370/413] [BUGFIX] [FRONTEND] Correct chat logprobs (#5029) Co-authored-by: Breno Faria --- tests/async_engine/test_openapi_server_ray.py | 6 +- tests/entrypoints/test_openai_server.py | 209 +++++++++++++++--- vllm/entrypoints/openai/protocol.py | 50 ++++- vllm/entrypoints/openai/serving_chat.py | 68 +++++- vllm/entrypoints/openai/serving_completion.py | 74 ++++++- vllm/entrypoints/openai/serving_engine.py | 52 +---- 6 files changed, 361 insertions(+), 98 deletions(-) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 7a8d4b3915617..4c362a0512feb 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI): chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None assert chat_completion.choices[0].logprobs is not None - assert chat_completion.choices[0].logprobs.top_logprobs is not None - assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + assert chat_completion.choices[0].logprobs.content[ + 0].top_logprobs is not None + assert len( + chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 2463ccde2bc8b..972137030f46f 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, completion.choices[0].text) >= 5 +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_no_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + @pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras @@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, choice = completion.choices[0] assert choice.logprobs is not None assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) <= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_some_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=6, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=6, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 @pytest.mark.asyncio @@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None assert chat_completion.choices[0].logprobs is not None - assert chat_completion.choices[0].logprobs.top_logprobs is not None - assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + assert chat_completion.choices[0].logprobs.content[ + 0].top_logprobs is not None + assert len( + chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" @@ -251,10 +338,93 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=False) + + choice = chat_completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=True, + top_logprobs=0) + + choice = chat_completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.content is not None + assert len(choice.logprobs.content[0].top_logprobs) <= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=True, + top_logprobs=5) + + choice = chat_completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.content is not None + assert len(choice.logprobs.content[0].top_logprobs) <= 6 + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, - model_name: str): +async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -263,13 +433,13 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, "content": "what is 1+1?" }] - # Default max_logprobs is 5, so this should raise an error + # Default max_logprobs is 20, so this should raise an error with pytest.raises((openai.BadRequestError, openai.APIError)): stream = await client.chat.completions.create(model=model_name, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10, + top_logprobs=21, stream=True) async for chunk in stream: ... @@ -279,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10, + top_logprobs=30, stream=False) - with pytest.raises((openai.BadRequestError, openai.APIError)): - stream = await client.completions.create(model=model_name, - prompt="Test", - max_tokens=10, - logprobs=10, - stream=True) - async for chunk in stream: - ... - - with pytest.raises(openai.BadRequestError): - await client.completions.create(model=model_name, - prompt="Test", - max_tokens=10, - logprobs=10, - stream=False) - # the server should still work afterwards chat_completion = await client.chat.completions.create(model=model_name, messages=messages, @@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, top_logprobs=5, extra_body=dict(guided_choice=TEST_CHOICE, guided_decoding_backend=guided_decoding_backend)) - top_logprobs = chat_completion.choices[0].logprobs.top_logprobs + top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs # -9999.0 is the minimum logprob returned by OpenAI assert all( - isinstance(logprob, float) and logprob >= -9999.0 - for token_dict in top_logprobs - for token, logprob in token_dict.items()) + isinstance(token.logprob, float) and token.logprob >= -9999.0 + for token in top_logprobs) @pytest.mark.asyncio diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index e6eae689d7e03..e380212a4d76b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -250,6 +250,19 @@ def check_guided_decoding_count(cls, data): "('guided_json', 'guided_regex' or 'guided_choice').") return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "top_logprobs" in data and data["top_logprobs"] is not None: + if "logprobs" not in data or data["logprobs"] is False: + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + elif not 0 <= data["top_logprobs"] <= 20: + raise ValueError( + "`top_logprobs` must be a value in the interval [0, 20].") + return data + class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -396,6 +409,15 @@ def check_guided_decoding_count(cls, data): "('guided_json', 'guided_regex' or 'guided_choice').") return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "logprobs" in data and data[ + "logprobs"] is not None and not 0 <= data["logprobs"] <= 5: + raise ValueError(("if passed, `logprobs` must be a value", + " in the interval [0, 5].")) + return data + class EmbeddingRequest(BaseModel): # Ordered by official OpenAI API documentation @@ -415,7 +437,7 @@ def to_pooling_params(self): return PoolingParams(additional_data=self.additional_data) -class LogProbs(OpenAIBaseModel): +class CompletionLogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) @@ -425,7 +447,7 @@ class LogProbs(OpenAIBaseModel): class CompletionResponseChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[LogProbs] = None + logprobs: Optional[CompletionLogProbs] = None finish_reason: Optional[str] = None stop_reason: Optional[Union[int, str]] = Field( default=None, @@ -448,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel): class CompletionResponseStreamChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[LogProbs] = None + logprobs: Optional[CompletionLogProbs] = None finish_reason: Optional[str] = None stop_reason: Optional[Union[int, str]] = Field( default=None, @@ -488,11 +510,25 @@ class ChatMessage(OpenAIBaseModel): content: str +class ChatCompletionLogProb(OpenAIBaseModel): + token: str + logprob: float = -9999.0 + bytes: Optional[List[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[List[ChatCompletionLogProbsContent]] = None + + class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage - logprobs: Optional[LogProbs] = None - finish_reason: Optional[str] = None + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None stop_reason: Optional[Union[int, str]] = None @@ -513,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel): class ChatCompletionResponseStreamChoice(OpenAIBaseModel): index: int delta: DeltaMessage - logprobs: Optional[LogProbs] = None - finish_reason: Optional[str] = None + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None stop_reason: Optional[Union[int, str]] = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8cb50e33e58d1..cc5b896e0e56c 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,8 +1,10 @@ import codecs import time from dataclasses import dataclass -from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional, - TypedDict, Union, cast, final) +from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List, + Optional) +from typing import Sequence as GenericSequence +from typing import TypedDict, Union, cast, final from fastapi import Request from openai.types.chat import ChatCompletionContentPartTextParam @@ -10,8 +12,9 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( - ChatCompletionContentPartParam, ChatCompletionMessageParam, - ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionContentPartParam, ChatCompletionLogProb, + ChatCompletionLogProbs, ChatCompletionLogProbsContent, + ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, UsageInfo) @@ -21,6 +24,7 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput +from vllm.sequence import Logprob from vllm.utils import random_uuid logger = init_logger(__name__) @@ -283,11 +287,10 @@ async def chat_completion_stream_generator( previous_num_tokens[i]:] if output.logprobs else None if request.logprobs: - logprobs = self._create_logprobs( + logprobs = self._create_chat_logprobs( token_ids=delta_token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.top_logprobs, - initial_text_offset=len(previous_texts[i]), ) else: logprobs = None @@ -370,7 +373,7 @@ async def chat_completion_full_generator( top_logprobs = output.logprobs if request.logprobs: - logprobs = self._create_logprobs( + logprobs = self._create_chat_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.top_logprobs, @@ -383,8 +386,7 @@ async def chat_completion_full_generator( message=ChatMessage(role=role, content=output.text), logprobs=logprobs, finish_reason=output.finish_reason, - stop_reason=output.stop_reason, - ) + stop_reason=output.stop_reason) choices.append(choice_data) if request.echo: @@ -414,3 +416,51 @@ async def chat_completion_full_generator( ) return response + + def _get_top_logprobs( + self, logprobs: Dict[int, Logprob], + top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]: + return [ + ChatCompletionLogProb( + token=self._get_decoded_token(p[1], p[0]), + logprob=max(p[1].logprob, -9999.0), + bytes=list( + self._get_decoded_token(p[1], + p[0]).encode("utf-8", + errors="replace"))) + for i, p in enumerate(logprobs.items()) + if top_logprobs and i < top_logprobs + ] + + def _create_chat_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: Optional[int] = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + + logprobs_content = [] + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + logprobs_content.append( + ChatCompletionLogProbsContent( + token=self.tokenizer.decode(token_id), + bytes=list( + self.tokenizer.decode(token_id).encode( + "utf-8", errors="replace")))) + else: + logprobs_content.append( + ChatCompletionLogProbsContent( + token=step_top_logprobs[token_id].decoded_token, + logprob=max(step_top_logprobs[token_id].logprob, + -9999.0), + bytes=list( + step_top_logprobs[token_id].decoded_token.encode( + "utf-8", errors="replace")), + top_logprobs=self._get_top_logprobs( + step_top_logprobs, num_output_top_logprobs))) + + return ChatCompletionLogProbs(content=logprobs_content) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index d1812c8f44f41..2fb122edaf98a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,23 +1,29 @@ import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, - Optional, Tuple) + Optional) +from typing import Sequence as GenericSequence +from typing import Tuple from fastapi import Request from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.protocol import (CompletionRequest, +# yapf: disable +from vllm.entrypoints.openai.protocol import (CompletionLogProbs, + CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - LogProbs, UsageInfo) + UsageInfo) +# yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput +from vllm.sequence import Logprob from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -25,7 +31,7 @@ TypeTokenIDs = List[int] TypeTopLogProbs = List[Optional[Dict[int, float]]] TypeCreateLogProbsFn = Callable[ - [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] + [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs] def parse_prompt_format(prompt) -> Tuple[bool, list]: @@ -235,7 +241,7 @@ async def completion_stream_generator( i]:] if output.logprobs else None if request.logprobs is not None: - logprobs = self._create_logprobs( + logprobs = self._create_completion_logprobs( token_ids=delta_token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs, @@ -317,7 +323,7 @@ def request_output_to_completion_response( assert top_logprobs is not None, ( "top_logprobs must be provided when logprobs " "is requested") - logprobs = self._create_logprobs( + logprobs = self._create_completion_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs, @@ -351,3 +357,59 @@ def request_output_to_completion_response( choices=choices, usage=usage, ) + + def _create_completion_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: int, + initial_text_offset: int = 0, + ) -> CompletionLogProbs: + """Create logprobs for OpenAI Completion API.""" + out_text_offset: List[int] = [] + out_token_logprobs: List[Optional[float]] = [] + out_tokens: List[str] = [] + out_top_logprobs: List[Optional[Dict[str, float]]] = [] + + last_token_len = 0 + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = self.tokenizer.decode(token_id) + out_tokens.append(token) + out_token_logprobs.append(None) + out_top_logprobs.append(None) + else: + token = self._get_decoded_token(step_top_logprobs[token_id], + token_id) + token_logprob = max(step_top_logprobs[token_id].logprob, + -9999.0) + out_tokens.append(token) + out_token_logprobs.append(token_logprob) + + # makes sure to add the top num_output_top_logprobs + 1 + # logprobs, as defined in the openai API + # (cf. https://github.com/openai/openai-openapi/blob/ + # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153) + out_top_logprobs.append({ + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + self._get_decoded_token(top_lp[1], top_lp[0]): + max(top_lp[1].logprob, -9999.0) + for i, top_lp in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs >= i + }) + + if len(out_text_offset) == 0: + out_text_offset.append(initial_text_offset) + else: + out_text_offset.append(out_text_offset[-1] + last_token_len) + last_token_len = len(token) + + return CompletionLogProbs( + text_offset=out_text_offset, + token_logprobs=out_token_logprobs, + tokens=out_tokens, + top_logprobs=out_top_logprobs, + ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 708b0dad102c4..066acdf1c019a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ErrorResponse, - LogProbs, ModelCard, ModelList, + ModelCard, ModelList, ModelPermission) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -75,51 +75,6 @@ async def show_available_models(self) -> ModelList: model_cards.extend(lora_cards) return ModelList(data=model_cards) - def _create_logprobs( - self, - token_ids: List[int], - top_logprobs: List[Optional[Dict[int, Logprob]]], - num_output_top_logprobs: Optional[int] = None, - initial_text_offset: int = 0, - ) -> LogProbs: - """Create OpenAI-style logprobs.""" - logprobs = LogProbs() - last_token_len = 0 - if num_output_top_logprobs: - logprobs.top_logprobs = [] - - for i, token_id in enumerate(token_ids): - step_top_logprobs = top_logprobs[i] - if step_top_logprobs is None: - token = self.tokenizer.decode(token_id) - logprobs.tokens.append(token) - logprobs.token_logprobs.append(None) - assert logprobs.top_logprobs is not None - logprobs.top_logprobs.append(None) - else: - token_logprob = step_top_logprobs[token_id].logprob - token = step_top_logprobs[token_id].decoded_token - logprobs.tokens.append(token) - token_logprob = max(token_logprob, -9999.0) - logprobs.token_logprobs.append(token_logprob) - - if num_output_top_logprobs: - assert logprobs.top_logprobs is not None - logprobs.top_logprobs.append({ - # Convert float("-inf") to the - # JSON-serializable float that OpenAI uses - p.decoded_token: max(p.logprob, -9999.0) - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) - - if len(logprobs.text_offset) == 0: - logprobs.text_offset.append(initial_text_offset) - else: - logprobs.text_offset.append(logprobs.text_offset[-1] + - last_token_len) - last_token_len = len(token) - return logprobs - def create_error_response( self, message: str, @@ -235,3 +190,8 @@ def _validate_prompt_and_tokenize( f"Please reduce the length of the messages or completion.", ) else: return input_ids, input_text + + def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str: + if logprob.decoded_token is not None: + return logprob.decoded_token + return self.tokenizer.decode(token_id) From d910816c7356f4decd56eefb80e963b476cdf3e5 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 30 May 2024 05:58:37 -0700 Subject: [PATCH 371/413] [Bugfix] Automatically Detect SparseML models (#5119) --- vllm/config.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4b256d00a32df..4d05b4ea36d5c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -156,6 +156,17 @@ def _verify_embedding_mode(self) -> None: self.embedding_mode = any( ModelRegistry.is_embedding_model(arch) for arch in architectures) + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # SparseML uses a "compression_config" with a "quantization_config". + compression_cfg = getattr(self.hf_config, "compression_config", + None) + if compression_cfg is not None: + quant_cfg = compression_cfg.get("quantization_config", None) + + return quant_cfg + def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = ["gptq", "squeezellm"] @@ -163,12 +174,13 @@ def _verify_quantization(self) -> None: self.quantization = self.quantization.lower() # Parse quantization method from the HF model config, if available. - quant_cfg = getattr(self.hf_config, "quantization_config", None) + quant_cfg = self._parse_quant_hf_config() + if quant_cfg is not None: quant_method = quant_cfg.get("quant_method", "").lower() # Detect which checkpoint is it - for name, method in QUANTIZATION_METHODS.items(): + for _, method in QUANTIZATION_METHODS.items(): quantization_override = method.override_quantization_method( quant_cfg, self.quantization) if quantization_override: From f758505c736ce53a13567852594c3e05215bb6b2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 30 May 2024 06:29:48 -0700 Subject: [PATCH 372/413] [CI/Build] increase wheel size limit to 200 MB (#5130) --- .buildkite/check-wheel-size.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 41d9e682572a6..75ad094fa1382 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,7 +1,7 @@ import os import zipfile -MAX_SIZE_MB = 150 +MAX_SIZE_MB = 200 def print_top_10_largest_files(zip_file): From d79d9eaaff90801668613a4e3d5d8a0004963f21 Mon Sep 17 00:00:00 2001 From: Hyunsung Lee Date: Thu, 30 May 2024 22:56:19 +0900 Subject: [PATCH 373/413] [Misc] remove duplicate definition of `seq_lens_tensor` in model_runner.py (#5129) --- vllm/worker/model_runner.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5ddd2d1b65f81..47aa70dc617af 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -518,9 +518,6 @@ def _prepare_model_input( else: multi_modal_input = None - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) query_lens_tensor = torch.tensor(query_lens, dtype=torch.long, device=self.device) From a9bcc7afb23d208efaa1b47549fa93eaa1d9d6cf Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 31 May 2024 00:59:23 +0800 Subject: [PATCH 374/413] [Doc] Use intersphinx and update entrypoints docs (#5125) --- docs/source/conf.py | 13 ++++++++++++- vllm/engine/async_llm_engine.py | 2 -- vllm/engine/llm_engine.py | 4 ++-- vllm/entrypoints/llm.py | 26 ++++++++++++++++++-------- 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 9da5a4991734d..cfebc2ff9bb33 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -80,7 +80,7 @@ def setup(app): generate_examples() -# Mock out external dependencies here. +# Mock out external dependencies here, otherwise the autodoc pages may be blank. autodoc_mock_imports = [ "cpuinfo", "torch", @@ -115,4 +115,15 @@ def add_line(self, line: str, source: str, *lineno: int) -> None: autodoc.ClassDocumenter = MockedClassDocumenter +intersphinx_mapping = { + 'python': ('https://docs.python.org/3', None), + 'typing_extensions': + ('https://typing-extensions.readthedocs.io/en/latest', None), + 'numpy': ('https://numpy.org/doc/stable', None), + 'torch': ('https://pytorch.org/docs/stable', None), + 'psutil': ('https://psutil.readthedocs.io/en/stable', None), +} + +autodoc_preserve_defaults = True + navigation_with_keys = False diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d4289c715d9e6..db4d2849b3f0e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -307,8 +307,6 @@ class AsyncLLMEngine: generate method when there are requests in the waiting queue. The generate method yields the outputs from the :class:`LLMEngine` to the caller. - NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`. - Args: worker_use_ray: Whether to use Ray for model workers. Required for distributed execution. Should be the same as diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 08bccf209b7c4..cb5893e707c8b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -70,8 +70,8 @@ class LLMEngine: The :class:`~vllm.LLM` class wraps this class for offline batched inference and the :class:`AsyncLLMEngine` class wraps this class for online serving. - NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs` - class. For the comprehensive list of arguments, see :ref:`engine_args`. + The config arguments are derived from :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) Args: model_config: The configuration related to the LLM model. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9759d05577796..6e971ae73f5d0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -30,12 +30,6 @@ class LLM: this class generates texts from the model, using an intelligent batching mechanism and efficient memory management. - NOTE: This class is intended to be used for offline inference. For online - serving, use the :class:`~vllm.AsyncLLMEngine` class instead. - - NOTE: For the comprehensive list of arguments, see - :class:`~vllm.EngineArgs`. - Args: model: The name or path of a HuggingFace Transformers model. tokenizer: The name or path of a HuggingFace Transformers tokenizer. @@ -84,6 +78,12 @@ class LLM: When a sequence has context length larger than this, we fall back to eager mode. disable_custom_all_reduce: See ParallelConfig + **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Note: + This class is intended to be used for offline inference. For online + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. """ DEPRECATE_LEGACY: ClassVar[bool] = False @@ -253,7 +253,7 @@ def generate( ) -> List[RequestOutput]: """Generates the completions for the input prompts. - NOTE: This class automatically batches the given prompts, considering + This class automatically batches the given prompts, considering the memory constraint. For the best performance, put all of your prompts into a single list and pass it to this method. @@ -270,6 +270,11 @@ def generate( Returns: A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. + + Note: + Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the ``inputs`` parameter. """ if prompt_token_ids is not None or multi_modal_data is not None: inputs = self._convert_v1_inputs( @@ -393,7 +398,7 @@ def encode( ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. - NOTE: This class automatically batches the given prompts, considering + This class automatically batches the given prompts, considering the memory constraint. For the best performance, put all of your prompts into a single list and pass it to this method. @@ -409,6 +414,11 @@ def encode( Returns: A list of `EmbeddingRequestOutput` objects containing the generated embeddings in the same order as the input prompts. + + Note: + Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the ``inputs`` parameter. """ if prompt_token_ids is not None or multi_modal_data is not None: inputs = self._convert_v1_inputs( From 429d89720e41901c3c0499a8ed3ad5be693cc945 Mon Sep 17 00:00:00 2001 From: Chansung Park Date: Fri, 31 May 2024 02:11:07 +0900 Subject: [PATCH 375/413] add doc about serving option on dstack (#3074) Co-authored-by: Roger Wang --- docs/source/serving/deploying_with_dstack.rst | 103 ++++++++++++++++++ docs/source/serving/integrations.rst | 1 + 2 files changed, 104 insertions(+) create mode 100644 docs/source/serving/deploying_with_dstack.rst diff --git a/docs/source/serving/deploying_with_dstack.rst b/docs/source/serving/deploying_with_dstack.rst new file mode 100644 index 0000000000000..baf87314ca8e4 --- /dev/null +++ b/docs/source/serving/deploying_with_dstack.rst @@ -0,0 +1,103 @@ +.. _deploying_with_dstack: + +Deploying with dstack +============================ + +.. raw:: html + +

+ vLLM_plus_dstack +

+ +vLLM can be run on a cloud based GPU machine with `dstack `__, an open-source framework for running LLMs on any cloud. This tutorial assumes that you have already configured credentials, gateway, and GPU quotas on your cloud environment. + +To install dstack client, run: + +.. code-block:: console + + $ pip install "dstack[all] + $ dstack server + +Next, to configure your dstack project, run: + +.. code-block:: console + + $ mkdir -p vllm-dstack + $ cd vllm-dstack + $ dstack init + +Next, to provision a VM instance with LLM of your choice(`NousResearch/Llama-2-7b-chat-hf` for this example), create the following `serve.dstack.yml` file for the dstack `Service`: + +.. code-block:: yaml + + type: service + + python: "3.11" + env: + - MODEL=NousResearch/Llama-2-7b-chat-hf + port: 8000 + resources: + gpu: 24GB + commands: + - pip install vllm + - python -m vllm.entrypoints.openai.api_server --model $MODEL --port 8000 + model: + format: openai + type: chat + name: NousResearch/Llama-2-7b-chat-hf + +Then, run the following CLI for provisioning: + +.. code-block:: console + + $ dstack run . -f serve.dstack.yml + + ⠸ Getting run plan... + Configuration serve.dstack.yml + Project deep-diver-main + User deep-diver + Min resources 2..xCPU, 8GB.., 1xGPU (24GB) + Max price - + Max duration - + Spot policy auto + Retry policy no + + # BACKEND REGION INSTANCE RESOURCES SPOT PRICE + 1 gcp us-central1 g2-standard-4 4xCPU, 16GB, 1xL4 (24GB), 100GB (disk) yes $0.223804 + 2 gcp us-east1 g2-standard-4 4xCPU, 16GB, 1xL4 (24GB), 100GB (disk) yes $0.223804 + 3 gcp us-west1 g2-standard-4 4xCPU, 16GB, 1xL4 (24GB), 100GB (disk) yes $0.223804 + ... + Shown 3 of 193 offers, $5.876 max + + Continue? [y/n]: y + ⠙ Submitting run... + ⠏ Launching spicy-treefrog-1 (pulling) + spicy-treefrog-1 provisioning completed (running) + Service is published at ... + +After the provisioning, you can interact with the model by using the OpenAI SDK: + +.. code-block:: python + + from openai import OpenAI + + client = OpenAI( + base_url="https://gateway.", + api_key="" + ) + + completion = client.chat.completions.create( + model="NousResearch/Llama-2-7b-chat-hf", + messages=[ + { + "role": "user", + "content": "Compose a poem that explains the concept of recursion in programming.", + } + ] + ) + + print(completion.choices[0].message.content) + +.. note:: + + dstack automatically handles authentication on the gateway using dstack's tokens. Meanwhile, if you don't want to configure a gateway, you can provision dstack `Task` instead of `Service`. The `Task` is for development purpose only. If you want to know more about hands-on materials how to serve vLLM using dstack, check out `this repository `__ diff --git a/docs/source/serving/integrations.rst b/docs/source/serving/integrations.rst index 2066e80b03298..83a8b5a88bd38 100644 --- a/docs/source/serving/integrations.rst +++ b/docs/source/serving/integrations.rst @@ -9,4 +9,5 @@ Integrations deploying_with_triton deploying_with_bentoml deploying_with_lws + deploying_with_dstack serving_with_langchain From 87a658c81219568fc30081d9cc11327238160563 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 30 May 2024 13:13:46 -0500 Subject: [PATCH 376/413] Bump version to v0.4.3 (#5046) --- vllm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/__init__.py b/vllm/__init__.py index a0e154d24087c..dc59bf4a81931 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -12,7 +12,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -__version__ = "0.4.2" +__version__ = "0.4.3" __all__ = [ "LLM", From ab8eced95a2a379024f95263c132b0be85b25af1 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 30 May 2024 19:11:58 +0000 Subject: [PATCH 377/413] make tunedgemm using custom kernels for BS=1 --- vllm/model_executor/layers/tuned_gemm.py | 49 +++++++----------------- 1 file changed, 13 insertions(+), 36 deletions(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 785840aef51fa..313cbc2a00cf6 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -33,23 +33,6 @@ def load_best_sols(self): if self.tune_path is not None and Path(self.tune_path).is_file(): self.bestsols = pd.read_csv(self.tune_path) - def apply_custom(self, ds): - M, N, K = ds['M'], ds['N'], ds['K'] - #apply custom matvec (only for f16 dtype) - if N == 1: - ds1 = ds.copy() - ds1['libtype'] = 'custom' - if K == 8192 and (M == 1280 or M == 7168): #NOQA: SIM114 - ds1['solidx'] = 8 - return ds1 - elif K == 3584 and M == 8192: - ds1['solidx'] = 8 - return ds1 - elif K <= 8192 and K % 8 == 0 and M % 4 == 0: - ds1['solidx'] = 1 - return ds1 - return ds - def create_ds(self): df = self.bestsols solds = {} @@ -60,8 +43,6 @@ def create_ds(self): soltype = 1 elif ds['libtype'] == 'rocblas': soltype = 2 - elif ds['libtype'] == 'custom': - soltype = 3 solds[key] = (soltype, int(ds['solidx'])) self.solids = solds #print('>>>',solds) @@ -86,27 +67,23 @@ def mm(self, inp, weights): m = weights.shape[0] n = inp_view.shape[0] k = inp_view.shape[1] + if n == 1 and inp_view.dtype == torch.float16: + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp.dtype, + device='cuda') + if (k == 8192 and + (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): + _custom_C.LLMM1(weights, inp_view, out, 8) + elif k <= 8192 and k % 8 == 0 and m % 4 == 0: + _custom_C.LLMM1(weights, inp_view, out, 4) + else: + out = F.linear(inp_view, weights) + return out soltype, solidx = self.query_sol(m=m, n=n, k=k) if soltype == 1: #print(">>> found hipblas") out = hipb_mm(inp_view, weights.t(), solidx) - elif soltype == 3: - ##only matvec is supported currently - out = torch.empty(inp.shape[0], - weights.shape[0], - dtype=torch.float16, - device='cuda') - #print('>>>Matvec',inp.shape,weights.shape,soltype,solidx) - if solidx <= 1: - _custom_C.LLMM1(weights, inp, out, 4) - elif solidx == 2: - _custom_C.LLMM1(weights, inp, out, 2) - elif solidx == 8: - _custom_C.LLMM1(weights, inp, out, 8) - elif solidx == 20: - _custom_C.LLZZ(weights, inp, out, 0) - elif solidx == 21: - _custom_C.LLZZ(weights, inp, out, 1) elif soltype == 2: #print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) From 10e07c5d82e648354f4027a9b414e71ba8b39534 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 30 May 2024 19:37:28 +0000 Subject: [PATCH 378/413] Revert "make tunedgemm using custom kernels for BS=1" This reverts commit ab8eced95a2a379024f95263c132b0be85b25af1. --- vllm/model_executor/layers/tuned_gemm.py | 49 +++++++++++++++++------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 313cbc2a00cf6..785840aef51fa 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -33,6 +33,23 @@ def load_best_sols(self): if self.tune_path is not None and Path(self.tune_path).is_file(): self.bestsols = pd.read_csv(self.tune_path) + def apply_custom(self, ds): + M, N, K = ds['M'], ds['N'], ds['K'] + #apply custom matvec (only for f16 dtype) + if N == 1: + ds1 = ds.copy() + ds1['libtype'] = 'custom' + if K == 8192 and (M == 1280 or M == 7168): #NOQA: SIM114 + ds1['solidx'] = 8 + return ds1 + elif K == 3584 and M == 8192: + ds1['solidx'] = 8 + return ds1 + elif K <= 8192 and K % 8 == 0 and M % 4 == 0: + ds1['solidx'] = 1 + return ds1 + return ds + def create_ds(self): df = self.bestsols solds = {} @@ -43,6 +60,8 @@ def create_ds(self): soltype = 1 elif ds['libtype'] == 'rocblas': soltype = 2 + elif ds['libtype'] == 'custom': + soltype = 3 solds[key] = (soltype, int(ds['solidx'])) self.solids = solds #print('>>>',solds) @@ -67,23 +86,27 @@ def mm(self, inp, weights): m = weights.shape[0] n = inp_view.shape[0] k = inp_view.shape[1] - if n == 1 and inp_view.dtype == torch.float16: - out = torch.empty(inp_view.shape[0], - weights.shape[0], - dtype=inp.dtype, - device='cuda') - if (k == 8192 and - (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): - _custom_C.LLMM1(weights, inp_view, out, 8) - elif k <= 8192 and k % 8 == 0 and m % 4 == 0: - _custom_C.LLMM1(weights, inp_view, out, 4) - else: - out = F.linear(inp_view, weights) - return out soltype, solidx = self.query_sol(m=m, n=n, k=k) if soltype == 1: #print(">>> found hipblas") out = hipb_mm(inp_view, weights.t(), solidx) + elif soltype == 3: + ##only matvec is supported currently + out = torch.empty(inp.shape[0], + weights.shape[0], + dtype=torch.float16, + device='cuda') + #print('>>>Matvec',inp.shape,weights.shape,soltype,solidx) + if solidx <= 1: + _custom_C.LLMM1(weights, inp, out, 4) + elif solidx == 2: + _custom_C.LLMM1(weights, inp, out, 2) + elif solidx == 8: + _custom_C.LLMM1(weights, inp, out, 8) + elif solidx == 20: + _custom_C.LLZZ(weights, inp, out, 0) + elif solidx == 21: + _custom_C.LLZZ(weights, inp, out, 1) elif soltype == 2: #print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) From 9672390bc1c381d7f20d930e76e005a09c04c9b9 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 30 May 2024 19:51:52 +0000 Subject: [PATCH 379/413] Using custom kernels inside TunedGemm when not tuned. --- vllm/model_executor/layers/tuned_gemm.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 785840aef51fa..cc5ae2b8601f2 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -111,7 +111,6 @@ def mm(self, inp, weights): #print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) else: - if (self.save_gemm == 1): #print('>>>Tgemm Default',inp_view.shape, # inp.shape,weights.shape,soltype,solidx) @@ -124,7 +123,19 @@ def mm(self, inp, weights): }) ]).drop_duplicates() self.tuned_df.to_csv(self.untune_path, index=False) - out = F.linear(inp, weights) + + if n == 1 and inp_view.dtype == torch.float16: + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + if (k == 8192 and + (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): + _custom_C.LLMM1(weights, inp_view, out, 8) + elif k <= 8192 and k % 8 == 0 and m % 4 == 0: + _custom_C.LLMM1(weights, inp_view, out, 4) + else: + out = F.linear(inp, weights) if batched: return out.view(inp.shape[0], inp.shape[1], weights.shape[0]) else: From 45a1a69b9841a4cb7cc70788cf7dea1a2d3ec3d6 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 30 May 2024 16:37:16 -0500 Subject: [PATCH 380/413] [Build] Disable sm_90a in cu11 (#5141) --- CMakeLists.txt | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b668cbc97de15..8df3a7a26d884 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -177,7 +177,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") include(FetchContent) SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) FetchContent_Declare( - cutlass + cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git # CUTLASS 3.5.0 GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc @@ -200,11 +200,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The CUTLASS kernels for Hopper require sm90a to be enabled. # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. # That adds an extra 17MB to compiled binary, so instead we selectively enable it. - set_source_files_properties( - "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 11) + set_source_files_properties( + "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") + endif() endif() From b35be5403f3cf8631aefe02a35d97013657e2e47 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Thu, 30 May 2024 17:04:37 -0700 Subject: [PATCH 381/413] [Bugfix] Avoid Warnings in SparseML Activation Quantization (#5120) --- .../compressed_tensors_w8a8_statictensor.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py index d16e570d12202..64a88b01cd260 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py @@ -89,23 +89,34 @@ def create_weights(self, layer: torch.nn.Module, requires_grad=False) layer.register_parameter("weight", weight) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - - set_weight_attrs(weight, {"weight_loader": weight_loader}) - + set_weight_attrs(weight, { + "weight_loader": weight_loader, + "input_dim": 1, + "output_dim": 0, + }) layer.register_parameter("input_scale", input_scale) - set_weight_attrs(input_scale, {"weight_loader": weight_loader}) + set_weight_attrs(input_scale, { + "weight_loader": weight_loader, + "ignore_warning": True, + }) layer.register_parameter("input_zero_point", input_zero_point) - set_weight_attrs(input_zero_point, {"weight_loader": weight_loader}) + set_weight_attrs(input_zero_point, { + "weight_loader": weight_loader, + "ignore_warning": True, + }) layer.register_parameter("weight_scale", weight_scale) - set_weight_attrs(weight_scale, {"weight_loader": weight_loader}) set_weight_attrs( weight_scale, { + "weight_loader": weight_loader, "shard_splitter": self.scales_shard_splitter, - "logical_widths": output_partition_sizes + "logical_widths": output_partition_sizes, + "ignore_warning": True, }) layer.register_parameter("weight_zero_point", weight_zero_point) - set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader}) + set_weight_attrs(weight_zero_point, { + "weight_loader": weight_loader, + "ignore_warning": True + }) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor): weight = layer.weight From 6d21fa1cadf1e623e302eb04c15e4927febc8cf1 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Thu, 30 May 2024 22:02:11 -0400 Subject: [PATCH 382/413] [Kernel] Marlin_24: Ensure the mma.sp instruction is using the ::ordered_metadata modifier (introduced with PTX 8.5) (#5136) --- csrc/quantization/marlin/sparse/common/mma.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h index 45ab67a78a1de..fd3dbda5b9c93 100644 --- a/csrc/quantization/marlin/sparse/common/mma.h +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -32,7 +32,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, float* c = reinterpret_cast(&frag_c); if (psel == 0) { asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." + "f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x0;\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) @@ -40,7 +41,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), "r"(e[0])); asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." + "f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x0;\n" : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) @@ -49,7 +51,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, "r"(e[0])); } else { asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." + "f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x1;\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) @@ -57,7 +60,8 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), "r"(e[0])); asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." + "f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x1;\n" : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) From 533c2177925ba19934eab0095a50d0a783185e6b Mon Sep 17 00:00:00 2001 From: simon-mo Date: Fri, 31 May 2024 02:13:01 +0000 Subject: [PATCH 383/413] Fix cutlass sm_90a vesrion in CMakeList --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8df3a7a26d884..5f991af61d9bd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -200,7 +200,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The CUTLASS kernels for Hopper require sm90a to be enabled. # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. # That adds an extra 17MB to compiled binary, so instead we selectively enable it. - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 11) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) set_source_files_properties( "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu" PROPERTIES From a22dea54d3e80bf069cfeed8002a193ef8b18e1b Mon Sep 17 00:00:00 2001 From: SnowDist Date: Fri, 31 May 2024 10:24:41 +0800 Subject: [PATCH 384/413] [Model] Support MAP-NEO model (#5081) Co-authored-by: Zhuohan Li --- benchmarks/kernels/benchmark_paged_attention.py | 2 +- benchmarks/kernels/benchmark_rope.py | 2 +- csrc/attention/attention_kernels.cu | 6 ++++++ csrc/cpu/attention.cpp | 6 ++++++ tests/kernels/test_attention.py | 2 +- tests/kernels/test_cache.py | 2 +- tests/kernels/test_pos_encoding.py | 2 +- vllm/attention/ops/paged_attn.py | 2 +- 8 files changed, 18 insertions(+), 6 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index fc9621e885dc4..e6f4e9e6b9716 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -170,7 +170,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 256], + choices=[64, 80, 96, 112, 128, 192, 256], default=128) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 9188e811e2982..00e55f6060b52 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -93,7 +93,7 @@ def benchmark_rope_kernels_multi_lora( parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--head-size", type=int, - choices=[64, 80, 96, 112, 128, 256], + choices=[64, 80, 96, 112, 128, 192, 256], default=128) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--dtype", diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 45edc3252380c..8f89f89786c3b 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -754,6 +754,9 @@ void paged_attention_v1_launcher( case 128: LAUNCH_PAGED_ATTENTION_V1(128); break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; case 256: LAUNCH_PAGED_ATTENTION_V1(256); break; @@ -911,6 +914,9 @@ void paged_attention_v2_launcher( case 128: LAUNCH_PAGED_ATTENTION_V2(128); break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; case 256: LAUNCH_PAGED_ATTENTION_V2(256); break; diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 438e9bdb19f50..ed8cfbd421f0f 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -390,6 +390,9 @@ void paged_attention_v1_impl_launcher( case 128: LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); break; + case 192: + LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); + break; case 256: LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); break; @@ -703,6 +706,9 @@ void paged_attention_v2_impl_launcher( case 128: LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); break; + case 192: + LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE); + break; case 256: LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); break; diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index fdf313262ca97..8bc4766fc93c4 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -28,7 +28,7 @@ # FlashAttention forward only supports head dimension at most 128 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 -HEAD_SIZES = [64, 80, 96, 112, 128, 256 +HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256 ] if not is_hip() else [64, 80, 96, 112, 128] BLOCK_SIZES = [16, 32] diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 9f0cb60dc16e2..29572cfa57499 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -11,7 +11,7 @@ NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing -HEAD_SIZES = [64, 80, 96, 112, 128, 256] +HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256] BLOCK_SIZES = [8, 16, 32] # Arbitrary values for testing diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 076730cdbae0d..fbabc02bf9a9d 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -10,7 +10,7 @@ IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] -HEAD_SIZES = [64, 80, 96, 112, 128, 256] +HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256] ROTARY_DIMS = [None, 32] # None means rotary dim == head size NUM_HEADS = [7, 17] # Arbitrary values for testing BATCH_SIZES = [1, 5] # Arbitrary values for testing diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index e119fdcf11113..a214f40d16514 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -31,7 +31,7 @@ class PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 256] + return [64, 80, 96, 112, 128, 192, 256] @staticmethod def get_kv_cache_shape( From e9d3aa04f6e55e2bb540f0810da97ddd0deebb13 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Fri, 31 May 2024 00:00:26 -0500 Subject: [PATCH 385/413] Revert "[Kernel] Marlin_24: Ensure the mma.sp instruction is using the ::ordered_metadata modifier (introduced with PTX 8.5)" (#5149) --- csrc/quantization/marlin/sparse/common/mma.h | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h index fd3dbda5b9c93..45ab67a78a1de 100644 --- a/csrc/quantization/marlin/sparse/common/mma.h +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -32,8 +32,7 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, float* c = reinterpret_cast(&frag_c); if (psel == 0) { asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." - "f32 " + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x0;\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) @@ -41,8 +40,7 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), "r"(e[0])); asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." - "f32 " + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x0;\n" : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) @@ -51,8 +49,7 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, "r"(e[0])); } else { asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." - "f32 " + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x1;\n" : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) @@ -60,8 +57,7 @@ __device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1, "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]), "r"(e[0])); asm volatile( - "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16." - "f32 " + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 " "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, " "{%12,%13,%14,%15}, %16, 0x1;\n" : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) From a377f0bd5e1fa0ca069e3dbf28f4de5af64d0bb1 Mon Sep 17 00:00:00 2001 From: functionxu123 <1229853312@qq.com> Date: Fri, 31 May 2024 13:14:50 +0800 Subject: [PATCH 386/413] [Misc]: optimize eager mode host time (#4196) Co-authored-by: xuhao --- vllm/utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 26140e15636a4..2781eceb7ba98 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -17,6 +17,7 @@ Hashable, List, Optional, OrderedDict, Tuple, TypeVar, Union) +import numpy as np import psutil import torch @@ -501,11 +502,6 @@ def str_to_int_tuple(s: str) -> Tuple[int, ...]: f"(e.g., 1, 2, 3). Given input: {s}") from e -def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]: - assert len(x) <= max_len - return x + [pad] * (max_len - len(x)) - - def make_tensor_with_pad( x: List[List[int]], max_len: int, @@ -518,7 +514,10 @@ def make_tensor_with_pad( The padding is applied to the end of each inner list until it reaches `max_len`. """ - padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x] + padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, :len(blocktb)] = blocktb return torch.tensor(padded_x, dtype=dtype, device=device) From 91bc8bc0a89f6583e3688c3e1d8f640d91a88cab Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 31 May 2024 21:17:40 +0000 Subject: [PATCH 387/413] adding fp8 gemm computation --- CMakeLists.txt | 9 + csrc/cache.h | 7 +- csrc/cache_kernels.cu | 63 ---- csrc/ops.h | 22 ++ csrc/pybind.cpp | 13 +- .../fp8/amd_detail/quant_utils.cuh | 211 +++++++++++--- csrc/quantization/fp8/convert_kernel.cu | 96 +++++++ csrc/quantization/fp8/gemm_kernel.cu | 269 ++++++++++++++++++ vllm/config.py | 2 +- vllm/model_executor/layers/linear.py | 77 +++++ .../layers/quantization/__init__.py | 2 + .../layers/quantization/fp8_rocm.py | 232 +++++++++++++++ vllm/model_executor/model_loader.py | 5 + vllm/model_executor/models/llama.py | 94 ++++++ vllm/model_executor/weight_utils.py | 4 + 15 files changed, 997 insertions(+), 109 deletions(-) create mode 100644 csrc/quantization/fp8/convert_kernel.cu create mode 100644 csrc/quantization/fp8/gemm_kernel.cu create mode 100644 vllm/model_executor/layers/quantization/fp8_rocm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 2510eb4e08ec8..3b6ea4b570a99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -151,6 +151,13 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() +# +# Set rocm version dev int. +# +if(VLLM_GPU_LANG STREQUAL "HIP") + list(APPEND VLLM_GPU_FLAGS "-DROCM_VERSION=${ROCM_VERSION_DEV_INT}") +endif() + # # Define extension targets # @@ -167,6 +174,8 @@ set(VLLM_EXT_SRC "csrc/layernorm_kernels.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" + "csrc/quantization/fp8/convert_kernel.cu" + "csrc/quantization/fp8/gemm_kernel.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") diff --git a/csrc/cache.h b/csrc/cache.h index 718a5f6cfd7f7..fa26ddb688588 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -22,9 +22,4 @@ void reshape_and_cache( torch::Tensor& value_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, - const float kv_scale); - -// Just for unittest -void convert_fp8( - torch::Tensor& src_cache, - torch::Tensor& dst_cache); + const float kv_scale); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 24aaa2ff3e263..1638b2aeee77e 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -274,66 +274,3 @@ void reshape_and_cache( TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } } - -namespace vllm { - -template -__global__ void convert_fp8_kernel( - const Tin* __restrict__ src_cache, - Tout* __restrict__ dst_cache, - const int64_t block_stride) { - const int64_t block_idx = blockIdx.x; - for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { - int64_t idx = block_idx * block_stride + i; -#if defined(ENABLE_FP8_E5M2) - dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); -#elif defined(ENABLE_FP8_E4M3) - dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); -#else - assert(false); -#endif - } -} - -} // namespace vllm - -#define CALL_CONVERT_FP8(Tout, Tin) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), \ - block_stride); - -void convert_fp8( - torch::Tensor& src_cache, - torch::Tensor& dst_cache) -{ - torch::Device src_device = src_cache.device(); - torch::Device dst_device = dst_cache.device(); - TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") - TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK( - src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); - at::cuda::OptionalCUDAGuard device_guard(src_device); - - int64_t num_blocks = src_cache.size(0); - int64_t block_stride = src_cache.stride(0); - - dim3 grid(num_blocks); - dim3 block(std::min(block_stride, int64_t(512))); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); - } -} diff --git a/csrc/ops.h b/csrc/ops.h index 41ecc1e89371b..3acbc1b3f8363 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -159,3 +159,25 @@ std::pair, std::vector> get_graph_buffer_ipc_meta( void register_graph_buffers(fptr_t _fa, const std::vector &handles, const std::vector> &offsets); #endif + +void convert_fp8( + torch::Tensor& src_cache, + torch::Tensor& dst_cache, + torch::Tensor& scale); + +torch::Tensor fp8_gemm( + torch::Tensor& a, + torch::Tensor& b, + torch::Tensor& scaleA, + torch::Tensor& scaleB, + torch::Tensor& scaleD, + int algo_idx +); + +torch::Tensor fp8_gemm_16( + torch::Tensor& a, + torch::Tensor& b, + torch::Tensor& scaleA, + torch::Tensor& scaleB, + int algo_idx +); \ No newline at end of file diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index de02afc162113..d533b6fcca498 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -60,6 +60,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "batched_rotary_embedding", &batched_rotary_embedding, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); + +// FP8 + ops.def( + "convert_fp8", + &convert_fp8, + "Convert the key and value tensors to or from fp8 data type"); + ops.def("fp8_gemm", &fp8_gemm, "fp8 GEMM"); + + ops.def("fp8_gemm_16", &fp8_gemm_16, "fp8 GEMM"); // Quantization ops #ifndef USE_ROCM @@ -90,10 +99,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, "Reshape the key and value tensors and cache them"); - cache_ops.def( - "convert_fp8", - &convert_fp8, - "Convert the key and value cache to fp8 data type"); // Cuda utils pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); diff --git a/csrc/quantization/fp8/amd_detail/quant_utils.cuh b/csrc/quantization/fp8/amd_detail/quant_utils.cuh index 894160972d9f4..9efe735cb0835 100644 --- a/csrc/quantization/fp8/amd_detail/quant_utils.cuh +++ b/csrc/quantization/fp8/amd_detail/quant_utils.cuh @@ -303,18 +303,18 @@ __inline__ __device__ bf16_8_t vec_conversion(const Float8_& } -/* Scaled and vectorized conversions, for data exchange between high and low precision domains +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains - Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale ) - s.t. - Quantize(HP / scale) => FP8 - Dequant(FP8) * scale => HP + Convention of the scale in API, e.g: FP8_data = Quantization( + High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) * + scale => HP */ // fp8 -> half template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, const float scale) +__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, float scale) { hip_fp8 f8{a, hip_fp8::from_bits()}; __half_raw res; @@ -324,9 +324,9 @@ __inline__ __device__ uint16_t scaled_vec_conversion(const ui // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, const float scale) +__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, float scale) { -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) +#if defined(__HIP__MI300__) const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); union { __half2_raw h2r; @@ -335,7 +335,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion(const u tmp.h2r.x.data = f2[0] * scale; tmp.h2r.y.data = f2[1] * scale; return tmp.ui32; -#else +#else union { uint16_t u16[2]; uint32_t u32; @@ -349,7 +349,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion(const u // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, const float scale) +__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, float scale) { union { uint2 u32x2; @@ -362,7 +362,7 @@ __inline__ __device__ uint2 scaled_vec_conversion(const uint32_ // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, const float scale) +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, float scale) { union { uint4 u64x2; @@ -377,7 +377,7 @@ using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale) +__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { hip_fp8 f8{a, hip_fp8::from_bits()}; float f{f8}; @@ -388,7 +388,7 @@ using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale) +__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, float scale) { __nv_bfloat162 res; res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); @@ -398,7 +398,7 @@ __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint1 // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, const float scale) +__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, float scale) { bf16_4_t res; res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); @@ -408,7 +408,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion(const u // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, const float scale) +__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, float scale) { bf16_4_t tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale); @@ -423,7 +423,7 @@ __inline__ __device__ bf16_8_t scaled_vec_conversion(const uint // fp8 -> float template <> -__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, const float scale) +__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, float scale) { hip_fp8 fp8{a, hip_fp8::from_bits()}; return static_cast(fp8) * scale; @@ -431,9 +431,9 @@ __inline__ __device__ float scaled_vec_conversion(const uint8_t& // fp8x2 -> float2 template <> -__inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, const float scale) +__inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, float scale) { -#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) +#if defined(__HIP__MI300__) float2 res; const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); res.x = f2[0] * scale; @@ -457,9 +457,17 @@ __inline__ __device__ Float4_ scaled_vec_conversion(const uin return res; } +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, float scale) +{ + Float4_ res = scaled_vec_conversion(a, scale); + return {res.x.x, res.x.y, res.y.x, res.y.y}; +} + // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, const float scale) +__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, float scale) { Float4_ tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale); @@ -472,45 +480,178 @@ __inline__ __device__ Float8_ scaled_vec_conversion(const uint2& return res; } - -/* Quantize(HP / scale) => FP8 */ - -// TODO(Hai): vectorized to add - // half -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, const float scale) +__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, float scale) { __half_raw tmp; tmp.x = a; - hip_fp8 f8{static_cast(tmp.data)/scale}; + hip_fp8 f8{static_cast(tmp.data / scale)}; return f8.data; } +// halfx2 -> fp8x2 +template<> +__inline__ __device__ uint16_t scaled_vec_conversion(const uint32_t& a, float scale) +{ +#ifdef __HIP__MI300__ + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = tmp.h2r.x.data / scale; + f2.f = tmp.h2r.y.data / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); +#else + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint8_t ui8[2]; + uint16_t ui16; + } res; + res.ui8[0] = scaled_vec_conversion(tmp.h2r.x.x, scale); + res.ui8[1] = scaled_vec_conversion(tmp.h2r.y.x, scale); + return res.ui16; +#endif +} + +// half2x2 -> fp8x4 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion(const uint2& a, float scale) +{ + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// half2x4 -> fp8x8 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, float scale) +{ + union { + uint2 ui2[2]; + uint4 ui4; + } tmp; + tmp.ui4 = a; + uint2 res; + res.x = scaled_vec_conversion(tmp.ui2[0], scale); + res.y = scaled_vec_conversion(tmp.ui2[1], scale); + return res; +} + // bf16 -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, const float scale) +__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, float scale) { - hip_fp8 res{__bfloat162float(a)/scale}; + hip_fp8 res{__bfloat162float(a) / scale}; return res.data; } +// bf16x2 -> fp8x2 +template <> +__inline__ __device__ uint16_t scaled_vec_conversion(const __nv_bfloat162& a, float scale) +{ + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; +} + +// bf16x4 -> fp8x4 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion(const bf16_4_t& a, float scale) +{ + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// bf16x8 -> fp8x8 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const bf16_8_t& a, float scale) +{ + uint2 res; + res.x = scaled_vec_conversion({a.x, a.y}, scale); + res.y = scaled_vec_conversion({a.z, a.w}, scale); + return res; +} + // float -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, const float scale) +__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, float scale) { - hip_fp8 f8(a/scale); + hip_fp8 f8(a); return f8.data; } -// fp8x4 -> float4 +// floatx2 -> fp8x2 template <> -__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, const float scale) +__inline__ __device__ uint16_t scaled_vec_conversion(const float2& a, float scale) { - Float4_ tmp = scaled_vec_conversion(a, scale); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +#ifdef __HIP__MI300__ + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = a.x / scale; + f2.f = a.y / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f,f2.f, 0, 0); +#else + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; +#endif +} + +// floatx4 -> fp8x4 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion(const float4& a, float scale) +{ + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); + tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); + return tmp.ui32; } } diff --git a/csrc/quantization/fp8/convert_kernel.cu b/csrc/quantization/fp8/convert_kernel.cu new file mode 100644 index 0000000000000..74636c9df9ffd --- /dev/null +++ b/csrc/quantization/fp8/convert_kernel.cu @@ -0,0 +1,96 @@ +#include +#include +#include + +#include "../../attention/attention_dtypes.h" +#if defined(ENABLE_FP8_E5M2) +#include "../fp8_e5m2_kvcache/quant_utils.cuh" +#elif defined(ENABLE_FP8_E4M3) +#include "amd_detail/quant_utils.cuh" +#endif + +namespace vllm +{ + +template +__global__ void convert_fp8_kernel( + const Tin* __restrict__ src_data, Tout* __restrict__ dst_data, const float* scale, size_t N) +{ + const int64_t block_idx = blockIdx.x; + + using V_in_vec = typename Vec::Type; + using V_out_vec = typename Vec::Type; + auto dst_data_vec = reinterpret_cast(dst_data); + auto src_data_vec = reinterpret_cast(src_data); + + int64_t startIdx = (threadIdx.x + blockDim.x * blockIdx.x); + auto idx = startIdx; + if (idx >= N) { + return; + } + dst_data_vec[idx] = fp8_e4m3::scaled_vec_conversion(src_data_vec[idx], *scale); + //dst_data_vec[idx+1] = fp8_e4m3::vec_conversion(src_data_vec[idx+1], *scale); + + //for (int64_t i = 0; i < loopSize; ++i) { + // auto idx = startIdx + i; + // if (idx >= N) { + // return; + // } + // dst_data_vec[idx] = fp8_e4m3::vec_conversion(src_data_vec[idx], *scale); + //} +} + +} // namespace vllm + +template +struct call_convert_fp8 +{ + void operator()(torch::Tensor& src_data, torch::Tensor& dst_data, torch::Tensor& scale) + { + const auto N = src_data.numel() / 2; + //std::cout << N << "\n"; + constexpr uint32_t loopSize = 1;//std::max(N / 50000000LL, 1); + constexpr dim3 numThreads{1024, 1, 1}; + auto neededBlocks = (N + (numThreads.x * loopSize) - 1) / (numThreads.x * loopSize); + uint32_t actualBlocks = neededBlocks; + + //static uint32_t maxBlocks = 0; + //if (actualBlocks != maxBlocks) { + // maxBlocks = actualBlocks; + // std::cout << actualBlocks << "\n"; + //} + + const dim3 grid{actualBlocks, 1, 1}; + + const auto stream = at::cuda::getCurrentCUDAStream(); + + vllm::convert_fp8_kernel + <<>>(reinterpret_cast(src_data.data_ptr()), + reinterpret_cast(dst_data.data_ptr()), (float*)scale.data_ptr(), N); + } +}; + +void convert_fp8(torch::Tensor& src_data, torch::Tensor& dst_data, torch::Tensor& scale) +{ + torch::Device src_device = src_data.device(); + torch::Device dst_device = dst_data.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + auto t1 = src_data.dtype(); + auto t2 = dst_data.dtype(); + if (src_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8<__nv_bfloat16, uint8_t, 2>{}(src_data, dst_data, scale); + } +} \ No newline at end of file diff --git a/csrc/quantization/fp8/gemm_kernel.cu b/csrc/quantization/fp8/gemm_kernel.cu new file mode 100644 index 0000000000000..0463cc75eac6c --- /dev/null +++ b/csrc/quantization/fp8/gemm_kernel.cu @@ -0,0 +1,269 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define max_workspace_size 2 * 128 * 1024 * 1024 + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#ifndef CHECK_HIP_ERROR +#define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +#ifndef CHECK_HIPBLASLT_ERROR +#define CHECK_HIPBLASLT_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf( \ + stderr, "hipBLASLt error: '%s'(%d) at %s:%d\n", hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA, torch::Tensor& scaleB, + torch::Tensor& scaleD, int algo_idx) +{ + auto a_strides{a.strides()}; + auto b_strides{b.strides()}; + auto a_sizes{a.sizes()}; + auto b_sizes{b.sizes()}; + + // CHECK_INPUT(a); + // CHECK_INPUT(b); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && b.dtype() == torch::kFloat8_e4m3fnuz, + "The input tensors should be in fp8."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); + TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); + + auto options{at::TensorOptions().dtype(torch::kFloat8_e4m3fnuz).device(at::kCUDA)}; + auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; + + constexpr bool transpose_result = true; + bool transpose_a; + bool transpose_b; + if ((b_strides[0] == 1) && (b_strides[1] >= std::max(1, b_sizes[0]))) { + transpose_b = false; + } else if ((b_strides[1] == 1) && (b_strides[0] >= std::max(1, b_sizes[1]))) { + transpose_b = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((a_strides[0] == 1) && (a_strides[1] >= std::max(1, a_sizes[0]))) { + transpose_a = false; + } else if ((a_strides[1] == 1) && (a_strides[0] >= std::max(1, a_sizes[1]))) { + transpose_a = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_a; + transpose_a = !transpose_b; + transpose_b = !tmp; + a_strides = b.strides(); + b_strides = a.strides(); + a_sizes = b.sizes(); + b_sizes = a.sizes(); + } + + float alpha = 1.0f; + float beta = 0.0f; + int64_t m = a_sizes[transpose_result ? 1 : 0]; + int64_t k = a_sizes[transpose_result ? 0 : 1]; + int64_t n = b_sizes[transpose_result ? 0 : 1]; + + void* d_a = static_cast((transpose_result ? b : a).data_ptr()); + void* d_b = static_cast((transpose_result ? a : b).data_ptr()); + void* d_d = static_cast(result.data_ptr()); + + // void *d_scaleA, *d_scaleB, *d_workspace; + // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), sizeof(float), hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleB, &(transpose_result ? scaleA : scaleB), sizeof(float), hipMemcpyHostToDevice)); + auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); + auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); + auto d_scaleD = scaleD.data_ptr(); + + auto handle = at::cuda::getCurrentCUDABlasLtHandle(); + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + hipblaslt_ext::GemmPreference gemmPref; + gemmPref.setMaxWorkspaceBytes(0); + hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, + HIP_R_8F_E4M3_FNUZ, HIPBLAS_COMPUTE_32F); + + hipblaslt_ext::GemmEpilogue epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only) + hipblaslt_ext::GemmInputs inputs; + inputs.a = d_a; + inputs.b = d_b; + inputs.c = d_d; + inputs.d = d_d; + inputs.alpha = α + inputs.beta = β + inputs.scaleA = d_scaleA; + inputs.scaleB = d_scaleB; + inputs.scaleD = d_scaleD; + gemm.setProblem(m, n, k, 1, epilogue, inputs); + if (algo_idx == 0) { + constexpr int request_solutions = 1024; + std::vector heuristicResult; + heuristicResult.reserve(request_solutions); + CHECK_HIPBLASLT_ERROR(gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); + static size_t solSize = 0; + if (heuristicResult.size() != solSize) { + std::cout << "fp8 sols: " << heuristicResult.size() << "\n"; + solSize = heuristicResult.size(); + for (auto& res : heuristicResult) { + auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); + std::cout << idx << "\n"; + } + } + TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); + algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); + } + std::vector algoIndex(1); + algoIndex[0] = algo_idx; + std::vector tmpAlgo; + TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); + + CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); + CHECK_HIPBLASLT_ERROR(gemm.run(stream)); + + // hipFree(d_scaleA); + // hipFree(d_scaleB); + + return result; +} + +torch::Tensor fp8_gemm_16( + torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA, torch::Tensor& scaleB, int algo_idx) +{ + auto a_strides{a.strides()}; + auto b_strides{b.strides()}; + auto a_sizes{a.sizes()}; + auto b_sizes{b.sizes()}; + + // CHECK_INPUT(a); + // CHECK_INPUT(b); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && b.dtype() == torch::kFloat8_e4m3fnuz, + "The input tensors should be in fp8."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); + TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); + + auto options{at::TensorOptions().dtype(torch::kFloat16).device(at::kCUDA)}; + auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; + + constexpr bool transpose_result = true; + bool transpose_a; + bool transpose_b; + if ((b_strides[0] == 1) && (b_strides[1] >= std::max(1, b_sizes[0]))) { + transpose_b = false; + } else if ((b_strides[1] == 1) && (b_strides[0] >= std::max(1, b_sizes[1]))) { + transpose_b = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((a_strides[0] == 1) && (a_strides[1] >= std::max(1, a_sizes[0]))) { + transpose_a = false; + } else if ((a_strides[1] == 1) && (a_strides[0] >= std::max(1, a_sizes[1]))) { + transpose_a = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_a; + transpose_a = !transpose_b; + transpose_b = !tmp; + a_strides = b.strides(); + b_strides = a.strides(); + a_sizes = b.sizes(); + b_sizes = a.sizes(); + } + + float alpha = 1.0f; + float beta = 0.0f; + int64_t m = a_sizes[transpose_result ? 1 : 0]; + int64_t k = a_sizes[transpose_result ? 0 : 1]; + int64_t n = b_sizes[transpose_result ? 0 : 1]; + + void* d_a = static_cast((transpose_result ? b : a).data_ptr()); + void* d_b = static_cast((transpose_result ? a : b).data_ptr()); + void* d_d = static_cast(result.data_ptr()); + + // void *d_scaleA, *d_scaleB, *d_workspace; + // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), sizeof(float), hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleB, &(transpose_result ? scaleA : scaleB), sizeof(float), hipMemcpyHostToDevice)); + auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); + auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); + + auto handle = at::cuda::getCurrentCUDABlasLtHandle(); + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + hipblaslt_ext::GemmPreference gemmPref; + gemmPref.setMaxWorkspaceBytes(0); + hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, HIP_R_16F, + HIPBLAS_COMPUTE_32F); + + hipblaslt_ext::GemmEpilogue epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only) + hipblaslt_ext::GemmInputs inputs; + inputs.a = d_a; + inputs.b = d_b; + inputs.c = d_d; + inputs.d = d_d; + inputs.alpha = α + inputs.beta = β + inputs.scaleA = d_scaleA; + inputs.scaleB = d_scaleB; + gemm.setProblem(m, n, k, 1, epilogue, inputs); + if (algo_idx == 0) { + constexpr int request_solutions = 1024; + std::vector heuristicResult; + heuristicResult.reserve(request_solutions); + CHECK_HIPBLASLT_ERROR(gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); + static size_t solSize = 0; + if (heuristicResult.size() != solSize) { + std::cout << "fp16 sols: " << heuristicResult.size() << "\n"; + solSize = heuristicResult.size(); + for (auto& res : heuristicResult) { + auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); + std::cout << idx << "\n"; + } + } + algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); + TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); + } + std::vector algoIndex(1); + algoIndex[0] = algo_idx; + std::vector tmpAlgo; + TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); + + CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); + CHECK_HIPBLASLT_ERROR(gemm.run(stream)); + + // hipFree(d_scaleA); + // hipFree(d_scaleB); + + return result; +} diff --git a/vllm/config.py b/vllm/config.py index bd13949a25623..5496c892c1638 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -173,7 +173,7 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "squeezellm", "marlin"] + supported_quantization = ["awq", "gptq", "squeezellm", "marlin", "fp8"] rocm_not_supported_quantization = ["awq", "marlin"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 227ec7c848d4b..1d2a8782a96f1 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -16,6 +16,7 @@ divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs from vllm.utils import is_hip +from vllm._C import ops logger = init_logger(__name__) @@ -286,11 +287,72 @@ def __init__( assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, skip_bias_add, params_dtype, linear_method) + self._wsfs = [0, 0] + self._asfs = [0, 0] + self._osfs = [0, 0] + for name, weight in self.linear_weights.items(): + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) + set_weight_attrs( + weight, { + "asf_loader": self.asf_loader, + "wsf_loader": self.wsf_loader, + "osf_loader": self.osf_loader, + "rescale": self.rescale, + }) + + def rescale(self): + if self._wsfs[0] < self._wsfs[1]: + factor = self._wsfs[0] / self._wsfs[1] + else: + factor = self._wsfs[1] / self._wsfs[0] + + weight_param = self.linear_weights["weight"] + param_data = weight_param.data + loaded_shard_id = 0 if self._wsfs[0] < self._wsfs[1] else 1 + tp_size = get_tensor_model_parallel_world_size() + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + output_dim = getattr(weight_param, "output_dim", None) + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + param_data16 = torch.empty_like(param_data, dtype=torch.float16) + ops.convert_fp8( + param_data, param_data16, + torch.tensor(factor, dtype=torch.float32, device='cuda')) + ops.convert_fp8(param_data16, param_data, + torch.tensor(1, dtype=torch.float32, device='cuda')) + #param_data *= factor[0] + pass + + def wsf_loader(self, param: Parameter, sf: torch.Tensor, + loaded_shard_id: int): + self._wsfs[loaded_shard_id] = sf + param.data.copy_(max(self._wsfs)) + + def asf_loader(self, param: Parameter, sf: torch.Tensor, + loaded_shard_id: int): + self._asfs[loaded_shard_id] = sf + param.data.copy_(max(self._asfs)) + + def osf_loader(self, param: Parameter, sf: torch.Tensor, + loaded_shard_id: int): + self._osfs[loaded_shard_id] = 1 / sf.item() + param.data.copy_(max(self._osfs)) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[int] = None): + if (param.data.dtype == torch.float8_e4m3fnuz + and loaded_weight.dtype != torch.int8): + # Quantized weights for this layer have already been loaded + return + if (param.data.dtype != loaded_weight.dtype + and loaded_weight.dtype == torch.int8 + and param.data.dtype != torch.float8_e4m3fnuz): + param.data = torch.empty_like(param.data, + dtype=torch.float8_e4m3fnuz) + param_data = param.data output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: @@ -353,6 +415,10 @@ def weight_loader(self, "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") assert param_data.shape == loaded_weight.shape + if param_data.dtype == torch.float8_e4m3fnuz: + loaded_weight[loaded_weight == -128] = 0 + assert param_data.is_contiguous() and loaded_weight.is_contiguous() + loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz) param_data.copy_(loaded_weight) @@ -417,9 +483,20 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, loaded_shard_id: Optional[str] = None): + if param.data.dtype == torch.float8_e4m3fnuz and loaded_weight.dtype == torch.float16: + # Quantized weights for this layer have already been loaded + return param_data = param.data output_dim = getattr(param, "output_dim", None) + if param.data.dtype != loaded_weight.dtype and loaded_weight.dtype == torch.int8: + param.data = torch.empty_like(param.data, + dtype=torch.float8_e4m3fnuz) + loaded_weight[loaded_weight == -128] = 0 + loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz) + self.weight_loader(param, loaded_weight) + return + if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index ad988d48755b0..de4c3f14e2d4a 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -6,12 +6,14 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantization.fp8_rocm import Fp8RocmConfig _QUANTIZATION_CONFIG_REGISTRY = { "awq": AWQConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, "marlin": MarlinConfig, + "fp8": Fp8RocmConfig } diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py new file mode 100644 index 0000000000000..e464b41329cbd --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -0,0 +1,232 @@ +from typing import Any, Dict, Iterator, List, Optional, Tuple +from safetensors import safe_open +import torch +from torch.nn.parameter import Parameter +import torch.nn.functional as F +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, +) +from vllm._C import ops +import pandas as pd +import os + + +class Fp8RocmConfig(QuantizationConfig): + def __init__(self, config) -> None: + self.quantized_weights_path = config["quantized_weights"] + self._tuned = {} + self._stats = {} + gemm_type = os.getenv("FP8_GEMM", "fp8_16") + #print(f"Integral Cross factor = {self.factor}") + if gemm_type == "fp8_8": + self.gemm_method = Fp8RocmLinearLayer.apply_fp8_8 + tuned_filename = "/projects/tuned_fp8_8.csv" + elif gemm_type == "fp8_16": + self.gemm_method = Fp8RocmLinearLayer.apply_fp8_16 + tuned_filename = "/projects/tuned_fp8_16.csv" + else: + raise Exception(f"Unknown fp8 gemm type: {gemm_type}") + try: + df = pd.read_csv(tuned_filename) + except: + return + + for i in range(len(df)): + shape = df.iloc[i] + m = shape["M"] + n = shape["N"] + k = shape["K"] + algo = shape["algo"] + self._tuned[(m, n, k)] = algo + + @staticmethod + def get_config_filenames() -> List[str]: + return ["serenity_config.json"] + + @classmethod + def from_config(cls, config) -> "Fp8RocmConfig": + return cls(config) + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.uint8, torch.float8_e4m3fnuz] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_name(cls) -> str: + return "serenity" + + def get_linear_method(self) -> "Fp8RocmLinearLayer": + return Fp8RocmLinearLayer(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +def set_weight_attrs( + weight: torch.Tensor, + weight_attrs: Optional[Dict[str, Any]], +): + """Set attributes on a weight tensor. + + This method is used to set attributes on a weight tensor. This method + will not overwrite existing attributes. + + Args: + weight: The weight tensor. + weight_attrs: A dictionary of attributes to set on the weight tensor. + """ + if weight_attrs is None: + return + for key, value in weight_attrs.items(): + assert not hasattr( + weight, key + ), f"Overwriting existing tensor attribute: {key}" + setattr(weight, key, value) + + +class Fp8RocmLinearLayer(LinearMethodBase): + def __init__(self, config: Fp8RocmConfig) -> None: + self._config = config + + def get_tensor(self) -> Iterator[Tuple[str, torch.Tensor]]: + with safe_open( + self._config.quantized_weights_path, framework="pt" + ) as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + # for a, b in self.get_tensor(): + # print(f"{a}: {b.shape}") + # pass + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + # orig_weight = Parameter(torch.empty(output_size_per_partition, + # input_size_per_partition, + # dtype=params_dtype), + # requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + # set_weight_attrs(orig_weight, {"input_dim": 1, "output_dim": 0}) + return { + "weight": weight, + # "orig_weight": orig_weight, + "activation_scaling_factor": Parameter( + torch.empty(1, dtype=torch.float32, device="cuda") + ), + "weights_scaling_factor": Parameter( + torch.empty(1, dtype=torch.float32, device="cuda") + ), + "output_scaling_factor": Parameter( + torch.empty(1, dtype=torch.float32, device="cuda") + ) + } + + def apply_fp8_16( + self, + x: torch.Tensor, + weight: torch.Tensor, + asf: torch.Tensor, + wsf: torch.Tensor, + osf: torch.Tensor, + ) -> torch.Tensor: + x8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) + ops.convert_fp8(x, x8, asf) + m = weight.shape[0] + n = x.shape[0] + k = x.shape[1] + + algo = self._config._tuned.get((m, n, k)) + if algo is None: + import os + + # print(f"Not found: {m} {n} {k}") + if os.getenv("TUNE_FP8") == "1": + try: + df = pd.read_csv("/projects/fp8_tune.csv") + except: + df = pd.DataFrame(columns=["M", "N", "K"]) + df = pd.concat( + [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] + ).drop_duplicates() + df.to_csv("/projects/fp8_tune.csv", index=False) + # print(f"{m},{n},{k}") + algo = 0 + res = ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) + return res + + def apply_fp8_8( + self, + x: torch.Tensor, + weight: torch.Tensor, + asf: torch.Tensor, + wsf: torch.Tensor, + osf: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert not bias + x8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) + ops.convert_fp8(x, x8, asf) + m = weight.shape[0] + n = x.shape[0] + k = x.shape[1] + + algo = self._config._tuned.get((m, n, k)) + if algo is None: + import os + + # print(f"Not found: {m} {n} {k}") + if os.getenv("TUNE_FP8") == "1": + try: + df = pd.read_csv("/projects/fp8_tune.csv") + except: + df = pd.DataFrame(columns=["M", "N", "K"]) + df = pd.concat( + [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] + ).drop_duplicates() + df.to_csv("/projects/fp8_tune.csv", index=False) + # print(f"{m},{n},{k}") + algo = 0 + + res = ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo)) + res16 = torch.empty_like(res, dtype=torch.float16) + ops.convert_fp8(res, res16, 1/osf) + return res16 + + def apply_weights( + self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + weight: torch.Tensor = weights["weight"] + if weight.dtype == torch.float8_e4m3fnuz: + asf: torch.Tensor = weights["activation_scaling_factor"] * 2 + wsf: torch.Tensor = weights["weights_scaling_factor"] * 2 + osf: torch.Tensor = weights["output_scaling_factor"] / 2 + #my_osf: torch.Tensor = self._config.factor / weights["my_osf"] + #with open("ratio.txt", "a") as f: + # f.write(f'{weights["output_scaling_factor"].item()},{weights["my_osf"].item()}\n') + #return self.test(weight, asf, wsf, osf, x, weights["my_osf"], bias) + return self._config.gemm_method(self, x, weight, asf, wsf, osf) + + return F.linear(x, weight, bias) \ No newline at end of file diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 2745dbd89ab0f..b348bf8e323ec 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -97,6 +97,11 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, # random values to the weights. initialize_dummy_weights(model) else: + load_ammo_quantized_weights = getattr(model, "load_ammo_quantized_weights", None) + if quant_config is not None and load_ammo_quantized_weights is not None: + load_ammo_quantized_weights(model_config.model, model_config.download_dir, + model_config.load_format, model_config.revision, + quant_config) # Load the weights from the cached or downloaded files. model.load_weights(model_config.model, model_config.download_dir, model_config.load_format, model_config.revision) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2172e96e8501c..a4d7b6119e0b5 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -32,6 +32,7 @@ from vllm.config import LoRAConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -428,6 +429,99 @@ def load_weights(self, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def load_ammo_quantized_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + quant_config: Optional[QuantizationConfig] = None): + import os.path + #import json + params_dict = dict(self.named_parameters()) + #with open("/projects/a.txt", "r") as f: + # j = json.load(f) + # for k, v in j.items(): + # params_dict[k].data.copy_(v) + quant_shards = [ + ("mlp.gate_up_proj", "mlp.fc", 0), # fc is gate_proj + ("mlp.gate_up_proj", "mlp.gate", 1), # gate is up_proj + ] + quant_map = [ + ("mlp.down_proj", "mlp.proj"), + ("self_attn.o_proj", "attention.dense"), + ("self_attn.qkv_proj", "attention.qkv"), + ] + tp_rank = get_tensor_model_parallel_rank() + quantized_filename: str = quant_config.quantized_weights_path[ + f"rank0"] + #quantized_filename = "quantized/osf/rank0.safetensors" + if not quantized_filename.startswith('/'): + quantized_filename = os.path.join(model_name_or_path, + quantized_filename) + #if not os.path.isdir(quantized_filename): + # quantized_filename = os.path.dirname(quantized_filename) + for name, loaded_weight in hf_model_weights_iterator( + quantized_filename, cache_dir, 'safetensors', revision): + #print(name) + name = name.replace('transformer', 'model') + name = name.replace('kv_cache_scaling_factor', 'qkv.output_scaling_factor') + #if "output_scaling_factor" in name: + # print(f"{name} {loaded_weight}") + #if "kv_cache_scaling_factor" in name: + # print(f"KVK: {name} {loaded_weight}") + #if "lm_head" in name: + # print(f"LM {name}: {loaded_weight}") + for (param_name, weight_name, shard_id) in quant_shards: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + if "activation_scaling_factor" in name: + asf_loader = getattr(param, "asf_loader", + default_weight_loader) + asf_loader(param, loaded_weight, shard_id) + elif "weights_scaling_factor" in name: + wsf_loader = getattr(param, "wsf_loader", + default_weight_loader) + #print(f"{name} {shard_id} {loaded_weight}") + wsf_loader(param, loaded_weight, shard_id) + elif "output_scaling_factor" in name: + osf_loader = getattr(param, "osf_loader", + default_weight_loader) + #print(f"{name} {shard_id} {loaded_weight}") + osf_loader(param, loaded_weight, shard_id) + else: + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + for (param_name, weight_name) in quant_map: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + if "activation_scaling_factor" in name or "weights_scaling_factor" in name: + param.data.copy_(loaded_weight) + if "output_scaling_factor" in name: + param.data.copy_(1 / loaded_weight.item()) + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + break + rescale = [ + "mlp.gate_up_proj.weight" + ] + for name, param in params_dict.items(): + for suffix in rescale: + if name.endswith(suffix): + param.rescale() # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 0961478930d74..8016941216994 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -169,6 +169,10 @@ def prepare_hf_model_weights( elif load_format == "safetensors": use_safetensors = True allow_patterns = ["*.safetensors"] + if os.path.isfile(model_name_or_path): + return (os.path.dirname(model_name_or_path), + [model_name_or_path], + True) elif load_format == "pt": allow_patterns = ["*.pt"] elif load_format == "npcache": From e9899fb7a4d9e032198d26ef84f1dd2cfd9621aa Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 31 May 2024 14:29:19 -0700 Subject: [PATCH 388/413] [Model] Enable FP8 QKV in MoE and refine kernel tuning script (#5039) --- benchmarks/kernels/benchmark_mixtral_moe.py | 48 ++++-- ...me=NVIDIA_H100_80GB_HBM3,dtype=float8.json | 138 +++++++++++++++++ ...me=NVIDIA_H100_80GB_HBM3,dtype=float8.json | 146 ++++++++++++++++++ ...me=NVIDIA_H100_80GB_HBM3,dtype=float8.json | 108 +++++++------ ...me=NVIDIA_H100_80GB_HBM3,dtype=float8.json | 146 ++++++++++++++++++ ...me=NVIDIA_H100_80GB_HBM3,dtype=float8.json | 84 +++++----- ...me=NVIDIA_H100_80GB_HBM3,dtype=float8.json | 146 ++++++++++++++++++ vllm/model_executor/models/mixtral.py | 9 -- 8 files changed, 711 insertions(+), 114 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py index 5280b214144c9..196ec8cfce88e 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ b/benchmarks/kernels/benchmark_mixtral_moe.py @@ -11,25 +11,36 @@ from vllm.model_executor.layers.fused_moe import (fused_moe, get_config_file_name) -os.environ['CUDA_VISIBLE_DEVICES'] = '0' - -def main(dtype: str): +def main(model, tp_size, gpu, dtype: str): + os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) method = fused_moe for bs in [ 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096 ]: - run_grid(bs, method=method, dtype=dtype) - - -def run_grid(bs, method, dtype: str): - d_model = 4096 + run_grid(bs, + model=model, + method=method, + gpu=gpu, + tp_size=tp_size, + dtype=dtype) + + +def run_grid(bs, model, method, gpu, tp_size, dtype: str): + if model == '8x7B': + d_model = 4096 + model_intermediate_size = 14336 + num_layers = 32 + elif model == '8x22B': + d_model = 6144 + model_intermediate_size = 16384 + num_layers = 56 + else: + raise ValueError(f'Unsupported Mixtral model {model}') num_total_experts = 8 top_k = 2 - tp_size = 2 - model_intermediate_size = 14336 - num_layers = 32 + # tp_size = 2 num_calls = 100 num_warmup_trials = 1 @@ -211,5 +222,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, choices=['float8', 'float16'], help='Data type used for fused_moe kernel computations', ) + parser.add_argument('--model', + type=str, + default='8x7B', + choices=['8x7B', '8x22B'], + help='The Mixtral model to benchmark') + parser.add_argument('--tp-size', + type=int, + default=2, + help='Tensor paralleli size') + parser.add_argument('--gpu', + type=int, + default=0, + help="GPU ID for benchmarking") args = parser.parse_args() - sys.exit(main(args.dtype)) + sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype)) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..3f3ccdafa88f3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,138 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..0c495e7e290c6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json index 9287808a94d0e..5b78c30f08b68 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -3,61 +3,59 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 }, "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 }, "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1 + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 }, "8": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 5 + "num_stages": 2 }, "16": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 - }, - "24": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5 }, - "32": { + "24": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "GROUP_SIZE_M": 64, + "num_warps": 4, "num_stages": 4 }, - "48": { + "32": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 }, - "64": { + "48": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, @@ -65,37 +63,45 @@ "num_warps": 4, "num_stages": 4 }, - "96": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 2 - }, - "128": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, - "256": { + "96": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5 }, - "512": { + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 }, "1024": { "BLOCK_SIZE_M": 128, @@ -109,7 +115,7 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4 }, @@ -125,7 +131,7 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4 }, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..60a65724d68b9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json index 2ad07bf79a25c..75f8b0017b9c6 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -2,104 +2,104 @@ "1": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "2": { - "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 2 }, "8": { - "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 8, + "num_warps": 4, "num_stages": 4 }, "24": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 + "num_warps": 4, + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, "num_stages": 4 }, "48": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "64": { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 4 }, "96": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4 }, "128": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 4 + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 }, "256": { "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 4 + "num_stages": 5 }, "512": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4 }, @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, + "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4 } diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..34b916e574f88 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d6dd7fa1fe9e2..2f4237339486e 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -278,15 +278,6 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - if isinstance( - quant_config, - Fp8Config) and not quant_config.is_checkpoint_fp8_serialized: - print_warning_once( - "For Mixtral FP8 quantization, we currently do not quantize " - "the attention layers until their FP8 performance is improved." - ) - quant_config = None - self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, From 929a7f05eaabd1a6a21f8858d8d6f9ccacb4a315 Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 31 May 2024 23:20:11 +0000 Subject: [PATCH 389/413] adding gradlib for fp8 --- gradlib_fp8/csrc/hipbsolgemm.cu | 673 ++++++++++++++++++ gradlib_fp8/gemm_runner.py | 58 ++ gradlib_fp8/gemm_tuner.py | 93 +++ gradlib_fp8/gradlib/GemmTuner.py | 143 ++++ gradlib_fp8/setup.py | 89 +++ .../layers/quantization/fp8_rocm.py | 2 +- 6 files changed, 1057 insertions(+), 1 deletion(-) create mode 100644 gradlib_fp8/csrc/hipbsolgemm.cu create mode 100644 gradlib_fp8/gemm_runner.py create mode 100644 gradlib_fp8/gemm_tuner.py create mode 100644 gradlib_fp8/gradlib/GemmTuner.py create mode 100644 gradlib_fp8/setup.py diff --git a/gradlib_fp8/csrc/hipbsolgemm.cu b/gradlib_fp8/csrc/hipbsolgemm.cu new file mode 100644 index 0000000000000..4d6ca59520dc0 --- /dev/null +++ b/gradlib_fp8/csrc/hipbsolgemm.cu @@ -0,0 +1,673 @@ +// #ifdef __gfx908__ +// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and not for others +// // below lines enable hip float to half conversion which are disabled by default in hip_fp16.h +// #undef __HIP_NO_HALF_OPERATORS__ +// #undef __HIP_NO_HALF_CONVERSIONS__ +// #endif + +#include +#include +#include +#include +#include +#include +#include +#include +// #include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include "nvToolsExt.h" + +//#include + + +// #ifdef USE_ROCM +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) +// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) +// #endif + +// #ifdef __HIP_PLATFORM_HCC__ +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) +// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) +// #if USE_GEMM_FLAGS_FP16_ALT_IMPL +// #ifdef ROCM_BACKWARD_PASS_GUARD +// flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; +// #endif +// #endif +// #endif + +#ifndef CHECK_HIP_ERROR +#define CHECK_HIP_ERROR(error) \ + if(error != hipSuccess) \ + { \ + fprintf(stderr, \ + "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), \ + error, \ + __FILE__, \ + __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +#ifndef CHECK_HIPBLAS_ERROR +#define CHECK_HIPBLAS_ERROR(error) \ + if(error != HIPBLAS_STATUS_SUCCESS) \ + { \ + fprintf(stderr, \ + "hipBLAS error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), \ + error, \ + __FILE__, \ + __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +namespace { + /*thread_local*/ cudaStream_t weight_stream; + // BUG: DLM has event and stream on different devices error + // In multi-GPU scenerio, do names defined in this namespace exist on all devices? + // C++ keyword: thread_local <- maybe this can help? + /*thread_local*/ cudaEvent_t event; + + // hipBLASLt + hipblasLtHandle_t hipblaslt_handle; + hipblasLtMatmulPreference_t preference; + size_t workspace_size = 2*128*1024*1024; + //uint64_t workspace_size = 0; + void* d_workspace; + int request_solutions = 1; + int returnedAlgoCount = 0; + + struct MatMulConfig { + hipblasOperation_t op_A; + hipblasOperation_t op_B; + int M; + int N; + int K; + hipDataType dtype; + + friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) -> bool { + return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); + } + }; + + // std::map, std::vector> heuristic_map; + std::map heuristic_map; + + hipEvent_t start, stop; + int bench_iters { 1 }; + int warmup_iters { 1 }; + + bool cout_print = false; + + torch::Tensor dTensor; + + //std::vector heuristicResult; +} + +//find all hipblaslt solutions for given gemm problem +std::vector hipblasLtMatmul_findallsols_wrapper( + hipblasLtHandle_t handle, + hipblasOperation_t op_A, + hipblasOperation_t op_B, + int m, int n, int k, + const void *alpha, + const void *a, + int lda, + const void *b, + int ldb, + const void *beta, + void *c, + int ldc, + hipDataType intype, + hipDataType outtype, + hipStream_t &stream) +{ + int flag { 0 }; + hipblasLtMatrixLayout_t matA, matB, matC; + hipblasLtMatmulDesc_t matmul; + if (op_A == HIPBLAS_OP_N) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); + } else { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); + } + if (op_B == HIPBLAS_OP_N) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); + } else { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); + } + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); + + //std::vector heuristicResult(10); + //CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( + // handle, matmul, matA, matB, matC, matC, + // preference, 10, heuristicResult.data(), &returnedAlgoCount)); + std::vector heuristicResult; + CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos(handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, + op_A, + op_B, + intype, + intype, + outtype, + outtype, + HIPBLAS_COMPUTE_32F, + heuristicResult)); + + std::vector algoIndex; + int returned_algo_count = heuristicResult.size(); + //for (int i = 0; i < returnedAlgoCount; i++) { + for (int i = 0; i < returned_algo_count; i++) { + auto algo = heuristicResult[i].algo; + size_t ret_workspace_size = 0; + auto status = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, + alpha, + matA, + matB, + beta, + matC, + matC, + algo, + ret_workspace_size + ); + if (status == HIPBLAS_STATUS_SUCCESS) { + if (ret_workspace_size heuristicResult(1); + if (solution_index<0) { + //nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); + std::cout << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" << std::endl; + if (cout_print) { + std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") + << " (" << m << ", " << n << ", " << k << "), dtype: " << intype + << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; + } + //std::vector heuristicResult(request_solutions); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( + handle, matmul, matA, matB, matC, matC, + preference, request_solutions, heuristicResult.data(), &returnedAlgoCount)); + if((returnedAlgoCount != request_solutions) && cout_print) { + std::cout << "less solution found! request: " << request_solutions + << ", found: " << returnedAlgoCount << std::endl; + } + //heuristic_map[gemm_key] = heuristicResult[0]; +/* + if (returnedAlgoCount == 1) { + heuristic_map[gemm_key] = heuristicResult[0]; + } else { + // benchmark requested solutions and pick best one + int bestIndex { -1 }; + double bestMs { std::numeric_limits::max() }; + for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { + // warm up + for (int iter { 0 }; iter < warmup_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the values in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + // performance measuring + double eventMs; + CHECK_HIP_ERROR(hipEventRecord(start, stream)); + for (int iter { 0 }; iter < bench_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the values in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + CHECK_HIP_ERROR(hipEventRecord(stop, stream)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + float temp; + CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); + eventMs = double(temp); + eventMs /= bench_iters; + + if (cout_print) { + std::cout << " Sol " << sol << ": average time per iter " << std::to_string(eventMs) << " ms"; + } + if (bestMs > eventMs) { + bestMs = eventMs; + bestIndex = sol; + if (cout_print) { + std::cout << " *" << std::endl; + } + } else { + if (cout_print) { + std::cout << std::endl; + } + } + } + heuristic_map[gemm_key] = heuristicResult[bestIndex]; + } +*/ + //nvtxRangePop(); + } else { + std::vector algoIndex(1); + algoIndex[0]=solution_index; + //std::vector tmpAlgo; + CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); + } + + //size_t ret_workspace_size = 0; + + //auto status1 = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, + // alpha, + // matA, + // matB, + // beta, + // matC, + // matC, + // heuristicResult[0].algo, + // ret_workspace_size + //); + //if (status1 == HIPBLAS_STATUS_SUCCESS) { + // std::cout << "Workspace size" << ret_workspace_size << std::endl; + + //} else { + // std::cout << "Algo not supported!!!" << std::endl; + + //} + hipblasStatus_t status = hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, + &heuristicResult[0].algo, + d_workspace, workspace_size, + stream); + + //nvtxRangePushA("hipBLASLt variables deletion"); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); + //nvtxRangePop(); + + return status; +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +torch::Tensor HipbSolIdxBlas( + const torch::Tensor& mat1, + const torch::Tensor& mat2, + const int solution_index, + at::optional Type = at::nullopt, + at::optional scale1 = at::nullopt, + at::optional scale2 = at::nullopt, + at::optional scaleOut = at::nullopt + ) +{ + auto mat1_strides { mat1.strides() }; + auto mat2_strides { mat2.strides() }; + auto mat1_sizes { mat1.sizes() }; + auto mat2_sizes { mat2.sizes() }; + // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl + // << " | mat2 info: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; + + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK( + mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() + ); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); + + auto inType { mat1.options().dtype() }; + auto outType = inType.toScalarType(); + if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) }; + auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; + // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl; + + bool transpose_result = true; + bool transpose_mat1; + bool transpose_mat2; + if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + transpose_mat2 = false; + } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + transpose_mat2 = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + transpose_mat1 = false; + } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + transpose_mat1 = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_mat1; + transpose_mat1 = !transpose_mat2; + transpose_mat2 = !tmp; + mat1_strides = mat2.strides(); + mat2_strides = mat1.strides(); + mat1_sizes = mat2.sizes(); + mat2_sizes = mat1.sizes(); + } + // std::cout << " | transpose_result: " << (transpose_result ? "true" : "false") << std::endl + // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << std::endl + // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << std::endl; + // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl + // << " | B matrix: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; + + float one { 1.0f }; + float zero { 0.0f }; + int64_t m = mat1_sizes[transpose_result ? 1 : 0]; + int64_t k = mat1_sizes[transpose_result ? 0 : 1]; + int64_t n = mat2_sizes[transpose_result ? 0 : 1]; + int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; + int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; + int64_t result_ld = result.stride(transpose_result ? 0 : 1); + + void * d_scale1 = nullptr, * d_scale2 = nullptr, * d_scaleOut = nullptr; + if (scale1.has_value()) { + d_scale1 = static_cast(scale1.value().data_ptr()); + } + if (scale2.has_value()) { + d_scale2 = static_cast(scale2.value().data_ptr()); + } + if (scaleOut.has_value()) { + d_scaleOut = static_cast(scaleOut.value().data_ptr()); + } + + + hipDataType hipblasInType, hipblasOutType; + if (inType == at::kHalf) { + hipblasInType = HIP_R_16F; + } else if (inType == at::kBFloat16) { + hipblasInType = HIP_R_16BF; + } else if (inType == at::kFloat) { + hipblasInType = HIP_R_32F; + } else if (inType == at::kFloat8_e4m3fnuz) { + hipblasInType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + + if (outType == at::kHalf) { + hipblasOutType = HIP_R_16F; + } else if (outType == at::kBFloat16) { + hipblasOutType = HIP_R_16BF; + } else if (outType == at::kFloat) { + hipblasOutType = HIP_R_32F; + } else if (outType == at::kFloat8_e4m3fnuz) { + hipblasOutType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; + void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; + void *ptrC { static_cast(result.data_ptr()) }; + if (transpose_result) std::swap(d_scale1, d_scale2); + auto current_stream { torch::hip::getCurrentHIPStream().stream() }; + + CHECK_HIPBLAS_ERROR(hipblasLtMatmul_sol_wrapper( + hipblaslt_handle, + transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + &one, + ptrA, mat1_ld, d_scale1, + ptrB, mat2_ld, d_scale2, + &zero, + ptrC, result_ld, d_scaleOut, + hipblasInType, + hipblasOutType, + current_stream,solution_index)); + + return result; +} + +//find all hipblas solutions and return them to python land +std::vector HipbFindAllSolIdxBlas( + const torch::Tensor& mat1, + const torch::Tensor& mat2, + at::optional Type = at::nullopt + ) +{ + auto mat1_strides { mat1.strides() }; + auto mat2_strides { mat2.strides() }; + auto mat1_sizes { mat1.sizes() }; + auto mat2_sizes { mat2.sizes() }; + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK( + mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() + ); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); + + auto inType { mat1.options().dtype() }; + auto outType = inType.toScalarType(); + if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) }; + auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; + bool transpose_result = true; + bool transpose_mat1; + bool transpose_mat2; + if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + transpose_mat2 = false; + } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + transpose_mat2 = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + transpose_mat1 = false; + } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + transpose_mat1 = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + if (transpose_result) { + bool tmp = transpose_mat1; + transpose_mat1 = !transpose_mat2; + transpose_mat2 = !tmp; + mat1_strides = mat2.strides(); + mat2_strides = mat1.strides(); + mat1_sizes = mat2.sizes(); + mat2_sizes = mat1.sizes(); + } + float one { 1.0f }; + float zero { 0.0f }; + int64_t m = mat1_sizes[transpose_result ? 1 : 0]; + int64_t k = mat1_sizes[transpose_result ? 0 : 1]; + int64_t n = mat2_sizes[transpose_result ? 0 : 1]; + int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; + int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; + int64_t result_ld = result.stride(transpose_result ? 0 : 1); + hipDataType hipblasInType, hipblasOutType; + if (inType == at::kHalf) { + hipblasInType = HIP_R_16F; + } else if (inType == at::kBFloat16) { + hipblasInType = HIP_R_16BF; + } else if (inType == at::kFloat) { + hipblasInType = HIP_R_32F; + } else if (inType == at::kFloat8_e4m3fnuz) { + hipblasInType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + if (outType == at::kHalf) { + hipblasOutType = HIP_R_16F; + } else if (outType == at::kBFloat16) { + hipblasOutType = HIP_R_16BF; + } else if (outType == at::kFloat) { + hipblasOutType = HIP_R_32F; + } else if (outType == at::kFloat8_e4m3fnuz) { + hipblasOutType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; + void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; + void *ptrC { static_cast(result.data_ptr()) }; + auto current_stream { torch::hip::getCurrentHIPStream().stream() }; + + return hipblasLtMatmul_findallsols_wrapper( + hipblaslt_handle, + transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + &one, + ptrA, mat1_ld, + ptrB, mat2_ld, + &zero, + ptrC, result_ld, + hipblasInType, + hipblasOutType, + current_stream); + +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +void hipb_create_extension() +{ + //CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); + //CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); + + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); + CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( + preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); + + //CHECK_HIP_ERROR(hipEventCreate(&start)); + //CHECK_HIP_ERROR(hipEventCreate(&stop)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +void hipb_destroy_extension() +{ + //CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); + //CHECK_HIP_ERROR(hipEventDestroy(event)); + + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); + CHECK_HIP_ERROR(hipFree(d_workspace)); + + //CHECK_HIP_ERROR(hipEventDestroy(start)); + //CHECK_HIP_ERROR(hipEventDestroy(stop)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); + m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); + m.def("hipb_mm", &HipbSolIdxBlas, "mm", py::arg("mat1"), py::arg("mat2"), py::arg("solution_index"), py::arg("outType")=at::nullopt, py::arg("scale1")=at::nullopt, py::arg("scale2")=at::nullopt, py::arg("scaleOut")=at::nullopt); + m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols", py::arg("mat1"), py::arg("mat2"), py::arg("outType")=at::nullopt); +} diff --git a/gradlib_fp8/gemm_runner.py b/gradlib_fp8/gemm_runner.py new file mode 100644 index 0000000000000..683faf7c936e0 --- /dev/null +++ b/gradlib_fp8/gemm_runner.py @@ -0,0 +1,58 @@ +import torch +import hipbsolidxgemm +import numpy as np +import torch.nn.functional as F +import sys +import pandas as pd +import timeit + +hipbsolidxgemm.hipb_create_extension() + +class TunedGemm: + def __init__(self,tuned_csv_file): + self.bestsols = pd.read_csv(tuned_csv_file,index_col=[0]) + self.create_ds() + def create_ds(self): + df = self.bestsols + solds = {} + for i in range(len(df)): + ds = df.iloc[i] + key = (ds['M'],ds['N'],ds['K']) + solds[key] = (int(ds['solidx'])) + #print(solds) + self.solids = solds + def query_sol(self,m,n,k): + return self.solids.get((m,n,k),(0,0)) + def mm(self,inp,weights): + soltype,solidx = self.query_sol(m=weights.shape[0],n=inp.shape[0],k=inp.shape[1]) + if soltype==1: + out = hipbsolidxgemm.hipb_mm(inp,weights.t(),solidx) + else: + out = F.linear(inp,weights) + return out + def run_all_tuned_sols(self): + for i in range(len(self.bestsols)): + ds = self.bestsols.iloc[i] + print('>>> Running tuned solution') + print(ds) + inp = torch.randn((ds['N'], ds['K']), dtype=get_dtype(ds['dtype']), device='cuda') + weights = torch.randn((ds['M'], ds['K']), dtype=get_dtype(ds['dtype']), device='cuda') + self.mm(inp,weights) + +def get_dtype(dtype_csv): + if dtype_csv=='torch.float16': + dtype = torch.float16 + elif dtype_csv=='torch.bfloat16': + dtype = torch.bfloat16 + elif dtype_csv=='torch.float32': + dtype = torch.float32 + elif dtype_csv=='torch.float32': + dtype = torch.float32 + return dtype + +if __name__ == '__main__': + tgemm = TunedGemm(sys.argv[1]) #csv file with tuned sols goes in argv[1] + print(tgemm.bestsols) + tgemm.run_all_tuned_sols() + + diff --git a/gradlib_fp8/gemm_tuner.py b/gradlib_fp8/gemm_tuner.py new file mode 100644 index 0000000000000..7cf0614aeb840 --- /dev/null +++ b/gradlib_fp8/gemm_tuner.py @@ -0,0 +1,93 @@ +import torch +import os +import argparse +from gradlib.GemmTuner import GemmTuner +import hipbsolidxgemm +import numpy as np +import torch.nn.functional as F +import sys +import pandas as pd +import json +import random +from pathlib import Path +hipbsolidxgemm.hipb_create_extension() + +''' +{'architectures': ['LlamaForCausalLM'], 'bos_token_id': 1, 'eos_token_id': 2, 'hidden_act': 'silu', 'hidden_size': 5120, 'initializer_range': 0.02, +'intermediate_size': 13824, 'max_position_embeddings': 2048, 'model_type': 'llama', 'num_attention_heads': 40, 'num_hidden_layers': 40, 'num_key_value_heads': 40, +'pretraining_tp': 1, 'rms_norm_eps': 1e-05, 'rope_scaling': None, 'tie_word_embeddings': False, 'torch_dtype': 'float16', 'transformers_version': '4.33.0.dev0', 'use_cache': True, 'vocab_size': 32000} +''' +def generate_mk_sets(model_dir, tp=1): + f = open(f'{model_dir}/config.json') + data = json.load(f) + hidden_size = data['hidden_size'] + intermediate_size = data['intermediate_size'] + total_num_heads = data['num_attention_heads'] + total_num_kv_heads = data['num_key_value_heads'] + head_dim = hidden_size // total_num_heads + return [((total_num_heads + (2*total_num_kv_heads)) * head_dim // tp, hidden_size), (hidden_size, hidden_size // tp), (intermediate_size *2 // tp, hidden_size), (hidden_size, intermediate_size // tp) ], hidden_size + +def get_dtype(dtype_str): + dtype = torch.float16 + if dtype_str == 'f32': + dtype = torch.float32 + elif dtype_str == 'bf16': + dtype = torch.bfloat16 + elif dtype_str == 'f16': + dtype = torch.float16 + elif dtype_str == 'f8': + dtype = torch.float8_e4m3fnuz + else: + print('>>> Warning! Invalid dtype', dtype_str, 'using default dtype f16') + return dtype + + +def list_of_ints(arg): + return list(map(int, arg.split(','))) + +def load_input_gemms(input_file): + if Path(input_file).is_file(): + return + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model_dir", type=str, default=os.getenv('GTUNE_MODEL', ""), help="Enter the location of your model directory") + parser.add_argument("--tuned_file", type=str, default=os.getenv('GTUNE_TUNED', "tuned.csv"), help="output file for tuned gemm solutions") + parser.add_argument("--input_file", type=str, default=os.getenv('GTUNE_INPUT', None), help="list of gemms to tune for, mutually exclusive with model_dir") + parser.add_argument("--tp", type=int, default=os.getenv('GTUNE_TP', 1), help="Tensor parallelism to be used.") + parser.add_argument("--indtype", type=str, default='f8', help="dtype f32 f16 bf16") + parser.add_argument("--outdtype", type=str, default='f16', help="dtype f32 f16 bf16") + parser.add_argument("--batch_size", type=int, default=os.getenv('GTUNE_BATCH_SIZE', 1), help="Batch size to tune for") + parser.add_argument("--nsets", type=list_of_ints, default=[1, 512, 1024, 2048, 3072, 4096, 8192, 16384], help="N sizes to tune for: 1,128,2048") + args = parser.parse_args() + + indtype = get_dtype(args.indtype) + outdtype = get_dtype(args.outdtype) + + gtuner = GemmTuner(indtype, outdtype, args.tuned_file) + nsets = [i * args.batch_size for i in args.nsets] + if args.input_file: + print(f">>> Loading {args.input_file}") + if not Path(args.input_file).is_file(): + print(f">>> ERROR: {args.input_file} does not exist. Exiting") + exit(1) + shapes = pd.read_csv(args.input_file) + for i in range(len(shapes)): + ds = shapes.iloc[i] + gtuner.add_gemm(ds['M'],ds['N'],ds['K']) + else: + if not args.model_dir: + print(">>> Warning! NO MODEL SPECIFIED. Tuning for LL2 13B TP1") + #LL2 13B sizes + mksets = [(15360, 5120), (5120, 5120), (27648, 5120), (5120, 13824)] + gtuner.add_gemm(m=32000, n=1, k=5120) # logits gemm + else: + mksets, hidden_size = generate_mk_sets(args.model_dir, args.tp) + gtuner.add_gemm(m=32000//args.tp, n=1 * args.batch_size, k=hidden_size) #TODO: Handle cases where vocab_size is not divisible by tp + + for n in sorted(nsets): + for m, k in mksets: + gtuner.add_gemm(m, n, k) + + gtuner.find_best_sols() diff --git a/gradlib_fp8/gradlib/GemmTuner.py b/gradlib_fp8/gradlib/GemmTuner.py new file mode 100644 index 0000000000000..55071228fb32d --- /dev/null +++ b/gradlib_fp8/gradlib/GemmTuner.py @@ -0,0 +1,143 @@ +import torch +import os +import argparse +import hipbsolidxgemm +import numpy as np +import torch.nn.functional as F +import sys +import pandas as pd +import json +import random +from pathlib import Path + +hipbsolidxgemm.hipb_create_extension() + +rtol = 1e-5 +atol = 1 +dtype = torch.float16 + +class Gemm: + def __init__(self,m,n,k,indtype,outdtype): + self.m=m + self.k=k + self.n=n + self.indtype=indtype + self.outdtype=outdtype + self.nb = 37 + self.inp = torch.randn((self.n, self.k), device='cuda').to(self.indtype) + self.weights = torch.randn((self.m, self.k), device='cuda').to(self.indtype) + #weights2 is used in measurement/warm iters to ensure HBM fetch for weight tensors + self.weights2 = torch.randn((self.nb, self.m, self.k), device='cuda').to(self.indtype) + self.blob = torch.ones(128*1024*1024, dtype=torch.float32, device='cuda') + self.topn = 20 #number of top solutions from each source + self.hipb_sols=[] + self.rtol = 1e-5 + self.atol = 1 + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + + + def find_hipblas_sols(self): + sols = hipbsolidxgemm.hipb_findallsols(self.inp,self.weights.t(),self.outdtype) + print('M N K',self.m,self.n,self.k,'>>> Total hipb solutions',len(sols), flush=True) + #print(sols) + self.hipb_sols = sols + + + def check_gemm_ref(self,libtype,solidx): + ref = F.linear(self.inp.to(torch.float32),self.weights.to(torch.float32)).to(self.outdtype) + c = hipbsolidxgemm.hipb_mm(self.inp,self.weights.t(),solidx,self.outdtype) + if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol): + #print('>>>',libtype,'Solidx',solidx,'passed reference test') + return True + else: + print('>>>','Solidx',solidx,'FAILED reference test', flush=True) + print(ref, flush=True) + print(c, flush=True) + return False + def hipb_time_sol(self,solidx,cold_iters=2,warm_iters=10): + #print('>>>hipbtime',solidx) + for i in range(cold_iters): + c = hipbsolidxgemm.hipb_mm(self.inp,self.weights.t(),solidx,self.outdtype) + self.start.record() + for i in range(warm_iters): + c = hipbsolidxgemm.hipb_mm(self.inp,self.weights2 [random.randint(0,self.nb-1)].t(),solidx,self.outdtype) + self.end.record() + torch.cuda.synchronize() + gtime = self.start.elapsed_time(self.end)/warm_iters + #print('>>> Solidx GTime',solidx,gtime,'ms') + return gtime + def hipb_time_all_sols(self,fast_mode=0,top_sols=0): + coldi=20; warmi=20 + if fast_mode: coldi=2; warmi=2 + solutions = self.hipb_sols + if top_sols: solutions = self.hipb_top_sols + gtimes = {} + for solidx in solutions: + gtimes[solidx] = self.hipb_time_sol(solidx, cold_iters=coldi, warm_iters=warmi) + self.hipb_gtimedf = pd.DataFrame.from_dict(gtimes,orient='index',columns=['gtimems']).sort_values(by='gtimems') + self.hipb_gtimedf.to_csv('/tmp/hipb_gtimedf.csv') + print('>>> HipBlasLt top solutions, Fast Mode',fast_mode) + print(self.hipb_gtimedf.head(self.topn)) + def warmup(self,warmi=500): + for i in range(warmi): + self.blob = self.blob + 0.00001 + def functional_check_topn_fastest(self): + hipb_topn = [] + for solidx in self.hipb_gtimedf.index[:self.topn]: + if self.check_gemm_ref(libtype='hipblaslt',solidx=solidx): + hipb_topn.append(solidx) + self.hipb_top_sols = hipb_topn + + def find_fastest_solution(self): + self.find_hipblas_sols() + self.warmup() + self.hipb_time_all_sols(fast_mode=1) + self.functional_check_topn_fastest() + self.warmup() + self.hipb_time_all_sols(fast_mode=0,top_sols=1) + if len(self.hipb_gtimedf)>0: + best_hipb_time = self.hipb_gtimedf.gtimems.iloc[0] + self.best_solidx = self.hipb_gtimedf.index[0] + self.best_soltime = best_hipb_time + else: + print('>>> No hipblas solutions found!',flush=True) + self.best_solidx = 0 + self.best_soltime = 0 + print('>>> Fastest Solution is',self.best_solidx,self.best_soltime,flush=True) + + +class GemmTuner: + def __init__(self, indtype, outdtype, tuned_file=None): + self.gemm_problems = pd.DataFrame(columns=['M','N','K']) + self.indtype = indtype + self.outdtype = outdtype + self.tuned_file = tuned_file + if Path(tuned_file).is_file(): + self.gdf = pd.read_csv(tuned_file) + else: + self.gdf = None + + def add_gemm(self,m,n,k): + if ( self.gdf is None or (self.gdf[(self.gdf['M'] == m) & (self.gdf['N'] == n) & (self.gdf['K'] == k)].empty)): + entry = {'M':[m], 'N':[n], 'K':[k]} + df = pd.DataFrame(entry) + self.gemm_problems = pd.concat([self.gemm_problems, df],ignore_index=True) + else: + print(f">>>Info: Found Duplicate shape(M:{m}, N:{n}, K:{k}), skipping") + + def find_best_sols(self): + df = self.gemm_problems + soldf = pd.DataFrame() + for i in range(len(df)): + ds = df.iloc[i] + gemmobj = Gemm(ds['M'],ds['N'],ds['K'],indtype=self.indtype,outdtype=self.outdtype) + gemmobj.find_fastest_solution() + soldf.loc[i,'solidx'] = gemmobj.best_solidx + soldf.loc[i,'soltimems'] = gemmobj.best_soltime + soldf['indtype'] = self.indtype + soldf['outdtype'] = self.outdtype + finaldf = pd.concat([self.gemm_problems, soldf],axis=1) + finaldf = pd.concat([finaldf, self.gdf]) + finaldf.to_csv(self.tuned_file, index=False) + print(finaldf) diff --git a/gradlib_fp8/setup.py b/gradlib_fp8/setup.py new file mode 100644 index 0000000000000..933503ec5d54e --- /dev/null +++ b/gradlib_fp8/setup.py @@ -0,0 +1,89 @@ +import torch +import setuptools +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +from torch.utils.hipify import hipify_python +import os +import subprocess +import re + +this_dir = os.path.dirname(os.path.abspath(__file__)) +#gpus = subprocess.check_output("/opt/rocm/bin/rocminfo").decode('UTF-8').split('\n') +#gpus = list(set([re.search('(gfx94.)', g).group(0) for g in gpus if 'gfx94' in g])) +gpus = ['gfx90a','gfx940','gfx941','gfx942'] +#gpus = ['gfx90a','gfx940'] +extra_args = ["--offload-arch=" + g for g in gpus] + + +#sets_rocm_pytorch = False +maj_ver, min_ver, *_ = torch.__version__.split('.') +if int(maj_ver) > 1 or (int(maj_ver) == 1 and int(min_ver) >= 5): + from torch.utils.cpp_extension import ROCM_HOME + is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False + +ext_modules = [] + +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, 'include', 'ATen', 'CUDAGenerator.h')): + generator_flag = ['-DOLD_GENERATOR'] + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +version_ge_1_1 = [] +if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): + version_ge_1_1 = ['-DVERSION_GE_1_1'] +version_ge_1_3 = [] +if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): + version_ge_1_3 = ['-DVERSION_GE_1_3'] +version_ge_1_5 = [] +if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): + version_ge_1_5 = ['-DVERSION_GE_1_5'] +version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 + +include_dirs=[os.path.join(this_dir, 'csrc')] + +#if is_rocm_pytorch: +# import shutil +# with hipify_python.GeneratedFileCleaner(keep_intermediates=True) as clean_ctx: +# hipify_python.hipify(project_directory=this_dir, output_directory=this_dir, includes="csrc/*", +# show_detailed=True, is_pytorch_extension=True, clean_ctx=clean_ctx) + +if not is_rocm_pytorch: + ext_modules.append( + CUDAExtension( + name='gradlib', + sources=['grad_funcs.cu'], + extra_compile_args={ + 'cxx': ['-O3',], + 'nvcc':['-O3','-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', "--expt-relaxed-constexpr", "-ftemplate-depth=1024", '-gencode=arch=compute_70,code=sm_70','-gencode=arch=compute_80,code=sm_80','-gencode=arch=compute_80,code=compute_80'] + } + ) + ) +elif is_rocm_pytorch: + ext_modules.append( + CUDAExtension( + name='hipbsolidxgemm', + sources=['./csrc/hipbsolgemm.cu'], + include_dirs=include_dirs, + # add additional libraries argument for hipblaslt + libraries=['hipblaslt'], + extra_compile_args={ + 'cxx': ['-O3',], + 'nvcc':['-O3','-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', + "-ftemplate-depth=1024"] + extra_args + } + ) + ) + +setup( + name='gradlib', + packages=['gradlib'], + ext_modules=ext_modules, + cmdclass={ + 'build_ext': BuildExtension +}) + +# python setup.py build && cp build/lib*/gradlib* ../ diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index e464b41329cbd..0e380083ca657 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -170,7 +170,7 @@ def apply_fp8_16( ).drop_duplicates() df.to_csv("/projects/fp8_tune.csv", index=False) # print(f"{m},{n},{k}") - algo = 0 + algo = 11943 res = ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) return res From 2b5c5d1023219d33b05186bea3e40b57f9199629 Mon Sep 17 00:00:00 2001 From: charlifu Date: Sat, 1 Jun 2024 00:19:46 +0000 Subject: [PATCH 390/413] adding doc --- ROCm_performance.md | 27 +++++++++++++++++++ .../layers/quantization/fp8_rocm.py | 16 +++++------ vllm/model_executor/models/llama.py | 3 +-- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/ROCm_performance.md b/ROCm_performance.md index 180c848a21950..f46741f3f19aa 100644 --- a/ROCm_performance.md +++ b/ROCm_performance.md @@ -18,3 +18,30 @@ Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order On ROCm, to have better performance, a custom paged attention is available by switching on the env variable: `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`. Currently, this env variable is enabled by default. To fallback to PagedAttention v2 kernel assign the env variable to 0. The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 to 16. On all the other cases, we fallback to PagedAttention v2 kernel. + +## Fp8 Quantization + +To use fp8 quantization, first step is to use Nvidia ammo to quantize your model to fp8 format, following this [instruction](https://github.com/vllm-project/vllm/blob/main/examples/fp8/quantizer/README.md). This will give a safetensor file that contains the quantized weights and the corresponding scaling factors of your model. We will need to put the safetensor file under your model folder, and add file called `serenity_config.json`, which contains a json object with a key: `"quantized_weights": "quantized/osf/rank0.safetensors"`, the value should be the releative path of your safetensor file containing the quantized weights. + +Then we can run a model with fp8 quantization using vllm, just add a parameter `quantization="fp8"` when creating the vllm.LLM object. + +## Gemm Tunning for Fp8 + +To get better performance of fp8 quantization, we will need to tune the gemm with the information of all the shapes used in the execution of the model. + +To obtain all the shapes of gemms during the execution of the model, set the env value TUNE_FP8=1 and the run the model as usual. We will get the a file called `/fp8_shapes.csv`. + +Next, run gradlib to obtain the best solutions of these shapes: + +``` +cd gradlib_fp8 +python3 -m pip uninstall gradlib +python3 setup.py install +python3 gemm_tunner.py --input_file /fp8_shapes.csv --tuned_file /tuned_fp8_16.csv +cd ../gradlib +python3 -m pip uninstall gradlib +python3 setup.py install +cd .. +``` + +Now, when running inference with fp8, we are using the tunned gemm for best performance. \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index 0e380083ca657..d39c074b97325 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -21,10 +21,10 @@ def __init__(self, config) -> None: #print(f"Integral Cross factor = {self.factor}") if gemm_type == "fp8_8": self.gemm_method = Fp8RocmLinearLayer.apply_fp8_8 - tuned_filename = "/projects/tuned_fp8_8.csv" + tuned_filename = "/tuned_fp8_8.csv" elif gemm_type == "fp8_16": self.gemm_method = Fp8RocmLinearLayer.apply_fp8_16 - tuned_filename = "/projects/tuned_fp8_16.csv" + tuned_filename = "/tuned_fp8_16.csv" else: raise Exception(f"Unknown fp8 gemm type: {gemm_type}") try: @@ -37,7 +37,7 @@ def __init__(self, config) -> None: m = shape["M"] n = shape["N"] k = shape["K"] - algo = shape["algo"] + algo = shape["solidx"] self._tuned[(m, n, k)] = algo @staticmethod @@ -162,15 +162,15 @@ def apply_fp8_16( # print(f"Not found: {m} {n} {k}") if os.getenv("TUNE_FP8") == "1": try: - df = pd.read_csv("/projects/fp8_tune.csv") + df = pd.read_csv("/fp8_shapes.csv") except: df = pd.DataFrame(columns=["M", "N", "K"]) df = pd.concat( [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] ).drop_duplicates() - df.to_csv("/projects/fp8_tune.csv", index=False) + df.to_csv("/fp8_shapes.csv", index=False) # print(f"{m},{n},{k}") - algo = 11943 + algo = 0 res = ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) return res @@ -197,13 +197,13 @@ def apply_fp8_8( # print(f"Not found: {m} {n} {k}") if os.getenv("TUNE_FP8") == "1": try: - df = pd.read_csv("/projects/fp8_tune.csv") + df = pd.read_csv("/fp8_shapes.csv") except: df = pd.DataFrame(columns=["M", "N", "K"]) df = pd.concat( [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] ).drop_duplicates() - df.to_csv("/projects/fp8_tune.csv", index=False) + df.to_csv("/fp8_shapes.csv", index=False) # print(f"{m},{n},{k}") algo = 0 diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a4d7b6119e0b5..f021be7ee4822 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -454,8 +454,7 @@ def load_ammo_quantized_weights( ("self_attn.qkv_proj", "attention.qkv"), ] tp_rank = get_tensor_model_parallel_rank() - quantized_filename: str = quant_config.quantized_weights_path[ - f"rank0"] + quantized_filename: str = quant_config.quantized_weights_path #quantized_filename = "quantized/osf/rank0.safetensors" if not quantized_filename.startswith('/'): quantized_filename = os.path.join(model_name_or_path, From 657579113f714c2e74bca373ecfb6c2c245b4101 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 31 May 2024 17:20:19 -0700 Subject: [PATCH 391/413] [Doc] Add checkmark for GPTBigCodeForCausalLM LoRA support (#5171) --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index e4bae80343a2c..82e71e61975c8 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -62,7 +62,7 @@ Alongside each architecture, we include some popular models that use it. * - :code:`GPTBigCodeForCausalLM` - StarCoder, SantaCoder, WizardCoder - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc. - - + - ✅︎ * - :code:`GPTJForCausalLM` - GPT-J - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc. From 1197e02141df1a7442f21ff6922c98ec0bba153e Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 31 May 2024 20:21:38 -0400 Subject: [PATCH 392/413] [Build] Guard against older CUDA versions when building CUTLASS 3.x kernels (#5168) --- csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu | 10 ++++++++-- csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu | 11 ++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 5fd6d8ff20867..531414bc45165 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -1,3 +1,9 @@ +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + #include #include @@ -6,8 +12,6 @@ #include #include -// clang-format will break include orders -// clang-format off #include "cutlass/cutlass.h" #include "cute/tensor.hpp" @@ -241,3 +245,5 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a, } } } + +#endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index dab73ac6c831e..eb532f2ac7a9b 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -1,5 +1,6 @@ +#include + #include -#include #include void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, @@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales); +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales); +#endif void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a, if (version_num >= 90) { // Hopper + + // Guard against compilation issues for sm90 kernels +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales); +#else + cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales); +#endif } else if (version_num == 89) { // Ada Lovelace cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales); From 977361bdbebde2aca7043db2245d6fb312bd1156 Mon Sep 17 00:00:00 2001 From: charlifu Date: Sat, 1 Jun 2024 00:47:27 +0000 Subject: [PATCH 393/413] doc fix --- ROCm_performance.md | 6 +++--- vllm/model_executor/layers/quantization/fp8_rocm.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ROCm_performance.md b/ROCm_performance.md index f46741f3f19aa..0f12ed1adc9af 100644 --- a/ROCm_performance.md +++ b/ROCm_performance.md @@ -21,9 +21,9 @@ The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head ## Fp8 Quantization -To use fp8 quantization, first step is to use Nvidia ammo to quantize your model to fp8 format, following this [instruction](https://github.com/vllm-project/vllm/blob/main/examples/fp8/quantizer/README.md). This will give a safetensor file that contains the quantized weights and the corresponding scaling factors of your model. We will need to put the safetensor file under your model folder, and add file called `serenity_config.json`, which contains a json object with a key: `"quantized_weights": "quantized/osf/rank0.safetensors"`, the value should be the releative path of your safetensor file containing the quantized weights. +To use fp8 quantization, first step is to quantize your model to fp8 format. Generating a safetensor file that contains the quantized weights and the corresponding scaling factors of your model. The safetensor file should be added under your model folder along with a file called `serenity_config.json`, which contains a json object with a key: `"quantized_weights": "quantized/osf/rank0.safetensors"`, the value should be the relative path of your safetensor file containing the quantized weights. -Then we can run a model with fp8 quantization using vllm, just add a parameter `quantization="fp8"` when creating the vllm.LLM object. +Then we can run a model with fp8 quantization using vllm, just add a parameter `quantization="fp8"` when creating the `vllm.LLM` object. ## Gemm Tunning for Fp8 @@ -37,7 +37,7 @@ Next, run gradlib to obtain the best solutions of these shapes: cd gradlib_fp8 python3 -m pip uninstall gradlib python3 setup.py install -python3 gemm_tunner.py --input_file /fp8_shapes.csv --tuned_file /tuned_fp8_16.csv +python3 gemm_tuner.py --input_file /fp8_shapes.csv --tuned_file /tuned_fp8_16.csv cd ../gradlib python3 -m pip uninstall gradlib python3 setup.py install diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index d39c074b97325..55642ae23ac6d 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -59,7 +59,7 @@ def get_min_capability(cls) -> int: @classmethod def get_name(cls) -> str: - return "serenity" + return "Fp8Rocm" def get_linear_method(self) -> "Fp8RocmLinearLayer": return Fp8RocmLinearLayer(self) From a360ff80bb34f9dfcd21cf880c2030daa2d6b3a3 Mon Sep 17 00:00:00 2001 From: Daniele Date: Sat, 1 Jun 2024 06:06:45 +0200 Subject: [PATCH 394/413] [CI/Build] CMakeLists: build all extensions' cmake targets at the same time (#5034) --- setup.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index b4baebb0d4801..d99fc050f6d84 100644 --- a/setup.py +++ b/setup.py @@ -187,19 +187,22 @@ def build_extensions(self) -> None: if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) + targets = [] # Build all the extensions for ext in self.extensions: self.configure(ext) + targets.append(remove_prefix(ext.name, "vllm.")) - ext_target_name = remove_prefix(ext.name, "vllm.") - num_jobs, _ = self.compute_num_jobs() + num_jobs, _ = self.compute_num_jobs() - build_args = [ - '--build', '.', '--target', ext_target_name, '-j', - str(num_jobs) - ] + build_args = [ + "--build", + ".", + f"-j={num_jobs}", + *[f"--target={name}" for name in targets], + ] - subprocess.check_call(['cmake', *build_args], cwd=self.build_temp) + subprocess.check_call(["cmake", *build_args], cwd=self.build_temp) def _is_cuda() -> bool: From 7fed416f9ef4b6aacbaef660cbc059040be7bdbe Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 3 Jun 2024 15:06:39 +0000 Subject: [PATCH 395/413] fix the model loading fp8 --- vllm/model_executor/model_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index b348bf8e323ec..cb01edc25fe7e 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -98,7 +98,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, initialize_dummy_weights(model) else: load_ammo_quantized_weights = getattr(model, "load_ammo_quantized_weights", None) - if quant_config is not None and load_ammo_quantized_weights is not None: + if model_config.quantization is not None and load_ammo_quantized_weights is not None: load_ammo_quantized_weights(model_config.model, model_config.download_dir, model_config.load_format, model_config.revision, quant_config) From 34e010c45b5ca1eb255e0d42ae8c7c05ad0d9fee Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:19:47 -0400 Subject: [PATCH 396/413] Update linear.py Fix bias handling with tgemm --- vllm/model_executor/layers/linear.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1d2a8782a96f1..6a0f55101e660 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -104,6 +104,8 @@ def apply_weights(self, if bias is not None: return F.linear(x, weight) + bias return F.linear(x, weight) + elif bias is not None: + return F.linear(x, weight, bias) return tgemm.mm(x, weight) From 40198075b0fbeb7dcf021d84f1fb96dcfc257ce5 Mon Sep 17 00:00:00 2001 From: Joe Shajrawi <17753158+shajrawi@users.noreply.github.com> Date: Thu, 30 May 2024 11:21:58 -0500 Subject: [PATCH 397/413] Update Dockerfile.rocm Updated the base docker to ROCm 6.1.1 Updated the RCCL pin to a new one with performance improvements --- Dockerfile.rocm | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 74db3c1a796be..11f9efb2e7371 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,5 +1,5 @@ # default base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.1.2" +ARG BASE_IMAGE="rocm/pytorch:rocm6.1.1_ubuntu20.04_py3.9_pytorch_release-2.1.2" ARG COMMON_WORKDIR=/app @@ -40,7 +40,7 @@ WORKDIR ${COMMON_WORKDIR} # ----------------------- # hipBLASLt build stages FROM base AS build_hipblaslt -ARG HIPBLASLT_BRANCH="ee51a9d1" +ARG HIPBLASLT_BRANCH="6f65c6e" RUN git clone https://github.com/ROCm/hipBLASLt \ && cd hipBLASLt \ && git checkout ${HIPBLASLT_BRANCH} \ @@ -56,7 +56,7 @@ FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt # ----------------------- # RCCL build stages FROM base AS build_rccl -ARG RCCL_BRANCH="eeea3b6" +ARG RCCL_BRANCH="73221b4" RUN git clone https://github.com/ROCm/rccl \ && cd rccl \ && git checkout ${RCCL_BRANCH} \ From 324cc8bab8c57219fbe21eede1c9b12534133df3 Mon Sep 17 00:00:00 2001 From: Matthew Wong Date: Tue, 4 Jun 2024 01:12:22 +0000 Subject: [PATCH 398/413] Use world group to broadcast metadata on ROCm Partially reverts [Core][Distributed] use cpu group to broadcast metadata in cpu (https://github.com/vllm-project/vllm/pull/4444) --- vllm/distributed/communication_op.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index dffeca38e65a8..b04adb532dd38 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -12,6 +12,7 @@ get_tensor_model_parallel_world_size, get_tp_ca_communicator, get_tp_pynccl_communicator) +from vllm.utils import is_hip @dataclass @@ -251,7 +252,10 @@ def broadcast_tensor_dict( return tensor_dict group = group or torch.distributed.group.WORLD - metadata_group = metadata_group or get_cpu_world_group() + if is_hip(): + metadata_group = metadata_group or torch.distributed.group.WORLD + else: + metadata_group = metadata_group or get_cpu_world_group() ranks = torch.distributed.get_process_group_ranks(group) assert src in ranks, f"Invalid src rank ({src})" From b373a0edab27c472e8db07bb6d42b1e20a35ea24 Mon Sep 17 00:00:00 2001 From: lcskrishna Date: Thu, 23 May 2024 03:15:00 +0000 Subject: [PATCH 399/413] Custom PagedAttn optimizations for ROCm initial commit for v0.4.0 with paged attn optimization update the integration code updates to custom attention kenrel update unit test case for custom update conditions to pick paged attn v2 vs custom update env condition enable more parameters in custom unit testing update conditions for custom vs v2 update gqa ratio condition for using custom kernel updated docs, cleanup and enabled it by default fixes imports for custom paged attn update the custom paged attn with latest data update conditions of max-context-len --- CMakeLists.txt | 1 + ROCm_performance.md | 6 + .../kernels/benchmark_paged_attention.py | 67 +- csrc/custom/custom.cu | 25 + .../custom/paged_attention/attention_ll4mi.cu | 849 ++++++++++++++++++ tests/kernels/test_attention_custom.py | 292 ++++++ vllm/attention/ops/paged_attn.py | 91 +- 7 files changed, 1284 insertions(+), 47 deletions(-) create mode 100644 csrc/custom/paged_attention/attention_ll4mi.cu create mode 100644 tests/kernels/test_attention_custom.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e75d13eef2bd..3a7cb981ac3e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -229,6 +229,7 @@ set(CUSTOM_SRC "csrc/custom/custom_kernels.cu" "csrc/custom/fused_kernels.cu" "csrc/custom/custom.cu" +"csrc/custom/paged_attention/attention_ll4mi.cu" ) define_gpu_extension_target( diff --git a/ROCm_performance.md b/ROCm_performance.md index b39f3b42aab76..180c848a21950 100644 --- a/ROCm_performance.md +++ b/ROCm_performance.md @@ -12,3 +12,9 @@ The default attention function on ROCm is using triton attention kernel. To fall ## Tunable ops Pytorch tunable ops are supported. Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order to enable both the runtime tuning and the subsequent use of tuned results. To only use the tuned results without tuning any newly encountered shapes, also define `PYTORCH_TUNABLEOP_TUNING=1` + +## Custom PagedAttention + +On ROCm, to have better performance, a custom paged attention is available by switching on the env variable: `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`. +Currently, this env variable is enabled by default. To fallback to PagedAttention v2 kernel assign the env variable to 0. +The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 to 16. On all the other cases, we fallback to PagedAttention v2 kernel. diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index e6f4e9e6b9716..0fcfc0a295ca2 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -6,10 +6,11 @@ import torch from vllm import _custom_ops as ops +from vllm._custom_C import paged_attention_custom from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random NUM_BLOCKS = 1024 -PARTITION_SIZE = 512 +PARTITION_SIZE = 256 @torch.inference_mode() @@ -77,7 +78,11 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + if not args.custom_paged_attn: + global PARTITION_SIZE + PARTITION_SIZE = 512 + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // + PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -117,24 +122,43 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: kv_scale, ) elif version == "v2": - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - kv_scale, - ) + if not args.custom_paged_attn: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + ) + else: + paged_attention_custom( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -188,6 +212,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: help="Data type for kv cache storage. If 'auto', will use model " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") + parser.add_argument("--custom-paged-attn", + action="store_true", + help="Use custom paged attention") args = parser.parse_args() print(args) diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index aeff9cc5e6ae7..d75b2d2e41005 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -64,11 +64,36 @@ void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) { at::cuda::getCurrentCUDAStream()); } +void paged_attention_custom( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + int num_kv_heads, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); + // declare the extension module with the AddGPU function: PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ m.doc() = "pybind11 example plugin"; m.def("LLMM1", &LLMM1); m.def("LLMM_Silu", &LLMM_Silu); m.def("LLZZ", &LLZZ); + m.def( + "paged_attention_custom", + &paged_attention_custom, + "PagedAttention LL4Mi Custom."); //m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/custom/paged_attention/attention_ll4mi.cu new file mode 100644 index 0000000000000..6c9e84ab2f5f4 --- /dev/null +++ b/csrc/custom/paged_attention/attention_ll4mi.cu @@ -0,0 +1,849 @@ +//TODO: add license terms +#include +#include +#include + +#include + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +#define WARP_SIZE 64 + +#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 +#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 + +using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +typedef float16x4 _Half4; +typedef struct _Half8 { _Half4 xy[2]; } _Half8; +////// Non temporal load stores /////// + +#if 1 + +template +__device__ __forceinline__ T load(T* addr) { + return addr[0]; +} + +template +__device__ __forceinline__ void store(T value, T* addr) { + addr[0] = value; +} + +#else + +template +__device__ __forceinline__ T load(const T* addr) { + return __builtin_nontemporal_load(addr); +} + +template <> +__device__ __forceinline__ +float2 load (const float2* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ +float4 load (const float4* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result1 = __builtin_nontemporal_load(addr_alias); + auto result2 = __builtin_nontemporal_load(addr_alias + 1); + float4 ret{}; + auto ret_alias = reinterpret_cast(&result1); + ret.x = ret_alias->x; + ret.y = ret_alias->y; + ret_alias = reinterpret_cast(&result2); + ret.z = ret_alias->x; + ret.w = ret_alias->y; + return ret; +} + +template <> +__device__ __forceinline__ +__half load (const __half* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast<__half *>(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ +__half2 load (const __half2* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast<__half2 *>(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ +vllm::Half4_ load (const vllm::Half4_* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result = __builtin_nontemporal_load(addr_alias); + auto ret = reinterpret_cast(&result); + return ret[0]; +} + +template <> +__device__ __forceinline__ +vllm::Half8_ load (const vllm::Half8_* addr) { + auto addr_alias { reinterpret_cast(addr) }; + auto result1 = __builtin_nontemporal_load(addr_alias); + auto result2 = __builtin_nontemporal_load(addr_alias + 1); + vllm::Half8_ ret {}; + auto ret_alias = reinterpret_cast(&result1); + ret.x = ret_alias->x; + ret.y = ret_alias->y; + ret_alias = reinterpret_cast(&result2); + ret.z = ret_alias->x; + ret.w = ret_alias->y; + return ret; +} + +//// Not using nontemporal stores for now +template +__device__ __forceinline__ void store(T value, T* addr) { + return __builtin_nontemporal_store(value, addr); +} + +#endif + +/////////////////////////////////////// + +//grid (num_seqs, num_partitions,num_heads/gqa_ratio) +//block (partition size) +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] +#if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] +#endif + int max_ctx_blocks + ) { + constexpr int NWARPS = NUM_THREADS/WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid%4; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + const int partition_size = blockDim.x; + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + const int partition_start_token_idx = partition_idx * partition_size; + //exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO,4); // each 4 lanes fetch 4 different qheads, total qheads =8, so qhloop is 2 + constexpr int GQA_RATIO4 = 4*QHLOOP; + __shared__ float shared_qk_max[NWARPS][GQA_RATIO4+1]; + __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4+1]; + _Half8 Qlocal[QHLOOP]; + constexpr int x = 16 / sizeof(scalar_t); + constexpr int KHELOOP = HEAD_SIZE/x; + _Half8 Klocal[KHELOOP]; + constexpr int VHELOOP = HEAD_SIZE/WARP_SIZE; //v head_size dimension is distributed across lanes + constexpr int VTLOOP = 8; //16 separate 4xtokens across warp -> 16/2 8xtokens + _Half8 Vlocal[VHELOOP][VTLOOP]; + floatx4 dout[QHLOOP]; + float qk_max[QHLOOP]; + #pragma unroll + for (int h=0; h= context_len) { //warp out of context + #pragma unroll + for(int h=0;h(block_table[block_idx]); + + //each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + const scalar_t* q_ptr = q + seq_idx*q_stride + wg_start_head_idx*HEAD_SIZE; + const _Half8* q_ptrh8 = reinterpret_cast(q_ptr); + const int qhead_elemh8 = laneid/4; + #pragma unroll + for (int h=0; h(k_ptr); + + const int physical_block_offset = local_token_idx%BLOCK_SIZE; //since x=half8, physical_block_offset is already cast as _H8 + + + #pragma unroll + for (int d=0;d(v_ptr); + //iterate over each v block + #pragma unroll + for (int b=0;b(vphysical_blocks[b]); + const _Half8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride)/8; + //iterate over each head elem (within head_size) + #pragma unroll + for (int h=0;h8) { + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h], 4, 8, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h], 4, 8, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h], 4, 9, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h], 4, 9, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h], 4, 10, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h], 4, 10, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h], 4, 11, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h], 4, 11, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h], 4, 12, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h], 4, 12, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h], 4, 13, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h], 4, 13, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h], 4, 14, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h], 4, 14, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h], 4, 15, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h], 4, 15, 0); + } //KHELOOP>8 + dout[h]*=scale; + } + //transpose dout so that 4 token ids are in each lane, and 4 heads are across 4 lanes + #pragma unroll + for (int h=0;h>2); + const int alibi_offset = lane4_token_idx - context_len + 1; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h=0;h=4; mask/=2) { + qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h],mask)); + } + } + + float exp_sum[QHLOOP]; + #pragma unroll + for (int h=0;h=4; mask/=2) { + exp_sum[h] += __shfl_xor(exp_sum[h],mask); + } + } + + + #pragma unroll + for (int h=0;h every 4 lanes hold 4 heads, each lane holds 4 tokens, there are 4x16 tokens across warp + float16x4 logits[QHLOOP]; + #pragma unroll + for (int h=0;h= context_len) { //warp out of context + #pragma unroll + for (int qh=0; qh partition_size) { + out_num_partitions = max_num_partitions; + out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + out_num_partitions = 1; + out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; + } + #pragma unroll + for (int qh=0; qh +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + //if num_partitions==1, main kernel will write to out directly, no work in reduction kernel + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + __shared__ float shared_exp_sums[2*WARP_SIZE]; + + if (warpid==0) { + + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + //valid partition is the last valid partition in case threadid > num partitions + const int valid_partition = (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions-1; + const int valid_partition2 = (WARP_SIZE+threadIdx.x < num_partitions) ? WARP_SIZE+threadIdx.x : num_partitions-1; + float reg_max_logit = max_logits_ptr[valid_partition]; + float reg_max_logit2 = max_logits_ptr[valid_partition2]; + float max_logit = fmaxf(reg_max_logit,reg_max_logit2); + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float global_exp_sum = 0.0f; + float rescaled_exp_sum = exp_sums_ptr[valid_partition]; + float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; + rescaled_exp_sum *= (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; + rescaled_exp_sum2 *= (threadIdx.x+WARP_SIZE < num_partitions) ? expf(reg_max_logit2 - max_logit) : 0.0f; + global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; + shared_exp_sums[threadIdx.x] = rescaled_exp_sum; + shared_exp_sums[threadIdx.x+WARP_SIZE] = rescaled_exp_sum2; + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x==0) { + shared_global_exp_sum = global_exp_sum; + } + }//warpid == 0 + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 64; + scalar_t tmps[MAX_NPAR]; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = 0.0f; + } + const int last_partition_offset = (num_partitions-1)*HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx=0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK*HEAD_SIZE; j+=HEAD_SIZE) { + //lastj is last valid partition + const int lastj_offset = (j JCHUNK) { + #pragma unroll + for (int j = JCHUNK*HEAD_SIZE; j < 2*JCHUNK*HEAD_SIZE; j+=HEAD_SIZE) { + const int lastj_offset = (j 2*JCHUNK) { + #pragma unroll + for (int j = 2*JCHUNK*HEAD_SIZE; j < MAX_NPAR*HEAD_SIZE; j+=HEAD_SIZE) { + const int lastj_offset = (j JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += tmps[j] * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2*JCHUNK; j++) { + acc += tmps[j] * shared_exp_sums[j]; + } + if (num_partitions > 2*JCHUNK) { + #pragma unroll + for (int j = 2*JCHUNK; j < MAX_NPAR; j++) { + acc += tmps[j] * shared_exp_sums[j]; + } + } + } + + if (num_partitions > MAX_NPAR) { + idx=0; + #pragma unroll + for (int j = MAX_NPAR*HEAD_SIZE; j < 2*MAX_NPAR*HEAD_SIZE; j+=HEAD_SIZE) { + //lastj is last valid partition + const int lastj_offset = (j \ + <<>>( \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + num_kv_heads, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,out_ptr,max_ctx_blocks); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + const int num_kv_heads, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes) { + + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); +#if 0 + T* qk_out_ptr = reinterpret_cast(qk_out.data_ptr()); + T* softmax_out_ptr = reinterpret_cast(softmax_out.data_ptr()); +#endif + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads/num_kv_heads; + assert(num_heads%num_kv_heads==0); + assert(head_size==HEAD_SIZE); + assert(max_num_partitions<=128); + + constexpr int NTHR = PARTITION_SIZE; + dim3 grid(num_seqs,max_num_partitions,num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (gqa_ratio) { + case 1: LAUNCH_CUSTOM_ATTENTION(1); break; + case 2: LAUNCH_CUSTOM_ATTENTION(2); break; + case 3: LAUNCH_CUSTOM_ATTENTION(3); break; + case 4: LAUNCH_CUSTOM_ATTENTION(4); break; + case 5: LAUNCH_CUSTOM_ATTENTION(5); break; + case 6: LAUNCH_CUSTOM_ATTENTION(6); break; + case 7: LAUNCH_CUSTOM_ATTENTION(7); break; + case 8: LAUNCH_CUSTOM_ATTENTION(8); break; + case 9: LAUNCH_CUSTOM_ATTENTION(9); break; + case 10: LAUNCH_CUSTOM_ATTENTION(10); break; + case 11: LAUNCH_CUSTOM_ATTENTION(11); break; + case 12: LAUNCH_CUSTOM_ATTENTION(12); break; + case 13: LAUNCH_CUSTOM_ATTENTION(13); break; + case 14: LAUNCH_CUSTOM_ATTENTION(14); break; + case 15: LAUNCH_CUSTOM_ATTENTION(15); break; + case 16: LAUNCH_CUSTOM_ATTENTION(16); break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + //dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); + //dim3 block2(1024); + // LAUNCH_CUSTOM_ATTENTION2; + + //reduction kernel is only required if max_context_len > partition size, otherwise main kernel writes directly to final output + // note there are cases with graphing where max_context_len is the max supported by graphing, not the actual max among + // all the sequences: in that case reduction kernel will still run but return immediately + if (max_context_len > PARTITION_SIZE) { + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + paged_attention_ll4mi_reduce_kernel + <<>>( + out_ptr, + exp_sums_ptr, + max_logits_ptr, + tmp_out_ptr, + context_lens_ptr, + max_num_partitions); + } +} + +#define CALL_CUSTOM_LAUNCHER(T,BLK_SIZE,HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len,\ + alibi_slopes); + +#define CALL_CUSTOM_LAUNCHER_BLK(T,HEAD_SIZE) \ + switch (block_size) { \ + case 8: CALL_CUSTOM_LAUNCHER(T,8,HEAD_SIZE); break; \ + case 16: CALL_CUSTOM_LAUNCHER(T,16,HEAD_SIZE); break; \ + case 32: CALL_CUSTOM_LAUNCHER(T,32,HEAD_SIZE); break; \ + default: TORCH_CHECK(false, "Unsupported block size: ", block_size); break; \ + } + +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ + switch (head_size) { \ + case 64: CALL_CUSTOM_LAUNCHER_BLK(T,64); break; \ + case 128: CALL_CUSTOM_LAUNCHER_BLK(T,128); break; \ + default: TORCH_CHECK(false, "Unsupported head size: ", head_size); break; \ + } + +void paged_attention_custom( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, +#if 0 + torch::Tensor& qk_out, + torch::Tensor& softmax_out, +#endif + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + const int head_size = query.size(2); + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/tests/kernels/test_attention_custom.py b/tests/kernels/test_attention_custom.py new file mode 100644 index 0000000000000..5bdbf126c22fa --- /dev/null +++ b/tests/kernels/test_attention_custom.py @@ -0,0 +1,292 @@ +import random +from typing import Optional, Tuple + +import pytest +import torch +from allclose_default import get_default_atol, get_default_rtol + +from vllm._C import cache_ops, ops +from vllm._custom_C import paged_attention_custom +from vllm.utils import get_max_shared_memory_bytes, is_hip + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. +NUM_BLOCKS = 4321 # Arbitrary values for testing +PARTITION_SIZE = 256 +# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} +DTYPES = [torch.half, torch.bfloat16, torch.float + ] if not is_hip() else [torch.half] +NUM_GEN_SEQS = [1, 17, 64] # Arbitrary values for testing +NUM_HEADS = [(8 * x, 8) for x in range(1, 17)] # Arbitrary values for testing + +# FlashAttention forward only supports head dimension at most 128 +# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62 +HEAD_SIZES = [128] +BLOCK_SIZES = [16] +USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto"] +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. + position_ids = torch.arange(context_len).int() + alibi_bias = (position_ids - context_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("version", ["custom"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_paged_attention( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + + context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + context_lens[-1] = MAX_SEQ_LEN + #context_lens = [8192 for _ in range(num_seqs)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int) + #print('>>> ctx lens', context_lens) + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + kv_scale = 1.0 + + # Call the paged attention kernel. + output = torch.empty_like(query) + if version == "v1": + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + ) + elif version == "v2" or version == "custom": + num_partitions = ((max_context_len + PARTITION_SIZE - 1) // + PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + if version == "v2": + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + ) + elif version == "custom": + paged_attention_custom( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + ) + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + cache_ops.convert_fp8(key_cache, dequantized_key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + cache_ops.convert_fp8(value_cache, dequantized_value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. + atol = get_default_atol(output) if is_hip() else 1e-3 + rtol = get_default_rtol(output) if is_hip() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-3, 1e-5 + atol = 5e-3 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index a214f40d16514..0d3ee47193306 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass from typing import List, Optional, Tuple @@ -5,9 +6,16 @@ from vllm import _custom_ops as ops from vllm.attention.ops.prefix_prefill import context_attention_fwd +from vllm.utils import is_hip + +custom_attn_available = is_hip() and \ + (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0") +if custom_attn_available: + from vllm._custom_C import paged_attention_custom # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. -_PARTITION_SIZE = 512 +_PARTITION_SIZE_V1V2 = 512 +_PARTITION_SIZE_CUSTOM = 256 @dataclass @@ -108,6 +116,16 @@ def forward_decode( output = torch.empty_like(query) block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape + gqa_ratio = num_heads // num_kv_heads + use_custom = (custom_attn_available and query.dtype == torch.half + and head_size == 128 and block_size == 16 + and kv_cache_dtype == "auto" + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 32768) + if not use_custom: + _PARTITION_SIZE = _PARTITION_SIZE_V1V2 + else: + _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use @@ -118,8 +136,8 @@ def forward_decode( # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = (max_seq_len <= 8192 - and (max_num_partitions == 1 or num_seqs * num_heads > 512)) - + and (max_num_partitions == 1 or num_seqs * num_heads > 512) + and not use_custom) if use_v1: # Run PagedAttention V1. ops.paged_attention_v1( @@ -143,7 +161,7 @@ def forward_decode( blocksparse_head_sliding_step, ) else: - # Run PagedAttention V2. + # Run PagedAttention V2 or PagedAttention Custom. assert _PARTITION_SIZE % block_size == 0 tmp_output = torch.empty( size=(num_seqs, num_heads, max_num_partitions, head_size), @@ -156,29 +174,48 @@ def forward_decode( device=output.device, ) max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - kv_scale, - tp_rank, - blocksparse_local_blocks, - blocksparse_vert_stride, - blocksparse_block_size, - blocksparse_head_sliding_step, - ) + if not use_custom: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + kv_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + else: + paged_attention_custom( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + ) return output @staticmethod From c893d7010b650084cc5b836b81ec6e357b5d0004 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Mon, 3 Jun 2024 18:19:47 -0400 Subject: [PATCH 400/413] Update linear.py Fix bias handling with tgemm Don't use custom matvec kernel for bf16 --- vllm/model_executor/layers/linear.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 94f1537d64154..29b5fe77ae705 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -92,7 +92,7 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: weight = layer.weight - if is_hip() and x.view(-1, x.size(-1)).shape[0] == 1: + if is_hip() and x.dtype == torch.float16 and x.view(-1, x.size(-1)).shape[0] == 1: batched = False if x.dim() == 3: inp = x.view(-1, x.size(-1)) @@ -120,6 +120,8 @@ def apply(self, if bias is not None: return F.linear(x, weight) + bias return F.linear(x, weight) + elif bias is not None: + return F.linear(x, weight, bias) return tgemm.mm(x, weight) From a6af475474c124c9922209c58463b446d97ca754 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 5 Jun 2024 19:17:39 +0000 Subject: [PATCH 401/413] update base docker image --- Dockerfile.rocm | 2 +- csrc/quantization/fp8/gemm_kernel.cu | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 11f9efb2e7371..49edca537d18e 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,5 +1,5 @@ # default base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.1.1_ubuntu20.04_py3.9_pytorch_release-2.1.2" +ARG BASE_IMAGE="rocm/pytorch:rocm6.1.1_ubuntu20.04_py3.9_pytorch_staging" ARG COMMON_WORKDIR=/app diff --git a/csrc/quantization/fp8/gemm_kernel.cu b/csrc/quantization/fp8/gemm_kernel.cu index 0463cc75eac6c..558228db2c084 100644 --- a/csrc/quantization/fp8/gemm_kernel.cu +++ b/csrc/quantization/fp8/gemm_kernel.cu @@ -101,7 +101,7 @@ torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA auto d_scaleD = scaleD.data_ptr(); auto handle = at::cuda::getCurrentCUDABlasLtHandle(); - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + auto stream = at::cuda::getCurrentCUDAStream(); hipblaslt_ext::GemmPreference gemmPref; gemmPref.setMaxWorkspaceBytes(0); @@ -218,7 +218,7 @@ torch::Tensor fp8_gemm_16( auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); auto handle = at::cuda::getCurrentCUDABlasLtHandle(); - auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + auto stream = at::cuda::getCurrentCUDAStream(); hipblaslt_ext::GemmPreference gemmPref; gemmPref.setMaxWorkspaceBytes(0); From 6e7ea443f65e339fbd33c11658e80ffa8892c619 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 6 Jun 2024 00:05:35 +0000 Subject: [PATCH 402/413] remove apply_custom --- vllm/model_executor/layers/tuned_gemm.py | 38 +----------------------- 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index cc5ae2b8601f2..10bd187f01101 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -33,35 +33,16 @@ def load_best_sols(self): if self.tune_path is not None and Path(self.tune_path).is_file(): self.bestsols = pd.read_csv(self.tune_path) - def apply_custom(self, ds): - M, N, K = ds['M'], ds['N'], ds['K'] - #apply custom matvec (only for f16 dtype) - if N == 1: - ds1 = ds.copy() - ds1['libtype'] = 'custom' - if K == 8192 and (M == 1280 or M == 7168): #NOQA: SIM114 - ds1['solidx'] = 8 - return ds1 - elif K == 3584 and M == 8192: - ds1['solidx'] = 8 - return ds1 - elif K <= 8192 and K % 8 == 0 and M % 4 == 0: - ds1['solidx'] = 1 - return ds1 - return ds - def create_ds(self): df = self.bestsols solds = {} for i in range(len(df)): - ds = self.apply_custom(df.iloc[i]) + ds = df.iloc[i] key = (ds['M'], ds['N'], ds['K']) if ds['libtype'] == 'hipblaslt': soltype = 1 elif ds['libtype'] == 'rocblas': soltype = 2 - elif ds['libtype'] == 'custom': - soltype = 3 solds[key] = (soltype, int(ds['solidx'])) self.solids = solds #print('>>>',solds) @@ -90,23 +71,6 @@ def mm(self, inp, weights): if soltype == 1: #print(">>> found hipblas") out = hipb_mm(inp_view, weights.t(), solidx) - elif soltype == 3: - ##only matvec is supported currently - out = torch.empty(inp.shape[0], - weights.shape[0], - dtype=torch.float16, - device='cuda') - #print('>>>Matvec',inp.shape,weights.shape,soltype,solidx) - if solidx <= 1: - _custom_C.LLMM1(weights, inp, out, 4) - elif solidx == 2: - _custom_C.LLMM1(weights, inp, out, 2) - elif solidx == 8: - _custom_C.LLMM1(weights, inp, out, 8) - elif solidx == 20: - _custom_C.LLZZ(weights, inp, out, 0) - elif solidx == 21: - _custom_C.LLZZ(weights, inp, out, 1) elif soltype == 2: #print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) From 32d4afa2730c68d221db83099bfc8cb66950a8e8 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Thu, 6 Jun 2024 14:18:37 -0500 Subject: [PATCH 403/413] Use inp_view for out = F.linear() in TunedGemm (#36) * use inp_view for out = F.linear() * add missing control path --- vllm/model_executor/layers/tuned_gemm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 10bd187f01101..096945f41bab6 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -98,8 +98,10 @@ def mm(self, inp, weights): _custom_C.LLMM1(weights, inp_view, out, 8) elif k <= 8192 and k % 8 == 0 and m % 4 == 0: _custom_C.LLMM1(weights, inp_view, out, 4) + else: + out = F.linear(inp_view, weights) else: - out = F.linear(inp, weights) + out = F.linear(inp_view, weights) if batched: return out.view(inp.shape[0], inp.shape[1], weights.shape[0]) else: From 65c02de7207c769a44baceb96b15c364e92526cf Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 6 Jun 2024 20:06:51 +0000 Subject: [PATCH 404/413] fix --- Dockerfile.rocm | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 49edca537d18e..04102560452a1 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -133,6 +133,14 @@ FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm # vLLM (and gradlib) build stages FROM fetch_vllm AS build_vllm ARG COMMON_WORKDIR +# Install hipblaslt +RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ +if ls /install/*.deb; then \ + apt-get purge -y hipblaslt \ + && dpkg -i /install/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ +fi # Build vLLM RUN cd vllm \ && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist From d571264f19cb55c8837efc7e03202247e4f3afdd Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Thu, 6 Jun 2024 16:22:53 -0400 Subject: [PATCH 405/413] Using rocm_flash_attention that supports bias computed from alibi slopes; Using attn_fwd triton kernel from ROCm/triton main_perf that does not cause triton compolier to hang (#38) --- vllm/attention/backends/rocm_flash_attn.py | 85 ++- vllm/attention/ops/triton_flash_attention.py | 532 +++++++------------ 2 files changed, 256 insertions(+), 361 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6019d917b4494..8acd401833cc9 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -102,6 +102,62 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): use_cuda_graph: bool +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_alibi_bias_v2( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], + make_attn_mask: bool = True +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + if make_attn_mask: + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + else: + attn_biases.append(bias.to(dtype)) + + return attn_biases + + class ROCmFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -155,7 +211,7 @@ def __init__( # AMD Radeon 7900 series (gfx1100) currently does not support # xFormers nor FlashAttention. As a temporary workaround, we use # naive PyTorch implementation of attention. - self.attn_fuc = _naive_attention() + self.attn_func = _naive_attention() logger.debug("Using naive attention in ROCmBackend") elif self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 @@ -229,13 +285,19 @@ def forward( # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) + att_masks = None + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias_v2( + self.alibi_slopes, query.dtype, + attn_metadata.prompt_lens, make_attn_mask=False) # type: ignore if self.use_naive_attn: - output = self.attn_fuc( + output = self.attn_func( query, key, value, attn_metadata.prompt_lens, self.scale, + att_masks ) else: output, _ = self.attn_func( @@ -249,6 +311,7 @@ def forward( attn_metadata.max_prompt_len, True, self.scale, + att_masks[0][None] if att_masks is not None else None, ) else: output = self.attn_func( @@ -304,17 +367,19 @@ def _naive_attention( value: torch.Tensor, prompt_lens: List[int], scale: float, + attn_masks: Optional[List[torch.Tensor]], ) -> torch.Tensor: num_tokens = query.shape[0] output = torch.empty_like(query) start = 0 - for _, prompt_len in enumerate(prompt_lens): + for i, prompt_len in enumerate(prompt_lens): end = start + prompt_len out = _naive_masked_attention( query[None, start:end], key[None, start:end], value[None, start:end], scale, + attn_masks[i], ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) @@ -332,14 +397,16 @@ def _naive_masked_attention( key: torch.Tensor, value: torch.Tensor, scale: float, + attn_mask: Optional[torch.Tensor], ) -> torch.Tensor: seq_len, _, _ = query.shape - attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) - attn_mask = attn_mask * torch.finfo(query.dtype).min + if attn_mask is None: + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() attn_weights = attn_weights + attn_mask.float() diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 77390f2d0d696..c99029175b5a2 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -61,81 +61,55 @@ def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): @triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) else: - tensor = tl.load(block_ptr) + tensor = tl.load(ptrs) return tensor @triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - actual_seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, -): +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 + if MASK_STEPS: # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. - # check if this masking works for that case. + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], - actual_seqlen_k, - dtype=tl.int32) + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) @@ -145,13 +119,15 @@ def _attn_fwd_inner( qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS - and (n_extra_tokens != 0), "zero") - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 + qk += (bias * 1.44269504089) + + # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) @@ -159,51 +135,29 @@ def _attn_fwd_inner( # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = (batch_philox_offset + - start_m * BLOCK_M * actual_seqlen_k + start_n - - BLOCK_N) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, - -p).to(encoded_softmax_block_ptr.type.element_ty), - ) + tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, BLOCK_N)) + encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i @@ -303,63 +257,23 @@ def _attn_fwd_inner( num_warps=4, ), ], - key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], ) @triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - hq, - hk, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, -): +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, + stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -384,118 +298,82 @@ def attn_fwd( # This block of code determines what N is, and if this WG is operating # on those M rows. n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: + if (IS_CAUSAL): # If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be handled here too? return - is_mqa = hq != hk - off_h_k = off_h_q % hk if is_mqa else off_h_q + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + n_extra_tokens = 0 if seqlen_k < BLOCK_N: n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N - padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. - q_offset = (off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = (off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = (off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) + q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) else: - bias_ptr = None + alibi_slope = None + if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base \ - + (off_z * hq + off_h_q) \ - * seqlen_q * seqlen_k + off_hz = off_z * HQ + off_h_q + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k else: batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) + encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k + encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] else: - encoded_softmax_block_ptr = 0 + encoded_sm_ptrs = None # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) @@ -504,8 +382,11 @@ def attn_fwd( # have native e^x support in HW. qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + q = (q * qk_scale).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -517,96 +398,50 @@ def attn_fwd( else: # Padding on Q does not need to be masked in the FA loop. masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. masked_blocks = min(masked_blocks, n_blocks) n_full_blocks = n_blocks - masked_blocks block_min = 0 block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its + # Compute for full blocks. Here we set causal to false regardless of its actual # value because there is no masking. Similarly we do not need padding. if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, alibi_slope, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) block_min = block_max block_max = n_blocks * BLOCK_N tl.debug_barrier() # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, n_full_blocks)) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + n_extra_tokens, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: @@ -619,45 +454,34 @@ def attn_fwd( start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 + if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] >= - out_mask_boundary[None, :]) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. - # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) def check_args( @@ -687,8 +511,7 @@ def check_args( assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype - # TODO: Fix assert to check head size <=256 once supported - assert head_size <= 128 + assert head_size <= 256 assert o.shape == q.shape assert (nheads_q % nheads_k) == 0 @@ -739,7 +562,7 @@ def forward( o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128} + unpadded_head_dims = {32, 64, 128, 256} if head_size not in unpadded_head_dims: padded_d_model = None for i in unpadded_head_dims: @@ -771,6 +594,8 @@ def forward( ) else: bias_strides = (0, 0, 0, 0) + alibi_strides = (0, 0) + M = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) attn_fwd[grid]( q, @@ -778,28 +603,31 @@ def forward( v, bias, sm_scale, - None, + M, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, + *alibi_strides, cu_seqlens_q, cu_seqlens_k, dropout_p=0.0, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, - hq=nheads_q, - hk=nheads_k, + alibi_slopes=None, + HQ=nheads_q, + HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=True, BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, + USE_BIAS=bias is not None, + USE_ALIBI=False, ENABLE_DROPOUT=False, RETURN_ENCODED_SOFTMAX=False, ) From b0ba2db99af8e0d33c38330ce859457ba63dd1c9 Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 14 May 2024 15:13:41 +0000 Subject: [PATCH 406/413] adding rocm fp8 fp8 computation Using convert_fp8 kernel delete convert.cu clean up clean up remove extra kernels remove int8 -> fp8 convert fix naming fix typo clean up add compilation guard add convert_fp8 in cache_ops clean up adding missing quant config back fix the convert_fp8 issue convert_fp8 fix fix --- CMakeLists.txt | 12 + csrc/cache.h | 6 +- csrc/cache_kernels.cu | 77 +--- csrc/ops.h | 21 + csrc/pybind.cpp | 9 +- csrc/quantization/fp8/amd/gemm_kernel.cu | 269 +++++++++++ csrc/quantization/fp8/amd/quant_utils.cuh | 424 ++++++++++++------ csrc/quantization/fp8/common.cu | 89 +++- vllm/_custom_ops.py | 5 +- vllm/config.py | 2 +- .../layers/quantization/__init__.py | 4 +- .../layers/quantization/fp8_rocm.py | 303 +++++++++++++ vllm/model_executor/model_loader/loader.py | 4 + vllm/model_executor/models/llama.py | 52 +++ 14 files changed, 1041 insertions(+), 236 deletions(-) create mode 100644 csrc/quantization/fp8/amd/gemm_kernel.cu create mode 100644 vllm/model_executor/layers/quantization/fp8_rocm.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a7cb981ac3e4..ad562d9c996f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -151,6 +151,13 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() +# +# Set rocm version dev int. +# +if(VLLM_GPU_LANG STREQUAL "HIP") + list(APPEND VLLM_GPU_FLAGS "-DROCM_VERSION=${ROCM_VERSION_DEV_INT}") +endif() + # # Define extension targets # @@ -173,6 +180,11 @@ set(VLLM_EXT_SRC "csrc/moe_align_block_size_kernels.cu" "csrc/pybind.cpp") +if(VLLM_GPU_LANG STREQUAL "HIP") + list(APPEND VLLM_EXT_SRC + "csrc/quantization/fp8/amd/gemm_kernel.cu") +endif() + if(VLLM_GPU_LANG STREQUAL "CUDA") include(FetchContent) SET(CUTLASS_ENABLE_HEADERS_ONLY=ON) diff --git a/csrc/cache.h b/csrc/cache.h index 435ae3e57f555..064815b7403db 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -21,8 +21,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype); - -// Just for unittest -void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const float scale, const std::string& kv_cache_dtype); + const std::string& kv_cache_dtype); \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index d924ac39b89ca..2ab63b21db1fb 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -308,79 +308,4 @@ void reshape_and_cache_flash( slot_mapping.data_ptr(), block_stride, key_stride, value_stride, num_heads, head_size, block_size); }); -} - -namespace vllm { - -template -__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, - Tout* __restrict__ dst_cache, - const float kv_scale, - const int64_t block_stride) { - const int64_t block_idx = blockIdx.x; - for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { - int64_t idx = block_idx * block_stride + i; - dst_cache[idx] = - fp8::scaled_convert(src_cache[idx], kv_scale); - } -} - -} // namespace vllm - -#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ - vllm::convert_fp8_kernel<<>>( \ - reinterpret_cast(src_cache.data_ptr()), \ - reinterpret_cast(dst_cache.data_ptr()), kv_scale, block_stride); - -// Only for testing. -void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const float kv_scale, const std::string& kv_cache_dtype) { - torch::Device src_device = src_cache.device(); - torch::Device dst_device = dst_cache.device(); - TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") - TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK(src_device.index() == dst_device.index(), - "src and dst must be on the same GPU"); - at::cuda::OptionalCUDAGuard device_guard(src_device); - - int64_t num_blocks = src_cache.size(0); - int64_t block_stride = src_cache.stride(0); - - dim3 grid(num_blocks); - dim3 block(std::min(block_stride, int64_t(512))); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (kv_cache_dtype == "auto") { - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); - } - } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { - if (src_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (src_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (src_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, - vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (dst_cache.dtype() == at::ScalarType::Float) { - CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (dst_cache.dtype() == at::ScalarType::Half) { - CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { - CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, - vllm::Fp8KVCacheDataType::kFp8E4M3); - } - } else { - TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); - } -} +} \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h index 567d9fae4bd2a..d6cdfab434f2c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -113,6 +113,27 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); +void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, torch::Tensor& scale); + +#ifdef USE_ROCM +torch::Tensor fp8_gemm( + torch::Tensor& a, + torch::Tensor& b, + torch::Tensor& scaleA, + torch::Tensor& scaleB, + torch::Tensor& scaleD, + int algo_idx +); + +torch::Tensor fp8_gemm_16( + torch::Tensor& a, + torch::Tensor& b, + torch::Tensor& scaleA, + torch::Tensor& scaleB, + int algo_idx +); +#endif + void moe_align_block_size(torch::Tensor topk_ids, int num_experts, int block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index cdbec4a34d77f..a4693ccc2ae75 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -66,6 +66,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("moe_align_block_size", &moe_align_block_size, "Aligning the number of tokens to be processed by each expert such " "that it is divisible by the block size."); + ops.def("convert_fp8", &convert_fp8, + "Convert the key and value cache to fp8 data type"); + +#ifdef USE_ROCM + ops.def("fp8_gemm", &fp8_gemm, "fp8 GEMM with fp8 output"); + ops.def("fp8_gemm_16", &fp8_gemm_16, "fp8 GEMM with fp16 output"); +#endif ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, "Compute int8 quantized tensor for given scaling factor"); @@ -80,8 +87,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Reshape the key and value tensors and cache them"); cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash, "Reshape the key and value tensors and cache them"); - cache_ops.def("convert_fp8", &convert_fp8, - "Convert the key and value cache to fp8 data type"); // Cuda utils pybind11::module cuda_utils = diff --git a/csrc/quantization/fp8/amd/gemm_kernel.cu b/csrc/quantization/fp8/amd/gemm_kernel.cu new file mode 100644 index 0000000000000..5464e9381e343 --- /dev/null +++ b/csrc/quantization/fp8/amd/gemm_kernel.cu @@ -0,0 +1,269 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define max_workspace_size 2 * 128 * 1024 * 1024 + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#ifndef CHECK_HIP_ERROR +#define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +#ifndef CHECK_HIPBLASLT_ERROR +#define CHECK_HIPBLASLT_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf( \ + stderr, "hipBLASLt error: '%s'(%d) at %s:%d\n", hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA, torch::Tensor& scaleB, + torch::Tensor& scaleD, int algo_idx) +{ + auto a_strides{a.strides()}; + auto b_strides{b.strides()}; + auto a_sizes{a.sizes()}; + auto b_sizes{b.sizes()}; + + // CHECK_INPUT(a); + // CHECK_INPUT(b); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && b.dtype() == torch::kFloat8_e4m3fnuz, + "The input tensors should be in fp8."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); + TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); + + auto options{at::TensorOptions().dtype(torch::kFloat8_e4m3fnuz).device(at::kCUDA)}; + auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; + + constexpr bool transpose_result = true; + bool transpose_a; + bool transpose_b; + if ((b_strides[0] == 1) && (b_strides[1] >= std::max(1, b_sizes[0]))) { + transpose_b = false; + } else if ((b_strides[1] == 1) && (b_strides[0] >= std::max(1, b_sizes[1]))) { + transpose_b = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((a_strides[0] == 1) && (a_strides[1] >= std::max(1, a_sizes[0]))) { + transpose_a = false; + } else if ((a_strides[1] == 1) && (a_strides[0] >= std::max(1, a_sizes[1]))) { + transpose_a = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_a; + transpose_a = !transpose_b; + transpose_b = !tmp; + a_strides = b.strides(); + b_strides = a.strides(); + a_sizes = b.sizes(); + b_sizes = a.sizes(); + } + + float alpha = 1.0f; + float beta = 0.0f; + int64_t m = a_sizes[transpose_result ? 1 : 0]; + int64_t k = a_sizes[transpose_result ? 0 : 1]; + int64_t n = b_sizes[transpose_result ? 0 : 1]; + + void* d_a = static_cast((transpose_result ? b : a).data_ptr()); + void* d_b = static_cast((transpose_result ? a : b).data_ptr()); + void* d_d = static_cast(result.data_ptr()); + + // void *d_scaleA, *d_scaleB, *d_workspace; + // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), sizeof(float), hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleB, &(transpose_result ? scaleA : scaleB), sizeof(float), hipMemcpyHostToDevice)); + auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); + auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); + auto d_scaleD = scaleD.data_ptr(); + + auto handle = at::cuda::getCurrentCUDABlasLtHandle(); + auto stream = at::cuda::getCurrentCUDAStream(); + + hipblaslt_ext::GemmPreference gemmPref; + gemmPref.setMaxWorkspaceBytes(0); + hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, + HIP_R_8F_E4M3_FNUZ, HIPBLAS_COMPUTE_32F); + + hipblaslt_ext::GemmEpilogue epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only) + hipblaslt_ext::GemmInputs inputs; + inputs.a = d_a; + inputs.b = d_b; + inputs.c = d_d; + inputs.d = d_d; + inputs.alpha = α + inputs.beta = β + inputs.scaleA = d_scaleA; + inputs.scaleB = d_scaleB; + inputs.scaleD = d_scaleD; + gemm.setProblem(m, n, k, 1, epilogue, inputs); + if (algo_idx == 0) { + constexpr int request_solutions = 1024; + std::vector heuristicResult; + heuristicResult.reserve(request_solutions); + CHECK_HIPBLASLT_ERROR(gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); + static size_t solSize = 0; + if (heuristicResult.size() != solSize) { + std::cout << "fp8 sols: " << heuristicResult.size() << "\n"; + solSize = heuristicResult.size(); + for (auto& res : heuristicResult) { + auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); + std::cout << idx << "\n"; + } + } + TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); + algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); + } + std::vector algoIndex(1); + algoIndex[0] = algo_idx; + std::vector tmpAlgo; + TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); + + CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); + CHECK_HIPBLASLT_ERROR(gemm.run(stream)); + + // hipFree(d_scaleA); + // hipFree(d_scaleB); + + return result; +} + +torch::Tensor fp8_gemm_16( + torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA, torch::Tensor& scaleB, int algo_idx) +{ + auto a_strides{a.strides()}; + auto b_strides{b.strides()}; + auto a_sizes{a.sizes()}; + auto b_sizes{b.sizes()}; + + // CHECK_INPUT(a); + // CHECK_INPUT(b); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && b.dtype() == torch::kFloat8_e4m3fnuz, + "The input tensors should be in fp8."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); + TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); + + auto options{at::TensorOptions().dtype(torch::kFloat16).device(at::kCUDA)}; + auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; + + constexpr bool transpose_result = true; + bool transpose_a; + bool transpose_b; + if ((b_strides[0] == 1) && (b_strides[1] >= std::max(1, b_sizes[0]))) { + transpose_b = false; + } else if ((b_strides[1] == 1) && (b_strides[0] >= std::max(1, b_sizes[1]))) { + transpose_b = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((a_strides[0] == 1) && (a_strides[1] >= std::max(1, a_sizes[0]))) { + transpose_a = false; + } else if ((a_strides[1] == 1) && (a_strides[0] >= std::max(1, a_sizes[1]))) { + transpose_a = true; + } else { + assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_a; + transpose_a = !transpose_b; + transpose_b = !tmp; + a_strides = b.strides(); + b_strides = a.strides(); + a_sizes = b.sizes(); + b_sizes = a.sizes(); + } + + float alpha = 1.0f; + float beta = 0.0f; + int64_t m = a_sizes[transpose_result ? 1 : 0]; + int64_t k = a_sizes[transpose_result ? 0 : 1]; + int64_t n = b_sizes[transpose_result ? 0 : 1]; + + void* d_a = static_cast((transpose_result ? b : a).data_ptr()); + void* d_b = static_cast((transpose_result ? a : b).data_ptr()); + void* d_d = static_cast(result.data_ptr()); + + // void *d_scaleA, *d_scaleB, *d_workspace; + // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), sizeof(float), hipMemcpyHostToDevice)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleB, &(transpose_result ? scaleA : scaleB), sizeof(float), hipMemcpyHostToDevice)); + auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); + auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); + + auto handle = at::cuda::getCurrentCUDABlasLtHandle(); + auto stream = at::cuda::getCurrentCUDAStream(); + + hipblaslt_ext::GemmPreference gemmPref; + gemmPref.setMaxWorkspaceBytes(0); + hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, HIP_R_16F, + HIPBLAS_COMPUTE_32F); + + hipblaslt_ext::GemmEpilogue epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only) + hipblaslt_ext::GemmInputs inputs; + inputs.a = d_a; + inputs.b = d_b; + inputs.c = d_d; + inputs.d = d_d; + inputs.alpha = α + inputs.beta = β + inputs.scaleA = d_scaleA; + inputs.scaleB = d_scaleB; + gemm.setProblem(m, n, k, 1, epilogue, inputs); + if (algo_idx == 0) { + constexpr int request_solutions = 1024; + std::vector heuristicResult; + heuristicResult.reserve(request_solutions); + CHECK_HIPBLASLT_ERROR(gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); + static size_t solSize = 0; + if (heuristicResult.size() != solSize) { + std::cout << "fp16 sols: " << heuristicResult.size() << "\n"; + solSize = heuristicResult.size(); + for (auto& res : heuristicResult) { + auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); + std::cout << idx << "\n"; + } + } + algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); + TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); + } + std::vector algoIndex(1); + algoIndex[0] = algo_idx; + std::vector tmpAlgo; + TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); + + CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); + CHECK_HIPBLASLT_ERROR(gemm.run(stream)); + + // hipFree(d_scaleA); + // hipFree(d_scaleB); + + return result; +} \ No newline at end of file diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index 35123d7fc65d4..23d975fe0f37e 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -5,9 +5,7 @@ #include #include -#include "../../../attention/dtype_fp8.cuh" -#include "../../../attention/dtype_float32.cuh" -#include "../../../attention/dtype_bfloat16.cuh" +#include "../../../attention/attention_dtypes.h" namespace vllm { #ifdef USE_ROCM @@ -309,212 +307,344 @@ vec_conversion(const Float8_& a) { // fp8 -> half template <> -__inline__ __device__ uint16_t -scaled_vec_conversion(const uint8_t& a, const float scale) { - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8) * scale; - return res.x; +__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, float scale) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t scaled_vec_conversion( - const uint16_t& a, const float scale) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0] * scale; - tmp.h2r.y.data = f2[1] * scale; - return tmp.ui32; - #else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = - scaled_vec_conversion(static_cast(a), scale); - tmp.u16[1] = scaled_vec_conversion( - static_cast(a >> 8U), scale); - return tmp.u32; - #endif +__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, float scale) +{ +#if defined(__HIP__MI300__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; +#else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = scaled_vec_conversion(static_cast(a), scale); + tmp.u16[1] = scaled_vec_conversion(static_cast(a >> 8U), scale); + return tmp.u32; +#endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 -scaled_vec_conversion(const uint32_t& a, const float scale) { - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); - tmp.u32[1] = - scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return tmp.u32x2; +__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, float scale) +{ + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 -scaled_vec_conversion(const uint2& a, const float scale) { - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = scaled_vec_conversion(a.x, scale); - tmp.u64[1] = scaled_vec_conversion(a.y, scale); - return tmp.u64x2; +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, float scale) +{ + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 -scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, - const float scale) { - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f * scale); +__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 -scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, - const float scale) { - __nv_bfloat162 res; - res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); - res.y = - scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); - return res; +__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, float scale) +{ + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t scaled_vec_conversion( - const uint32_t& a, const float scale) { - bf16_4_t res; - res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), - scale); - return res; +__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, float scale) +{ + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t -scaled_vec_conversion(const uint2& a, const float scale) { - bf16_4_t tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, float scale) +{ + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float scaled_vec_conversion( - const uint8_t& a, const float scale) { - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8) * scale; +__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, float scale) +{ + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; } // fp8x2 -> float2 template <> -__inline__ __device__ float2 -scaled_vec_conversion(const uint16_t& a, const float scale) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0] * scale; - res.y = f2[1] * scale; - return res; - #else - float2 res; - res.x = scaled_vec_conversion(static_cast(a), scale); - res.y = scaled_vec_conversion(static_cast(a >> 8U), - scale); - return res; - #endif +__inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, float scale) +{ +#if defined(__HIP__MI300__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; +#else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), scale); + return res; +#endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ -scaled_vec_conversion(const uint32_t& a, const float scale) { - Float4_ res; - res.x = scaled_vec_conversion((uint16_t)a, scale); - res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ Float4_ scaled_vec_conversion(const uint32_t& a, const float scale) +{ + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; } -// fp8x8 -> float8 +// fp8x4 -> float4 template <> -__inline__ __device__ Float8_ -scaled_vec_conversion(const uint2& a, const float scale) { - Float4_ tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, float scale) +{ + Float4_ res = scaled_vec_conversion(a, scale); + return {res.x.x, res.x.y, res.y.x, res.y.y}; } -/* Quantize(HP / scale) => FP8 */ - -// TODO(Hai): vectorized to add +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, float scale) +{ + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} // half -> fp8 template <> -__inline__ __device__ uint8_t -scaled_vec_conversion(const uint16_t& a, const float scale) { - __half_raw tmp; - tmp.x = a; - - hip_fp8 f8{static_cast(tmp.data) / scale}; - return f8.data; +__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, float scale) +{ + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data / scale)}; + return f8.data; +} + +// halfx2 -> fp8x2 +template<> +__inline__ __device__ uint16_t scaled_vec_conversion(const uint32_t& a, float scale) +{ +#ifdef __HIP__MI300__ + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = tmp.h2r.x.data / scale; + f2.f = tmp.h2r.y.data / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); +#else + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint8_t ui8[2]; + uint16_t ui16; + } res; + res.ui8[0] = scaled_vec_conversion(tmp.h2r.x.x, scale); + res.ui8[1] = scaled_vec_conversion(tmp.h2r.y.x, scale); + return res.ui16; +#endif +} + +// half2x2 -> fp8x4 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion(const uint2& a, float scale) +{ + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// half2x4 -> fp8x8 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, float scale) +{ + union { + uint2 ui2[2]; + uint4 ui4; + } tmp; + tmp.ui4 = a; + uint2 res; + res.x = scaled_vec_conversion(tmp.ui2[0], scale); + res.y = scaled_vec_conversion(tmp.ui2[1], scale); + return res; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion( - const __nv_bfloat16& a, const float scale) { - hip_fp8 res{__bfloat162float(a) / scale}; - return res.data; +__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, float scale) +{ + hip_fp8 res{__bfloat162float(a) / scale}; + return res.data; } -// float -> fp8 +// bf16x2 -> fp8x2 template <> -__inline__ __device__ uint8_t -scaled_vec_conversion(const float& a, const float scale) { - hip_fp8 f8(a / scale); - return f8.data; +__inline__ __device__ uint16_t scaled_vec_conversion(const __nv_bfloat162& a, float scale) +{ + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; } -// fp8x4 -> float4 +// bf16x4 -> fp8x4 template <> -__inline__ __device__ float4 -scaled_vec_conversion(const uint32_t& a, const float scale) { - Float4_ tmp = scaled_vec_conversion(a, scale); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ uint32_t scaled_vec_conversion(const bf16_4_t& a, float scale) +{ + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// bf16x8 -> fp8x8 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const bf16_8_t& a, float scale) +{ + uint2 res; + res.x = scaled_vec_conversion({a.x, a.y}, scale); + res.y = scaled_vec_conversion({a.z, a.w}, scale); + return res; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, float scale) +{ + hip_fp8 f8(a); + return f8.data; +} + +// floatx2 -> fp8x2 +template <> +__inline__ __device__ uint16_t scaled_vec_conversion(const float2& a, float scale) +{ +#ifdef __HIP__MI300__ + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = a.x / scale; + f2.f = a.y / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f,f2.f, 0, 0); +#else + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; +#endif +} + +// floatx4 -> fp8x4 +template <> +__inline__ __device__ uint32_t scaled_vec_conversion(const float4& a, float scale) +{ + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); + tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); + return tmp.ui32; } #endif // ENABLE_FP8 diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 55be3305a9b8c..937df5a0bec13 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -7,7 +7,41 @@ #include "cuda_compat.h" #include "dispatch_utils.h" +#ifdef USE_ROCM + #include "amd/quant_utils.cuh" +#else + #include "nvidia/quant_utils.cuh" +#endif + namespace vllm { + +template +__global__ void convert_fp8_kernel( + const Tin* __restrict__ src_data, Tout* __restrict__ dst_data, const float* scale, size_t N) +{ + const int64_t block_idx = blockIdx.x; + + using V_in_vec = typename Vec::Type; + using V_out_vec = typename Vec::Type; + auto dst_data_vec = reinterpret_cast(dst_data); + auto src_data_vec = reinterpret_cast(src_data); + + int64_t startIdx = (threadIdx.x + blockDim.x * blockIdx.x); + auto idx = startIdx; + if (idx >= N) { + return; + } + dst_data_vec[idx] = fp8::scaled_vec_conversion(src_data_vec[idx], *scale); + //dst_data_vec[idx+1] = fp8_e4m3::vec_conversion(src_data_vec[idx+1], *scale); + + //for (int64_t i = 0; i < loopSize; ++i) { + // auto idx = startIdx + i; + // if (idx >= N) { + // return; + // } + // dst_data_vec[idx] = fp8_e4m3::vec_conversion(src_data_vec[idx], *scale); + //} +} __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { float old; @@ -83,7 +117,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, } } -} // namespace vllm +} // namespace vllm void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] @@ -122,3 +156,56 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] scale.data_ptr(), num_elems); }); } + +template +struct call_convert_fp8 +{ + void operator()(torch::Tensor& src_data, torch::Tensor& dst_data, torch::Tensor& scale) + { + const auto N = src_data.numel() / 2; + //std::cout << N << "\n"; + constexpr uint32_t loopSize = 1;//std::max(N / 50000000LL, 1); + constexpr dim3 numThreads{1024, 1, 1}; + auto neededBlocks = (N + (numThreads.x * loopSize) - 1) / (numThreads.x * loopSize); + uint32_t actualBlocks = neededBlocks; + + //static uint32_t maxBlocks = 0; + //if (actualBlocks != maxBlocks) { + // maxBlocks = actualBlocks; + // std::cout << actualBlocks << "\n"; + //} + + const dim3 grid{actualBlocks, 1, 1}; + + const auto stream = at::cuda::getCurrentCUDAStream(); + + vllm::convert_fp8_kernel + <<>>(reinterpret_cast(src_data.data_ptr()), + reinterpret_cast(dst_data.data_ptr()), (float*)scale.data_ptr(), N); + } +}; + +void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, torch::Tensor& scale) +{ + torch::Device src_device = src_data.device(); + torch::Device dst_device = dst_data.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + auto t1 = src_data.dtype(); + auto t2 = dst_data.dtype(); + if (src_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8<__nv_bfloat16, uint8_t, 2>{}(src_data, dst_data, scale); + } +} \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 22cf5a44e341f..fe1c4eff698d3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -328,9 +328,8 @@ def swap_blocks(src: torch.Tensor, dst: torch.Tensor, def convert_fp8(output: torch.Tensor, input: torch.Tensor, - scale: float = 1.0, - kv_dtype: str = "fp8") -> None: - vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) + scale: float = 1.0) -> None: + vllm_ops.convert_fp8(output, input, torch.Tensor([scale])) #TODO: cuda_utils, custom_ar diff --git a/vllm/config.py b/vllm/config.py index 579a210ba3412..63471aa5301b1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -170,7 +170,7 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["gptq", "squeezellm"] + rocm_supported_quantization = ["gptq", "squeezellm", "fp8"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 7b9abe1b629a1..b963576aa4471 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig) from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.fp8_rocm import Fp8RocmConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) @@ -16,12 +17,13 @@ GPTQMarlin24Config) from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.utils import is_hip QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, - "fp8": Fp8Config, + "fp8": Fp8Config if not is_hip() else Fp8RocmConfig, # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py new file mode 100644 index 0000000000000..caa53fb6ceee8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -0,0 +1,303 @@ +from typing import Any, Dict, List, Optional, Tuple, Union, Iterator + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter +import torch.nn.functional as F +from safetensors import safe_open + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +import pandas as pd +import os + +try: + from vllm._C import ops as vllm_ops +except ImportError: + pass + +logger = init_logger(__name__) + + +class Fp8RocmConfig(QuantizationConfig): + def __init__(self) -> None: + # self.quantized_weights_path = config["quantized_weights"] + self._tuned = {} + self._stats = {} + gemm_type = os.getenv("FP8_GEMM", "fp8_16") + #print(f"Integral Cross factor = {self.factor}") + if gemm_type == "fp8_8": + self.gemm_method = Fp8RocmLinearMethod.apply_fp8_8 + tuned_filename = "/projects/tuned_fp8_8.csv" + elif gemm_type == "fp8_16": + self.gemm_method = Fp8RocmLinearMethod.apply_fp8_16 + tuned_filename = "/projects/tuned_fp8_16.csv" + else: + raise Exception(f"Unknown fp8 gemm type: {gemm_type}") + try: + df = pd.read_csv(tuned_filename) + except: + return + + for i in range(len(df)): + shape = df.iloc[i] + m = shape["M"] + n = shape["N"] + k = shape["K"] + algo = shape["algo"] + self._tuned[(m, n, k)] = algo + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config) -> "Fp8RocmConfig": + return cls(config) + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.uint8, torch.float8_e4m3fnuz] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 94 + + @classmethod + def get_name(cls) -> str: + return "Fp8Rocm" + + def get_quant_method(self, + layer: torch.nn.Module) -> Optional["Fp8RocmLinearMethod"]: + if isinstance(layer, LinearBase): + return Fp8RocmLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8RocmLinearMethod(LinearMethodBase): + def __init__(self, config: Fp8RocmConfig): + self._config = config + + + def _create_scale_param( + self, + scale_name: str, + layer: torch.nn.Module, + output_partition_sizes: List[int], + **extra_weight_attrs, + ) -> None: + scale = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.float32), + requires_grad=False) + layer.register_parameter(scale_name, scale) + set_weight_attrs( + scale, { + **extra_weight_attrs, + "fp8_scales_shard_indexer": + self.scales_shard_indexer, + }) + + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fnuz, + ), + requires_grad=False, + ) + layer.process_after_load = True + layer.logical_widths = output_partition_sizes + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0 + }) + + self._create_scale_param( + scale_name="weights_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + self._create_scale_param( + scale_name="activation_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + self._create_scale_param( + scale_name="output_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + + def process_weights_after_loading(self, layer: Module) -> None: + if (not hasattr(layer, "process_after_load") + or not layer.process_after_load): + return + + layer.activation_scaling_factor = Parameter(layer.activation_scaling_factor.max(), + requires_grad=False) + layer.output_scaling_factor = Parameter(layer.output_scaling_factor.reciprocal().max(), + requires_grad=False) + + max_w_scale = layer.weights_scaling_factor.max() + if len(layer.logical_widths) > 1: + start = 0 + for idx, logical_width in enumerate(layer.logical_widths): + end = start + logical_width + weight_dq = _per_tensor_dequantize(layer.weight[start:end, :], + layer.weights_scaling_factor[idx]) + + layer.weight[start:end, :] = _per_tensor_quantize( + weight_dq, max_w_scale) + start = end + layer.weights_scaling_factor = Parameter(max_w_scale, requires_grad=False) + + # WEIGHT + # Transpose weight for passing to torch._scaled_mm + weight = layer.weight + layer.weight = Parameter(weight, requires_grad=False) + + + def scales_shard_indexer( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, int): + pass + elif isinstance(shard_id, str): + if shard_id not in qkv_idxs: + raise ValueError(f"Unknown shard_id: {shard_id}") + shard_id = qkv_idxs[shard_id] + else: + ValueError(f"Shard id must be int or str but got {type(shard_id)}") + + # To handle the scalar loaded tensor + if loaded_weight.numel() == 1 and len(loaded_weight.shape) != 0: + loaded_weight = torch.scalar_tensor(loaded_weight[0]) + + return param[shard_id], loaded_weight + + def apply_fp8_16( + self, + x: torch.Tensor, + weight: torch.Tensor, + asf: torch.Tensor, + wsf: torch.Tensor, + osf: torch.Tensor, + ) -> torch.Tensor: + x8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) + vllm_ops.convert_fp8(x8, x, asf) + m = weight.shape[0] + n = x.shape[0] + k = x.shape[1] + + algo = self._config._tuned.get((m, n, k)) + if algo is None: + import os + + if os.getenv("TUNE_FP8") == "1": + try: + df = pd.read_csv("/projects/fp8_shapes.csv") + except: + df = pd.DataFrame(columns=["M", "N", "K"]) + df = pd.concat( + [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] + ).drop_duplicates() + df.to_csv("/projects/fp8_shapes.csv", index=False) + algo = 0 + res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) + return res + + def apply_fp8_8( + self, + x: torch.Tensor, + weight: torch.Tensor, + asf: torch.Tensor, + wsf: torch.Tensor, + osf: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert not bias + x8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) + vllm_ops.convert_fp8(x8, x, asf) + m = weight.shape[0] + n = x.shape[0] + k = x.shape[1] + + algo = self._config._tuned.get((m, n, k)) + if algo is None: + import os + + if os.getenv("TUNE_FP8") == "1": + try: + df = pd.read_csv("/projects/fp8_shapes.csv") + except: + df = pd.DataFrame(columns=["M", "N", "K"]) + df = pd.concat( + [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] + ).drop_duplicates() + df.to_csv("/projects/fp8_shapese.csv", index=False) + algo = 0 + + res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo)) + res16 = torch.empty_like(res, dtype=torch.float16) + vllm_ops.convert_fp8(res16, res, 1/osf) + return res16 + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + weight: torch.Tensor = layer.weight + if weight.dtype == torch.float8_e4m3fnuz: + asf: torch.Tensor = layer.activation_scaling_factor * 2 + wsf: torch.Tensor = layer.weights_scaling_factor * 2 + osf: torch.Tensor = layer.output_scaling_factor / 2 + + return self._config.gemm_method(self, x, weight, asf, wsf, osf) + + return F.linear(x, weight, bias) + + +def _per_tensor_quantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fnuz) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fnuz) + + +def _per_tensor_dequantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index b7b5b5e7695f4..aa143c65d82b9 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -247,6 +247,10 @@ def load_model(self, *, model_config: ModelConfig, model, "fall_back_to_pt_during_load", True)), ) + if model_config.quantization == 'fp8' and model_config.quantization_param_path is not None: + model.load_quantized_weights( + safetensors_weights_iterator([model_config.model + model_config.quantization_param_path]) + ) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 218876d2b42c9..2b8d3573f45cf 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -438,6 +438,58 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + #with open("/projects/a.txt", "r") as f: + # j = json.load(f) + # for k, v in j.items(): + # params_dict[k].data.copy_(v) + quant_shards = [ + ("mlp.gate_up_proj", "mlp.fc", 0), # fc is gate_proj + ("mlp.gate_up_proj", "mlp.gate", 1), # gate is up_proj + ] + quant_map = [ + ("mlp.down_proj", "mlp.proj"), + ("self_attn.o_proj", "attention.dense"), + ("self_attn.qkv_proj", "attention.qkv"), + ] + for name, loaded_weight in weights: + #print(name) + name = name.replace('transformer', 'model') + name = name.replace('kv_cache_scaling_factor', 'qkv.output_scaling_factor') + loaded_weight = loaded_weight.to("cuda") + if loaded_weight.dtype == torch.int8: + loaded_weight[loaded_weight == -128] = 0 + assert loaded_weight.is_contiguous + loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz) + for (param_name, weight_name, shard_id) in quant_shards: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + for (param_name, weight_name) in quant_map: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + if "activation_scaling_factor" in name or "weights_scaling_factor" in name: + param.data.copy_(loaded_weight) + elif "output_scaling_factor" in name: + param.data.copy_(loaded_weight) + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + break # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should From 86bbfef79330fd58cea288bd3eb3dd97fbe0a38b Mon Sep 17 00:00:00 2001 From: Matthew Wong Date: Thu, 6 Jun 2024 21:16:30 +0000 Subject: [PATCH 407/413] Fixes from main: Restore use of FA Triton as default update base docker image remove apply_custom Use inp_view for out = F.linear() in TunedGemm (#36) * use inp_view for out = F.linear() * add missing control path fix --- Dockerfile.rocm | 10 ++++- vllm/envs.py | 2 +- vllm/model_executor/layers/tuned_gemm.py | 54 +++++++----------------- 3 files changed, 26 insertions(+), 40 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 38872c6c0931d..6d0dd31d346f1 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,5 +1,5 @@ # default base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.1.1_ubuntu20.04_py3.9_pytorch_release-2.1.2" +ARG BASE_IMAGE="rocm/pytorch:rocm6.1.1_ubuntu20.04_py3.9_pytorch_staging" ARG COMMON_WORKDIR=/app @@ -133,6 +133,14 @@ FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm # vLLM (and gradlib) build stages FROM fetch_vllm AS build_vllm ARG COMMON_WORKDIR +# Install hipblaslt +RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ +if ls /install/*.deb; then \ + apt-get purge -y hipblaslt \ + && dpkg -i /install/*.deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \ +fi # Build vLLM RUN cd vllm \ && python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist diff --git a/vllm/envs.py b/vllm/envs.py index 5cc766bcf8705..35421b9026f1e 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -8,7 +8,7 @@ VLLM_INSTANCE_ID: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None - VLLM_USE_TRITON_FLASH_ATTN: bool = False + VLLM_USE_TRITON_FLASH_ATTN: bool = True RANK: int = 0 LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index a76f7c6af3d36..a3d299c05caef 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -33,35 +33,16 @@ def load_best_sols(self): if self.tune_path is not None and Path(self.tune_path).is_file(): self.bestsols = pd.read_csv(self.tune_path) - def apply_custom(self, ds): - M, N, K = ds['M'], ds['N'], ds['K'] - #apply custom matvec (only for f16 dtype) - if N == 1: - ds1 = ds.copy() - ds1['libtype'] = 'custom' - if K == 8192 and (M == 1280 or M == 7168): #NOQA: SIM114 - ds1['solidx'] = 8 - return ds1 - elif K == 3584 and M == 8192: - ds1['solidx'] = 8 - return ds1 - elif K <= 8192 and K % 8 == 0 and M % 4 == 0: - ds1['solidx'] = 1 - return ds1 - return ds - def create_ds(self): df: pd.DataFrame = self.bestsols solds = {} for i in range(len(df)): - ds = self.apply_custom(df.iloc[i]) + ds = df.iloc[i] key = (ds['M'], ds['N'], ds['K']) if ds['libtype'] == 'hipblaslt': soltype = 1 elif ds['libtype'] == 'rocblas': soltype = 2 - elif ds['libtype'] == 'custom': - soltype = 3 solds[key] = (soltype, int(ds['solidx'])) self.solids = solds #print('>>>',solds) @@ -90,23 +71,6 @@ def mm(self, inp, weights): if soltype == 1: #print(">>> found hipblas") out = hipb_mm(inp_view, weights.t(), solidx) - elif soltype == 3: - ##only matvec is supported currently - out = torch.empty(inp.shape[0], - weights.shape[0], - dtype=torch.float16, - device='cuda') - #print('>>>Matvec',inp.shape,weights.shape,soltype,solidx) - if solidx <= 1: - _custom_C.LLMM1(weights, inp, out, 4) - elif solidx == 2: - _custom_C.LLMM1(weights, inp, out, 2) - elif solidx == 8: - _custom_C.LLMM1(weights, inp, out, 8) - elif solidx == 20: - _custom_C.LLZZ(weights, inp, out, 0) - elif solidx == 21: - _custom_C.LLZZ(weights, inp, out, 1) elif soltype == 2: #print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) @@ -124,7 +88,21 @@ def mm(self, inp, weights): }) ]).drop_duplicates() self.tuned_df.to_csv(self.untune_path, index=False) - out = F.linear(inp, weights) + + if n == 1 and inp_view.dtype == torch.float16: + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + if (k == 8192 and + (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): + _custom_C.LLMM1(weights, inp_view, out, 8) + elif k <= 8192 and k % 8 == 0 and m % 4 == 0: + _custom_C.LLMM1(weights, inp_view, out, 4) + else: + out = F.linear(inp_view, weights) + else: + out = F.linear(inp_view, weights) if batched: return out.view(inp.shape[0], inp.shape[1], weights.shape[0]) else: From a8228756cfece32ff3225b3f17c14e032dbc6187 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Thu, 6 Jun 2024 18:01:33 -0400 Subject: [PATCH 408/413] Re-applying G42 bias triton fix on 0.4.3 (#41) * Using rocm_flash_attention that supports bias computed from alibi slopes; Using attn_fwd triton kernel from ROCm/triton main_perf that does not cause triton compolier to hang * Uninitialized variable fix --- vllm/attention/backends/rocm_flash_attn.py | 87 +++- vllm/attention/ops/triton_flash_attention.py | 517 ++++++------------- 2 files changed, 251 insertions(+), 353 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e92e6c5e2dc8d..ad6ec10892b6e 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -165,6 +165,62 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: ) return self._cached_decode_metadata +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_alibi_bias_v2( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], + make_attn_mask: bool = True +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + if make_attn_mask: + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + else: + attn_biases.append(bias.to(dtype)) + + return attn_biases + + class ROCmFlashAttentionImpl(AttentionImpl): """ @@ -324,7 +380,12 @@ def forward( # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. + att_masks = None if self.use_triton_flash_attn: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias_v2( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens, make_attn_mask=False) # type: ignore out, _ = self.attn_func( query, key, @@ -336,8 +397,13 @@ def forward( prefill_meta.max_prefill_seq_len, True, self.scale, + att_masks[0][None] if att_masks is not None else None, ) elif self.use_naive_attn: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias_v2( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens, make_attn_mask=True) # type: ignore if self.num_kv_heads != self.num_heads: # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) @@ -348,6 +414,7 @@ def forward( value, prefill_meta.seq_lens, self.scale, + att_masks ) else: out = self.attn_func( @@ -408,16 +475,18 @@ def _naive_attention( value: torch.Tensor, seq_lens: List[int], scale: float, + attn_masks: Optional[List[torch.Tensor]], ) -> torch.Tensor: output = torch.empty_like(query) start = 0 - for _, seq_len in enumerate(seq_lens): + for i, seq_len in enumerate(seq_lens): end = start + seq_len out = _naive_masked_attention( query[start:end], key[start:end], value[start:end], scale, + attn_masks[i], ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) @@ -431,16 +500,18 @@ def _naive_masked_attention( key: torch.Tensor, value: torch.Tensor, scale: float, + attn_mask: Optional[torch.Tensor], ) -> torch.Tensor: seq_len, head_size, head_dim = query.shape - attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) - attn_mask = attn_mask * torch.finfo(query.dtype).min + if attn_mask is None: + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() attn_weights = attn_weights + attn_mask.float() attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out + return out \ No newline at end of file diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index f94211116a746..c99029175b5a2 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -61,81 +61,55 @@ def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): @triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) else: - tensor = tl.load(block_ptr) + tensor = tl.load(ptrs) return tensor @triton.jit -def _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - actual_seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, - OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, -): +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 + if MASK_STEPS: # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. - # check if this masking works for that case. + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], - actual_seqlen_k, - dtype=tl.int32) + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) @@ -145,13 +119,15 @@ def _attn_fwd_inner( qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS - and (n_extra_tokens != 0), "zero") - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 + qk += (bias * 1.44269504089) + + # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) @@ -159,51 +135,29 @@ def _attn_fwd_inner( # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = (batch_philox_offset + - start_m * BLOCK_M * actual_seqlen_k + start_n - - BLOCK_N) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, - -p).to(encoded_softmax_block_ptr.type.element_ty), - ) + tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + acc += tl.dot(p.to(v.type.element_ty), v) + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, BLOCK_N)) + encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i @@ -306,60 +260,20 @@ def _attn_fwd_inner( key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], ) @triton.jit -def attn_fwd( - Q, - K, - V, - bias, - sm_scale, - L, - Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, - cu_seqlens_q, - cu_seqlens_k, - dropout_p, - philox_seed, - philox_offset_base, - encoded_softmax, - HQ: tl.constexpr, - HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, - MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, - VARLEN: tl.constexpr, - IS_CAUSAL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, -): +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, + stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -384,120 +298,82 @@ def attn_fwd( # This block of code determines what N is, and if this WG is operating # on those M rows. n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: + if (IS_CAUSAL): # If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be handled here too? return # If MQA / GQA, set the K and V head offsets appropriately. GROUP_SIZE: tl.constexpr = HQ // HK - off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q n_extra_tokens = 0 if seqlen_k < BLOCK_N: n_extra_tokens = BLOCK_N - seqlen_k elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N - padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. - q_offset = (off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = (off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = (off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) + q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) else: - bias_ptr = None + alibi_slope = None + if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base \ - + (off_z * HQ + off_h_q) \ - * seqlen_q * seqlen_k + off_hz = off_z * HQ + off_h_q + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k else: batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) + encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k + encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] else: - encoded_softmax_block_ptr = 0 + encoded_sm_ptrs = None # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) @@ -506,8 +382,11 @@ def attn_fwd( # have native e^x support in HW. qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + q = (q * qk_scale).to(q.type.element_ty) # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -519,96 +398,50 @@ def attn_fwd( else: # Padding on Q does not need to be masked in the FA loop. masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. masked_blocks = min(masked_blocks, n_blocks) n_full_blocks = n_blocks - masked_blocks block_min = 0 block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its + # Compute for full blocks. Here we set causal to false regardless of its actual # value because there is no masking. Similarly we do not need padding. if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, alibi_slope, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) block_min = block_max block_max = n_blocks * BLOCK_N tl.debug_barrier() # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, n_full_blocks)) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, + start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, + encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, + n_extra_tokens, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: @@ -621,45 +454,34 @@ def attn_fwd( start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 + if IS_CAUSAL: if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] >= - out_mask_boundary[None, :]) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. - # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) def check_args( @@ -772,6 +594,8 @@ def forward( ) else: bias_strides = (0, 0, 0, 0) + alibi_strides = (0, 0) + M = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) attn_fwd[grid]( q, @@ -779,19 +603,21 @@ def forward( v, bias, sm_scale, - None, + M, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, + *alibi_strides, cu_seqlens_q, cu_seqlens_k, dropout_p=0.0, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, + alibi_slopes=None, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, @@ -800,7 +626,8 @@ def forward( IS_CAUSAL=causal, VARLEN=True, BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, + USE_BIAS=bias is not None, + USE_ALIBI=False, ENABLE_DROPOUT=False, RETURN_ENCODED_SOFTMAX=False, ) From 9d2f09313438f68647fcfb54cb4a92c9c00c3e21 Mon Sep 17 00:00:00 2001 From: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Date: Fri, 7 Jun 2024 05:57:46 -0500 Subject: [PATCH 409/413] Fix RCCL pkg broken install, update linear.py custom logic, update requirements, disable custom_C for CUDA (#42) --- CMakeLists.txt | 8 ++++---- Dockerfile.rocm | 2 ++ requirements-rocm.txt | 1 + vllm/model_executor/layers/linear.py | 28 ++-------------------------- 4 files changed, 9 insertions(+), 30 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ad562d9c996f3..15b9cfe677a57 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -343,7 +343,7 @@ add_custom_target(default) if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") message(STATUS "Enabling C extension.") add_dependencies(default _C) - add_dependencies(default _custom_C) + message(STATUS "Enabling moe extension.") add_dependencies(default _moe_C) @@ -357,7 +357,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") endif() endif() -if(VLLM_GPU_LANG STREQUAL "CUDA") - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) +if(VLLM_GPU_LANG STREQUAL "HIP") + message(STATUS "Enabling custom extension.") + add_dependencies(default _custom_C) endif() diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 6d0dd31d346f1..83a483075f8a4 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -178,6 +178,8 @@ RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \ RUN --mount=type=bind,from=export_rccl,src=/,target=/install \ if ls /install/*.deb; then \ dpkg -i /install/*.deb \ + # RCCL needs to be installed twice + && dpkg -i /install/*.deb \ && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status; \ fi diff --git a/requirements-rocm.txt b/requirements-rocm.txt index cc42839a975d0..cc60d21e717b5 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -4,3 +4,4 @@ # Dependencies for AMD GPUs ray >= 2.10.0 pytest-asyncio +pandas # Required for fp8 linear diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 29b5fe77ae705..476301a216c48 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -92,34 +92,10 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: weight = layer.weight - if is_hip() and x.dtype == torch.float16 and x.view(-1, x.size(-1)).shape[0] == 1: - batched = False - if x.dim() == 3: - inp = x.view(-1, x.size(-1)) - batched = True - else: - inp = x - m, k = weight.shape[0], inp.shape[1] - out = torch.empty(inp.shape[0], - weight.shape[0], - dtype=inp.dtype, - device='cuda') - if (k == 8192 and - (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): - _custom_C.LLMM1(weight, inp, out, 8) - elif k <= 8192 and k % 8 == 0 and m % 4 == 0: - _custom_C.LLMM1(weight, inp, out, 4) - else: - out = F.linear(inp, weight) - if batched: - out = out.view(x.shape[0], x.shape[1], weight.shape[0]) - if bias is not None: - out = out + bias - return out if self.separate_bias_add: if bias is not None: - return F.linear(x, weight) + bias - return F.linear(x, weight) + return tgemm.mm(x, weight) + bias + return tgemm.mm(x, weight) elif bias is not None: return F.linear(x, weight, bias) return tgemm.mm(x, weight) From 27eea97d7d55fb4c0530d48c6ede0ea560b7ed1c Mon Sep 17 00:00:00 2001 From: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Date: Fri, 7 Jun 2024 15:52:46 -0500 Subject: [PATCH 410/413] Linting main in line with upstream requirements (#43) * Linting * Fix linting for triton: unmeld if with constexpr --- .../kernels/benchmark_paged_attention.py | 3 +- csrc/custom/custom.cu | 137 +- csrc/custom/custom_kernels.cu | 604 ++++---- csrc/custom/fused_kernels.cu | 343 ++--- .../custom/paged_attention/attention_ll4mi.cu | 1335 +++++++++-------- csrc/ops.h | 26 +- csrc/pybind.cpp | 4 +- csrc/quantization/fp8/amd/gemm_kernel.cu | 514 ++++--- csrc/quantization/fp8/amd/quant_utils.cuh | 491 +++--- csrc/quantization/fp8/common.cu | 156 +- vllm/attention/backends/rocm_flash_attn.py | 91 +- vllm/attention/ops/triton_flash_attention.py | 295 ++-- vllm/distributed/communication_op.py | 3 +- vllm/distributed/parallel_state.py | 3 +- vllm/model_executor/layers/linear.py | 2 - .../layers/quantization/__init__.py | 2 +- .../layers/quantization/fp8_rocm.py | 123 +- vllm/model_executor/model_loader/loader.py | 9 +- vllm/model_executor/models/llama.py | 14 +- 19 files changed, 2197 insertions(+), 1958 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 0fcfc0a295ca2..d0d990410bc6e 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -81,8 +81,7 @@ def main( if not args.custom_paged_attn: global PARTITION_SIZE PARTITION_SIZE = 512 - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index d75b2d2e41005..3da25ece3e87c 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -6,94 +6,89 @@ namespace py = pybind11; // declare templates for front (cpp) and back (cuda) sides of function: -//template - -void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block); -void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block) { - int M = in_a.size(0); - int K = in_a.size(1); - LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), - out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),rows_per_block); +// template + +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block); +void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int rows_per_block) { + int M = in_a.size(0); + int K = in_a.size(1); + LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); } -void LLGemm1(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream,const int rows_per_block); - -//template -void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block=4) { - int M = in_a.size(0); - int K = in_a.size(1); - //if (N != in_b.numel()) - // throw std::invalid_argument("Size mismatch A.numel(): " + std::to_string(in_a.numel()) - // + ", B.numel(): " + std::to_string(in_b.numel())); - - //out_c.resize_({N}); - - // call the kernel function... - LLGemm1(in_a.data_ptr(), in_b.data_ptr(), - out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),rows_per_block); +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block); + +// template +void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int rows_per_block = 4) { + int M = in_a.size(0); + int K = in_a.size(1); + // if (N != in_b.numel()) + // throw std::invalid_argument("Size mismatch A.numel(): " + + // std::to_string(in_a.numel()) + // + ", B.numel(): " + + // std::to_string(in_b.numel())); + + // out_c.resize_({N}); + + // call the kernel function... + LLGemm1(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); } -void LLGemmZZ(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int solidx); +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx); -void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int solidx=0) { - int M = in_a.size(0); - int K = in_a.size(1); +void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int solidx = 0) { + int M = in_a.size(0); + int K = in_a.size(1); - LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), - out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),solidx); + LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), solidx); } // instantiate the CPP template for T=float: -//template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c); - - -void MMGPUKernel(float *in_a, float *in_b, float *out_c, - int numARows, int numAColumns, - int numBRows, int numBColumns, - int numCRows, int numCColumns, - cudaStream_t stream); +// template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor +// out_c); +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream); void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) { - auto matA_sizes { in_a.sizes() }; - auto matB_sizes { in_b.sizes() }; - auto matO_sizes { out_c.sizes() }; - MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), - matA_sizes[0], matA_sizes[1], - matB_sizes[0], matB_sizes[1], - matO_sizes[0], matO_sizes[1], - at::cuda::getCurrentCUDAStream()); + auto matA_sizes{in_a.sizes()}; + auto matB_sizes{in_b.sizes()}; + auto matO_sizes{out_c.sizes()}; + MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), + out_c.data_ptr(), matA_sizes[0], matA_sizes[1], + matB_sizes[0], matB_sizes[1], matO_sizes[0], matO_sizes[1], + at::cuda::getCurrentCUDAStream()); } -void paged_attention_custom( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, +void paged_attention_custom(torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, + float scale, torch::Tensor& block_tables, + torch::Tensor& context_lens, int block_size, + int max_context_len, #if 0 torch::Tensor& qk_out, torch::Tensor& softmax_out, #endif - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); // declare the extension module with the AddGPU function: -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ - m.doc() = "pybind11 example plugin"; - m.def("LLMM1", &LLMM1); - m.def("LLMM_Silu", &LLMM_Silu); - m.def("LLZZ", &LLZZ); - m.def( - "paged_attention_custom", - &paged_attention_custom, - "PagedAttention LL4Mi Custom."); -//m.def("MMCustomGPU", &MMCustomGPU); +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "pybind11 example plugin"; + m.def("LLMM1", &LLMM1); + m.def("LLMM_Silu", &LLMM_Silu); + m.def("LLZZ", &LLZZ); + m.def("paged_attention_custom", &paged_attention_custom, + "PagedAttention LL4Mi Custom."); + // m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index b5ab0dbe8317c..6321f7ba23b3f 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -7,361 +7,355 @@ constexpr int WARP_SIZE = 64; template __device__ __forceinline__ T loadnt(T* addr) { - return __builtin_nontemporal_load(addr); + return __builtin_nontemporal_load(addr); } __device__ __forceinline__ float4 load_ntmprl(const float4* addr) { - auto addr_alias = reinterpret_cast(addr); - auto dat0 = loadnt(addr_alias); - auto dat1 = loadnt(addr_alias + 1); - auto dat2 = loadnt(addr_alias + 2); - auto dat3 = loadnt(addr_alias + 3); - //auto dat0 = *(addr_alias); - //auto dat1 = *(addr_alias+1); - //auto dat2 = *(addr_alias+2); - //auto dat3 = *(addr_alias+3); - return make_float4(dat0,dat1,dat2,dat3); + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); } -//TBlock fetches entire rows of A, and entire col of B (K dimension); assume N=1 for time being -//grid is M/A_NUM_ROWS blocks +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks template -__global__ void LLGemm1_kernel(float4 *af4, __half2 *bf4, __half2 *c) { - __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; - const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * blockDim.x; - //int row_addr_1 = row_addr + CUDA_NUM_THREADS; - //int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; - //int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; - const int threadid = threadIdx.x; - const int warp = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - const int num_warps = blockDim.x / WARP_SIZE; - const int qwarpid = threadid/16; - const int qthreadid = threadid%16; - float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; - //float4 colB_elem4; - __half2 colB_elem4x,colB_elem4y,colB_elem4z,colB_elem4w; - float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; - float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; - __half2 acch2; - __half2 oval; - - //rowA_elem4 = af4[row_addr + threadid]; - //__syncthreads(); - //rowA_elem4_1 = af4[row_addr_1 + threadid]; - //rowA_elem4_2 = af4[row_addr_2 + threadid]; - //rowA_elem4_3 = af4[row_addr_3 + threadid]; - #pragma unroll - for (int i=0; i(&colB_elem4); - //auto Bf2x = *Bh2ptr; - //auto Bf2y = *(Bh2ptr+1); - //auto Bf2z = *(Bh2ptr+2); - //auto Bf2w = *(Bh2ptr+3); - auto Ah2ptr = reinterpret_cast<__half2 *>(&rowA_elem4); - __half2 *ah2lptr; - #pragma unroll - for (int i=0; i= 1; mask /= 2) { - #pragma unroll - for (int i=0; i(&colB_elem4); + // auto Bf2x = *Bh2ptr; + // auto Bf2y = *(Bh2ptr+1); + // auto Bf2z = *(Bh2ptr+2); + // auto Bf2w = *(Bh2ptr+3); + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + acc[i] = S.x + S.y; + } - ////if (qthreadid= 1; mask /= 2) { - //#pragma unroll - //for (int i=0; i= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + // Warp leaders store the data to shared memory. + // if (lane == 0) { + // #pragma unroll + // for (int i=0; i8) { - // #pragma unroll - // for (int j=0; j<8; j++) { - // acc[2*threadid] += red_smem[2*threadid][j]; - // acc[2*threadid+1] += red_smem[2*threadid+1][j]; - // } - // } - // #pragma unroll - // for (int j=0; j= 1; mask /= 2) { + // #pragma unroll + // for (int i=0; i8) { + // #pragma unroll + // for (int j=0; j<8; j++) { + // acc[2*threadid] += red_smem[2*threadid][j]; + // acc[2*threadid+1] += red_smem[2*threadid+1][j]; + // } + // } + // #pragma unroll + // for (int j=0; j -void LLGemm1(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block=4) { - float4 *af4 = reinterpret_cast(in_a); - auto *bf4 = reinterpret_cast<__half2*>(in_b); - auto *c = reinterpret_cast<__half2*>(out_c); - //constexpr int A_ROWS_PER_BLOCK = 8; - const int NUM_THREADS = K*2/16; - int NUM_BLOCKS = M/rows_per_block; - if (rows_per_block==2) { - LLGemm1_kernel<2><<>>(af4, bf4, c); - } - else if (rows_per_block==4) { - LLGemm1_kernel<4><<>>(af4, bf4, c); - } - else if (rows_per_block==8) { - LLGemm1_kernel<8><<>>(af4, bf4, c); - } - else if (rows_per_block==16) { - LLGemm1_kernel<16><<>>(af4, bf4, c); - } - else { - NUM_BLOCKS = M/4; - LLGemm1_kernel<4><<>>(af4, bf4, c); - } - +// template +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<__half2*>(out_c); + // constexpr int A_ROWS_PER_BLOCK = 8; + const int NUM_THREADS = K * 2 / 16; + int NUM_BLOCKS = M / rows_per_block; + if (rows_per_block == 2) { + LLGemm1_kernel<2><<>>(af4, bf4, c); + } else if (rows_per_block == 4) { + LLGemm1_kernel<4><<>>(af4, bf4, c); + } else if (rows_per_block == 8) { + LLGemm1_kernel<8><<>>(af4, bf4, c); + } else if (rows_per_block == 16) { + LLGemm1_kernel<16><<>>(af4, bf4, c); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel<4><<>>(af4, bf4, c); + } - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } // instantiate the kernel template for T=float: -//template void AddGPUKernel(float *in_a, float *in_b, float *out_c, const int M, const int K, cudaStream_t stream); +// template void AddGPUKernel(float *in_a, float *in_b, float *out_c, +// const int M, const int K, cudaStream_t stream); const unsigned int TILE_WIDTH = 32; // Compute C = A * B -__global__ void matrixMultiplyShared(float *A, float *B, float *C, - int numARows, int numAColumns, - int numBRows, int numBColumns, - int numCRows, int numCColumns) { - __shared__ float sA[TILE_WIDTH][TILE_WIDTH]; // Tile size of 32x32 - __shared__ float sB[TILE_WIDTH][TILE_WIDTH]; - - int Row = blockDim.y * blockIdx.y + threadIdx.y; - int Col = blockDim.x * blockIdx.x + threadIdx.x; - float Cvalue = 0.0; - sA[threadIdx.y][threadIdx.x] = 0.0; - sB[threadIdx.y][threadIdx.x] = 0.0; - - for (int ph = 0; ph < (((numAColumns - 1) / TILE_WIDTH) + 1); ph++) { - if ((Row < numARows) && (threadIdx.x + (ph * TILE_WIDTH)) < numAColumns) { - sA[threadIdx.y][threadIdx.x] = A[(Row * numAColumns) + threadIdx.x + (ph * TILE_WIDTH)]; - } else { - sA[threadIdx.y][threadIdx.x] = 0.0; - } - if (Col < numBColumns && (threadIdx.y + ph * TILE_WIDTH) < numBRows) { - sB[threadIdx.y][threadIdx.x] = B[(threadIdx.y + ph * TILE_WIDTH) * numBColumns + Col]; - } else { - sB[threadIdx.y][threadIdx.x] = 0.0; - } - __syncthreads(); - for (int j = 0; j < TILE_WIDTH; ++j) { - Cvalue += sA[threadIdx.y][j] * sB[j][threadIdx.x]; - } - } - if (Row < numCRows && Col < numCColumns) { - C[Row * numCColumns + Col] = Cvalue; - } +__global__ void matrixMultiplyShared(float* A, float* B, float* C, int numARows, + int numAColumns, int numBRows, + int numBColumns, int numCRows, + int numCColumns) { + __shared__ float sA[TILE_WIDTH][TILE_WIDTH]; // Tile size of 32x32 + __shared__ float sB[TILE_WIDTH][TILE_WIDTH]; + + int Row = blockDim.y * blockIdx.y + threadIdx.y; + int Col = blockDim.x * blockIdx.x + threadIdx.x; + float Cvalue = 0.0; + sA[threadIdx.y][threadIdx.x] = 0.0; + sB[threadIdx.y][threadIdx.x] = 0.0; + + for (int ph = 0; ph < (((numAColumns - 1) / TILE_WIDTH) + 1); ph++) { + if ((Row < numARows) && (threadIdx.x + (ph * TILE_WIDTH)) < numAColumns) { + sA[threadIdx.y][threadIdx.x] = + A[(Row * numAColumns) + threadIdx.x + (ph * TILE_WIDTH)]; + } else { + sA[threadIdx.y][threadIdx.x] = 0.0; + } + if (Col < numBColumns && (threadIdx.y + ph * TILE_WIDTH) < numBRows) { + sB[threadIdx.y][threadIdx.x] = + B[(threadIdx.y + ph * TILE_WIDTH) * numBColumns + Col]; + } else { + sB[threadIdx.y][threadIdx.x] = 0.0; + } + __syncthreads(); + for (int j = 0; j < TILE_WIDTH; ++j) { + Cvalue += sA[threadIdx.y][j] * sB[j][threadIdx.x]; + } + } + if (Row < numCRows && Col < numCColumns) { + C[Row * numCColumns + Col] = Cvalue; + } } - -void MMGPUKernel(float *in_a, float *in_b, float *out_c, - int numARows, int numAColumns, - int numBRows, int numBColumns, - int numCRows, int numCColumns, - cudaStream_t stream) { - - // Initialize the grid and block dimensions - dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); - dim3 dimGrid((numCColumns / TILE_WIDTH) + 1, (numCRows / TILE_WIDTH) + 1, 1); - //@@ Launch the GPU Kernel here - matrixMultiplyShared <<>> - (in_a, in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, numCColumns); - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream) { + // Initialize the grid and block dimensions + dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); + dim3 dimGrid((numCColumns / TILE_WIDTH) + 1, (numCRows / TILE_WIDTH) + 1, 1); + //@@ Launch the GPU Kernel here + matrixMultiplyShared<<>>( + in_a, in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, + numCColumns); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } - - -template -__global__ -__launch_bounds__(512) -void HGEMV_WFPerRow(int m, int n, const _Float16 *A, int lda, const _Float16 *x, _Float16 *y) -{ +template +__global__ __launch_bounds__(512) void HGEMV_WFPerRow( + int m, int n, const _Float16* A, int lda, const _Float16* x, _Float16* y) { int num_row_per_block = CTA / nThreads_per_row; - int row_id = (blockIdx.x*num_row_per_block+threadIdx.y)*MT0; - int inc = (gridDim.x * num_row_per_block)*MT0; + int row_id = (blockIdx.x * num_row_per_block + threadIdx.y) * MT0; + int inc = (gridDim.x * num_row_per_block) * MT0; while (row_id < m) { float2 sum2[MT0]; #pragma unroll - for (int i = 0; i < MT0; ++i) - { - sum2[i] = {0.0,0.0}; + for (int i = 0; i < MT0; ++i) { + sum2[i] = {0.0, 0.0}; } - for (int j = threadIdx.x; j < n; j += (nThreads_per_row*MT1)){ - bool is_active = j < n; - if (is_active) { - float2 x2[MT1>>1]; + for (int j = threadIdx.x; j < n; j += (nThreads_per_row * MT1)) { + bool is_active = j < n; + if (is_active) { + float2 x2[MT1 >> 1]; #pragma unroll - for(int offset = 0; offset < MT1; offset += 2) - { - x2[offset>>1] = {x[j+nThreads_per_row*offset], x[j+nThreads_per_row*(offset+1)]}; - } - float2 a2[MT0][MT1>>1]; + for (int offset = 0; offset < MT1; offset += 2) { + x2[offset >> 1] = {x[j + nThreads_per_row * offset], + x[j + nThreads_per_row * (offset + 1)]}; + } + float2 a2[MT0][MT1 >> 1]; #pragma unroll - for (int i = 0; i < MT0; i++) - { + for (int i = 0; i < MT0; i++) { #pragma unroll - for (int offset = 0; offset < MT1; offset += 2) - { - a2[i][offset>>1] = {A[(row_id+i)*n+j+nThreads_per_row*offset], A[(row_id+i)*n+j+nThreads_per_row*(offset+1)]}; - } - } + for (int offset = 0; offset < MT1; offset += 2) { + a2[i][offset >> 1] = { + A[(row_id + i) * n + j + nThreads_per_row * offset], + A[(row_id + i) * n + j + nThreads_per_row * (offset + 1)]}; + } + } #pragma unroll - for (int i = 0; i < MT0; i++) - { + for (int i = 0; i < MT0; i++) { #pragma unroll - for (int offset = 0; offset < (MT1>>1); offset++) - { - sum2[i] += a2[i][offset]*x2[offset]; - } - } - + for (int offset = 0; offset < (MT1 >> 1); offset++) { + sum2[i] += a2[i][offset] * x2[offset]; + } } + } } float sum[MT0]; #pragma unroll - for (int i = 0; i < MT0; i++) - { - sum[i] = sum2[i].x+sum2[i].y; + for (int i = 0; i < MT0; i++) { + sum[i] = sum2[i].x + sum2[i].y; } #pragma unroll - for (int i = 0; i < MT0; i++) - { -#pragma unroll - for (int offset = nThreads_per_row >> 1; offset >= 1; offset = offset >> 1) { - sum[i] += __shfl_down(sum[i], offset, nThreads_per_row); - } + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = nThreads_per_row >> 1; offset >= 1; + offset = offset >> 1) { + sum[i] += __shfl_down(sum[i], offset, nThreads_per_row); + } } - if (threadIdx.x == 0) - { + if (threadIdx.x == 0) { #pragma unroll - for (int i = 0; i < MT0; i++) - { - y[row_id+i] = sum[i]; - } + for (int i = 0; i < MT0; i++) { + y[row_id + i] = sum[i]; + } } row_id += inc; } } -void LLGemmZZ(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int solidx=0) { - //m -> M, n-> K - dim3 grid(1024); - dim3 block(64, 8); - if (solidx==0) { - HGEMV_WFPerRow<64, 512, 4, 8><<>>(M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); - } - else if (solidx==1) { - HGEMV_WFPerRow<64, 512, 2, 8><<>>(M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); - } - else if (solidx==2) { - HGEMV_WFPerRow<64, 512, 1, 8><<>>(M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); - } - else { - HGEMV_WFPerRow<64, 512, 4, 8><<>>(M, K, reinterpret_cast(in_a), K, - reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); - } - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx = 0) { + // m -> M, n-> K + dim3 grid(1024); + dim3 block(64, 8); + if (solidx == 0) { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 1) { + HGEMV_WFPerRow<64, 512, 2, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 2) { + HGEMV_WFPerRow<64, 512, 1, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } diff --git a/csrc/custom/fused_kernels.cu b/csrc/custom/fused_kernels.cu index 5a4a11f914eb9..4f3eea4562949 100644 --- a/csrc/custom/fused_kernels.cu +++ b/csrc/custom/fused_kernels.cu @@ -5,188 +5,191 @@ constexpr int WARP_SIZE = 64; -template +template __device__ __forceinline__ T silu(const T& x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } template __device__ __forceinline__ T loadnt(T* addr) { - return __builtin_nontemporal_load(addr); + return __builtin_nontemporal_load(addr); } __device__ __forceinline__ float4 load_ntmprl(const float4* addr) { - auto addr_alias = reinterpret_cast(addr); - auto dat0 = loadnt(addr_alias); - auto dat1 = loadnt(addr_alias + 1); - auto dat2 = loadnt(addr_alias + 2); - auto dat3 = loadnt(addr_alias + 3); - //auto dat0 = *(addr_alias); - //auto dat1 = *(addr_alias+1); - //auto dat2 = *(addr_alias+2); - //auto dat3 = *(addr_alias+3); - return make_float4(dat0,dat1,dat2,dat3); + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); } -//TBlock fetches entire rows of A, and entire col of B (K dimension); assume N=1 for time being -//grid is M/A_NUM_ROWS blocks +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks template -__global__ void LLGemm_Silu_kernel(float4 *af4, __half2 *bf4, _Float16 *c, const int d) { - __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; - const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK/2 * blockDim.x; - const int row_addr_d = row_addr + d * blockDim.x; - //int row_addr_1 = row_addr + CUDA_NUM_THREADS; - //int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; - //int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; - const int threadid = threadIdx.x; - const int warp = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - const int num_warps = blockDim.x / WARP_SIZE; - const int qwarpid = threadid/16; - const int qthreadid = threadid%16; - float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; - //float4 colB_elem4; - __half2 colB_elem4x,colB_elem4y,colB_elem4z,colB_elem4w; - float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; - float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; - __half2 acch2; - __half2 oval; - - //rowA_elem4 = af4[row_addr + threadid]; - //__syncthreads(); - //rowA_elem4_1 = af4[row_addr_1 + threadid]; - //rowA_elem4_2 = af4[row_addr_2 + threadid]; - //rowA_elem4_3 = af4[row_addr_3 + threadid]; - #pragma unroll - for (int i=0; i(&colB_elem4); - //auto Bf2x = *Bh2ptr; - //auto Bf2y = *(Bh2ptr+1); - //auto Bf2z = *(Bh2ptr+2); - //auto Bf2w = *(Bh2ptr+3); - auto Ah2ptr = reinterpret_cast<__half2 *>(&rowA_elem4); - __half2 *ah2lptr; - #pragma unroll - for (int i=0; i= 1; mask /= 2) { - #pragma unroll - for (int i=0; i= 1; mask /= 2) { - //#pragma unroll - //for (int i=0; i(&colB_elem4); + // auto Bf2x = *Bh2ptr; + // auto Bf2y = *(Bh2ptr+1); + // auto Bf2z = *(Bh2ptr+2); + // auto Bf2w = *(Bh2ptr+3); + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + acc[i] = S.x + S.y; + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + // if (lane == 0) { + // #pragma unroll + // for (int i=0; i= 1; mask /= 2) { + // #pragma unroll + // for (int i=0; i -void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block=4) { - float4 *af4 = reinterpret_cast(in_a); - auto *bf4 = reinterpret_cast<__half2*>(in_b); - auto *c = reinterpret_cast<_Float16*>(out_c); - const int d = M/2; - const int NUM_THREADS = K*2/16; - int NUM_BLOCKS = M/rows_per_block; - if (rows_per_block==2) { - LLGemm_Silu_kernel<2><<>>(af4, bf4, c, d); - } - else if (rows_per_block==4) { - LLGemm_Silu_kernel<4><<>>(af4, bf4, c, d); - } - else if (rows_per_block==8) { - LLGemm_Silu_kernel<8><<>>(af4, bf4, c, d); - } - else if (rows_per_block==16) { - LLGemm_Silu_kernel<16><<>>(af4, bf4, c, d); - } - else { - NUM_BLOCKS = M/4; - LLGemm_Silu_kernel<4><<>>(af4, bf4, c, d); - } - - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +// template +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<_Float16*>(out_c); + const int d = M / 2; + const int NUM_THREADS = K * 2 / 16; + int NUM_BLOCKS = M / rows_per_block; + if (rows_per_block == 2) { + LLGemm_Silu_kernel<2> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 4) { + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 8) { + LLGemm_Silu_kernel<8> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 16) { + LLGemm_Silu_kernel<16> + <<>>(af4, bf4, c, d); + } else { + NUM_BLOCKS = M / 4; + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } - diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/custom/paged_attention/attention_ll4mi.cu index 6c9e84ab2f5f4..dcabc7932cfd5 100644 --- a/csrc/custom/paged_attention/attention_ll4mi.cu +++ b/csrc/custom/paged_attention/attention_ll4mi.cu @@ -1,4 +1,4 @@ -//TODO: add license terms +// TODO: add license terms #include #include #include @@ -14,9 +14,12 @@ #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; -using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; typedef float16x4 _Half4; -typedef struct _Half8 { _Half4 xy[2]; } _Half8; +typedef struct _Half8 { + _Half4 xy[2]; +} _Half8; ////// Non temporal load stores /////// #if 1 @@ -39,68 +42,62 @@ __device__ __forceinline__ T load(const T* addr) { } template <> -__device__ __forceinline__ -float2 load (const float2* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ float2 load(const float2* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast(&result); + auto ret = reinterpret_cast(&result); return ret[0]; } template <> -__device__ __forceinline__ -float4 load (const float4* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ float4 load(const float4* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result1 = __builtin_nontemporal_load(addr_alias); auto result2 = __builtin_nontemporal_load(addr_alias + 1); float4 ret{}; - auto ret_alias = reinterpret_cast(&result1); + auto ret_alias = reinterpret_cast(&result1); ret.x = ret_alias->x; ret.y = ret_alias->y; - ret_alias = reinterpret_cast(&result2); + ret_alias = reinterpret_cast(&result2); ret.z = ret_alias->x; ret.w = ret_alias->y; return ret; } template <> -__device__ __forceinline__ -__half load (const __half* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ __half load(const __half* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast<__half *>(&result); + auto ret = reinterpret_cast<__half*>(&result); return ret[0]; } template <> -__device__ __forceinline__ -__half2 load (const __half2* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ __half2 load(const __half2* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast<__half2 *>(&result); + auto ret = reinterpret_cast<__half2*>(&result); return ret[0]; } template <> -__device__ __forceinline__ -vllm::Half4_ load (const vllm::Half4_* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ vllm::Half4_ load(const vllm::Half4_* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result = __builtin_nontemporal_load(addr_alias); - auto ret = reinterpret_cast(&result); + auto ret = reinterpret_cast(&result); return ret[0]; } template <> -__device__ __forceinline__ -vllm::Half8_ load (const vllm::Half8_* addr) { - auto addr_alias { reinterpret_cast(addr) }; +__device__ __forceinline__ vllm::Half8_ load(const vllm::Half8_* addr) { + auto addr_alias{reinterpret_cast(addr)}; auto result1 = __builtin_nontemporal_load(addr_alias); auto result2 = __builtin_nontemporal_load(addr_alias + 1); - vllm::Half8_ ret {}; - auto ret_alias = reinterpret_cast(&result1); + vllm::Half8_ ret{}; + auto ret_alias = reinterpret_cast(&result1); ret.x = ret_alias->x; ret.y = ret_alias->y; - ret_alias = reinterpret_cast(&result2); + ret_alias = reinterpret_cast(&result2); ret.z = ret_alias->x; ret.w = ret_alias->y; return ret; @@ -116,394 +113,456 @@ __device__ __forceinline__ void store(T value, T* addr) { /////////////////////////////////////// -//grid (num_seqs, num_partitions,num_heads/gqa_ratio) -//block (partition size) -template +// grid (num_seqs, num_partitions,num_heads/gqa_ratio) +// block (partition size) +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] #if 0 scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] #endif - int max_ctx_blocks - ) { - constexpr int NWARPS = NUM_THREADS/WARP_SIZE; - const int warpid = threadIdx.x / WARP_SIZE; - const int laneid = threadIdx.x % WARP_SIZE; - const int lane4id = laneid%4; - - const int seq_idx = blockIdx.x; - const int partition_idx = blockIdx.y; - const int partition_size = blockDim.x; - const int max_num_partitions = gridDim.y; - - const int context_len = context_lens[seq_idx]; - const int partition_start_token_idx = partition_idx * partition_size; - //exit if partition is out of context for seq - if (partition_start_token_idx >= context_len) { - return; - } - constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO,4); // each 4 lanes fetch 4 different qheads, total qheads =8, so qhloop is 2 - constexpr int GQA_RATIO4 = 4*QHLOOP; - __shared__ float shared_qk_max[NWARPS][GQA_RATIO4+1]; - __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4+1]; - _Half8 Qlocal[QHLOOP]; - constexpr int x = 16 / sizeof(scalar_t); - constexpr int KHELOOP = HEAD_SIZE/x; - _Half8 Klocal[KHELOOP]; - constexpr int VHELOOP = HEAD_SIZE/WARP_SIZE; //v head_size dimension is distributed across lanes - constexpr int VTLOOP = 8; //16 separate 4xtokens across warp -> 16/2 8xtokens - _Half8 Vlocal[VHELOOP][VTLOOP]; - floatx4 dout[QHLOOP]; - float qk_max[QHLOOP]; - #pragma unroll - for (int h=0; h= context_len) { + return; + } + constexpr int QHLOOP = + DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, + // total qheads =8, so qhloop is 2 + constexpr int GQA_RATIO4 = 4 * QHLOOP; + __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; + __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; + _Half8 Qlocal[QHLOOP]; + constexpr int x = 16 / sizeof(scalar_t); + constexpr int KHELOOP = HEAD_SIZE / x; + _Half8 Klocal[KHELOOP]; + constexpr int VHELOOP = + HEAD_SIZE / + WARP_SIZE; // v head_size dimension is distributed across lanes + constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 + // 8xtokens + _Half8 Vlocal[VHELOOP][VTLOOP]; + floatx4 dout[QHLOOP]; + float qk_max[QHLOOP]; +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = {0}; + qk_max[h] = -FLT_MAX; + } - const int warp_start_token_idx = partition_start_token_idx + warpid*WARP_SIZE; + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; - if (warp_start_token_idx >= context_len) { //warp out of context - #pragma unroll - for(int h=0;h= context_len) { // warp out of context +#pragma unroll + for (int h = 0; h < GQA_RATIO4; h++) { + shared_qk_max[warpid][h] = -FLT_MAX; + shared_exp_sum[warpid][h] = 0.0f; + } + } else { // warp within context + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + const int local_token_idx = threadIdx.x; + const int global_token_idx = partition_start_token_idx + local_token_idx; + + const int block_idx = (global_token_idx < context_len) + ? global_token_idx / BLOCK_SIZE + : last_ctx_block; + + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + const scalar_t* q_ptr = + q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; + const _Half8* q_ptrh8 = reinterpret_cast(q_ptr); + const int qhead_elemh8 = laneid / 4; +#pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { + const int qhead_idx = h * 4 + lane4id; + Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } + const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id; + if (final_qhead_idx < GQA_RATIO) { + Qlocal[QHLOOP - 1] = + q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } else { + Qlocal[QHLOOP - 1].xy[0] = {0}; + Qlocal[QHLOOP - 1].xy[1] = {0}; + } - const int local_token_idx = threadIdx.x; - const int global_token_idx = partition_start_token_idx + local_token_idx; + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; + const _Half8* k_ptrh8 = reinterpret_cast(k_ptr); - const int block_idx = (global_token_idx < context_len) ? global_token_idx / BLOCK_SIZE : last_ctx_block; + const int physical_block_offset = + local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset + // is already cast as _H8 - //int32 physical_block_number leads to overflow when multiplied with kv_block_stride - const int64_t physical_block_number = static_cast(block_table[block_idx]); +#pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } - //each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems - const scalar_t* q_ptr = q + seq_idx*q_stride + wg_start_head_idx*HEAD_SIZE; - const _Half8* q_ptrh8 = reinterpret_cast(q_ptr); - const int qhead_elemh8 = laneid/4; - #pragma unroll - for (int h=0; h(k_ptr); - - const int physical_block_offset = local_token_idx%BLOCK_SIZE; //since x=half8, physical_block_offset is already cast as _H8 - + } - #pragma unroll - for (int d=0;d(v_ptr); +// iterate over each v block +#pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _Half8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; +// iterate over each head elem (within head_size) +#pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _Half8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; +// iterate over all velems within block +#pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } + } - constexpr int VBLOCKS=8*VTLOOP/BLOCK_SIZE; - int vphysical_blocks[VBLOCKS]; - - const int warp_start_block_idx = warp_start_token_idx/BLOCK_SIZE; - //fetch vphysical block numbers - #pragma unroll - for (int b=0;b 8) { + dout[h] = + GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h], 4, 8, 0); + dout[h] = + GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h], 4, 8, 0); + dout[h] = + GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h], 4, 9, 0); + dout[h] = + GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h], 4, 9, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h], 4, + 10, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h], 4, + 10, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h], 4, + 11, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h], 4, + 11, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h], 4, + 12, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h], 4, + 12, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h], 4, + 13, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h], 4, + 13, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h], 4, + 14, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h], 4, + 14, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h], 4, + 15, 0); + dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h], 4, + 15, 0); + } // KHELOOP>8 + dout[h] *= scale; + } +// transpose dout so that 4 token ids are in each lane, and 4 heads are across 4 +// lanes +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { + floatx4 tmp = {0}; +#pragma unroll + for (int i = 0; i < 4; i++) { + const float B = (lane4id == i) ? 1.0f : 0.0f; + // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; + tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); + // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); } + dout[h] = tmp; + } - const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; - const _Half8* v_ptrh8 = reinterpret_cast(v_ptr); - //iterate over each v block - #pragma unroll - for (int b=0;b(vphysical_blocks[b]); - const _Half8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride)/8; - //iterate over each head elem (within head_size) - #pragma unroll - for (int h=0;h> 2); + const int alibi_offset = lane4_token_idx - context_len + 1; + if (alibi_slopes != nullptr) { +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { +#pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] += alibi_slope[h] * (alibi_offset + i); } } + } - #pragma unroll - for (int h=0;h8) { - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[8].xy[0], dout[h], 4, 8, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[8].xy[1], dout[h], 4, 8, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[9].xy[0], dout[h], 4, 9, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[9].xy[1], dout[h], 4, 9, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[10].xy[0], dout[h], 4, 10, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[10].xy[1], dout[h], 4, 10, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[11].xy[0], dout[h], 4, 11, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[11].xy[1], dout[h], 4, 11, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[12].xy[0], dout[h], 4, 12, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[12].xy[1], dout[h], 4, 12, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[13].xy[0], dout[h], 4, 13, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[13].xy[1], dout[h], 4, 13, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[14].xy[0], dout[h], 4, 14, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[14].xy[1], dout[h], 4, 14, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[15].xy[0], dout[h], 4, 15, 0); - dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[1], Klocal[15].xy[1], dout[h], 4, 15, 0); - } //KHELOOP>8 - dout[h]*=scale; - } - //transpose dout so that 4 token ids are in each lane, and 4 heads are across 4 lanes - #pragma unroll - for (int h=0;h>2); - const int alibi_offset = lane4_token_idx - context_len + 1; - if (alibi_slopes != nullptr) { - #pragma unroll - for (int h=0;h=4; mask/=2) { - qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h],mask)); - } - } - - float exp_sum[QHLOOP]; - #pragma unroll - for (int h=0;h=4; mask/=2) { - exp_sum[h] += __shfl_xor(exp_sum[h],mask); - } - } +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { + qk_max[h] = -FLT_MAX; +#pragma unroll + for (int i = 0; i < 4; i++) { + qk_max[h] = (lane4_token_idx + i < context_len) + ? fmaxf(qk_max[h], dout[h][i]) + : qk_max[h]; + } +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + } + } + float exp_sum[QHLOOP]; +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { + exp_sum[h] = 0.0f; +#pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] = (lane4_token_idx + i < context_len) + ? __expf(dout[h][i] - qk_max[h]) + : 0.0f; + exp_sum[h] += dout[h][i]; + } +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + exp_sum[h] += __shfl_xor(exp_sum[h], mask); + } + } - #pragma unroll - for (int h=0;h every 4 lanes hold 4 heads, each lane holds 4 tokens, there are 4x16 tokens across warp - float16x4 logits[QHLOOP]; - #pragma unroll - for (int h=0;h every 4 lanes hold 4 heads, each lane holds 4 tokens, there + // are 4x16 tokens across warp + float16x4 logits[QHLOOP]; +#pragma unroll + for (int h = 0; h < QHLOOP; h++) { +#pragma unroll + for (int i = 0; i < 4; i++) { + logits[h][i] = (scalar_t)dout[h][i]; + } + } - __shared__ float16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS+1]; + __shared__ float16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; - if (warp_start_token_idx >= context_len) { //warp out of context - #pragma unroll - for (int qh=0; qh= context_len) { // warp out of context +#pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { +#pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout_shared[qh][vh][laneid][warpid] = {0}; + } } - else{//warp in context - //iterate across heads - #pragma unroll - for (int qh=0; qh partition_size) { + out_num_partitions = max_num_partitions; + out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + out_num_partitions = 1; + out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; } - }//warp in context - - __syncthreads(); - - if (warpid==0) { - float16x4 vout[QHLOOP][VHELOOP]; - //iterate across heads - scalar_t* out_ptr; - int out_num_partitions; - if (context_len > partition_size) { - out_num_partitions = max_num_partitions; - out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; - } else { - out_num_partitions = 1; - out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; +#pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { +// iterate over each v head elem (within head_size) +#pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout[qh][vh] = {0}; +#pragma unroll + for (int w = 0; w < NWARPS; w++) { + vout[qh][vh] += vout_shared[qh][vh][laneid][w]; } - #pragma unroll - for (int qh=0; qh -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - //if num_partitions==1, main kernel will write to out directly, no work in reduction kernel - return; - } - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int warpid = threadIdx.x / WARP_SIZE; - const int laneid = threadIdx.x % WARP_SIZE; - - __shared__ float shared_global_exp_sum; - __shared__ float shared_exp_sums[2*WARP_SIZE]; - - if (warpid==0) { - - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // if num_partitions==1, main kernel will write to out directly, no work in + // reduction kernel + return; + } - //valid partition is the last valid partition in case threadid > num partitions - const int valid_partition = (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions-1; - const int valid_partition2 = (WARP_SIZE+threadIdx.x < num_partitions) ? WARP_SIZE+threadIdx.x : num_partitions-1; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + __shared__ float shared_exp_sums[2 * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + const int valid_partition = + (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; + const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) + ? WARP_SIZE + threadIdx.x + : num_partitions - 1; float reg_max_logit = max_logits_ptr[valid_partition]; float reg_max_logit2 = max_logits_ptr[valid_partition2]; - float max_logit = fmaxf(reg_max_logit,reg_max_logit2); + float max_logit = fmaxf(reg_max_logit, reg_max_logit2); - #pragma unroll +#pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); } - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; float rescaled_exp_sum = exp_sums_ptr[valid_partition]; float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; - rescaled_exp_sum *= (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; - rescaled_exp_sum2 *= (threadIdx.x+WARP_SIZE < num_partitions) ? expf(reg_max_logit2 - max_logit) : 0.0f; + rescaled_exp_sum *= + (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; + rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) + ? expf(reg_max_logit2 - max_logit) + : 0.0f; global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; shared_exp_sums[threadIdx.x] = rescaled_exp_sum; - shared_exp_sums[threadIdx.x+WARP_SIZE] = rescaled_exp_sum2; + shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; - #pragma unroll +#pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { global_exp_sum += __shfl_xor(global_exp_sum, mask); } - if (threadIdx.x==0) { + if (threadIdx.x == 0) { shared_global_exp_sum = global_exp_sum; } - }//warpid == 0 - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; - constexpr int MAX_NPAR = 64; - scalar_t tmps[MAX_NPAR]; - #pragma unroll - for (int j = 0; j < MAX_NPAR; j++) { - tmps[j] = 0.0f; + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 64; + scalar_t tmps[MAX_NPAR]; +#pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = 0.0f; + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + +#pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { +#pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; } - const int last_partition_offset = (num_partitions-1)*HEAD_SIZE; - const int num_partition_offset = (num_partitions)*HEAD_SIZE; - int idx=0; - constexpr int JCHUNK = 16; - - #pragma unroll - for (int j = 0; j < JCHUNK*HEAD_SIZE; j+=HEAD_SIZE) { - //lastj is last valid partition - const int lastj_offset = (j 2 * JCHUNK) { +#pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; tmps[idx] = tmp_out_ptr[lastj_offset]; idx++; + } } - __syncthreads(); - - if (num_partitions > JCHUNK) { - #pragma unroll - for (int j = JCHUNK*HEAD_SIZE; j < 2*JCHUNK*HEAD_SIZE; j+=HEAD_SIZE) { - const int lastj_offset = (j JCHUNK - if (num_partitions > 2*JCHUNK) { - #pragma unroll - for (int j = 2*JCHUNK*HEAD_SIZE; j < MAX_NPAR*HEAD_SIZE; j+=HEAD_SIZE) { - const int lastj_offset = (j JCHUNK - - // Aggregate tmp_out to out. - float acc = 0.0f; - #pragma unroll - for (int j = 0; j < JCHUNK; j++) { + // Aggregate tmp_out to out. + float acc = 0.0f; +#pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += tmps[j] * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { +#pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { acc += tmps[j] * shared_exp_sums[j]; } - if (num_partitions > JCHUNK) { - #pragma unroll - for (int j = JCHUNK; j < 2*JCHUNK; j++) { - acc += tmps[j] * shared_exp_sums[j]; - } - if (num_partitions > 2*JCHUNK) { - #pragma unroll - for (int j = 2*JCHUNK; j < MAX_NPAR; j++) { - acc += tmps[j] * shared_exp_sums[j]; - } - } + if (num_partitions > 2 * JCHUNK) { +#pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += tmps[j] * shared_exp_sums[j]; + } } + } - if (num_partitions > MAX_NPAR) { - idx=0; - #pragma unroll - for (int j = MAX_NPAR*HEAD_SIZE; j < 2*MAX_NPAR*HEAD_SIZE; j+=HEAD_SIZE) { - //lastj is last valid partition - const int lastj_offset = (j MAX_NPAR) { + idx = 0; +#pragma unroll + for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; } - const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); - acc *= inv_global_exp_sum; - //from_float(out_ptr[threadIdx.x], acc); - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - out_ptr[threadIdx.x] = (scalar_t)acc; +#pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += tmps[j] * shared_exp_sums[j + MAX_NPAR]; + } } + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + // from_float(out_ptr[threadIdx.x], acc); + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + out_ptr[threadIdx.x] = (scalar_t)acc; +} + +#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ + paged_attention_ll4mi_QKV_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks); -#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ - <<>>( \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - num_kv_heads, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,out_ptr,max_ctx_blocks); - -template +template void paged_attention_custom_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - const int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, #if 0 torch::Tensor& qk_out, torch::Tensor& softmax_out, #endif - const c10::optional& alibi_slopes) { + const c10::optional& alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); @@ -712,9 +774,10 @@ void paged_attention_custom_launcher( int kv_head_stride = key_cache.stride(1); // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -731,116 +794,152 @@ void paged_attention_custom_launcher( #endif const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); - const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - const int gqa_ratio = num_heads/num_kv_heads; - assert(num_heads%num_kv_heads==0); - assert(head_size==HEAD_SIZE); - assert(max_num_partitions<=128); + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + assert(max_num_partitions <= 128); constexpr int NTHR = PARTITION_SIZE; - dim3 grid(num_seqs,max_num_partitions,num_kv_heads); + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (gqa_ratio) { - case 1: LAUNCH_CUSTOM_ATTENTION(1); break; - case 2: LAUNCH_CUSTOM_ATTENTION(2); break; - case 3: LAUNCH_CUSTOM_ATTENTION(3); break; - case 4: LAUNCH_CUSTOM_ATTENTION(4); break; - case 5: LAUNCH_CUSTOM_ATTENTION(5); break; - case 6: LAUNCH_CUSTOM_ATTENTION(6); break; - case 7: LAUNCH_CUSTOM_ATTENTION(7); break; - case 8: LAUNCH_CUSTOM_ATTENTION(8); break; - case 9: LAUNCH_CUSTOM_ATTENTION(9); break; - case 10: LAUNCH_CUSTOM_ATTENTION(10); break; - case 11: LAUNCH_CUSTOM_ATTENTION(11); break; - case 12: LAUNCH_CUSTOM_ATTENTION(12); break; - case 13: LAUNCH_CUSTOM_ATTENTION(13); break; - case 14: LAUNCH_CUSTOM_ATTENTION(14); break; - case 15: LAUNCH_CUSTOM_ATTENTION(15); break; - case 16: LAUNCH_CUSTOM_ATTENTION(16); break; - default: - TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); - break; + case 1: + LAUNCH_CUSTOM_ATTENTION(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; } - //dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); - //dim3 block2(1024); - // LAUNCH_CUSTOM_ATTENTION2; - - //reduction kernel is only required if max_context_len > partition size, otherwise main kernel writes directly to final output - // note there are cases with graphing where max_context_len is the max supported by graphing, not the actual max among - // all the sequences: in that case reduction kernel will still run but return immediately + // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); + // dim3 block2(1024); + // LAUNCH_CUSTOM_ATTENTION2; + + // reduction kernel is only required if max_context_len > partition size, + // otherwise main kernel writes directly to final output + // note there are cases with graphing where max_context_len is the max + // supported by graphing, not the actual max among all the sequences: in that + // case reduction kernel will still run but return immediately if (max_context_len > PARTITION_SIZE) { dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_block(head_size); paged_attention_ll4mi_reduce_kernel - <<>>( - out_ptr, - exp_sums_ptr, - max_logits_ptr, - tmp_out_ptr, - context_lens_ptr, - max_num_partitions); + <<>>( + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, + context_lens_ptr, max_num_partitions); } } -#define CALL_CUSTOM_LAUNCHER(T,BLK_SIZE,HEAD_SIZE) \ - paged_attention_custom_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len,\ - alibi_slopes); - -#define CALL_CUSTOM_LAUNCHER_BLK(T,HEAD_SIZE) \ - switch (block_size) { \ - case 8: CALL_CUSTOM_LAUNCHER(T,8,HEAD_SIZE); break; \ - case 16: CALL_CUSTOM_LAUNCHER(T,16,HEAD_SIZE); break; \ - case 32: CALL_CUSTOM_LAUNCHER(T,32,HEAD_SIZE); break; \ - default: TORCH_CHECK(false, "Unsupported block size: ", block_size); break; \ - } +#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes); + +#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE) \ + switch (block_size) { \ + case 8: \ + CALL_CUSTOM_LAUNCHER(T, 8, HEAD_SIZE); \ + break; \ + case 16: \ + CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } -#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ - switch (head_size) { \ - case 64: CALL_CUSTOM_LAUNCHER_BLK(T,64); break; \ - case 128: CALL_CUSTOM_LAUNCHER_BLK(T,128); break; \ - default: TORCH_CHECK(false, "Unsupported head size: ", head_size); break; \ - } +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T) \ + switch (head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 64); \ + break; \ + case 128: \ + CALL_CUSTOM_LAUNCHER_BLK(T, 128); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size: ", head_size); \ + break; \ + } void paged_attention_custom( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] - int block_size, - int max_context_len, + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int num_kv_heads, float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, int max_context_len, #if 0 torch::Tensor& qk_out, torch::Tensor& softmax_out, #endif - const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype) { - const int head_size = query.size(2); - if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + const int head_size = query.size(2); + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } #undef WARP_SIZE diff --git a/csrc/ops.h b/csrc/ops.h index d6cdfab434f2c..aa015c3d5dc39 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -113,25 +113,17 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); -void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, torch::Tensor& scale); +void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, + torch::Tensor& scale); #ifdef USE_ROCM -torch::Tensor fp8_gemm( - torch::Tensor& a, - torch::Tensor& b, - torch::Tensor& scaleA, - torch::Tensor& scaleB, - torch::Tensor& scaleD, - int algo_idx -); - -torch::Tensor fp8_gemm_16( - torch::Tensor& a, - torch::Tensor& b, - torch::Tensor& scaleA, - torch::Tensor& scaleB, - int algo_idx -); +torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, + torch::Tensor& scaleA, torch::Tensor& scaleB, + torch::Tensor& scaleD, int algo_idx); + +torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b, + torch::Tensor& scaleA, torch::Tensor& scaleB, + int algo_idx); #endif void moe_align_block_size(torch::Tensor topk_ids, int num_experts, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a4693ccc2ae75..a507af396bcf9 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -67,8 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Aligning the number of tokens to be processed by each expert such " "that it is divisible by the block size."); ops.def("convert_fp8", &convert_fp8, - "Convert the key and value cache to fp8 data type"); - + "Convert the key and value cache to fp8 data type"); + #ifdef USE_ROCM ops.def("fp8_gemm", &fp8_gemm, "fp8 GEMM with fp8 output"); ops.def("fp8_gemm_16", &fp8_gemm_16, "fp8 GEMM with fp16 output"); diff --git a/csrc/quantization/fp8/amd/gemm_kernel.cu b/csrc/quantization/fp8/amd/gemm_kernel.cu index 5464e9381e343..f8586b77d7792 100644 --- a/csrc/quantization/fp8/amd/gemm_kernel.cu +++ b/csrc/quantization/fp8/amd/gemm_kernel.cu @@ -12,258 +12,290 @@ #define max_workspace_size 2 * 128 * 1024 * 1024 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) #ifndef CHECK_HIP_ERROR -#define CHECK_HIP_ERROR(error) \ - if (error != hipSuccess) { \ - fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", hipGetErrorString(error), error, __FILE__, __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif #ifndef CHECK_HIPBLASLT_ERROR -#define CHECK_HIPBLASLT_ERROR(error) \ - if (error != HIPBLAS_STATUS_SUCCESS) { \ - fprintf( \ - stderr, "hipBLASLt error: '%s'(%d) at %s:%d\n", hipblasStatusToString(error), error, __FILE__, __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIPBLASLT_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "hipBLASLt error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif -torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA, torch::Tensor& scaleB, - torch::Tensor& scaleD, int algo_idx) -{ - auto a_strides{a.strides()}; - auto b_strides{b.strides()}; - auto a_sizes{a.sizes()}; - auto b_sizes{b.sizes()}; - - // CHECK_INPUT(a); - // CHECK_INPUT(b); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && b.dtype() == torch::kFloat8_e4m3fnuz, - "The input tensors should be in fp8."); - TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); - TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); - - auto options{at::TensorOptions().dtype(torch::kFloat8_e4m3fnuz).device(at::kCUDA)}; - auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; - - constexpr bool transpose_result = true; - bool transpose_a; - bool transpose_b; - if ((b_strides[0] == 1) && (b_strides[1] >= std::max(1, b_sizes[0]))) { - transpose_b = false; - } else if ((b_strides[1] == 1) && (b_strides[0] >= std::max(1, b_sizes[1]))) { - transpose_b = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); +torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b, + torch::Tensor& scaleA, torch::Tensor& scaleB, + torch::Tensor& scaleD, int algo_idx) { + auto a_strides{a.strides()}; + auto b_strides{b.strides()}; + auto a_sizes{a.sizes()}; + auto b_sizes{b.sizes()}; + + // CHECK_INPUT(a); + // CHECK_INPUT(b); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && + b.dtype() == torch::kFloat8_e4m3fnuz, + "The input tensors should be in fp8."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); + TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); + + auto options{ + at::TensorOptions().dtype(torch::kFloat8_e4m3fnuz).device(at::kCUDA)}; + auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; + + constexpr bool transpose_result = true; + bool transpose_a; + bool transpose_b; + if ((b_strides[0] == 1) && + (b_strides[1] >= std::max(1, b_sizes[0]))) { + transpose_b = false; + } else if ((b_strides[1] == 1) && + (b_strides[0] >= std::max(1, b_sizes[1]))) { + transpose_b = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((a_strides[0] == 1) && + (a_strides[1] >= std::max(1, a_sizes[0]))) { + transpose_a = false; + } else if ((a_strides[1] == 1) && + (a_strides[0] >= std::max(1, a_sizes[1]))) { + transpose_a = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_a; + transpose_a = !transpose_b; + transpose_b = !tmp; + a_strides = b.strides(); + b_strides = a.strides(); + a_sizes = b.sizes(); + b_sizes = a.sizes(); + } + + float alpha = 1.0f; + float beta = 0.0f; + int64_t m = a_sizes[transpose_result ? 1 : 0]; + int64_t k = a_sizes[transpose_result ? 0 : 1]; + int64_t n = b_sizes[transpose_result ? 0 : 1]; + + void* d_a = static_cast((transpose_result ? b : a).data_ptr()); + void* d_b = static_cast((transpose_result ? a : b).data_ptr()); + void* d_d = static_cast(result.data_ptr()); + + // void *d_scaleA, *d_scaleB, *d_workspace; + // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), + // sizeof(float), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(d_scaleB, + // &(transpose_result ? scaleA : scaleB), sizeof(float), + // hipMemcpyHostToDevice)); + auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); + auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); + auto d_scaleD = scaleD.data_ptr(); + + auto handle = at::cuda::getCurrentCUDABlasLtHandle(); + auto stream = at::cuda::getCurrentCUDAStream(); + + hipblaslt_ext::GemmPreference gemmPref; + gemmPref.setMaxWorkspaceBytes(0); + hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, + HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, + HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, + HIPBLAS_COMPUTE_32F); + + hipblaslt_ext::GemmEpilogue + epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. + // (Gemm only) + hipblaslt_ext::GemmInputs inputs; + inputs.a = d_a; + inputs.b = d_b; + inputs.c = d_d; + inputs.d = d_d; + inputs.alpha = α + inputs.beta = β + inputs.scaleA = d_scaleA; + inputs.scaleB = d_scaleB; + inputs.scaleD = d_scaleD; + gemm.setProblem(m, n, k, 1, epilogue, inputs); + if (algo_idx == 0) { + constexpr int request_solutions = 1024; + std::vector heuristicResult; + heuristicResult.reserve(request_solutions); + CHECK_HIPBLASLT_ERROR( + gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); + static size_t solSize = 0; + if (heuristicResult.size() != solSize) { + std::cout << "fp8 sols: " << heuristicResult.size() << "\n"; + solSize = heuristicResult.size(); + for (auto& res : heuristicResult) { + auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); + std::cout << idx << "\n"; + } } - if ((a_strides[0] == 1) && (a_strides[1] >= std::max(1, a_sizes[0]))) { - transpose_a = false; - } else if ((a_strides[1] == 1) && (a_strides[0] >= std::max(1, a_sizes[1]))) { - transpose_a = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); - } - - if (transpose_result) { - bool tmp = transpose_a; - transpose_a = !transpose_b; - transpose_b = !tmp; - a_strides = b.strides(); - b_strides = a.strides(); - a_sizes = b.sizes(); - b_sizes = a.sizes(); - } - - float alpha = 1.0f; - float beta = 0.0f; - int64_t m = a_sizes[transpose_result ? 1 : 0]; - int64_t k = a_sizes[transpose_result ? 0 : 1]; - int64_t n = b_sizes[transpose_result ? 0 : 1]; - - void* d_a = static_cast((transpose_result ? b : a).data_ptr()); - void* d_b = static_cast((transpose_result ? a : b).data_ptr()); - void* d_d = static_cast(result.data_ptr()); - - // void *d_scaleA, *d_scaleB, *d_workspace; - // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); - // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); - // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); - // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), sizeof(float), hipMemcpyHostToDevice)); - // CHECK_HIP_ERROR(hipMemcpy(d_scaleB, &(transpose_result ? scaleA : scaleB), sizeof(float), hipMemcpyHostToDevice)); - auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); - auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); - auto d_scaleD = scaleD.data_ptr(); - - auto handle = at::cuda::getCurrentCUDABlasLtHandle(); - auto stream = at::cuda::getCurrentCUDAStream(); - - hipblaslt_ext::GemmPreference gemmPref; - gemmPref.setMaxWorkspaceBytes(0); - hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, - HIP_R_8F_E4M3_FNUZ, HIPBLAS_COMPUTE_32F); - - hipblaslt_ext::GemmEpilogue epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only) - hipblaslt_ext::GemmInputs inputs; - inputs.a = d_a; - inputs.b = d_b; - inputs.c = d_d; - inputs.d = d_d; - inputs.alpha = α - inputs.beta = β - inputs.scaleA = d_scaleA; - inputs.scaleB = d_scaleB; - inputs.scaleD = d_scaleD; - gemm.setProblem(m, n, k, 1, epilogue, inputs); - if (algo_idx == 0) { - constexpr int request_solutions = 1024; - std::vector heuristicResult; - heuristicResult.reserve(request_solutions); - CHECK_HIPBLASLT_ERROR(gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); - static size_t solSize = 0; - if (heuristicResult.size() != solSize) { - std::cout << "fp8 sols: " << heuristicResult.size() << "\n"; - solSize = heuristicResult.size(); - for (auto& res : heuristicResult) { - auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); - std::cout << idx << "\n"; - } - } - TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); - algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); - } - std::vector algoIndex(1); - algoIndex[0] = algo_idx; - std::vector tmpAlgo; - TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); - - CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); - CHECK_HIPBLASLT_ERROR(gemm.run(stream)); - - // hipFree(d_scaleA); - // hipFree(d_scaleB); - - return result; + TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); + algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); + } + std::vector algoIndex(1); + algoIndex[0] = algo_idx; + std::vector tmpAlgo; + TORCH_CUDABLAS_CHECK( + hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); + + CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); + CHECK_HIPBLASLT_ERROR(gemm.run(stream)); + + // hipFree(d_scaleA); + // hipFree(d_scaleB); + + return result; } -torch::Tensor fp8_gemm_16( - torch::Tensor& a, torch::Tensor& b, torch::Tensor& scaleA, torch::Tensor& scaleB, int algo_idx) -{ - auto a_strides{a.strides()}; - auto b_strides{b.strides()}; - auto a_sizes{a.sizes()}; - auto b_sizes{b.sizes()}; - - // CHECK_INPUT(a); - // CHECK_INPUT(b); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && b.dtype() == torch::kFloat8_e4m3fnuz, - "The input tensors should be in fp8."); - TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); - TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); - - auto options{at::TensorOptions().dtype(torch::kFloat16).device(at::kCUDA)}; - auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; - - constexpr bool transpose_result = true; - bool transpose_a; - bool transpose_b; - if ((b_strides[0] == 1) && (b_strides[1] >= std::max(1, b_sizes[0]))) { - transpose_b = false; - } else if ((b_strides[1] == 1) && (b_strides[0] >= std::max(1, b_sizes[1]))) { - transpose_b = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); +torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b, + torch::Tensor& scaleA, torch::Tensor& scaleB, + int algo_idx) { + auto a_strides{a.strides()}; + auto b_strides{b.strides()}; + auto a_sizes{a.sizes()}; + auto b_sizes{b.sizes()}; + + // CHECK_INPUT(a); + // CHECK_INPUT(b); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fnuz && + b.dtype() == torch::kFloat8_e4m3fnuz, + "The input tensors should be in fp8."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Input tensors must be 2-D."); + TORCH_CHECK(a_sizes[1] == b_sizes[0], "a dim 1 must match b dim 0."); + + auto options{at::TensorOptions().dtype(torch::kFloat16).device(at::kCUDA)}; + auto result{torch::empty({a_sizes[0], b_sizes[1]}, options)}; + + constexpr bool transpose_result = true; + bool transpose_a; + bool transpose_b; + if ((b_strides[0] == 1) && + (b_strides[1] >= std::max(1, b_sizes[0]))) { + transpose_b = false; + } else if ((b_strides[1] == 1) && + (b_strides[0] >= std::max(1, b_sizes[1]))) { + transpose_b = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((a_strides[0] == 1) && + (a_strides[1] >= std::max(1, a_sizes[0]))) { + transpose_a = false; + } else if ((a_strides[1] == 1) && + (a_strides[0] >= std::max(1, a_sizes[1]))) { + transpose_a = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_a; + transpose_a = !transpose_b; + transpose_b = !tmp; + a_strides = b.strides(); + b_strides = a.strides(); + a_sizes = b.sizes(); + b_sizes = a.sizes(); + } + + float alpha = 1.0f; + float beta = 0.0f; + int64_t m = a_sizes[transpose_result ? 1 : 0]; + int64_t k = a_sizes[transpose_result ? 0 : 1]; + int64_t n = b_sizes[transpose_result ? 0 : 1]; + + void* d_a = static_cast((transpose_result ? b : a).data_ptr()); + void* d_b = static_cast((transpose_result ? a : b).data_ptr()); + void* d_d = static_cast(result.data_ptr()); + + // void *d_scaleA, *d_scaleB, *d_workspace; + // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); + // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); + // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), + // sizeof(float), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(d_scaleB, + // &(transpose_result ? scaleA : scaleB), sizeof(float), + // hipMemcpyHostToDevice)); + auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); + auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); + + auto handle = at::cuda::getCurrentCUDABlasLtHandle(); + auto stream = at::cuda::getCurrentCUDAStream(); + + hipblaslt_ext::GemmPreference gemmPref; + gemmPref.setMaxWorkspaceBytes(0); + hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, + HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, + HIP_R_16F, HIPBLAS_COMPUTE_32F); + + hipblaslt_ext::GemmEpilogue + epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. + // (Gemm only) + hipblaslt_ext::GemmInputs inputs; + inputs.a = d_a; + inputs.b = d_b; + inputs.c = d_d; + inputs.d = d_d; + inputs.alpha = α + inputs.beta = β + inputs.scaleA = d_scaleA; + inputs.scaleB = d_scaleB; + gemm.setProblem(m, n, k, 1, epilogue, inputs); + if (algo_idx == 0) { + constexpr int request_solutions = 1024; + std::vector heuristicResult; + heuristicResult.reserve(request_solutions); + CHECK_HIPBLASLT_ERROR( + gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); + static size_t solSize = 0; + if (heuristicResult.size() != solSize) { + std::cout << "fp16 sols: " << heuristicResult.size() << "\n"; + solSize = heuristicResult.size(); + for (auto& res : heuristicResult) { + auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); + std::cout << idx << "\n"; + } } - if ((a_strides[0] == 1) && (a_strides[1] >= std::max(1, a_sizes[0]))) { - transpose_a = false; - } else if ((a_strides[1] == 1) && (a_strides[0] >= std::max(1, a_sizes[1]))) { - transpose_a = true; - } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); - } - - if (transpose_result) { - bool tmp = transpose_a; - transpose_a = !transpose_b; - transpose_b = !tmp; - a_strides = b.strides(); - b_strides = a.strides(); - a_sizes = b.sizes(); - b_sizes = a.sizes(); - } - - float alpha = 1.0f; - float beta = 0.0f; - int64_t m = a_sizes[transpose_result ? 1 : 0]; - int64_t k = a_sizes[transpose_result ? 0 : 1]; - int64_t n = b_sizes[transpose_result ? 0 : 1]; - - void* d_a = static_cast((transpose_result ? b : a).data_ptr()); - void* d_b = static_cast((transpose_result ? a : b).data_ptr()); - void* d_d = static_cast(result.data_ptr()); - - // void *d_scaleA, *d_scaleB, *d_workspace; - // CHECK_HIP_ERROR(hipMalloc(&d_scaleA, sizeof(float))); - // CHECK_HIP_ERROR(hipMalloc(&d_scaleB, sizeof(float))); - // CHECK_HIP_ERROR(hipMalloc(&d_workspace, max_workspace_size)); - // CHECK_HIP_ERROR(hipMemcpy(d_scaleA, &(transpose_result ? scaleB : scaleA), sizeof(float), hipMemcpyHostToDevice)); - // CHECK_HIP_ERROR(hipMemcpy(d_scaleB, &(transpose_result ? scaleA : scaleB), sizeof(float), hipMemcpyHostToDevice)); - auto d_scaleA = transpose_result ? scaleB.data_ptr() : scaleA.data_ptr(); - auto d_scaleB = transpose_result ? scaleA.data_ptr() : scaleB.data_ptr(); - - auto handle = at::cuda::getCurrentCUDABlasLtHandle(); - auto stream = at::cuda::getCurrentCUDAStream(); - - hipblaslt_ext::GemmPreference gemmPref; - gemmPref.setMaxWorkspaceBytes(0); - hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F, HIP_R_16F, - HIPBLAS_COMPUTE_32F); - - hipblaslt_ext::GemmEpilogue epilogue{}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only) - hipblaslt_ext::GemmInputs inputs; - inputs.a = d_a; - inputs.b = d_b; - inputs.c = d_d; - inputs.d = d_d; - inputs.alpha = α - inputs.beta = β - inputs.scaleA = d_scaleA; - inputs.scaleB = d_scaleB; - gemm.setProblem(m, n, k, 1, epilogue, inputs); - if (algo_idx == 0) { - constexpr int request_solutions = 1024; - std::vector heuristicResult; - heuristicResult.reserve(request_solutions); - CHECK_HIPBLASLT_ERROR(gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult)); - static size_t solSize = 0; - if (heuristicResult.size() != solSize) { - std::cout << "fp16 sols: " << heuristicResult.size() << "\n"; - solSize = heuristicResult.size(); - for (auto& res : heuristicResult) { - auto idx = hipblaslt_ext::getIndexFromAlgo(res.algo); - std::cout << idx << "\n"; - } - } - algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); - TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); - } - std::vector algoIndex(1); - algoIndex[0] = algo_idx; - std::vector tmpAlgo; - TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); - - CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); - CHECK_HIPBLASLT_ERROR(gemm.run(stream)); - - // hipFree(d_scaleA); - // hipFree(d_scaleB); - - return result; + algo_idx = hipblaslt_ext::getIndexFromAlgo(heuristicResult[0].algo); + TORCH_CHECK(!heuristicResult.empty(), "No valid solution found!"); + } + std::vector algoIndex(1); + algoIndex[0] = algo_idx; + std::vector tmpAlgo; + TORCH_CUDABLAS_CHECK( + hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo)); + + CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr)); + CHECK_HIPBLASLT_ERROR(gemm.run(stream)); + + // hipFree(d_scaleA); + // hipFree(d_scaleB); + + return result; } \ No newline at end of file diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index 23d975fe0f37e..8a35467edbc21 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -307,344 +307,351 @@ vec_conversion(const Float8_& a) { // fp8 -> half template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const uint8_t& a, float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - __half_raw res; - res.data = static_cast(f8) * scale; - return res.x; +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint8_t& a, float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; } // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const uint16_t& a, float scale) -{ -#if defined(__HIP__MI300__) - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - union { - __half2_raw h2r; - uint32_t ui32; - } tmp; - tmp.h2r.x.data = f2[0] * scale; - tmp.h2r.y.data = f2[1] * scale; - return tmp.ui32; -#else - union { - uint16_t u16[2]; - uint32_t u32; - } tmp; - - tmp.u16[0] = scaled_vec_conversion(static_cast(a), scale); - tmp.u16[1] = scaled_vec_conversion(static_cast(a >> 8U), scale); - return tmp.u32; -#endif +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = + scaled_vec_conversion(static_cast(a), scale); + tmp.u16[1] = scaled_vec_conversion( + static_cast(a >> 8U), scale); + return tmp.u32; + #endif } // fp8x4 -> half2x2 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const uint32_t& a, float scale) -{ - union { - uint2 u32x2; - uint32_t u32[2]; - } tmp; - tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); - tmp.u32[1] = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return tmp.u32x2; +__inline__ __device__ uint2 +scaled_vec_conversion(const uint32_t& a, float scale) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = + scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; } // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, float scale) -{ - union { - uint4 u64x2; - uint2 u64[2]; - } tmp; - tmp.u64[0] = scaled_vec_conversion(a.x, scale); - tmp.u64[1] = scaled_vec_conversion(a.y, scale); - return tmp.u64x2; +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, + float scale) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; } using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> -__inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) -{ - hip_fp8 f8{a, hip_fp8::from_bits()}; - float f{f8}; - return __float2bfloat16(f * scale); +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); } using __nv_bfloat162 = __hip_bfloat162; // fp8x2 -> __nv_bfloat162 template <> -__inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, float scale) -{ - __nv_bfloat162 res; - res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); - return res; +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, + float scale) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = + scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; } // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t scaled_vec_conversion(const uint32_t& a, float scale) -{ - bf16_4_t res; - res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); - res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ bf16_4_t +scaled_vec_conversion(const uint32_t& a, float scale) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale); + return res; } // fp8x8 -> bf16_8_t template <> -__inline__ __device__ bf16_8_t scaled_vec_conversion(const uint2& a, float scale) -{ - bf16_4_t tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - bf16_8_t res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ bf16_8_t +scaled_vec_conversion(const uint2& a, float scale) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // fp8 -> float template <> -__inline__ __device__ float scaled_vec_conversion(const uint8_t& a, float scale) -{ - hip_fp8 fp8{a, hip_fp8::from_bits()}; - return static_cast(fp8) * scale; +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, float scale) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; } // fp8x2 -> float2 template <> -__inline__ __device__ float2 scaled_vec_conversion(const uint16_t& a, float scale) -{ -#if defined(__HIP__MI300__) - float2 res; - const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); - res.x = f2[0] * scale; - res.y = f2[1] * scale; - return res; -#else - float2 res; - res.x = scaled_vec_conversion(static_cast(a), scale); - res.y = scaled_vec_conversion(static_cast(a >> 8U), scale); - return res; -#endif +__inline__ __device__ float2 +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; + #else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), + scale); + return res; + #endif } // fp8x4 -> float4 template <> -__inline__ __device__ Float4_ scaled_vec_conversion(const uint32_t& a, const float scale) -{ - Float4_ res; - res.x = scaled_vec_conversion((uint16_t)a, scale); - res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); - return res; +__inline__ __device__ Float4_ +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; } // fp8x4 -> float4 template <> -__inline__ __device__ float4 scaled_vec_conversion(const uint32_t& a, float scale) -{ - Float4_ res = scaled_vec_conversion(a, scale); - return {res.x.x, res.x.y, res.y.x, res.y.y}; +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, float scale) { + Float4_ res = scaled_vec_conversion(a, scale); + return {res.x.x, res.x.y, res.y.x, res.y.y}; } // fp8x8 -> float8 template <> -__inline__ __device__ Float8_ scaled_vec_conversion(const uint2& a, float scale) -{ - Float4_ tmp1, tmp2; - tmp1 = scaled_vec_conversion(a.x, scale); - tmp2 = scaled_vec_conversion(a.y, scale); - Float8_ res; - res.x = tmp1.x; - res.y = tmp1.y; - res.z = tmp2.x; - res.w = tmp2.y; - return res; +__inline__ __device__ Float8_ +scaled_vec_conversion(const uint2& a, float scale) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; } // half -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const uint16_t& a, float scale) -{ - __half_raw tmp; - tmp.x = a; +__inline__ __device__ uint8_t +scaled_vec_conversion(const uint16_t& a, float scale) { + __half_raw tmp; + tmp.x = a; - hip_fp8 f8{static_cast(tmp.data / scale)}; - return f8.data; + hip_fp8 f8{static_cast(tmp.data / scale)}; + return f8.data; } // halfx2 -> fp8x2 -template<> -__inline__ __device__ uint16_t scaled_vec_conversion(const uint32_t& a, float scale) -{ -#ifdef __HIP__MI300__ - union { - uint32_t ui32; - __half2_raw h2r; - } tmp; - tmp.ui32 = a; - - union { - uint32_t ui32; - float f; - } f1, f2; - f1.f = tmp.h2r.x.data / scale; - f2.f = tmp.h2r.y.data / scale; - if ((f1.ui32 & 0x7F800000) != 0x7F800000) { - f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); - } - if ((f2.ui32 & 0x7F800000) != 0x7F800000) { - f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); - } - return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); -#else - union { - uint32_t ui32; - __half2_raw h2r; - } tmp; - tmp.ui32 = a; - - union { - uint8_t ui8[2]; - uint16_t ui16; - } res; - res.ui8[0] = scaled_vec_conversion(tmp.h2r.x.x, scale); - res.ui8[1] = scaled_vec_conversion(tmp.h2r.y.x, scale); - return res.ui16; -#endif +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint32_t& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = tmp.h2r.x.data / scale; + f2.f = tmp.h2r.y.data / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint8_t ui8[2]; + uint16_t ui16; + } res; + res.ui8[0] = scaled_vec_conversion(tmp.h2r.x.x, scale); + res.ui8[1] = scaled_vec_conversion(tmp.h2r.y.x, scale); + return res.ui16; + #endif } // half2x2 -> fp8x4 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const uint2& a, float scale) -{ - union { - uint16_t ui16[2]; - uint32_t ui32; - } tmp; - tmp.ui16[0] = scaled_vec_conversion(a.x, scale); - tmp.ui16[1] = scaled_vec_conversion(a.y, scale); - return tmp.ui32; +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint2& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; } // half2x4 -> fp8x8 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, float scale) -{ - union { - uint2 ui2[2]; - uint4 ui4; - } tmp; - tmp.ui4 = a; - uint2 res; - res.x = scaled_vec_conversion(tmp.ui2[0], scale); - res.y = scaled_vec_conversion(tmp.ui2[1], scale); - return res; +__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, + float scale) { + union { + uint2 ui2[2]; + uint4 ui4; + } tmp; + tmp.ui4 = a; + uint2 res; + res.x = scaled_vec_conversion(tmp.ui2[0], scale); + res.y = scaled_vec_conversion(tmp.ui2[1], scale); + return res; } // bf16 -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const __nv_bfloat16& a, float scale) -{ - hip_fp8 res{__bfloat162float(a) / scale}; - return res.data; +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, float scale) { + hip_fp8 res{__bfloat162float(a) / scale}; + return res.data; } // bf16x2 -> fp8x2 template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const __nv_bfloat162& a, float scale) -{ - union { - uint8_t ui8[2]; - uint16_t ui16; - } tmp; - tmp.ui8[0] = scaled_vec_conversion(a.x, scale); - tmp.ui8[1] = scaled_vec_conversion(a.y, scale); - return tmp.ui16; +__inline__ __device__ uint16_t scaled_vec_conversion( + const __nv_bfloat162& a, float scale) { + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; } // bf16x4 -> fp8x4 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const bf16_4_t& a, float scale) -{ - union { - uint16_t ui16[2]; - uint32_t ui32; - } tmp; - tmp.ui16[0] = scaled_vec_conversion(a.x, scale); - tmp.ui16[1] = scaled_vec_conversion(a.y, scale); - return tmp.ui32; +__inline__ __device__ uint32_t +scaled_vec_conversion(const bf16_4_t& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; } // bf16x8 -> fp8x8 template <> -__inline__ __device__ uint2 scaled_vec_conversion(const bf16_8_t& a, float scale) -{ - uint2 res; - res.x = scaled_vec_conversion({a.x, a.y}, scale); - res.y = scaled_vec_conversion({a.z, a.w}, scale); - return res; +__inline__ __device__ uint2 +scaled_vec_conversion(const bf16_8_t& a, float scale) { + uint2 res; + res.x = scaled_vec_conversion({a.x, a.y}, scale); + res.y = scaled_vec_conversion({a.z, a.w}, scale); + return res; } // float -> fp8 template <> -__inline__ __device__ uint8_t scaled_vec_conversion(const float& a, float scale) -{ - hip_fp8 f8(a); - return f8.data; +__inline__ __device__ uint8_t +scaled_vec_conversion(const float& a, float scale) { + hip_fp8 f8(a); + return f8.data; } // floatx2 -> fp8x2 template <> -__inline__ __device__ uint16_t scaled_vec_conversion(const float2& a, float scale) -{ -#ifdef __HIP__MI300__ - union { - uint32_t ui32; - float f; - } f1, f2; - f1.f = a.x / scale; - f2.f = a.y / scale; - if ((f1.ui32 & 0x7F800000) != 0x7F800000) { - f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); - } - if ((f2.ui32 & 0x7F800000) != 0x7F800000) { - f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); - } - return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f,f2.f, 0, 0); -#else - union { - uint8_t ui8[2]; - uint16_t ui16; - } tmp; - tmp.ui8[0] = scaled_vec_conversion(a.x, scale); - tmp.ui8[1] = scaled_vec_conversion(a.y, scale); - return tmp.ui16; -#endif +__inline__ __device__ uint16_t +scaled_vec_conversion(const float2& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = a.x / scale; + f2.f = a.y / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; + #endif } // floatx4 -> fp8x4 template <> -__inline__ __device__ uint32_t scaled_vec_conversion(const float4& a, float scale) -{ - union { - uint16_t ui16[2]; - uint32_t ui32; - } tmp; - tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); - tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); - return tmp.ui32; +__inline__ __device__ uint32_t +scaled_vec_conversion(const float4& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); + tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); + return tmp.ui32; } #endif // ENABLE_FP8 diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 937df5a0bec13..bcb8fa514444d 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -14,33 +14,36 @@ #endif namespace vllm { - -template -__global__ void convert_fp8_kernel( - const Tin* __restrict__ src_data, Tout* __restrict__ dst_data, const float* scale, size_t N) -{ - const int64_t block_idx = blockIdx.x; - - using V_in_vec = typename Vec::Type; - using V_out_vec = typename Vec::Type; - auto dst_data_vec = reinterpret_cast(dst_data); - auto src_data_vec = reinterpret_cast(src_data); - int64_t startIdx = (threadIdx.x + blockDim.x * blockIdx.x); - auto idx = startIdx; - if (idx >= N) { - return; - } - dst_data_vec[idx] = fp8::scaled_vec_conversion(src_data_vec[idx], *scale); - //dst_data_vec[idx+1] = fp8_e4m3::vec_conversion(src_data_vec[idx+1], *scale); - - //for (int64_t i = 0; i < loopSize; ++i) { - // auto idx = startIdx + i; - // if (idx >= N) { - // return; - // } - // dst_data_vec[idx] = fp8_e4m3::vec_conversion(src_data_vec[idx], *scale); - //} +template +__global__ void convert_fp8_kernel(const Tin* __restrict__ src_data, + Tout* __restrict__ dst_data, + const float* scale, size_t N) { + const int64_t block_idx = blockIdx.x; + + using V_in_vec = typename Vec::Type; + using V_out_vec = typename Vec::Type; + auto dst_data_vec = reinterpret_cast(dst_data); + auto src_data_vec = reinterpret_cast(src_data); + + int64_t startIdx = (threadIdx.x + blockDim.x * blockIdx.x); + auto idx = startIdx; + if (idx >= N) { + return; + } + dst_data_vec[idx] = fp8::scaled_vec_conversion( + src_data_vec[idx], *scale); + // dst_data_vec[idx+1] = fp8_e4m3::vec_conversion(src_data_vec[idx+1], *scale); + + // for (int64_t i = 0; i < loopSize; ++i) { + // auto idx = startIdx + i; + // if (idx >= N) { + // return; + // } + // dst_data_vec[idx] = fp8_e4m3::vec_conversion(src_data_vec[idx], *scale); + // } } __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { @@ -117,7 +120,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out, } } -} // namespace vllm +} // namespace vllm void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] @@ -158,54 +161,57 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] } template -struct call_convert_fp8 -{ - void operator()(torch::Tensor& src_data, torch::Tensor& dst_data, torch::Tensor& scale) - { - const auto N = src_data.numel() / 2; - //std::cout << N << "\n"; - constexpr uint32_t loopSize = 1;//std::max(N / 50000000LL, 1); - constexpr dim3 numThreads{1024, 1, 1}; - auto neededBlocks = (N + (numThreads.x * loopSize) - 1) / (numThreads.x * loopSize); - uint32_t actualBlocks = neededBlocks; - - //static uint32_t maxBlocks = 0; - //if (actualBlocks != maxBlocks) { - // maxBlocks = actualBlocks; - // std::cout << actualBlocks << "\n"; - //} - - const dim3 grid{actualBlocks, 1, 1}; - - const auto stream = at::cuda::getCurrentCUDAStream(); - - vllm::convert_fp8_kernel - <<>>(reinterpret_cast(src_data.data_ptr()), - reinterpret_cast(dst_data.data_ptr()), (float*)scale.data_ptr(), N); - } +struct call_convert_fp8 { + void operator()(torch::Tensor& src_data, torch::Tensor& dst_data, + torch::Tensor& scale) { + const auto N = src_data.numel() / 2; + // std::cout << N << "\n"; + constexpr uint32_t loopSize = 1; // std::max(N / 50000000LL, 1); + constexpr dim3 numThreads{1024, 1, 1}; + auto neededBlocks = + (N + (numThreads.x * loopSize) - 1) / (numThreads.x * loopSize); + uint32_t actualBlocks = neededBlocks; + + // static uint32_t maxBlocks = 0; + // if (actualBlocks != maxBlocks) { + // maxBlocks = actualBlocks; + // std::cout << actualBlocks << "\n"; + // } + + const dim3 grid{actualBlocks, 1, 1}; + + const auto stream = at::cuda::getCurrentCUDAStream(); + + vllm::convert_fp8_kernel + <<>>( + reinterpret_cast(src_data.data_ptr()), + reinterpret_cast(dst_data.data_ptr()), + (float*)scale.data_ptr(), N); + } }; -void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, torch::Tensor& scale) -{ - torch::Device src_device = src_data.device(); - torch::Device dst_device = dst_data.device(); - TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") - TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") - TORCH_CHECK(src_device.index() == dst_device.index(), "src and dst must be on the same GPU"); - at::cuda::OptionalCUDAGuard device_guard(src_device); - auto t1 = src_data.dtype(); - auto t2 = dst_data.dtype(); - if (src_data.dtype() == at::ScalarType::Float) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (src_data.dtype() == at::ScalarType::Half) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (src_data.dtype() == at::ScalarType::BFloat16) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (dst_data.dtype() == at::ScalarType::Float) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (dst_data.dtype() == at::ScalarType::Half) { - call_convert_fp8{}(src_data, dst_data, scale); - } else if (dst_data.dtype() == at::ScalarType::BFloat16) { - call_convert_fp8<__nv_bfloat16, uint8_t, 2>{}(src_data, dst_data, scale); - } +void convert_fp8(torch::Tensor& dst_data, torch::Tensor& src_data, + torch::Tensor& scale) { + torch::Device src_device = src_data.device(); + torch::Device dst_device = dst_data.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + auto t1 = src_data.dtype(); + auto t2 = dst_data.dtype(); + if (src_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (src_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Float) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::Half) { + call_convert_fp8{}(src_data, dst_data, scale); + } else if (dst_data.dtype() == at::ScalarType::BFloat16) { + call_convert_fp8<__nv_bfloat16, uint8_t, 2>{}(src_data, dst_data, scale); + } } \ No newline at end of file diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index ad6ec10892b6e..894c9e9dc6554 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -165,6 +165,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: ) return self._cached_decode_metadata + def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, @@ -185,43 +186,44 @@ def _make_alibi_bias( bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( + alibi_slopes.device) attn_biases.append((bias + inf_mask).to(dtype)) return attn_biases -def _make_alibi_bias_v2( - alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: List[int], - make_attn_mask: bool = True -) -> List[torch.Tensor]: +def _make_alibi_bias_v2(alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: Optional[List[int]], + make_attn_mask: bool = True) -> List[torch.Tensor]: attn_biases = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat((num_heads, 1, 1)).to(alibi_slopes.device) - bias.mul_(alibi_slopes[:, None, None]) - if make_attn_mask: - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(alibi_slopes.device) - attn_biases.append((bias + inf_mask).to(dtype)) - else: - attn_biases.append(bias.to(dtype)) + if seq_lens: + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat( + (num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + if make_attn_mask: + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( + alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + else: + attn_biases.append(bias.to(dtype)) return attn_biases - class ROCmFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -384,8 +386,10 @@ def forward( if self.use_triton_flash_attn: if self.alibi_slopes is not None: att_masks = _make_alibi_bias_v2( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens, make_attn_mask=False) # type: ignore + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=False) # type: ignore out, _ = self.attn_func( query, key, @@ -402,20 +406,17 @@ def forward( elif self.use_naive_attn: if self.alibi_slopes is not None: att_masks = _make_alibi_bias_v2( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens, make_attn_mask=True) # type: ignore + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=True) # type: ignore if self.num_kv_heads != self.num_heads: # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) - out = self.attn_func( - query, - key, - value, - prefill_meta.seq_lens, - self.scale, - att_masks - ) + out = self.attn_func(query, key, value, + prefill_meta.seq_lens, self.scale, + att_masks) else: out = self.attn_func( q=query, @@ -486,7 +487,7 @@ def _naive_attention( key[start:end], value[start:end], scale, - attn_masks[i], + attn_masks[i] if attn_masks else None, ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) @@ -505,13 +506,13 @@ def _naive_masked_attention( seq_len, head_size, head_dim = query.shape if attn_mask is None: attn_mask = torch.triu(torch.ones(seq_len, - seq_len, - dtype=query.dtype, - device=query.device), - diagonal=1) + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) attn_mask = attn_mask * torch.finfo(query.dtype).min attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() attn_weights = attn_weights + attn_mask.float() attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out \ No newline at end of file + return out diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index c99029175b5a2..05134872ba39c 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -61,7 +61,8 @@ def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): @triton.jit -def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): +def load_fn(ptrs, offset_first, offset_second, boundary_first, + boundary_second): if offset_first is not None and offset_second is not None: mask = (offset_first[:, None] < boundary_first) & \ (offset_second[None, :] < boundary_second) @@ -78,38 +79,43 @@ def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): @triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, - actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, - block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, - IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr): +def _attn_fwd_inner( + acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, + stride_bn, start_m, actual_seqlen_k, actual_seqlen_q, dropout_p, + philox_seed, batch_philox_offset, encoded_sm_ptrs, block_min, + block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - if MASK_STEPS: - k_offs_n = start_n + tl.arange(0, BLOCK_N) - else: - k_offs_n = None + k_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) - k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, + actual_seqlen_k) if PRE_LOAD_V: # We can use the same offsets as k, just with dims transposed. - v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: - # If this is the last block / iteration, we want to - # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. - # last step might get wasted but that is okay. check if this masking works For - # that case. - if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): - boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + if MASK_STEPS: # NOQA: SIM102 + if start_n + BLOCK_N == block_max and n_extra_tokens != 0: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if + # not is_modulo_mn. Last step might get wasted but that is okay. + # Check if this masking works for that case. + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) size_n = start_n + OFFS_N[None, :] mask = size_n < boundary_m[:, None] qk = tl.where(mask, qk, float("-inf")) @@ -120,11 +126,14 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # -- compute qk ---- qk += tl.dot(q, k) if bias_ptrs is not None: - bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None - bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + bias_offs_n = start_n + tl.arange(0, + BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, + actual_seqlen_k) # While bias is added after multiplying qk with sm_scale, - # our optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. + # our optimization to use 2^x instead of e^x results in an + # additional scale factor of log2(e) which we must also multiply + # the bias with. qk += (bias * 1.44269504089) # softmax @@ -135,10 +144,15 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) if ENABLE_DROPOUT: - philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N - keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, + BLOCK_N, actual_seqlen_k) if RETURN_ENCODED_SOFTMAX: - tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) + tl.store( + encoded_sm_ptrs, + tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) p = tl.where(keep, p, 0.0) elif RETURN_ENCODED_SOFTMAX: tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) @@ -146,7 +160,8 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i @@ -260,14 +275,20 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], ) @triton.jit -def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, - stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, - stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr, - HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, - MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): +def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, + stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, + stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, + stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, + stride_ah, cu_seqlens_q, cu_seqlens_k, dropout_p, philox_seed, + philox_offset_base, encoded_softmax, alibi_slopes, + HQ: tl.constexpr, HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr): start_m = tl.program_id(0) off_h_q = tl.program_id(1) off_z = tl.program_id(2) @@ -298,42 +319,46 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s # This block of code determines what N is, and if this WG is operating # on those M rows. n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if (IS_CAUSAL): + if IS_CAUSAL: # If seqlen_q == seqlen_k, the attn scores are a square matrix. # If seqlen_q != seqlen_k, attn scores are rectangular which means # the causal mask boundary is bottom right aligned, and ends at either # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn matrix - n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This captures the decrease in n_blocks if we have a rectangular + # attn matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) # This is what adjusts the block_max for the current WG, only # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is part of - # the blocks that are all 0. We exit early. + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. if n_blocks <= 0: - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[ + None, :] * stride_on acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) o_ptrs_mask = offs_m[:, None] < seqlen_q # We still need to write 0s to the result tl.store(o_ptrs, acc, mask=o_ptrs_mask) # The tensor allocated for L is based on MAX_SEQLENS_Q as that is # statically known. - l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # We store inf to LSE, not -inf because in the bwd pass, we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. - l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + + off_h_q * MAX_SEQLENS_Q + offs_m) + # We store inf to LSE, not -inf because in the bwd pass, we subtract + # this from qk which makes it -inf, such that exp(qk - inf) = 0 for + # these masked blocks. + l = tl.full( # NOQA: E741 + [BLOCK_M], value=float("inf"), dtype=tl.float32) l_ptrs_mask = offs_m < MAX_SEQLENS_Q tl.store(l_ptrs, l, mask=l_ptrs_mask) - # TODO: Should dropout and return encoded softmax be handled here too? + # TODO: Should dropout & return encoded softmax be handled here too? return # If MQA / GQA, set the K and V head offsets appropriately. GROUP_SIZE: tl.constexpr = HQ // HK - if GROUP_SIZE != 1: - off_h_k = off_h_q // GROUP_SIZE - else: - off_h_k = off_h_q + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q n_extra_tokens = 0 if seqlen_k < BLOCK_N: @@ -343,16 +368,23 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) # Compute pointers for all the tensors used in this kernel. - q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk - k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn - v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + q_ptrs = (q_offset + offs_m[:, None] * stride_qm + + offs_d[None, :] * stride_qk) + k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + k_ptrs = (k_offset + offs_d[:, None] * stride_kk + + offs_n[None, :] * stride_kn) + v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + v_ptrs = (v_offset + offs_n[:, None] * stride_vk + + offs_d[None, :] * stride_vn) if USE_BIAS: # Note: this might get large enough to overflow on some configs bias_offset = off_h_q * stride_bh - bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[ + None, :] * stride_bn else: bias_ptrs = None @@ -367,11 +399,13 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k else: batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. In + # We can ask to return the dropout mask without actually doing dropout. In # this case, we return an invalid pointer so indicate the mask is not valid. if RETURN_ENCODED_SOFTMAX: encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k - encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] + encoded_sm_ptrs = encoded_sm_base + offs_m[:, + None] * seqlen_k + offs_n[ + None, :] else: encoded_sm_ptrs = None # initialize pointer to m and l @@ -398,50 +432,105 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s else: # Padding on Q does not need to be masked in the FA loop. masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. - # In this case we might exceed n_blocks so pick the min. + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. masked_blocks = min(masked_blocks, n_blocks) n_full_blocks = n_blocks - masked_blocks block_min = 0 block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its actual - # value because there is no masking. Similarly we do not need padding. + # Compute for full blocks. Here we set causal to false unconditionally + # because there is no masking. Similarly we do not need padding. if n_full_blocks > 0: block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - encoded_sm_ptrs, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, block_max, 0, 0, 0, alibi_slope, - # IS_CAUSAL, .... - False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + alibi_slope, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) block_min = block_max block_max = n_blocks * BLOCK_N tl.debug_barrier() # Remaining blocks, if any, are full / not masked. if (masked_blocks > 0): - if IS_CAUSAL: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) - else: - offs_n_causal = 0 + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 k_ptrs += n_full_blocks * BLOCK_N * stride_kn v_ptrs += n_full_blocks * BLOCK_N * stride_vk if USE_BIAS: bias_ptrs += n_full_blocks * BLOCK_N * stride_bn if RETURN_ENCODED_SOFTMAX: encoded_sm_ptrs += n_full_blocks * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, - start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, batch_philox_offset, - encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, - n_extra_tokens, alibi_slope, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, + ACTUAL_BLOCK_DMODEL) # epilogue acc = acc / l_i[:, None] if ENABLE_DROPOUT: @@ -454,28 +543,36 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: + if IS_CAUSAL: # NOQA: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + out_ptrs_mask = mask_m_offsets[:, + None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. - # This is only true for the last M block. For others, overflow_size will be -ve + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few + # rows. This is only true for the last M block. For others, overflow_size + # will be -ve overflow_size = end_m_idx - seqlen_q if overflow_size > 0: - boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + boundary = tl.full((BLOCK_M, ), + BLOCK_M - overflow_size, + dtype=tl.int32) l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) else: tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O - o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = (o_offset + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) if overflow_size > 0: o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) @@ -595,7 +692,9 @@ def forward( else: bias_strides = (0, 0, 0, 0) alibi_strides = (0, 0) - M = torch.empty((batch, nheads_q, max_seqlens_q), device=q.device, dtype=torch.float32) + M = torch.empty((batch, nheads_q, max_seqlens_q), + device=q.device, + dtype=torch.float32) attn_fwd[grid]( q, diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index b04adb532dd38..d5acc965ad200 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -6,13 +6,14 @@ import torch from torch.distributed import ProcessGroup +from vllm.utils import is_hip + from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator, get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_ca_communicator, get_tp_pynccl_communicator) -from vllm.utils import is_hip @dataclass diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6d1b3b47c3dd7..6e9b017ea93b3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -188,7 +188,8 @@ def initialize_model_parallel( _TP_CPU_GROUP = cpu_group if tensor_model_parallel_size > 1 and not is_hip(): - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator) _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( group=_TP_CPU_GROUP, device=_LOCAL_RANK, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 476301a216c48..00890a49b9be3 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,7 +5,6 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm import _custom_C from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -16,7 +15,6 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.tuned_gemm import tgemm from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import is_hip logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index b963576aa4471..228a646d471f6 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -23,7 +23,7 @@ "aqlm": AQLMConfig, "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, - "fp8": Fp8Config if not is_hip() else Fp8RocmConfig, + "fp8": Fp8Config if not is_hip() else Fp8RocmConfig, # type: ignore # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index caa53fb6ceee8..ddccc5825c8a4 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -1,22 +1,19 @@ -from typing import Any, Dict, List, Optional, Tuple, Union, Iterator +import os +from typing import List, Optional, Tuple, Union +import pandas as pd import torch +import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter -import torch.nn.functional as F -from safetensors import safe_open -from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs -import pandas as pd -import os - -try: +try: # NOQA: SIM105 from vllm._C import ops as vllm_ops except ImportError: pass @@ -25,10 +22,10 @@ class Fp8RocmConfig(QuantizationConfig): + def __init__(self) -> None: # self.quantized_weights_path = config["quantized_weights"] self._tuned = {} - self._stats = {} gemm_type = os.getenv("FP8_GEMM", "fp8_16") #print(f"Integral Cross factor = {self.factor}") if gemm_type == "fp8_8": @@ -38,10 +35,14 @@ def __init__(self) -> None: self.gemm_method = Fp8RocmLinearMethod.apply_fp8_16 tuned_filename = "/projects/tuned_fp8_16.csv" else: - raise Exception(f"Unknown fp8 gemm type: {gemm_type}") + raise ValueError(f"Unknown fp8 gemm type: {gemm_type}") try: df = pd.read_csv(tuned_filename) - except: + except pd.errors.ParserError as e: + logger.warning( + "An error occurred while parsing `%s`: %s" + "FP8 tuning results will not be used!", tuned_filename, e) + except (IOError, pd.errors.EmptyDataError): return for i in range(len(df)): @@ -58,7 +59,7 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config) -> "Fp8RocmConfig": - return cls(config) + return cls() @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: @@ -73,8 +74,8 @@ def get_min_capability(cls) -> int: def get_name(cls) -> str: return "Fp8Rocm" - def get_quant_method(self, - layer: torch.nn.Module) -> Optional["Fp8RocmLinearMethod"]: + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["Fp8RocmLinearMethod"]: if isinstance(layer, LinearBase): return Fp8RocmLinearMethod(self) return None @@ -84,10 +85,10 @@ def get_scaled_act_names(self) -> List[str]: class Fp8RocmLinearMethod(LinearMethodBase): + def __init__(self, config: Fp8RocmConfig): self._config = config - def _create_scale_param( self, scale_name: str, @@ -106,7 +107,6 @@ def _create_scale_param( self.scales_shard_indexer, }) - def create_weights( self, layer: torch.nn.Module, @@ -132,59 +132,56 @@ def create_weights( layer.register_parameter("weight", weight) set_weight_attrs(weight, { - **extra_weight_attrs, - "input_dim": 1, + **extra_weight_attrs, "input_dim": 1, "output_dim": 0 }) - - self._create_scale_param( - scale_name="weights_scaling_factor", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs) - - self._create_scale_param( - scale_name="activation_scaling_factor", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs) - - self._create_scale_param( - scale_name="output_scaling_factor", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs) - - + + self._create_scale_param(scale_name="weights_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + self._create_scale_param(scale_name="activation_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + self._create_scale_param(scale_name="output_scaling_factor", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + def process_weights_after_loading(self, layer: Module) -> None: if (not hasattr(layer, "process_after_load") or not layer.process_after_load): return - layer.activation_scaling_factor = Parameter(layer.activation_scaling_factor.max(), - requires_grad=False) - layer.output_scaling_factor = Parameter(layer.output_scaling_factor.reciprocal().max(), - requires_grad=False) + layer.activation_scaling_factor = Parameter( + layer.activation_scaling_factor.max(), requires_grad=False) + layer.output_scaling_factor = Parameter( + layer.output_scaling_factor.reciprocal().max(), + requires_grad=False) max_w_scale = layer.weights_scaling_factor.max() if len(layer.logical_widths) > 1: start = 0 for idx, logical_width in enumerate(layer.logical_widths): end = start + logical_width - weight_dq = _per_tensor_dequantize(layer.weight[start:end, :], - layer.weights_scaling_factor[idx]) + weight_dq = _per_tensor_dequantize( + layer.weight[start:end, :], + layer.weights_scaling_factor[idx]) layer.weight[start:end, :] = _per_tensor_quantize( weight_dq, max_w_scale) start = end - layer.weights_scaling_factor = Parameter(max_w_scale, requires_grad=False) + layer.weights_scaling_factor = Parameter(max_w_scale, + requires_grad=False) # WEIGHT # Transpose weight for passing to torch._scaled_mm weight = layer.weight layer.weight = Parameter(weight, requires_grad=False) - def scales_shard_indexer( self, param: torch.Tensor, loaded_weight: torch.Tensor, shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: @@ -198,7 +195,7 @@ def scales_shard_indexer( shard_id = qkv_idxs[shard_id] else: ValueError(f"Shard id must be int or str but got {type(shard_id)}") - + # To handle the scalar loaded tensor if loaded_weight.numel() == 1 and len(loaded_weight.shape) != 0: loaded_weight = torch.scalar_tensor(loaded_weight[0]) @@ -212,7 +209,9 @@ def apply_fp8_16( asf: torch.Tensor, wsf: torch.Tensor, osf: torch.Tensor, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert not bias x8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) vllm_ops.convert_fp8(x8, x, asf) m = weight.shape[0] @@ -226,11 +225,15 @@ def apply_fp8_16( if os.getenv("TUNE_FP8") == "1": try: df = pd.read_csv("/projects/fp8_shapes.csv") - except: + except (IOError, pd.errors.EmptyDataError, + pd.errors.ParserError): df = pd.DataFrame(columns=["M", "N", "K"]) df = pd.concat( - [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] - ).drop_duplicates() + [df, pd.DataFrame({ + "M": [m], + "N": [n], + "K": [k] + })]).drop_duplicates() df.to_csv("/projects/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) @@ -259,17 +262,21 @@ def apply_fp8_8( if os.getenv("TUNE_FP8") == "1": try: df = pd.read_csv("/projects/fp8_shapes.csv") - except: + except (IOError, pd.errors.EmptyDataError, + pd.errors.ParserError): df = pd.DataFrame(columns=["M", "N", "K"]) df = pd.concat( - [df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})] - ).drop_duplicates() - df.to_csv("/projects/fp8_shapese.csv", index=False) + [df, pd.DataFrame({ + "M": [m], + "N": [n], + "K": [k] + })]).drop_duplicates() + df.to_csv("/projects/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo)) res16 = torch.empty_like(res, dtype=torch.float16) - vllm_ops.convert_fp8(res16, res, 1/osf) + vllm_ops.convert_fp8(res16, res, 1 / osf) return res16 def apply( @@ -287,17 +294,17 @@ def apply( return self._config.gemm_method(self, x, weight, asf, wsf, osf) return F.linear(x, weight, bias) - + def _per_tensor_quantize(tensor: torch.Tensor, - inv_scale: float) -> torch.Tensor: + inv_scale: float) -> torch.Tensor: finfo = torch.finfo(torch.float8_e4m3fnuz) qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) return qweight.to(torch.float8_e4m3fnuz) def _per_tensor_dequantize(tensor: torch.Tensor, - inv_scale: float) -> torch.Tensor: + inv_scale: float) -> torch.Tensor: fake_qweight = tensor.to(torch.float16) dq_weight = fake_qweight * inv_scale return dq_weight diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index aa143c65d82b9..a7b8d1ad35620 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -247,10 +247,13 @@ def load_model(self, *, model_config: ModelConfig, model, "fall_back_to_pt_during_load", True)), ) - if model_config.quantization == 'fp8' and model_config.quantization_param_path is not None: + if (model_config.quantization == 'fp8' + and model_config.quantization_param_path is not None): model.load_quantized_weights( - safetensors_weights_iterator([model_config.model + model_config.quantization_param_path]) - ) + safetensors_weights_iterator([ + model_config.model + + model_config.quantization_param_path + ])) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2b8d3573f45cf..c7d63df353ca5 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -438,8 +438,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - - def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + def load_quantized_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]): params_dict = dict(self.named_parameters()) #with open("/projects/a.txt", "r") as f: # j = json.load(f) @@ -457,7 +458,8 @@ def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: #print(name) name = name.replace('transformer', 'model') - name = name.replace('kv_cache_scaling_factor', 'qkv.output_scaling_factor') + name = name.replace('kv_cache_scaling_factor', + 'qkv.output_scaling_factor') loaded_weight = loaded_weight.to("cuda") if loaded_weight.dtype == torch.int8: loaded_weight[loaded_weight == -128] = 0 @@ -481,9 +483,9 @@ def load_quantized_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) param = params_dict[name] - if "activation_scaling_factor" in name or "weights_scaling_factor" in name: - param.data.copy_(loaded_weight) - elif "output_scaling_factor" in name: + if ("activation_scaling_factor" in name + or "weights_scaling_factor" in name + or "output_scaling_factor" in name): param.data.copy_(loaded_weight) else: weight_loader = getattr(param, "weight_loader", From 95b3accb1805824e13426e2db7246c4881568afa Mon Sep 17 00:00:00 2001 From: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Date: Mon, 10 Jun 2024 04:28:21 -0500 Subject: [PATCH 411/413] Include benchmark scripts in container (#45) --- Dockerfile.rocm | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 83a483075f8a4..ac0d2d8a6aa5e 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -153,10 +153,12 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl / COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/gradlib/dist/*.whl / COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/rocm_patch /rocm_patch COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt / +COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks # ----------------------- # Final vLLM image FROM base AS final +ARG COMMON_WORKDIR ARG BUILD_FA RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/* @@ -215,6 +217,9 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ *) ;; esac \ && pip install *.whl +# Copy over the benchmark scripts as well +COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks + ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 # Performance environment variable. From d254de7f0b27b2c6513ac6f8e472a01de6c49a46 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Mon, 10 Jun 2024 11:40:08 -0500 Subject: [PATCH 412/413] Adding fp8 to gradlib (#44) * adding fp8 gemm tunner to gradlib * formatting * add instructions * Linting * adding fp8 gemm tunner to gradlib formatting add instructions * Linting fp8 gradlib * fix merging issue of ROCm_performance.md * delete fp8_gemm_tuner.py * Fix linting for triton: unmeld if with constexpr * update tutorial * Fix linting again * fix typo --------- Co-authored-by: Matthew Wong --- ROCm_performance.md | 21 + gradlib/csrc/hipbsolgemm.cu | 813 +++++++++--------- gradlib/gradlib/fp8_gemm_tuner.py | 289 +++++++ .../layers/quantization/fp8_rocm.py | 12 +- 4 files changed, 731 insertions(+), 404 deletions(-) create mode 100644 gradlib/gradlib/fp8_gemm_tuner.py diff --git a/ROCm_performance.md b/ROCm_performance.md index 180c848a21950..bea77d1a27fc4 100644 --- a/ROCm_performance.md +++ b/ROCm_performance.md @@ -18,3 +18,24 @@ Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order On ROCm, to have better performance, a custom paged attention is available by switching on the env variable: `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`. Currently, this env variable is enabled by default. To fallback to PagedAttention v2 kernel assign the env variable to 0. The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 to 16. On all the other cases, we fallback to PagedAttention v2 kernel. + +## Fp8 Quantization + +To use fp8 quantization, first step is to quantize your model to fp8 format. Please follow this [instruction](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) to generating a safetensor file that contains the quantized weights and the corresponding scaling factors of your model. The safetensor file should be placed under your model folder. + +Then we can run a model with fp8 quantization using vllm. When creating `vllm.LLM` object, two additional parameters should be added: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors with your model path}`. + +## Gemm Tuning for Fp8 + +To get better performance of fp8 quantization, we will need to tune the gemm with the information of all the shapes used in the execution of the model. + +To obtain all the shapes of gemms during the execution of the model, set the env value `TUNE_FP8=1` and then run the model as usual. We will get the a file called `/tmp/fp8_shapes.csv`. + +Next, run gradlib to obtain the best solutions of these shapes: + +``` +python3 gradlib/gradlib/fp8_gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv +``` +where `/tmp/tuned_fp8_16` will be used by our fp8 gemm linear layer. + +Now, when running inference with fp8, we are using the tuned gemm for best performance. \ No newline at end of file diff --git a/gradlib/csrc/hipbsolgemm.cu b/gradlib/csrc/hipbsolgemm.cu index bf15fb1297667..7888abb6e923c 100644 --- a/gradlib/csrc/hipbsolgemm.cu +++ b/gradlib/csrc/hipbsolgemm.cu @@ -1,9 +1,9 @@ // #ifdef __gfx908__ -// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and not for others -// // below lines enable hip float to half conversion which are disabled by default in hip_fp16.h -// #undef __HIP_NO_HALF_OPERATORS__ -// #undef __HIP_NO_HALF_CONVERSIONS__ -// #endif +// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below +// just for gfx908 and not for others +// // below lines enable hip float to half conversion which are disabled by +// default in hip_fp16.h #undef __HIP_NO_HALF_OPERATORS__ #undef +// __HIP_NO_HALF_CONVERSIONS__ #endif #include #include @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -31,168 +32,138 @@ #include #include "nvToolsExt.h" -//#include - +// #include // #ifdef USE_ROCM -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #endif +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #endif // #ifdef __HIP_PLATFORM_HCC__ -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #if USE_GEMM_FLAGS_FP16_ALT_IMPL +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #if USE_GEMM_FLAGS_FP16_ALT_IMPL // #ifdef ROCM_BACKWARD_PASS_GUARD -// flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; -// #endif -// #endif -// #endif +// flag = at::BackwardPassGuard::is_backward_pass() ? +// rocblas_gemm_flags_fp16_alt_impl : 0; #endif #endif #endif #ifndef CHECK_HIP_ERROR -#define CHECK_HIP_ERROR(error) \ - if(error != hipSuccess) \ - { \ - fprintf(stderr, \ - "Hip error: '%s'(%d) at %s:%d\n", \ - hipGetErrorString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif #ifndef CHECK_HIPBLAS_ERROR -#define CHECK_HIPBLAS_ERROR(error) \ - if(error != HIPBLAS_STATUS_SUCCESS) \ - { \ - fprintf(stderr, \ - "hipBLAS error: '%s'(%d) at %s:%d\n", \ - hipblasStatusToString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIPBLAS_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "hipBLAS error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif namespace { - /*thread_local*/ cudaStream_t weight_stream; - // BUG: DLM has event and stream on different devices error - // In multi-GPU scenerio, do names defined in this namespace exist on all devices? - // C++ keyword: thread_local <- maybe this can help? - /*thread_local*/ cudaEvent_t event; +/*thread_local*/ cudaStream_t weight_stream; +// BUG: DLM has event and stream on different devices error +// In multi-GPU scenerio, do names defined in this namespace exist on all +// devices? C++ keyword: thread_local <- maybe this can help? +/*thread_local*/ cudaEvent_t event; + +// hipBLASLt +hipblasLtHandle_t hipblaslt_handle; +hipblasLtMatmulPreference_t preference; +size_t workspace_size = 2 * 128 * 1024 * 1024; +// uint64_t workspace_size = 0; +void* d_workspace; +int request_solutions = 1; +int returnedAlgoCount = 0; + +struct MatMulConfig { + hipblasOperation_t op_A; + hipblasOperation_t op_B; + int M; + int N; + int K; + hipDataType dtype; + + friend auto operator<(const MatMulConfig& left, + const MatMulConfig& right) -> bool { + return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < + std::tie(right.op_A, right.op_B, right.M, right.N, right.K, + right.dtype); + } +}; - // hipBLASLt - hipblasLtHandle_t hipblaslt_handle; - hipblasLtMatmulPreference_t preference; - size_t workspace_size = 2*128*1024*1024; - //uint64_t workspace_size = 0; - void* d_workspace; - int request_solutions = 1; - int returnedAlgoCount = 0; - - struct MatMulConfig { - hipblasOperation_t op_A; - hipblasOperation_t op_B; - int M; - int N; - int K; - hipDataType dtype; - - friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) -> bool { - return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); - } - }; +// std::map, +// std::vector> heuristic_map; +std::map heuristic_map; - // std::map, std::vector> heuristic_map; - std::map heuristic_map; +hipEvent_t start, stop; +int bench_iters{1}; +int warmup_iters{1}; - hipEvent_t start, stop; - int bench_iters { 1 }; - int warmup_iters { 1 }; +bool cout_print = false; - bool cout_print = false; - - //std::vector heuristicResult; -} +torch::Tensor dTensor; -//find all hipblaslt solutions for given gemm problem +// std::vector heuristicResult; +} // namespace + +// find all hipblaslt solutions for given gemm problem std::vector hipblasLtMatmul_findallsols_wrapper( - hipblasLtHandle_t handle, - hipblasOperation_t op_A, - hipblasOperation_t op_B, - int m, int n, int k, - const void *alpha, - const void *a, - int lda, - const void *b, - int ldb, - const void *beta, - void *c, - int ldc, - hipDataType dtype, - hipStream_t &stream) -{ - int flag { 0 }; + hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, + int m, int n, int k, const void* alpha, const void* a, int lda, + const void* b, int ldb, const void* beta, void* c, int ldc, + hipDataType intype, hipDataType outtype, hipStream_t& stream) { + int flag{0}; hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); } if (op_B == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); - - //std::vector heuristicResult(10); - //CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - // handle, matmul, matA, matB, matC, matC, - // preference, 10, heuristicResult.data(), &returnedAlgoCount)); + + // std::vector heuristicResult(10); + // CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( + // handle, matmul, matA, matB, matC, matC, + // preference, 10, heuristicResult.data(), &returnedAlgoCount)); std::vector heuristicResult; - CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos(handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, - op_A, - op_B, - dtype, - dtype, - dtype, - dtype, - HIPBLAS_COMPUTE_32F, - heuristicResult)); + CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos( + handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, op_A, op_B, intype, + intype, outtype, outtype, HIPBLAS_COMPUTE_32F, heuristicResult)); std::vector algoIndex; int returned_algo_count = heuristicResult.size(); - //for (int i = 0; i < returnedAlgoCount; i++) { + // for (int i = 0; i < returnedAlgoCount; i++) { for (int i = 0; i < returned_algo_count; i++) { - auto algo = heuristicResult[i].algo; - size_t ret_workspace_size = 0; - auto status = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, - alpha, - matA, - matB, - beta, - matC, - matC, - algo, - ret_workspace_size - ); - if (status == HIPBLAS_STATUS_SUCCESS) { - if (ret_workspace_size hipblasLtMatmul_findallsols_wrapper( ///////////////////////////////////////////////////////////////////////////////////////////////////////// /** * hipBLASLt GEMM call -*/ + */ hipblasStatus_t hipblasLtMatmul_sol_wrapper( - hipblasLtHandle_t handle, - hipblasOperation_t op_A, - hipblasOperation_t op_B, - int m, int n, int k, - const void *alpha, - const void *a, - int lda, - const void *b, - int ldb, - const void *beta, - void *c, - int ldc, - hipDataType dtype, - hipStream_t &stream, - int solution_index=-1) -{ + hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, + int m, int n, int k, const void* alpha, const void* a, int lda, + const void* scaleA, const void* b, int ldb, const void* scaleB, + const void* beta, void* c, int ldc, const void* scaleC, hipDataType intype, + hipDataType outtype, hipStream_t& stream, int solution_index = -1) { // TODO: flag is not supported for hipblasLt yet - int flag { 0 }; - //if (dtype == HIPBLAS_R_16F) { - // use fp16 alt impl for MI200 - // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - //flag = rocblas_gemm_flags_fp16_alt_impl; + int flag{0}; + // if (dtype == HIPBLAS_R_16F) { + // use fp16 alt impl for MI200 + // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + // flag = rocblas_gemm_flags_fp16_alt_impl; //} - //nvtxRangePushA("hipBLASLt variables creation"); + // nvtxRangePushA("hipBLASLt variables creation"); hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); } if (op_B == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); - //nvtxRangePop(); - // if heuristic does not exist in the map, do search and push into the map - //auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; - //if (heuristic_map.count(gemm_key) <= 0) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC))); + // nvtxRangePop(); + // if heuristic does not exist in the map, do search and push into the map + // auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; + // if (heuristic_map.count(gemm_key) <= 0) { std::vector heuristicResult(1); - if (solution_index<0) { - //nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); - std::cout << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" << std::endl; + if (solution_index < 0) { + // nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); + std::cout + << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" + << std::endl; if (cout_print) { - std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") - << " (" << m << ", " << n << ", " << k << "), dtype: " << dtype - << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; + std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") + << (op_B == HIPBLAS_OP_N ? "N" : "T") << " (" << m << ", " << n + << ", " << k << "), dtype: " << intype << ", (lda, ldb, ldc): (" + << lda << ", " << ldb << ", " << ldc << "), " << std::endl; } - //std::vector heuristicResult(request_solutions); + // std::vector + // heuristicResult(request_solutions); CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - handle, matmul, matA, matB, matC, matC, - preference, request_solutions, heuristicResult.data(), &returnedAlgoCount)); - if((returnedAlgoCount != request_solutions) && cout_print) { + handle, matmul, matA, matB, matC, matC, preference, request_solutions, + heuristicResult.data(), &returnedAlgoCount)); + if ((returnedAlgoCount != request_solutions) && cout_print) { std::cout << "less solution found! request: " << request_solutions << ", found: " << returnedAlgoCount << std::endl; } - //heuristic_map[gemm_key] = heuristicResult[0]; -/* - if (returnedAlgoCount == 1) { - heuristic_map[gemm_key] = heuristicResult[0]; - } else { - // benchmark requested solutions and pick best one - int bestIndex { -1 }; - double bestMs { std::numeric_limits::max() }; - for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { - // warm up - for (int iter { 0 }; iter < warmup_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - // performance measuring - double eventMs; - CHECK_HIP_ERROR(hipEventRecord(start, stream)); - for (int iter { 0 }; iter < bench_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - CHECK_HIP_ERROR(hipEventRecord(stop, stream)); - CHECK_HIP_ERROR(hipEventSynchronize(stop)); - float temp; - CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); - eventMs = double(temp); - eventMs /= bench_iters; - - if (cout_print) { - std::cout << " Sol " << sol << ": average time per iter " << std::to_string(eventMs) << " ms"; - } - if (bestMs > eventMs) { - bestMs = eventMs; - bestIndex = sol; - if (cout_print) { - std::cout << " *" << std::endl; - } + // heuristic_map[gemm_key] = heuristicResult[0]; + /* + if (returnedAlgoCount == 1) { + heuristic_map[gemm_key] = heuristicResult[0]; } else { - if (cout_print) { - std::cout << std::endl; + // benchmark requested solutions and pick best one + int bestIndex { -1 }; + double bestMs { std::numeric_limits::max() }; + for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { + // warm up + for (int iter { 0 }; iter < warmup_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the + values in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + // performance measuring + double eventMs; + CHECK_HIP_ERROR(hipEventRecord(start, stream)); + for (int iter { 0 }; iter < bench_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the + values in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + CHECK_HIP_ERROR(hipEventRecord(stop, stream)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + float temp; + CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); + eventMs = double(temp); + eventMs /= bench_iters; + + if (cout_print) { + std::cout << " Sol " << sol << ": average time per iter " << + std::to_string(eventMs) << " ms"; + } + if (bestMs > eventMs) { + bestMs = eventMs; + bestIndex = sol; + if (cout_print) { + std::cout << " *" << std::endl; + } + } else { + if (cout_print) { + std::cout << std::endl; + } + } } + heuristic_map[gemm_key] = heuristicResult[bestIndex]; } - } - heuristic_map[gemm_key] = heuristicResult[bestIndex]; - } -*/ - //nvtxRangePop(); + */ + // nvtxRangePop(); } else { - std::vector algoIndex(1); - algoIndex[0]=solution_index; - //std::vector tmpAlgo; - CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); + std::vector algoIndex(1); + algoIndex[0] = solution_index; + // std::vector tmpAlgo; + CHECK_HIPBLAS_ERROR( + hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); } - - //size_t ret_workspace_size = 0; - - //auto status1 = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, - // alpha, - // matA, - // matB, - // beta, - // matC, - // matC, - // heuristicResult[0].algo, - // ret_workspace_size + + // size_t ret_workspace_size = 0; + + // auto status1 = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, + // alpha, + // matA, + // matB, + // beta, + // matC, + // matC, + // heuristicResult[0].algo, + // ret_workspace_size //); - //if (status1 == HIPBLAS_STATUS_SUCCESS) { - // std::cout << "Workspace size" << ret_workspace_size << std::endl; + // if (status1 == HIPBLAS_STATUS_SUCCESS) { + // std::cout << "Workspace size" << ret_workspace_size << std::endl; //} else { - // std::cout << "Algo not supported!!!" << std::endl; + // std::cout << "Algo not supported!!!" << std::endl; //} - hipblasStatus_t status = hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, - &heuristicResult[0].algo, - d_workspace, workspace_size, - stream); - - //nvtxRangePushA("hipBLASLt variables deletion"); + hipblasStatus_t status = hipblasLtMatmul( + handle, matmul, alpha, a, matA, b, matB, beta, c, matC, c, matC, + &heuristicResult[0].algo, d_workspace, workspace_size, stream); + + // nvtxRangePushA("hipBLASLt variables deletion"); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); - //nvtxRangePop(); + // nvtxRangePop(); return status; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// torch::Tensor HipbSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2, - const int solution_index - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; - // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | mat2 info: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; + const torch::Tensor& mat1, const torch::Tensor& mat2, + const int solution_index, at::optional Type = at::nullopt, + at::optional scale1 = at::nullopt, + at::optional scale2 = at::nullopt, + at::optional scaleOut = at::nullopt) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | mat2 info: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - auto abcType { mat1.options().dtype() }; - auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; - // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl; + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto inType{mat1.options().dtype()}; + auto outType = inType.toScalarType(); + if (Type.has_value()) + outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + // std::cout << " | result info: size: " << result.sizes() << " stride: " << + // result.strides() << std::endl; bool transpose_result = true; bool transpose_mat1; bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { @@ -434,90 +414,120 @@ torch::Tensor HipbSolIdxBlas( mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } - // std::cout << " | transpose_result: " << (transpose_result ? "true" : "false") << std::endl - // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << std::endl - // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << std::endl; - // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | B matrix: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; - - float one { 1.0f }; - float zero { 0.0f }; + // std::cout << " | transpose_result: " << (transpose_result ? "true" : + // "false") << std::endl + // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << + // std::endl + // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << + // std::endl; + // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | B matrix: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; + + float one{1.0f}; + float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 0 : 1); - // std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl - // << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " << result_ld << std::endl; - - hipDataType hipblasType; - if (abcType == at::kHalf) { - hipblasType = HIP_R_16F; - } else if (abcType == at::kBFloat16) { - hipblasType = HIP_R_16BF; - } else if (abcType == at::kFloat) { - hipblasType = HIP_R_32F; + + void *d_scale1 = nullptr, *d_scale2 = nullptr, *d_scaleOut = nullptr; + if (scale1.has_value()) { + d_scale1 = static_cast(scale1.value().data_ptr()); + } + if (scale2.has_value()) { + d_scale2 = static_cast(scale2.value().data_ptr()); + } + if (scaleOut.has_value()) { + d_scaleOut = static_cast(scaleOut.value().data_ptr()); + } + + hipDataType hipblasInType, hipblasOutType; + if (inType == at::kHalf) { + hipblasInType = HIP_R_16F; + } else if (inType == at::kBFloat16) { + hipblasInType = HIP_R_16BF; + } else if (inType == at::kFloat) { + hipblasInType = HIP_R_32F; + } else if (inType == at::kFloat8_e4m3fnuz) { + hipblasInType = HIP_R_8F_E4M3_FNUZ; } else { assert(false && "Wrong datatype!"); } - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; + + if (outType == at::kHalf) { + hipblasOutType = HIP_R_16F; + } else if (outType == at::kBFloat16) { + hipblasOutType = HIP_R_16BF; + } else if (outType == at::kFloat) { + hipblasOutType = HIP_R_32F; + } else if (outType == at::kFloat8_e4m3fnuz) { + hipblasOutType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + if (transpose_result) std::swap(d_scale1, d_scale2); + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; CHECK_HIPBLAS_ERROR(hipblasLtMatmul_sol_wrapper( - hipblaslt_handle, - transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - &one, - ptrA, mat1_ld, - ptrB, mat2_ld, - &zero, - ptrC, result_ld, - hipblasType, - current_stream,solution_index)); + hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, + mat1_ld, d_scale1, ptrB, mat2_ld, d_scale2, &zero, ptrC, result_ld, + d_scaleOut, hipblasInType, hipblasOutType, current_stream, + solution_index)); return result; } -//find all hipblas solutions and return them to python land +// find all hipblas solutions and return them to python land std::vector HipbFindAllSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2 - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; + const torch::Tensor& mat1, const torch::Tensor& mat2, + at::optional Type = at::nullopt) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - auto abcType { mat1.options().dtype() }; - auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto inType{mat1.options().dtype()}; + auto outType = inType.toScalarType(); + if (Type.has_value()) + outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; bool transpose_result = true; bool transpose_mat1; bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { bool tmp = transpose_mat1; @@ -528,83 +538,90 @@ std::vector HipbFindAllSolIdxBlas( mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } - float one { 1.0f }; - float zero { 0.0f }; + float one{1.0f}; + float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 0 : 1); - hipDataType hipblasType; - if (abcType == at::kHalf) { - hipblasType = HIP_R_16F; - } else if (abcType == at::kBFloat16) { - hipblasType = HIP_R_16BF; - } else if (abcType == at::kFloat) { - hipblasType = HIP_R_32F; + hipDataType hipblasInType, hipblasOutType; + if (inType == at::kHalf) { + hipblasInType = HIP_R_16F; + } else if (inType == at::kBFloat16) { + hipblasInType = HIP_R_16BF; + } else if (inType == at::kFloat) { + hipblasInType = HIP_R_32F; + } else if (inType == at::kFloat8_e4m3fnuz) { + hipblasInType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + if (outType == at::kHalf) { + hipblasOutType = HIP_R_16F; + } else if (outType == at::kBFloat16) { + hipblasOutType = HIP_R_16BF; + } else if (outType == at::kFloat) { + hipblasOutType = HIP_R_32F; + } else if (outType == at::kFloat8_e4m3fnuz) { + hipblasOutType = HIP_R_8F_E4M3_FNUZ; } else { assert(false && "Wrong datatype!"); } - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; return hipblasLtMatmul_findallsols_wrapper( - hipblaslt_handle, - transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - &one, - ptrA, mat1_ld, - ptrB, mat2_ld, - &zero, - ptrC, result_ld, - hipblasType, - current_stream); - + hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, + mat1_ld, ptrB, mat2_ld, &zero, ptrC, result_ld, hipblasInType, + hipblasOutType, current_stream); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -void hipb_create_extension() -{ - //CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); - //CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); +void hipb_create_extension() { + // CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); + // CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); // hipBLASLt CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( - preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); - - //CHECK_HIP_ERROR(hipEventCreate(&start)); - //CHECK_HIP_ERROR(hipEventCreate(&stop)); + preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, + sizeof(workspace_size))); + + // CHECK_HIP_ERROR(hipEventCreate(&start)); + // CHECK_HIP_ERROR(hipEventCreate(&stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -void hipb_destroy_extension() -{ - //CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); - //CHECK_HIP_ERROR(hipEventDestroy(event)); +void hipb_destroy_extension() { + // CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); + // CHECK_HIP_ERROR(hipEventDestroy(event)); - // hipBLASLt - CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); - CHECK_HIP_ERROR(hipFree(d_workspace)); + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); + CHECK_HIP_ERROR(hipFree(d_workspace)); - //CHECK_HIP_ERROR(hipEventDestroy(start)); - //CHECK_HIP_ERROR(hipEventDestroy(stop)); + // CHECK_HIP_ERROR(hipEventDestroy(start)); + // CHECK_HIP_ERROR(hipEventDestroy(stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); - m.def("hipb_mm", &HipbSolIdxBlas, "mm"); - m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols"); -} + m.def("hipb_mm", &HipbSolIdxBlas, "mm", py::arg("mat1"), py::arg("mat2"), + py::arg("solution_index"), py::arg("outType") = at::nullopt, + py::arg("scale1") = at::nullopt, py::arg("scale2") = at::nullopt, + py::arg("scaleOut") = at::nullopt); + m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols", + py::arg("mat1"), py::arg("mat2"), py::arg("outType") = at::nullopt); +} \ No newline at end of file diff --git a/gradlib/gradlib/fp8_gemm_tuner.py b/gradlib/gradlib/fp8_gemm_tuner.py new file mode 100644 index 0000000000000..61df1933f8658 --- /dev/null +++ b/gradlib/gradlib/fp8_gemm_tuner.py @@ -0,0 +1,289 @@ +import argparse +import json +import os +import random +from pathlib import Path + +import hipbsolidxgemm +import pandas as pd +import torch +import torch.nn.functional as F + +hipbsolidxgemm.hipb_create_extension() + +rtol = 1e-5 +atol = 1 + + +class Fp8Gemm: + + def __init__(self, m, n, k, indtype, outdtype): + self.m = m + self.k = k + self.n = n + self.indtype = indtype + self.outdtype = outdtype + self.nb = 37 + self.inp = torch.randn((self.n, self.k), + device='cuda').to(self.indtype) + self.weights = torch.randn((self.m, self.k), + device='cuda').to(self.indtype) + # weights2 is used in measurement/warm iters to ensure HBM + # fetch for weight tensors + self.weights2 = torch.randn((self.nb, self.m, self.k), + device='cuda').to(self.indtype) + self.blob = torch.ones(128 * 1024 * 1024, + dtype=torch.float32, + device='cuda') + self.topn = 20 #number of top solutions from each source + self.hipb_sols = [] + self.rtol = 1e-5 + self.atol = 1 + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + + def find_hipblas_sols(self): + sols = hipbsolidxgemm.hipb_findallsols(self.inp, self.weights.t(), + self.outdtype) + print('M N K', + self.m, + self.n, + self.k, + '>>> Total hipb solutions', + len(sols), + flush=True) + #print(sols) + self.hipb_sols = sols + + def check_gemm_ref(self, libtype, solidx): + ref = F.linear(self.inp.to(torch.float32), + self.weights.to(torch.float32)).to(self.outdtype) + c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, + self.outdtype) + if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol): + #print('>>>',libtype,'Solidx',solidx,'passed reference test') + return True + else: + print('>>>', 'Solidx', solidx, 'FAILED reference test', flush=True) + print(ref, flush=True) + print(c, flush=True) + return False + + def hipb_time_sol(self, solidx, cold_iters=2, warm_iters=10): + #print('>>>hipbtime',solidx) + for i in range(cold_iters): + hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, + self.outdtype) + self.start.record() + for i in range(warm_iters): + hipbsolidxgemm.hipb_mm( + self.inp, self.weights2[random.randint(0, self.nb - 1)].t(), + solidx, self.outdtype) + self.end.record() + torch.cuda.synchronize() + gtime = self.start.elapsed_time(self.end) / warm_iters + #print('>>> Solidx GTime',solidx,gtime,'ms') + return gtime + + def hipb_time_all_sols(self, fast_mode=0, top_sols=0): + coldi = 20 + warmi = 20 + if fast_mode: + coldi = 2 + warmi = 2 + solutions = self.hipb_sols + if top_sols: + solutions = self.hipb_top_sols + gtimes = {} + for solidx in solutions: + gtimes[solidx] = self.hipb_time_sol(solidx, + cold_iters=coldi, + warm_iters=warmi) + self.hipb_gtimedf = pd.DataFrame.from_dict( + gtimes, orient='index', + columns=['gtimems']).sort_values(by='gtimems') + self.hipb_gtimedf.to_csv('/tmp/hipb_gtimedf.csv') + print('>>> HipBlasLt top solutions, Fast Mode', fast_mode) + print(self.hipb_gtimedf.head(self.topn)) + + def warmup(self, warmi=500): + for i in range(warmi): + self.blob = self.blob + 0.00001 + + def functional_check_topn_fastest(self): + hipb_topn = [] + for solidx in self.hipb_gtimedf.index[:self.topn]: + if self.check_gemm_ref(libtype='hipblaslt', solidx=solidx): + hipb_topn.append(solidx) + self.hipb_top_sols = hipb_topn + + def find_fastest_solution(self): + self.find_hipblas_sols() + self.warmup() + self.hipb_time_all_sols(fast_mode=1) + self.functional_check_topn_fastest() + self.warmup() + self.hipb_time_all_sols(fast_mode=0, top_sols=1) + if len(self.hipb_gtimedf) > 0: + best_hipb_time = self.hipb_gtimedf.gtimems.iloc[0] + self.best_solidx = self.hipb_gtimedf.index[0] + self.best_soltime = best_hipb_time + else: + print('>>> No hipblas solutions found!', flush=True) + self.best_solidx = 0 + self.best_soltime = 0 + print('>>> Fastest Solution is', + self.best_solidx, + self.best_soltime, + flush=True) + + +class Fp8GemmTuner: + + def __init__(self, indtype, outdtype, tuned_file=None): + self.gemm_problems = pd.DataFrame(columns=['M', 'N', 'K']) + self.indtype = indtype + self.outdtype = outdtype + self.tuned_file = tuned_file + if Path(tuned_file).is_file(): + self.gdf = pd.read_csv(tuned_file) + else: + self.gdf = None + + def add_gemm(self, m, n, k): + if (self.gdf is None + or (self.gdf[(self.gdf['M'] == m) & (self.gdf['N'] == n) & + (self.gdf['K'] == k)].empty)): + entry = {'M': [m], 'N': [n], 'K': [k]} + df = pd.DataFrame(entry) + self.gemm_problems = pd.concat([self.gemm_problems, df], + ignore_index=True) + else: + print( + f">>>Info: Found Duplicate shape(M:{m}, N:{n}, K:{k}), skipping" + ) + + def find_best_sols(self): + df = self.gemm_problems + soldf = pd.DataFrame() + for i in range(len(df)): + ds = df.iloc[i] + gemmobj = Fp8Gemm(ds['M'], + ds['N'], + ds['K'], + indtype=self.indtype, + outdtype=self.outdtype) + gemmobj.find_fastest_solution() + soldf.loc[i, 'solidx'] = gemmobj.best_solidx + soldf.loc[i, 'soltimems'] = gemmobj.best_soltime + soldf['indtype'] = self.indtype + soldf['outdtype'] = self.outdtype + finaldf = pd.concat([self.gemm_problems, soldf], axis=1) + finaldf = pd.concat([finaldf, self.gdf]) + finaldf.to_csv(self.tuned_file, index=False) + print(finaldf) + + +def generate_mk_sets(model_dir, tp=1): + with open(f'{model_dir}/config.json') as f: + data = json.load(f) + hidden_size = data['hidden_size'] + intermediate_size = data['intermediate_size'] + total_num_heads = data['num_attention_heads'] + total_num_kv_heads = data['num_key_value_heads'] + head_dim = hidden_size // total_num_heads + return [((total_num_heads + (2 * total_num_kv_heads)) * head_dim // tp, + hidden_size), (hidden_size, hidden_size // tp), + (intermediate_size * 2 // tp, hidden_size), + (hidden_size, intermediate_size // tp)], hidden_size + + +def get_dtype(dtype_str): + dtype = torch.float8_e4m3fnuz + if dtype_str == 'f32': + dtype = torch.float32 + elif dtype_str == 'bf16': + dtype = torch.bfloat16 + elif dtype_str == 'f16': + dtype = torch.float16 + elif dtype_str == 'f8': + dtype = torch.float8_e4m3fnuz + else: + print('>>> Warning! Invalid dtype', dtype_str, + 'using default dtype f8') + return dtype + + +def list_of_ints(arg): + return list(map(int, arg.split(','))) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model_dir", + type=str, + default=os.getenv('GTUNE_MODEL', ""), + help="Enter the location of your model directory") + parser.add_argument("--tuned_file", + type=str, + default=os.getenv('GTUNE_TUNED', "tuned.csv"), + help="output file for tuned gemm solutions") + parser.add_argument( + "--input_file", + type=str, + default=os.getenv('GTUNE_INPUT', None), + help="list of gemms to tune for, mutually exclusive with model_dir") + parser.add_argument("--tp", + type=int, + default=os.getenv('GTUNE_TP', 1), + help="Tensor parallelism to be used.") + parser.add_argument("--indtype", + type=str, + default='f8', + help="dtype f32 f16 bf16 fp8") + parser.add_argument("--outdtype", + type=str, + default='f16', + help="dtype f32 f16 bf16 fp8") + parser.add_argument("--batch_size", + type=int, + default=os.getenv('GTUNE_BATCH_SIZE', 1), + help="Batch size to tune for") + parser.add_argument("--nsets", + type=list_of_ints, + default=[1, 512, 1024, 2048, 3072, 4096, 8192, 16384], + help="N sizes to tune for: 1,128,2048") + args = parser.parse_args() + + indtype = get_dtype(args.indtype) + outdtype = get_dtype(args.outdtype) + + gtuner = Fp8GemmTuner(indtype, outdtype, args.tuned_file) + nsets = [i * args.batch_size for i in args.nsets] + if args.input_file: + print(f">>> Loading {args.input_file}") + if not Path(args.input_file).is_file(): + print(f">>> ERROR: {args.input_file} does not exist. Exiting") + exit(1) + shapes = pd.read_csv(args.input_file) + for i in range(len(shapes)): + ds = shapes.iloc[i] + gtuner.add_gemm(ds['M'], ds['N'], ds['K']) + else: + if not args.model_dir: + print(">>> Warning! NO MODEL SPECIFIED. Tuning for LL2 13B TP1") + #LL2 13B sizes + mksets = [(15360, 5120), (5120, 5120), (27648, 5120), + (5120, 13824)] + gtuner.add_gemm(m=32000, n=1, k=5120) # logits gemm + else: + mksets, hidden_size = generate_mk_sets(args.model_dir, args.tp) + gtuner.add_gemm( + m=32000 // args.tp, n=1 * args.batch_size, k=hidden_size + ) #TODO: Handle cases where vocab_size is not divisible by tp + + for n in sorted(nsets): + for m, k in mksets: + gtuner.add_gemm(m, n, k) + + gtuner.find_best_sols() diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index ddccc5825c8a4..ddb83a6ed452e 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -30,10 +30,10 @@ def __init__(self) -> None: #print(f"Integral Cross factor = {self.factor}") if gemm_type == "fp8_8": self.gemm_method = Fp8RocmLinearMethod.apply_fp8_8 - tuned_filename = "/projects/tuned_fp8_8.csv" + tuned_filename = "/tmp/tuned_fp8_8.csv" elif gemm_type == "fp8_16": self.gemm_method = Fp8RocmLinearMethod.apply_fp8_16 - tuned_filename = "/projects/tuned_fp8_16.csv" + tuned_filename = "/tmp/tuned_fp8_16.csv" else: raise ValueError(f"Unknown fp8 gemm type: {gemm_type}") try: @@ -50,7 +50,7 @@ def __init__(self) -> None: m = shape["M"] n = shape["N"] k = shape["K"] - algo = shape["algo"] + algo = shape["solidx"] self._tuned[(m, n, k)] = algo @classmethod @@ -224,7 +224,7 @@ def apply_fp8_16( if os.getenv("TUNE_FP8") == "1": try: - df = pd.read_csv("/projects/fp8_shapes.csv") + df = pd.read_csv("/tmp/fp8_shapes.csv") except (IOError, pd.errors.EmptyDataError, pd.errors.ParserError): df = pd.DataFrame(columns=["M", "N", "K"]) @@ -234,7 +234,7 @@ def apply_fp8_16( "N": [n], "K": [k] })]).drop_duplicates() - df.to_csv("/projects/fp8_shapes.csv", index=False) + df.to_csv("/tmp/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) return res @@ -271,7 +271,7 @@ def apply_fp8_8( "N": [n], "K": [k] })]).drop_duplicates() - df.to_csv("/projects/fp8_shapes.csv", index=False) + df.to_csv("/tmp/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo)) From ff241027abbfdef8a1e9d0c431cf9bc10520884d Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 12 Jun 2024 00:07:30 +0800 Subject: [PATCH 413/413] Update fp8_gemm_tuner.py exchange import torch and hipbsolidxgemm (#46) * Update fp8_gemm_tuner.py exchange import torch and hipbsolidxgemm ImportError: libc10.so: cannot open shared object file: No such file or directory https://stackoverflow.com/a/65710714 * run isort on fp9_gemm_tuner.py * add # isort: split * fix yapf --------- Co-authored-by: charlifu --- gradlib/gradlib/fp8_gemm_tuner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradlib/gradlib/fp8_gemm_tuner.py b/gradlib/gradlib/fp8_gemm_tuner.py index 61df1933f8658..babf304927605 100644 --- a/gradlib/gradlib/fp8_gemm_tuner.py +++ b/gradlib/gradlib/fp8_gemm_tuner.py @@ -4,9 +4,9 @@ import random from pathlib import Path +import torch # isort: split import hipbsolidxgemm import pandas as pd -import torch import torch.nn.functional as F hipbsolidxgemm.hipb_create_extension()