From b3104b2a10ab7cb7532442177ae0d0c40acf9d03 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=83=A1=E8=AF=91=E6=96=87?= <1020030101@qq.com>
Date: Wed, 10 Apr 2024 15:09:36 +0800
Subject: [PATCH 001/413] [Bugfix] Fix logits processor when prompt_logprobs is
not None (#3899)
---
tests/samplers/test_logits_processor.py | 62 +++++++++++++++++++
.../model_executor/layers/logits_processor.py | 11 +++-
2 files changed, 72 insertions(+), 1 deletion(-)
create mode 100644 tests/samplers/test_logits_processor.py
diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py
new file mode 100644
index 0000000000000..3788e9e9752ff
--- /dev/null
+++ b/tests/samplers/test_logits_processor.py
@@ -0,0 +1,62 @@
+import pytest
+import torch
+
+from vllm import SamplingParams
+
+MODELS = ["facebook/opt-125m"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_logits_processor_force_generate(
+ vllm_runner,
+ example_prompts,
+ model: str,
+ dtype: str,
+) -> None:
+ vllm_model = vllm_runner(model, dtype=dtype)
+ tokenizer = vllm_model.model.get_tokenizer()
+ repeat_times = 2
+ enforced_answers = " vLLM"
+ vllm_token_ids = tokenizer.encode(enforced_answers,
+ add_special_tokens=False)
+ max_tokens = len(vllm_token_ids) * repeat_times
+
+ def pick_vllm(token_ids, logits):
+ token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
+ logits[token_id] = torch.finfo(logits.dtype).max
+ return logits
+
+ params_with_logprobs = SamplingParams(
+ logits_processors=[pick_vllm],
+ prompt_logprobs=3,
+ max_tokens=max_tokens,
+ )
+
+ # test logits_processors when prompt_logprobs is not None
+ vllm_model.model._add_request(
+ prompt=example_prompts[0],
+ sampling_params=params_with_logprobs,
+ prompt_token_ids=None,
+ )
+
+ # test prompt_logprobs is not None
+ vllm_model.model._add_request(
+ prompt=example_prompts[1],
+ sampling_params=SamplingParams(
+ prompt_logprobs=3,
+ max_tokens=max_tokens,
+ ),
+ prompt_token_ids=None,
+ )
+
+ # test grouped requests
+ vllm_model.model._add_request(
+ prompt=example_prompts[2],
+ sampling_params=SamplingParams(max_tokens=max_tokens),
+ prompt_token_ids=None,
+ )
+
+ outputs = vllm_model.model._run_engine(False)
+
+ assert outputs[0].outputs[0].text == enforced_answers * repeat_times
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 28e8f6bb7e638..ec531f79ced52 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -86,8 +86,16 @@ def _apply_logits_processors(
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
- for seq_ids, sampling_params in sampling_metadata.seq_groups:
+ for i, seq_group in enumerate(sampling_metadata.seq_groups):
+ seq_ids, sampling_params = seq_group
logits_processors = sampling_params.logits_processors
+ # handle prompt_logprobs by skipping rows in logits added for
+ # the prompt tokens (prompt logprobs are not processed)
+ if (i < sampling_metadata.num_prompts
+ and sampling_params.prompt_logprobs is not None):
+ assert len(seq_ids) == 1
+ logits_row_idx += sampling_metadata.prompt_lens[i] - 1
+
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
@@ -100,5 +108,6 @@ def _apply_logits_processors(
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
+ # verifies that no rows in logits were missed unexpectedly
assert logits_row_idx == logits.shape[0]
return logits
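The row-index bookkeeping in the hunk above can be hard to picture from the diff alone. Below is a minimal, self-contained sketch of the idea (the helper, the dict-based sampling params, and the example numbers are illustrative, not vLLM's actual API): a prompt seq_group that requests prompt_logprobs contributes `prompt_len - 1` extra rows in `logits` which are skipped, and the final assertion checks that every row was accounted for.

```python
import torch

def apply_logits_processors(logits, seq_groups, prompt_lens, num_prompts):
    """Toy version of the patched loop: skip the logits rows that were
    added for prompt tokens (when prompt_logprobs is requested) before
    applying per-sequence logits processors to the sampling rows."""
    row = 0
    for i, (seq_ids, params) in enumerate(seq_groups):
        if i < num_prompts and params.get("prompt_logprobs") is not None:
            assert len(seq_ids) == 1
            row += prompt_lens[i] - 1  # prompt-token rows are not processed
        for _ in seq_ids:
            for proc in params.get("logits_processors") or []:
                # the real processors also receive the generated token ids
                logits[row] = proc(logits[row])
            row += 1
    assert row == logits.shape[0]  # no rows missed unexpectedly
    return logits

# One prompt group with prompt_logprobs (4 prompt tokens -> 4 rows) and
# one decode group with two sequences (2 rows) -> 6 rows of logits.
logits = torch.zeros(6, 32)
seq_groups = [
    ([0], {"prompt_logprobs": 3, "logits_processors": [lambda row: row + 1.0]}),
    ([1, 2], {}),
]
apply_logits_processors(logits, seq_groups, prompt_lens=[4], num_prompts=1)
```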
From 0258b7a94b08321ca01cf170f867b67c1920af87 Mon Sep 17 00:00:00 2001
From: Travis Johnson
Date: Wed, 10 Apr 2024 02:39:56 -0600
Subject: [PATCH 002/413] [Bugfix] handle prompt_logprobs in
_apply_min_tokens_penalty (#3876)
Signed-off-by: Travis Johnson
---
tests/samplers/test_sampler.py | 116 +++++++++++++++++++++-----
vllm/model_executor/layers/sampler.py | 19 ++++-
2 files changed, 112 insertions(+), 23 deletions(-)
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 1626b72282072..26e2d29ffd04c 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -1,3 +1,4 @@
+import itertools
import random
from typing import List, Optional, Tuple
from unittest.mock import patch
@@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sampling_params(min_tokens,
eos_token_id=0,
- stop_token_ids=None):
+ *,
+ stop_token_ids: Optional[List[str]] = None,
+ prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
max_tokens=9999, # keep higher than max of min_tokens
stop_token_ids=stop_token_ids,
+ # requesting prompt_logprobs changes the structure of `logits`
+ prompt_logprobs=prompt_logprobs,
)
sampling_params.eos_token_id = eos_token_id
return sampling_params
@@ -217,9 +222,9 @@ def generate_test_case():
expected_penalization = []
sequence_metadata_list = []
+ # 20% chance to generate seq group metadata list with all prompts
+ is_prompt = random.random() < 0.2
while batch_size > 0:
- # 20% chance to generate prompt seq group with single sequence
- is_prompt = random.random() < 0.2
num_seqs = 1 if is_prompt else random.randint(1, batch_size)
eos_token_id = random.randint(0, VOCAB_SIZE - 1)
@@ -240,7 +245,7 @@ def generate_test_case():
seq_group_penalization = []
for _ in range(num_seqs):
num_input = random.randint(1, 100)
- num_generated = random.randint(1, 100) if not is_prompt else 0
+ num_generated = 0 if is_prompt else random.randint(1, 100)
seq_data[next(seq_id_counter)] = create_sequence_data(
num_input=num_input, num_generated=num_generated)
seq_group_penalization.append(num_generated < min_tokens)
@@ -292,6 +297,21 @@ def generate_test_case():
]
}
+ prompt_with_penalization_and_prompt_logprobs = {
+ "expected_penalization": [False, False, True],
+ "seq_group_metadata_list": [
+ SequenceGroupMetadata(
+ request_id="test_1",
+ is_prompt=True,
+ seq_data={
+ next(seq_id_counter): create_sequence_data(num_input=3),
+ },
+ sampling_params=create_sampling_params(1, prompt_logprobs=3),
+ block_tables={},
+ ),
+ ]
+ }
+
stop_penalizing_after_min_tokens = {
"expected_penalization": [False],
"seq_group_metadata_list": [
@@ -309,8 +329,34 @@ def generate_test_case():
}
stop_token_ids = [42, 99, 42, 0] # intentional duplication
- simple_combination = {
- "expected_penalization": [True, False, False],
+ prompt_combination = {
+ "expected_penalization": [False, True, False],
+ "seq_group_metadata_list": [
+ SequenceGroupMetadata(
+ request_id="test_2",
+ is_prompt=True,
+ seq_data={
+ next(seq_id_counter): create_sequence_data(num_input=2),
+ },
+ sampling_params=create_sampling_params(1, prompt_logprobs=3),
+ block_tables={},
+ ),
+ SequenceGroupMetadata(
+ request_id="test_3",
+ is_prompt=True,
+ seq_data={
+ next(seq_id_counter): create_sequence_data(),
+ },
+ sampling_params=create_sampling_params(
+ 0, stop_token_ids=stop_token_ids),
+ block_tables={},
+ )
+ ]
+ }
+
+ stop_token_ids = [1, 999, 37, 37] # intentional duplication
+ decode_combination = {
+ "expected_penalization": [True, False, False, True, False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
@@ -327,14 +373,19 @@ def generate_test_case():
),
SequenceGroupMetadata(
request_id="test_2",
- is_prompt=True,
+ is_prompt=False,
seq_data={
- next(seq_id_counter): create_sequence_data(),
+ next(seq_id_counter):
+ create_sequence_data(num_generated=20),
+ next(seq_id_counter):
+ create_sequence_data(num_generated=1),
+ next(seq_id_counter):
+ create_sequence_data(num_generated=10),
},
sampling_params=create_sampling_params(
- 0, stop_token_ids=stop_token_ids),
+ 10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
block_tables={},
- )
+ ),
]
}
@@ -342,8 +393,10 @@ def generate_test_case():
test_cases = [
prompt_without_penalization,
prompt_with_penalization,
+ prompt_with_penalization_and_prompt_logprobs,
stop_penalizing_after_min_tokens,
- simple_combination,
+ prompt_combination,
+ decode_combination,
]
else:
test_cases = [generate_test_case()]
@@ -351,30 +404,49 @@ def generate_test_case():
def run_test_case(*,
expected_penalization=None,
seq_group_metadata_list=None):
- assert expected_penalization, "Invalid test case"
- assert seq_group_metadata_list, "Invalid test case"
+ assert expected_penalization, \
+ "Invalid test case, need expected_penalization"
+ assert seq_group_metadata_list, \
+ "Invalid test case, need seq_group_metadata_list"
batch_size = 0
prompt_lens = []
- sampling_params_per_seq = []
+ sampling_params_per_row = []
for sgm in seq_group_metadata_list:
- num_seqs = len(sgm.seq_data)
- batch_size += num_seqs
sampling_params = sgm.sampling_params
- for seq_id in sgm.seq_data:
- prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
- sampling_params_per_seq.append(sampling_params)
+
+ num_rows = len(sgm.seq_data)
+ if sgm.is_prompt:
+ # a prompt seq_group has only one sequence
+ seq_data = next(iter(sgm.seq_data.values()))
+ prompt_len = seq_data.get_prompt_len()
+ prompt_lens.append(prompt_len)
+
+ if sgm.sampling_params.prompt_logprobs:
+ # with prompt_logprobs each token in the prompt has a row in
+ # logits
+ num_rows = prompt_len
+
+ batch_size += num_rows
+ sampling_params_per_row.extend(
+ itertools.repeat(sampling_params, num_rows))
+
+ assert len(
+ expected_penalization
+ ) == batch_size, \
+ ("Invalid test case, expected_penalization does not match computed"
+ "batch size")
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list,
- prompt_lens=prompt_lens,
- subquery_lens=prompt_lens)
+ prompt_lens=prompt_lens if prompt_lens else None,
+ subquery_lens=prompt_lens if prompt_lens else None)
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
for logits_idx, (should_penalize, sampling_params) in enumerate(
- zip(expected_penalization, sampling_params_per_seq)):
+ zip(expected_penalization, sampling_params_per_row)):
tokens_to_check = [sampling_params.eos_token_id]
if sampling_params.stop_token_ids:
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index cb1480de03e3a..03bf38caebe0e 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -27,6 +27,12 @@ class Sampler(nn.Module):
6. Sample the next tokens.
Here, each sequence group within the batch can have different sampling
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
+
+ The structure of the logits tensor is coupled with the seq_groups in
+ sampling_metadata. Typically, each sequence in each seq_group has one row in
+ logits for the next token to be sampled; however, for a seq_group with a
+ prompt request with the prompt_logprobs sampling parameter, there are rows
+ in logits for each token in the input prompt.
"""
def forward(
@@ -106,7 +112,16 @@ def _apply_min_tokens_penalty(
# list of indices in logits that will be set to -inf
logits_to_penalize = []
start_idx = 0
- for seq_ids, sampling_params in sampling_metadata.seq_groups:
+ for i, seq_group in enumerate(sampling_metadata.seq_groups):
+ seq_ids, sampling_params = seq_group
+
+ # handle prompt_logprobs by skipping rows in logits added for the prompt
+ # tokens (prompt logprobs are not penalized)
+ if (i < sampling_metadata.num_prompts
+ and sampling_params.prompt_logprobs is not None):
+ assert len(seq_ids) == 1
+ start_idx += sampling_metadata.prompt_lens[i] - 1
+
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
seqs_to_penalize = []
@@ -132,6 +147,8 @@ def _apply_min_tokens_penalty(
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
+ # verifies that no rows in logits were missed unexpectedly
+ assert start_idx == logits.shape[0]
return logits
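The docstring added to `Sampler` above describes the row layout that both this fix and the test accounting depend on: one row per sequence for ordinary groups, but one row per prompt token for a prompt group that requests prompt_logprobs. A small sketch of that accounting, using a hypothetical `Group` record in place of `SequenceGroupMetadata`:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Group:                          # stand-in for SequenceGroupMetadata
    is_prompt: bool
    prompt_len: int
    num_seqs: int = 1
    prompt_logprobs: Optional[int] = None

def logits_rows(groups: List[Group]) -> int:
    """How many rows the sampler's logits tensor will contain."""
    rows = 0
    for g in groups:
        if g.is_prompt and g.prompt_logprobs is not None:
            rows += g.prompt_len      # one row per prompt token
        else:
            rows += g.num_seqs        # one row per sequence being sampled
    return rows

# A prompt request with prompt_logprobs (5 prompt tokens) plus a decode
# group with 3 running sequences -> 5 + 3 = 8 logits rows, so the
# min-tokens penalty must advance start_idx past the 4 prompt-only rows.
assert logits_rows([Group(True, prompt_len=5, prompt_logprobs=3),
                    Group(False, prompt_len=5, num_seqs=3)]) == 8
```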
From bd3c144e0b8e82c9b3c5c40c6d557fe8665de5a3 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 10 Apr 2024 07:37:17 -0700
Subject: [PATCH 003/413] [Bugfix][ROCm] Add numba to Dockerfile.rocm (#3962)
---
Dockerfile.rocm | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index 10b8bf1e7fabd..b1c5fac9d78ef 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -91,7 +91,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
COPY ./ /app/vllm
-RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --upgrade pip numba
RUN python3 -m pip install xformers==0.0.23 --no-deps
RUN cd /app \
From 8b317c6dd09ce566f4b4abeb446585ac75262cce Mon Sep 17 00:00:00 2001
From: James Whedbee
Date: Wed, 10 Apr 2024 10:12:00 -0500
Subject: [PATCH 004/413] [Model][AMD] ROCm support for 256 head dims for Gemma
(#3972)
---
vllm/attention/ops/triton_flash_attention.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index b86e845020b07..87cf30cbef79a 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -677,8 +677,7 @@ def check_args(
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
- # TODO: Fix assert to check head size <=256 once supported
- assert head_size <= 128
+ assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
@@ -729,7 +728,7 @@ def forward(
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32.
- unpadded_head_dims = {32, 64, 128}
+ unpadded_head_dims = {32, 64, 128, 256}
if head_size not in unpadded_head_dims:
padded_d_model = None
for i in unpadded_head_dims:
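For context, the padding scheme this change extends (sketched here, not the Triton kernel itself) rounds an unsupported head size up to the smallest supported width, and with this patch 256, as used by Gemma, is now a valid target:

```python
SUPPORTED_HEAD_DIMS = (32, 64, 128, 256)  # after this patch

def padded_head_dim(head_size: int) -> int:
    """Pick the smallest supported head dim that fits head_size."""
    assert head_size <= 256, "head sizes above 256 remain unsupported"
    for dim in SUPPORTED_HEAD_DIMS:
        if head_size <= dim:
            return dim
    raise AssertionError("unreachable for head_size <= 256")

assert padded_head_dim(96) == 128
assert padded_head_dim(256) == 256   # Gemma's 256 head dim now passes
```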
From e35397468f36a857b8d2b7d92a472265e1c500cc Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 10 Apr 2024 10:03:02 -0700
Subject: [PATCH 005/413] [Doc] Add doc to state our model support policy
(#3948)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
---
docs/source/models/supported_models.rst | 26 +++++++++++++++++++++++++
1 file changed, 26 insertions(+)
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index e7bfdcb65316e..c09b0ff250437 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -168,3 +168,29 @@ Alternatively, you can raise an issue on our `GitHub `_
+1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `test_models.py `_ and `test_big_models.py `_ for the models that have passed this test.
+2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test.
+3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests `_ and `examples `_ for the models that have passed this test.
+4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.
From e4c4072c94b346053768691451566c56664e26a7 Mon Sep 17 00:00:00 2001
From: Daniel E Marasco
Date: Wed, 10 Apr 2024 13:15:51 -0400
Subject: [PATCH 006/413] [Bugfix] Remove key sorting for `guided_json`
 parameter in OpenAI compatible Server (#3945)

---
vllm/model_executor/guided_decoding.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py
index e56f74c7794fb..8e710f1ac2b53 100644
--- a/vllm/model_executor/guided_decoding.py
+++ b/vllm/model_executor/guided_decoding.py
@@ -91,7 +91,7 @@ def _get_guide_and_mode(
json = request.guided_json
if isinstance(json, dict):
# turn dict into hashable string
- json = json_dumps(json, sort_keys=True)
+ json = json_dumps(json)
elif isinstance(json, BaseModel):
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
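To see why the `sort_keys=True` call was a problem (toy schema below, not taken from the PR): sorting rewrites the property order of the schema string handed to the guided decoder, so the field order the caller specified in `guided_json` was silently changed.

```python
import json

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},   # caller wants "name" first
        "age": {"type": "integer"},
    },
}

print(json.dumps(schema, sort_keys=True))  # "age" now precedes "name"
print(json.dumps(schema))                  # original property order kept
```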
From 92cd2e2f21e8ec65b2cb635a9f15de38157a1359 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fr=CE=B1n=C3=A7ois?=
Date: Wed, 10 Apr 2024 20:05:52 +0200
Subject: [PATCH 007/413] [Doc] Fix getting started to use publicly available
model (#3963)
---
docs/source/serving/openai_compatible_server.md | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index 032fe5d03bd52..388b5daa79a92 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat
You can start the server using Python, or using [Docker](deploying_with_docker.rst):
```bash
-python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123
+python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123
```
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
@@ -16,9 +16,8 @@ client = OpenAI(
)
completion = client.chat.completions.create(
- model="meta-llama/Llama-2-7b-hf",
+ model="mistralai/Mistral-7B-Instruct-v0.2",
messages=[
- {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
]
)
@@ -38,9 +37,8 @@ Or directly merge them into the JSON payload if you are using HTTP call directly
```python
completion = client.chat.completions.create(
- model="meta-llama/Llama-2-7b-hf",
+ model="mistralai/Mistral-7B-Instruct-v0.2",
messages=[
- {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
],
extra_body={
@@ -89,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
specifies how roles, messages, and other chat-specific tokens are encoded in the input.
-An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12)
+An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format)
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those models,
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat
From 934d3662f716d60abfb04cf9fdd6d20f6e75f140 Mon Sep 17 00:00:00 2001
From: Travis Johnson
Date: Wed, 10 Apr 2024 16:28:25 -0600
Subject: [PATCH 008/413] [Bugfix] handle hf_config with architectures == None
(#3982)
Signed-off-by: Travis Johnson
Co-authored-by: Simon Mo
---
vllm/config.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/vllm/config.py b/vllm/config.py
index 753fc33e9b717..bca250e922288 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -158,7 +158,9 @@ def _verify_load_format(self) -> None:
# TODO: Remove this check once HF updates the pt weights of Mixtral.
architectures = getattr(self.hf_config, "architectures", [])
- if "MixtralForCausalLM" in architectures and load_format == "pt":
+ # architectures can be None instead of []
+ if architectures and "MixtralForCausalLM" in architectures \
+ and load_format == "pt":
raise ValueError(
"Currently, the 'pt' format is not supported for Mixtral. "
"Please use the 'safetensors' format instead. ")
From 63e7176f265be43dcc425f5ab4ab45c90234f5c3 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 10 Apr 2024 15:33:30 -0700
Subject: [PATCH 009/413] [Core][Refactor] move parallel_utils into
vllm/distributed (#3950)
[WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (#3950)
---
tests/conftest.py | 3 +--
tests/distributed/test_comm_ops.py | 6 +++---
tests/distributed/test_custom_all_reduce.py | 13 ++++++-------
tests/distributed/test_pynccl.py | 4 ++--
tests/lora/conftest.py | 3 +--
vllm/distributed/__init__.py | 3 +++
.../communication_op.py | 14 ++++++++------
.../device_communicators}/__init__.py | 0
.../device_communicators}/custom_all_reduce.py | 5 +++--
.../device_communicators}/pynccl.py | 0
.../device_communicators}/pynccl_utils.py | 4 ++--
.../parallel_state.py | 4 ++--
.../parallel_utils => distributed}/utils.py | 0
vllm/lora/layers.py | 13 ++++++-------
vllm/model_executor/layers/activation.py | 5 ++---
vllm/model_executor/layers/linear.py | 11 +++++------
vllm/model_executor/layers/logits_processor.py | 3 +--
.../layers/vocab_parallel_embedding.py | 8 +++-----
vllm/model_executor/models/baichuan.py | 4 ++--
vllm/model_executor/models/bloom.py | 4 ++--
vllm/model_executor/models/chatglm.py | 3 +--
vllm/model_executor/models/commandr.py | 4 ++--
vllm/model_executor/models/dbrx.py | 7 +++----
vllm/model_executor/models/deepseek.py | 7 +++----
vllm/model_executor/models/falcon.py | 7 +++----
vllm/model_executor/models/gemma.py | 3 +--
vllm/model_executor/models/gpt2.py | 3 +--
vllm/model_executor/models/gpt_bigcode.py | 3 +--
vllm/model_executor/models/gpt_j.py | 3 +--
vllm/model_executor/models/gpt_neox.py | 3 +--
vllm/model_executor/models/internlm2.py | 3 +--
vllm/model_executor/models/jais.py | 4 ++--
vllm/model_executor/models/llama.py | 4 ++--
vllm/model_executor/models/minicpm.py | 7 +++----
vllm/model_executor/models/mixtral.py | 7 +++----
vllm/model_executor/models/mixtral_quant.py | 7 +++----
vllm/model_executor/models/mpt.py | 4 ++--
vllm/model_executor/models/olmo.py | 3 +--
vllm/model_executor/models/opt.py | 3 +--
vllm/model_executor/models/orion.py | 3 +--
vllm/model_executor/models/phi.py | 3 +--
vllm/model_executor/models/qwen.py | 3 +--
vllm/model_executor/models/qwen2.py | 3 +--
vllm/model_executor/models/qwen2_moe.py | 7 +++----
vllm/model_executor/models/stablelm.py | 3 +--
vllm/model_executor/models/starcoder2.py | 3 +--
vllm/model_executor/models/xverse.py | 3 +--
vllm/model_executor/parallel_utils/README.md | 1 -
vllm/test_utils.py | 4 ++--
vllm/worker/cpu_worker.py | 7 +++----
vllm/worker/model_runner.py | 8 +++-----
vllm/worker/worker.py | 12 ++++++------
52 files changed, 111 insertions(+), 141 deletions(-)
create mode 100644 vllm/distributed/__init__.py
rename vllm/{model_executor/parallel_utils => distributed}/communication_op.py (94%)
rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/__init__.py (100%)
rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/custom_all_reduce.py (98%)
rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/pynccl.py (100%)
rename vllm/{model_executor/parallel_utils => distributed/device_communicators}/pynccl_utils.py (91%)
rename vllm/{model_executor/parallel_utils => distributed}/parallel_state.py (98%)
rename vllm/{model_executor/parallel_utils => distributed}/utils.py (100%)
delete mode 100644 vllm/model_executor/parallel_utils/README.md
diff --git a/tests/conftest.py b/tests/conftest.py
index e00f3eb871e37..a7e8963af0eda 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,8 +11,7 @@
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
-from vllm.model_executor.parallel_utils.parallel_state import (
- destroy_model_parallel)
+from vllm.distributed import destroy_model_parallel
from vllm.sequence import MultiModalData
from vllm.transformers_utils.tokenizer import get_tokenizer
diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py
index d1811cb694db6..aa9e0537c6910 100644
--- a/tests/distributed/test_comm_ops.py
+++ b/tests/distributed/test_comm_ops.py
@@ -8,9 +8,9 @@
import ray
import torch
-from vllm.model_executor.parallel_utils.communication_op import (
- broadcast_tensor_dict, tensor_model_parallel_all_gather,
- tensor_model_parallel_all_reduce)
+from vllm.distributed import (broadcast_tensor_dict,
+ tensor_model_parallel_all_gather,
+ tensor_model_parallel_all_reduce)
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py
index 1e6e7f89a528c..3b1cd1773af19 100644
--- a/tests/distributed/test_custom_all_reduce.py
+++ b/tests/distributed/test_custom_all_reduce.py
@@ -6,9 +6,8 @@
import torch
import torch.distributed as dist
-from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_reduce)
+from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed.device_communicators import custom_all_reduce
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
@@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port):
init_test_distributed_environment(1, world_size, rank,
distributed_init_port)
- custom_ar.init_custom_ar()
+ custom_all_reduce.init_custom_all_reduce()
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
- with custom_ar.capture():
+ with custom_all_reduce.capture():
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),
@@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port):
distributed_init_port)
sz = 1024
- custom_ar.init_custom_ar()
- fa = custom_ar.get_handle()
+ custom_all_reduce.init_custom_all_reduce()
+ fa = custom_all_reduce.get_handle()
inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp)
assert torch.allclose(out, inp * world_size)
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index 29782045130a6..b50eed1c8c722 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -4,8 +4,8 @@
import pytest
import torch
-from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
- ncclGetUniqueId)
+from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
+ ncclGetUniqueId)
def distributed_run(fn, world_size):
diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index acb5fa91e2012..207c635e2dc86 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -12,6 +12,7 @@
import vllm
from vllm.config import LoRAConfig
+from vllm.distributed import destroy_model_parallel, initialize_model_parallel
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
@@ -19,8 +20,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.parallel_utils.parallel_state import (
- destroy_model_parallel, initialize_model_parallel)
def cleanup():
diff --git a/vllm/distributed/__init__.py b/vllm/distributed/__init__.py
new file mode 100644
index 0000000000000..db325cfabf55e
--- /dev/null
+++ b/vllm/distributed/__init__.py
@@ -0,0 +1,3 @@
+from .communication_op import *
+from .parallel_state import *
+from .utils import *
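To make the scope of the rename concrete, here is a before/after of a typical import site (paths come from the rename list above; the exact symbols a given module pulls in will vary):

```python
# Before this patch:
# from vllm.model_executor.parallel_utils.communication_op import (
#     tensor_model_parallel_all_reduce)
# from vllm.model_executor.parallel_utils.parallel_state import (
#     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

# After this patch, the common helpers are re-exported from one package:
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)

# GPU-specific pieces live one level deeper:
from vllm.distributed.device_communicators import pynccl_utils
```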
diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/distributed/communication_op.py
similarity index 94%
rename from vllm/model_executor/parallel_utils/communication_op.py
rename to vllm/distributed/communication_op.py
index 9cbb40708dd5b..cf15db099b304 100644
--- a/vllm/model_executor/parallel_utils/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -4,12 +4,10 @@
import torch
from torch.distributed import ProcessGroup
-from vllm.model_executor.parallel_utils import pynccl_utils
-from vllm.model_executor.parallel_utils.custom_all_reduce import (
- custom_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)
+from .parallel_state import (get_tensor_model_parallel_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ is_pynccl_enabled_for_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
+ from vllm.distributed.device_communicators import pynccl_utils
+ from vllm.distributed.device_communicators.custom_all_reduce import (
+ custom_all_reduce)
+
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
diff --git a/vllm/model_executor/parallel_utils/__init__.py b/vllm/distributed/device_communicators/__init__.py
similarity index 100%
rename from vllm/model_executor/parallel_utils/__init__.py
rename to vllm/distributed/device_communicators/__init__.py
diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
similarity index 98%
rename from vllm/model_executor/parallel_utils/custom_all_reduce.py
rename to vllm/distributed/device_communicators/custom_all_reduce.py
index bf8ee07070c8a..84238d2e46076 100644
--- a/vllm/model_executor/parallel_utils/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -5,8 +5,6 @@
import torch.distributed as dist
from vllm.logger import init_logger
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
try:
import pynvml
@@ -25,6 +23,9 @@
def init_custom_ar() -> None:
+ from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
+
global _CA_HANDLE
if _CA_HANDLE is not None:
return
diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
similarity index 100%
rename from vllm/model_executor/parallel_utils/pynccl.py
rename to vllm/distributed/device_communicators/pynccl.py
diff --git a/vllm/model_executor/parallel_utils/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py
similarity index 91%
rename from vllm/model_executor/parallel_utils/pynccl_utils.py
rename to vllm/distributed/device_communicators/pynccl_utils.py
index a099777aa0005..aeb73015733d1 100644
--- a/vllm/model_executor/parallel_utils/pynccl_utils.py
+++ b/vllm/distributed/device_communicators/pynccl_utils.py
@@ -9,8 +9,8 @@
logger = init_logger(__name__)
try:
- from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
- ncclGetVersion)
+ from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
+ ncclGetVersion)
except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs
diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/distributed/parallel_state.py
similarity index 98%
rename from vllm/model_executor/parallel_utils/parallel_state.py
rename to vllm/distributed/parallel_state.py
index 3bbfa1bd5443a..4bb77146295af 100644
--- a/vllm/model_executor/parallel_utils/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -8,8 +8,6 @@
import torch
-from vllm.model_executor.parallel_utils import pynccl_utils
-
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
@@ -266,6 +264,7 @@ def destroy_model_parallel():
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None
+ from vllm.distributed.device_communicators import pynccl_utils
# Destroy the pynccl states if any.
pynccl_utils.destroy_process_group()
@@ -279,6 +278,7 @@ def destroy_model_parallel():
@contextlib.contextmanager
def with_pynccl_for_all_reduce():
+ from vllm.distributed.device_communicators import pynccl_utils
"""use pynccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1:
diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/distributed/utils.py
similarity index 100%
rename from vllm/model_executor/parallel_utils/utils.py
rename to vllm/distributed/utils.py
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 0505014753951..dd33868f76302 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -10,6 +10,12 @@
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ split_tensor_along_last_dim,
+ tensor_model_parallel_all_gather,
+ tensor_model_parallel_all_reduce,
+ tensor_model_parallel_gather)
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -18,13 +24,6 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce,
- tensor_model_parallel_gather)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.utils import (
- split_tensor_along_last_dim)
if TYPE_CHECKING:
pass
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index f569a5a49cbdf..6786c48e0caba 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -7,10 +7,9 @@
import torch.nn.functional as F
from vllm._C import ops
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index f3d4d1789db2d..8f42b3e8a4abe 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -5,13 +5,12 @@
import torch.nn.functional as F
from torch.nn.parameter import Parameter
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ split_tensor_along_last_dim,
+ tensor_model_parallel_all_gather,
+ tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.utils import (
- divide, split_tensor_along_last_dim)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index ec531f79ced52..e556e31f99378 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -4,8 +4,7 @@
import torch
import torch.nn as nn
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_gather)
+from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.sampling_metadata import SamplingMetadata
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index 73bbfac33ed13..088c0849243c0 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -4,11 +4,9 @@
import torch.nn.functional as F
from torch.nn.parameter import Parameter
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.utils import divide
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index fa5a27b5a6974..30588aecdebe9 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -27,6 +27,8 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -38,8 +40,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index a9ff909090586..40966ab33631a 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -24,6 +24,8 @@
from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -33,8 +35,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 4008896e48dd1..7b46ba306619a 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -10,6 +10,7 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -21,8 +22,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 29ba3844eb11d..aa27f0a96c745 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -29,6 +29,8 @@
from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
@@ -39,8 +41,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 14c0fece69214..49eb7f1b2c185 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -5,6 +5,9 @@
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
@@ -15,10 +18,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py
index 2a2182ff4ebad..c7dd11d07e6da 100644
--- a/vllm/model_executor/models/deepseek.py
+++ b/vllm/model_executor/models/deepseek.py
@@ -28,6 +28,9 @@
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -41,10 +44,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 77c19b227d213..4f1ebcd5fb43c 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -27,6 +27,9 @@
from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -37,10 +40,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 08609532b8b3e..fc1fc35570368 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -23,6 +23,7 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -35,8 +36,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 3f816a9996be5..43f0d47fcb122 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -24,6 +24,7 @@
from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -33,8 +34,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 07c647c2e1c41..cec2d771adfa8 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -25,6 +25,7 @@
from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -34,8 +35,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 94048efe48420..5660097652748 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -23,6 +23,7 @@
from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -33,8 +34,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index a5b5d717d9846..2f9e2171cf114 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -23,6 +23,7 @@
from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -33,8 +34,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index bdb48bf21042e..6e9cbd3f9f43f 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -6,6 +6,7 @@
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -17,8 +18,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index 12fc9dbd50732..a041b0c9a0452 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -26,6 +26,8 @@
from torch import nn
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
@@ -34,8 +36,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 72fe21df67d8a..c86e292e7df1a 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -29,6 +29,8 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -40,8 +42,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator,
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 99d1b4eb97bb8..49eda9c9a8112 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -29,6 +29,9 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -42,10 +45,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 429bc8109b9f8..ff552a9d86536 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -29,6 +29,9 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -40,10 +43,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index 75f86bc134ee3..1f0c0e912beea 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -30,6 +30,9 @@
from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
@@ -40,10 +43,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index a39f94359a948..af4cdce29d085 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -7,6 +7,8 @@
import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -16,8 +18,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 611a48a9aad2b..3513c72879102 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -44,6 +44,7 @@
from torch import nn
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -55,8 +56,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index c1ae1b2ae0f03..3a640850662c0 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -24,6 +24,7 @@
from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -34,8 +35,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index ee910563b20df..c606ac027e9d9 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -11,6 +11,7 @@
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
@@ -21,8 +22,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 40e068acaba7d..e91624da90955 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -42,6 +42,7 @@
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -52,8 +53,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index a63b9c8d63d13..6213a2ded65ab 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -11,6 +11,7 @@
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -22,8 +23,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 8c92cd773f6b9..796e30e633e85 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -30,6 +30,7 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -41,8 +42,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 6b4a74198fd52..f920b4f5a40c7 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -30,6 +30,9 @@
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -43,10 +46,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.communication_op import (
- tensor_model_parallel_all_reduce)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index b83637fd50dc7..651598b770f13 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -26,6 +26,7 @@
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
@@ -36,8 +37,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 50d23e0a3b6ef..76e8e48673413 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -25,6 +25,7 @@
from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
@@ -35,8 +36,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 83d2ddb2bcf35..7e9ce9e5c8e15 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -28,6 +28,7 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -39,8 +40,6 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.parallel_utils.parallel_state import (
- get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
diff --git a/vllm/model_executor/parallel_utils/README.md b/vllm/model_executor/parallel_utils/README.md
deleted file mode 100644
index b25e3afddad9c..0000000000000
--- a/vllm/model_executor/parallel_utils/README.md
+++ /dev/null
@@ -1 +0,0 @@
-The files in this folder are ported from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core). We only keep the codes that are used in inference.
\ No newline at end of file
diff --git a/vllm/test_utils.py b/vllm/test_utils.py
index bc220d3b8a430..0cf23e4bb7e75 100644
--- a/vllm/test_utils.py
+++ b/vllm/test_utils.py
@@ -1,7 +1,7 @@
import ray
-from vllm.model_executor.parallel_utils.parallel_state import (
- ensure_model_parallel_initialized, init_distributed_environment)
+from vllm.distributed import (ensure_model_parallel_initialized,
+ init_distributed_environment)
from vllm.utils import get_open_port
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 42f0828b826e2..751384eb72af3 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -7,13 +7,12 @@
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
+from vllm.distributed import (broadcast_tensor_dict,
+ ensure_model_parallel_initialized,
+ init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.parallel_utils.communication_op import (
- broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.parallel_state import (
- ensure_model_parallel_initialized, init_distributed_environment)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.model_runner import ModelRunner
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index e7f20475ab1a7..1de4748b7bcc9 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -9,17 +9,15 @@
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
+from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
+from vllm.distributed.device_communicators import (custom_all_reduce,
+ pynccl_utils)
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils
-from vllm.model_executor.parallel_utils.communication_op import (
- broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.parallel_state import (
- with_pynccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 19de33089b2db..3f0b2fd83f3e5 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -8,14 +8,14 @@
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+from vllm.distributed import (broadcast_tensor_dict,
+ ensure_model_parallel_initialized,
+ init_distributed_environment)
+from vllm.distributed.device_communicators import pynccl_utils
+from vllm.distributed.device_communicators.custom_all_reduce import (
+ init_custom_ar)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
-from vllm.model_executor.parallel_utils import pynccl_utils
-from vllm.model_executor.parallel_utils.communication_op import (
- broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
-from vllm.model_executor.parallel_utils.parallel_state import (
- ensure_model_parallel_initialized, init_distributed_environment)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
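
For code outside this patch that still imports from the removed `vllm.model_executor.parallel_utils` package, the migration is a straight import-path swap. The sketch below mirrors the substitutions made in the hunks above and assumes the same symbols are re-exported from `vllm.distributed`, as these diffs indicate.

```python
# Old layout (deleted by this patch):
# from vllm.model_executor.parallel_utils.communication_op import (
#     broadcast_tensor_dict)
# from vllm.model_executor.parallel_utils.parallel_state import (
#     ensure_model_parallel_initialized, init_distributed_environment)
# from vllm.model_executor.parallel_utils import pynccl_utils

# New layout: the same helpers come from the vllm.distributed package.
from vllm.distributed import (broadcast_tensor_dict,
                              ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils
```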
From 67b4221a61ace91a79aff507df0a95a01978300e Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Thu, 11 Apr 2024 09:56:48 +0900
Subject: [PATCH 010/413] [Core][5/N] Fully working chunked prefill e2e (#3884)
---
.buildkite/test-pipeline.yaml | 2 +
benchmarks/benchmark_latency.py | 3 +-
benchmarks/benchmark_throughput.py | 62 ++--
.../basic_correctness/test_chunked_prefill.py | 70 ++++
tests/core/test_chunked_prefill_scheduler.py | 16 +-
.../test_basic_distributed_correctness.py | 7 +-
.../test_chunked_prefill_distributed.py | 66 ++++
tests/entrypoints/test_openai_server.py | 2 +-
tests/models/test_models.py | 2 +-
tests/worker/test_model_runner.py | 189 ++++++++--
vllm/attention/__init__.py | 4 +-
vllm/attention/backends/abstract.py | 42 ++-
vllm/attention/backends/flash_attn.py | 85 +++--
vllm/attention/backends/rocm_flash_attn.py | 97 ++++--
vllm/attention/backends/torch_sdpa.py | 67 ++--
vllm/attention/backends/xformers.py | 138 ++++----
vllm/attention/layer.py | 5 +-
vllm/attention/ops/paged_attn.py | 6 -
vllm/config.py | 13 +-
vllm/core/scheduler.py | 15 +-
vllm/distributed/communication_op.py | 10 +-
vllm/engine/arg_utils.py | 5 +-
vllm/engine/llm_engine.py | 5 +-
vllm/lora/layers.py | 5 +-
vllm/sequence.py | 3 +-
vllm/worker/model_runner.py | 323 +++++++++++++-----
26 files changed, 927 insertions(+), 315 deletions(-)
create mode 100644 tests/basic_correctness/test_chunked_prefill.py
create mode 100644 tests/distributed/test_chunked_prefill_distributed.py
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 27e44463a30a6..695290ed74ab5 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -29,6 +29,8 @@ steps:
- pytest -v -s test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
+ - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
+ - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
- label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index 91510dafc57a5..aadbc441713fc 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -177,8 +177,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
help='block size of key/value cache')
parser.add_argument(
'--enable-chunked-prefill',
- type=bool,
- default=False,
+ action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument(
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index e71338273d1e5..6df1e1d628e6c 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -74,25 +74,31 @@ def run_vllm(
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
+ enable_chunked_prefill: bool,
+ max_num_batched_tokens: int,
gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None,
) -> float:
from vllm import LLM, SamplingParams
- llm = LLM(model=model,
- tokenizer=tokenizer,
- quantization=quantization,
- tensor_parallel_size=tensor_parallel_size,
- seed=seed,
- trust_remote_code=trust_remote_code,
- dtype=dtype,
- max_model_len=max_model_len,
- gpu_memory_utilization=gpu_memory_utilization,
- enforce_eager=enforce_eager,
- kv_cache_dtype=kv_cache_dtype,
- quantization_param_path=quantization_param_path,
- device=device,
- enable_prefix_caching=enable_prefix_caching,
- download_dir=download_dir)
+ llm = LLM(
+ model=model,
+ tokenizer=tokenizer,
+ quantization=quantization,
+ tensor_parallel_size=tensor_parallel_size,
+ seed=seed,
+ trust_remote_code=trust_remote_code,
+ dtype=dtype,
+ max_model_len=max_model_len,
+ gpu_memory_utilization=gpu_memory_utilization,
+ enforce_eager=enforce_eager,
+ kv_cache_dtype=kv_cache_dtype,
+ quantization_param_path=quantization_param_path,
+ device=device,
+ enable_prefix_caching=enable_prefix_caching,
+ download_dir=download_dir,
+ enable_chunked_prefill=enable_chunked_prefill,
+ max_num_batched_tokens=max_num_batched_tokens,
+ )
# Add the requests to the engine.
for prompt, _, output_len in requests:
@@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
args.output_len)
if args.backend == "vllm":
- elapsed_time = run_vllm(requests, args.model, args.tokenizer,
- args.quantization, args.tensor_parallel_size,
- args.seed, args.n, args.use_beam_search,
- args.trust_remote_code, args.dtype,
- args.max_model_len, args.enforce_eager,
- args.kv_cache_dtype,
- args.quantization_param_path, args.device,
- args.enable_prefix_caching,
- args.gpu_memory_utilization, args.download_dir)
+ elapsed_time = run_vllm(
+ requests, args.model, args.tokenizer, args.quantization,
+ args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
+ args.trust_remote_code, args.dtype, args.max_model_len,
+ args.enforce_eager, args.kv_cache_dtype,
+ args.quantization_param_path, args.device,
+ args.enable_prefix_caching, args.enable_chunked_prefill,
+ args.max_num_batched_tokens, args.gpu_memory_utilization,
+ args.download_dir)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -335,6 +341,14 @@ def main(args: argparse.Namespace):
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
+ parser.add_argument("--enable-chunked-prefill",
+ action='store_true',
+ help="enable chunked prefill for vLLM backend.")
+ parser.add_argument('--max-num-batched-tokens',
+ type=int,
+ default=None,
+ help='maximum number of batched tokens per '
+ 'iteration')
parser.add_argument('--download-dir',
type=str,
default=None,
diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
new file mode 100644
index 0000000000000..9ff07b3c09020
--- /dev/null
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -0,0 +1,70 @@
+"""Compare the outputs of HF and vLLM when using greedy sampling.
+
+This suite tests chunked prefill. Chunked prefill is enabled with
+enable_chunked_prefill=True. If a prefill's size exceeds
+max_num_batched_tokens, the prefill request is chunked.
+
+Run `pytest tests/basic_correctness/test_chunked_prefill.py`.
+"""
+import pytest
+
+MODELS = [
+ "facebook/opt-125m",
+ "meta-llama/Llama-2-7b-hf",
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset the distributed env properly. Use a value > 1 only for manual testing.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+def test_models(
+ hf_runner,
+ vllm_runner,
+ example_prompts,
+ model: str,
+ dtype: str,
+ max_tokens: int,
+ chunked_prefill_token_size: int,
+ enforce_eager: bool,
+ tensor_parallel_size: int,
+) -> None:
+ if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
+ and not enforce_eager):
+ pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
+ "for high TP to save testing time.")
+ max_num_seqs = min(chunked_prefill_token_size, 256)
+ enable_chunked_prefill = False
+ max_num_batched_tokens = None
+ if chunked_prefill_token_size != -1:
+ enable_chunked_prefill = True
+ max_num_batched_tokens = chunked_prefill_token_size
+
+ hf_model = hf_runner(model, dtype=dtype)
+ hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+ del hf_model
+
+ vllm_model = vllm_runner(
+ model,
+ dtype=dtype,
+ max_num_batched_tokens=max_num_batched_tokens,
+ enable_chunked_prefill=enable_chunked_prefill,
+ tensor_parallel_size=tensor_parallel_size,
+ enforce_eager=enforce_eager,
+ max_num_seqs=max_num_seqs,
+ )
+ vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+ del vllm_model
+ print(vllm_outputs[0])
+
+ for i in range(len(example_prompts)):
+ hf_output_ids, hf_output_str = hf_outputs[i]
+ vllm_output_ids, vllm_output_str = vllm_outputs[i]
+ assert hf_output_str == vllm_output_str, (
+ f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
+ assert hf_output_ids == vllm_output_ids, (
+ f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py
index 05e62ced5898f..cce396bf4953c 100644
--- a/tests/core/test_chunked_prefill_scheduler.py
+++ b/tests/core/test_chunked_prefill_scheduler.py
@@ -104,10 +104,10 @@ def test_chunk():
# One chunked prefill, and one decoding.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
- # The first one is decoding.
- assert seq_group_meta[0].token_chunk_size == 1
+ # The first one is prefill. Scheduler guarantees ordering.
+ assert seq_group_meta[0].token_chunk_size == 56
# The second one is a chunked prefill.
- assert seq_group_meta[1].token_chunk_size == 56
+ assert seq_group_meta[1].token_chunk_size == 1
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 57
@@ -157,12 +157,12 @@ def test_complex():
# Decoding & chunked prefill & first chunk of 3rd request is scheduled.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(get_sequence_groups(out)) == 3
- # The first one is decoding.
- assert seq_group_meta[0].token_chunk_size == 1
- # The second one is a chunked prefill.
+ # The first one is the first chunked prefill.
+ assert seq_group_meta[0].token_chunk_size == 7
+ # The second one is the second new chunked prefill.
assert seq_group_meta[1].token_chunk_size == 56
- # The third one is also chunked.
- assert seq_group_meta[2].token_chunk_size == 7
+ # The last one is decode.
+ assert seq_group_meta[2].token_chunk_size == 1
# Two of them are in chunked prefill.
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64
diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py
index 1eba14d7a6422..77aa90b12bf8f 100644
--- a/tests/distributed/test_basic_distributed_correctness.py
+++ b/tests/distributed/test_basic_distributed_correctness.py
@@ -33,11 +33,16 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
+
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
- vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2)
+ vllm_model = vllm_runner(
+ model,
+ dtype=dtype,
+ tensor_parallel_size=2,
+ )
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py
new file mode 100644
index 0000000000000..737b1f3169519
--- /dev/null
+++ b/tests/distributed/test_chunked_prefill_distributed.py
@@ -0,0 +1,66 @@
+"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
+vLLM will allocate all the available memory, so we need to run the tests one
+by one. We therefore pass the arguments (the model name) via environment
+variables.
+
+Run:
+```sh
+TEST_DIST_MODEL=facebook/opt-125m pytest \
+ test_chunked_prefill_distributed.py
+TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
+ test_chunked_prefill_distributed.py
+```
+"""
+import os
+
+import pytest
+import torch
+
+MODELS = [
+ os.environ["TEST_DIST_MODEL"],
+]
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+ reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("chunked_prefill_token_size", [16])
+def test_models(
+ hf_runner,
+ vllm_runner,
+ example_prompts,
+ model: str,
+ dtype: str,
+ max_tokens: int,
+ chunked_prefill_token_size: int,
+) -> None:
+ # Add a chunked prefill config.
+ max_num_seqs = min(chunked_prefill_token_size, 256)
+ assert chunked_prefill_token_size != -1
+ enable_chunked_prefill = True
+ max_num_batched_tokens = chunked_prefill_token_size
+
+ hf_model = hf_runner(model, dtype=dtype)
+ hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+ del hf_model
+
+ vllm_model = vllm_runner(
+ model,
+ dtype=dtype,
+ tensor_parallel_size=2,
+ max_num_seqs=max_num_seqs,
+ enable_chunked_prefill=enable_chunked_prefill,
+ max_num_batched_tokens=max_num_batched_tokens,
+ )
+ vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+ del vllm_model
+
+ for i in range(len(example_prompts)):
+ hf_output_ids, hf_output_str = hf_outputs[i]
+ vllm_output_ids, vllm_output_str = vllm_outputs[i]
+ assert hf_output_str == vllm_output_str, (
+ f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
+ assert hf_output_ids == vllm_output_ids, (
+ f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 442f8bdf3b4ba..6f2086c4dd269 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -141,7 +141,7 @@ def server(zephyr_lora_files):
"--max-cpu-loras",
"2",
"--max-num-seqs",
- "128"
+ "128",
])
ray.get(server_runner.ready.remote())
yield server_runner
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 53a80d4619646..cfe2539e3a052 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -12,7 +12,7 @@
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/pythia-70m",
- "bigscience/bloom-560m",
+ "bigscience/bloom-560m", # Testing alibi slopes.
"microsoft/phi-2",
"stabilityai/stablelm-3b-4e1t",
# "allenai/OLMo-1B", # Broken
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index 5b6f001f62fa7..dcaae4af4a6f8 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -1,14 +1,18 @@
import pytest
import torch
-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, SchedulerConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
- model_runner = ModelRunner(None, None, None, None, None)
+ scheduler_config = SchedulerConfig(100000,
+ 100000,
+ 100000,
+ enable_chunked_prefill=False)
+ model_runner = ModelRunner(None, None, scheduler_config, None, None)
model_runner.set_block_size(16)
prompt_lens = []
@@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size):
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
- _, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
+ _, _,
+ slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
+ assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts.
device = model_runner.device
@@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert attn_metadata.prompt_lens == prompt_lens
- assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
- assert attn_metadata.num_generation_tokens == 0
assert attn_metadata.max_prompt_len == max(prompt_lens)
# Test subquery start locs.
@@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.block_tables, expected)
     # Cuda graph should not be used for prefill.
assert attn_metadata.use_cuda_graph is False
- assert attn_metadata.kv_cache_dtype == "auto"
- assert input_tokens.shape == (sum(prompt_lens), )
- assert input_positions.shape == (sum(prompt_lens), )
+ assert len(input_tokens) == sum(prompt_lens)
+ assert len(input_positions) == sum(prompt_lens)
torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
- assert input_tokens.shape == (sum(prompt_lens), )
- assert input_positions.shape == (sum(prompt_lens), )
+ assert len(input_tokens) == sum(prompt_lens)
+ assert len(input_positions) == sum(prompt_lens)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
- torch.testing.assert_close(input_tokens, input_positions)
+ assert input_tokens == input_positions
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
@@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size):
revision=None,
enforce_eager=False,
)
- model_runner = ModelRunner(model_config, None, None, None, None)
+ scheduler_config = SchedulerConfig(100000,
+ 100000,
+ 100000,
+ enable_chunked_prefill=False)
+ model_runner = ModelRunner(model_config, None, scheduler_config, None,
+ None)
model_runner.set_block_size(16)
prompt_lens = []
@@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
- input_tokens, input_positions, attn_metadata, _, _, _ = (
+ input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
+ assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for decode.
device = model_runner.device
assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None
- assert attn_metadata.num_prompt_tokens == 0
- assert attn_metadata.num_generation_tokens == expected_bs
assert attn_metadata.max_prompt_len is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
@@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size):
model_runner.get_max_block_per_batch())
     # Cuda graph should be used for decode.
assert attn_metadata.use_cuda_graph is True
- assert attn_metadata.kv_cache_dtype == "auto"
- assert input_tokens.shape == (expected_bs, )
- assert input_positions.shape == (expected_bs, )
- torch.testing.assert_close(input_tokens, input_positions)
+ assert len(input_tokens) == expected_bs
+ assert len(input_positions) == expected_bs
+ assert input_tokens == input_positions
# Verify Sampling
expected_selected_token_indices = []
@@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size):
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
+
+
+def test_empty_seq_group():
+ """Verify prepare prompt and decode returns empty output."""
+ model_config = ModelConfig(
+ "facebook/opt-125m",
+ "facebook/opt-125m",
+ tokenizer_mode="auto",
+ trust_remote_code=False,
+ download_dir=None,
+ load_format="dummy",
+ seed=0,
+ dtype="float16",
+ revision=None,
+ enforce_eager=False,
+ )
+ model_runner = ModelRunner(model_config, None, None, None, None)
+ model_runner.set_block_size(16)
+ seq_group_metadata_list = []
+ input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
+ model_runner._prepare_decode(seq_group_metadata_list))
+ assert len(input_tokens) == 0
+ assert len(input_positions) == 0
+ assert attn_metadata is None
+ assert len(slot_mapping) == 0
+
+ (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
+ _, _,
+ slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
+ assert len(input_tokens) == 0
+ assert len(input_positions) == 0
+ assert attn_metadata is None
+ assert len(slot_mapping) == 0
+ assert len(return_prompt_lens) == 0
+
+
+@pytest.mark.parametrize("batch_size", list(range(2, 128)))
+@pytest.mark.parametrize("enforce_eager", [True, False])
+def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
+
+ def get_world_size(group=None):
+ return 1
+
+ def mock_get_process_group_ranks(group=None):
+ return [0]
+
+ monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
+ monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
+ mock_get_process_group_ranks)
+
+ model_config = ModelConfig(
+ "facebook/opt-125m",
+ "facebook/opt-125m",
+ tokenizer_mode="auto",
+ trust_remote_code=False,
+ download_dir=None,
+ load_format="dummy",
+ seed=0,
+ dtype="float16",
+ revision=None,
+ enforce_eager=enforce_eager,
+ )
+ scheduler_config = SchedulerConfig(100000,
+ 100000,
+ 100000,
+ enable_chunked_prefill=True)
+ model_runner = ModelRunner(model_config,
+ None,
+ scheduler_config,
+ None,
+ None,
+ is_driver_worker=True)
+ model_runner.set_block_size(16)
+
+ # Add prefill requests.
+ prompt_lens = []
+ seq_group_metadata_list = []
+ prefill_metadata_list = []
+ decode_metadata_list = []
+ block_tables = {0: [1]}
+ prefill_batch_size = batch_size // 2
+ decode_batch_size = batch_size - prefill_batch_size
+ for i in range(prefill_batch_size):
+ # make sure all tokens fit into one block
+ prompt_len = i % (model_runner.block_size - 1) + 1
+ prompt_lens.append(prompt_len)
+ seq_data = SequenceData(list(range(prompt_len)))
+ seq_group_metadata = SequenceGroupMetadata(
+ request_id=f"test_{i}",
+ is_prompt=True,
+ seq_data={0: seq_data},
+ sampling_params=SamplingParams(temperature=0),
+ block_tables=block_tables,
+ )
+ assert seq_group_metadata.token_chunk_size == seq_data.get_len()
+ seq_group_metadata_list.append(seq_group_metadata)
+ prefill_metadata_list.append(seq_group_metadata)
+
+ # Add decode requests
+ for i in range(prefill_batch_size, batch_size):
+ # make sure all tokens fit into one block
+ prompt_len = i % (model_runner.block_size - 1) + 1
+ prompt_toks = list(range(prompt_len))
+ seq_data = SequenceData(prompt_toks)
+ seq_group_metadata = SequenceGroupMetadata(
+ request_id=f"test_{i}",
+ is_prompt=False,
+ seq_data={0: seq_data},
+ sampling_params=SamplingParams(temperature=0),
+ block_tables={0: [1]},
+ )
+ assert seq_group_metadata.token_chunk_size == 1
+ seq_group_metadata_list.append(seq_group_metadata)
+ decode_metadata_list.append(seq_group_metadata)
+
+ (input_tokens, input_positions, attn_metadata, _, _, _,
+ _) = model_runner.prepare_input_tensors(seq_group_metadata_list)
+
+ prefill_meta_actual = attn_metadata.prefill_metadata
+ decode_meta_actual = attn_metadata.decode_metadata
+
+ assert len(attn_metadata.slot_mapping) == len(input_tokens)
+ assert len(input_positions) == len(input_tokens)
+ assert attn_metadata.kv_cache_dtype == "auto"
+ assert attn_metadata.num_prefills == prefill_batch_size
+ if enforce_eager:
+ assert attn_metadata.num_decode_tokens == decode_batch_size
+ else:
+ assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
+ decode_batch_size)
+ assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
+
+ # Verify attn metadata is consistent. We don't need to test individual
+ # values here because they are tested above.
+ prefill_meta = model_runner._prepare_prompt(
+ prefill_metadata_list).attn_metadata
+ decode_meta = model_runner._prepare_decode(
+ decode_metadata_list).attn_metadata
+
+ for attr_expected, attr_actual in zip(vars(prefill_meta),
+ vars(prefill_meta_actual)):
+ assert attr_expected[1] == attr_actual[1]
+ for attr_expected, attr_actual in zip(vars(decode_meta),
+ vars(decode_meta_actual)):
+ assert attr_expected[1] == attr_actual[1]
diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py
index 9acb82c0df2c2..7636b34a16fed 100644
--- a/vllm/attention/__init__.py
+++ b/vllm/attention/__init__.py
@@ -1,5 +1,6 @@
from vllm.attention.backends.abstract import (AttentionBackend,
- AttentionMetadata)
+ AttentionMetadata,
+ AttentionMetadataPerStage)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
@@ -8,4 +9,5 @@
"AttentionMetadata",
"Attention",
"get_attn_backend",
+ "AttentionMetadataPerStage",
]
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index a03cf2dd7a6fa..7a4ccecf702f4 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
import torch
@@ -47,7 +47,8 @@ def copy_blocks(
@dataclass
-class AttentionMetadata:
+class AttentionMetadataPerStage:
+ """Attention metadata for a specific stage. I.e., prefill or decode."""
def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
@@ -59,6 +60,41 @@ def asdict_zerocopy(self) -> Dict[str, Any]:
}
+T = TypeVar("T", bound=AttentionMetadataPerStage)
+
+
+@dataclass
+class AttentionMetadata(Generic[T]):
+ """Attention metadata for prefill and decode batched together."""
+ # Total number of prefill requests.
+ num_prefills: int
+ # Number of prefill tokens.
+ num_prefill_tokens: int
+ # Number of decode tokens. Note that it is equivalent to the number of
+ # decode requests.
+ num_decode_tokens: int
+ # The attention metadata for prefill requests in a batch.
+    # None if there are no prefill requests in a batch.
+ prefill_metadata: Optional[T]
+ # The attention metadata for decode requests in a batch.
+    # None if there are no decode requests in a batch.
+ decode_metadata: Optional[T]
+ # (num_tokens,). The indices of the token slots that input tokens will be
+ # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
+ # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
+ # in block 0, and 1st slot in block 1, respectively.
+ slot_mapping: torch.Tensor
+ # The kv cache's data type.
+ kv_cache_dtype: str
+
+ def __post_init__(self):
+ if self.num_prefill_tokens > 0:
+ assert self.num_prefills > 0
+ assert self.prefill_metadata is not None
+ if self.num_decode_tokens > 0:
+ assert self.decode_metadata is not None
+
+
class AttentionImpl(ABC):
@abstractmethod
@@ -80,7 +116,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata,
+ attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float,
) -> torch.Tensor:
raise NotImplementedError
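
To make the prefill/decode split concrete, here is a small hedged sketch that builds the new generic `AttentionMetadata` container for a decode-only batch. `_DummyStageMetadata` is a stand-in for a backend's per-stage metadata (not part of the patch), and the numbers are illustrative.

```python
from dataclasses import dataclass

import torch

from vllm.attention.backends.abstract import (AttentionMetadata,
                                              AttentionMetadataPerStage)


@dataclass
class _DummyStageMetadata(AttentionMetadataPerStage):
    """Minimal stand-in for a backend's per-stage metadata."""
    max_context_len: int


# A decode-only batch of three sequences: no prefill stage, so
# prefill_metadata is None and the prefill counters are zero.
metadata = AttentionMetadata(
    num_prefills=0,
    num_prefill_tokens=0,
    num_decode_tokens=3,
    prefill_metadata=None,
    decode_metadata=_DummyStageMetadata(max_context_len=128),
    slot_mapping=torch.tensor([35, 2, 17]),  # one KV slot index per token
    kv_cache_dtype="auto",
)
# __post_init__ only requires decode_metadata here, since there are no
# prefill tokens in this batch.
assert metadata.prefill_metadata is None
```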
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 4e0d9d1418b32..12e8c4404b94e 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -11,7 +11,8 @@
from flash_attn import flash_attn_varlen_func
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
- AttentionMetadata)
+ AttentionMetadata,
+ AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@@ -53,7 +54,8 @@ def copy_blocks(
@dataclass
-class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
+class FlashAttentionMetadata(AttentionMetadataPerStage,
+ PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
@@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
- # The number of prompt tokens. Doesn't include padding.
- num_prompt_tokens: int
- # The number of generation tokens. Doesn't include padding.
- num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
@@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
- |<--------------- num_prompt_tokens -------------->|
- |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
+ |<--------------- num_prefill_tokens ----------------->|
+ |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
- |<------------------ num_generation_tokens (M) ----------------->|
- |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+ |<----------------- num_decode_tokens ------------------>|
+ |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
+
+ If chunked prefill is enabled, prefill tokens and decode tokens can be
+ batched together in a flattened 1D query.
+
+ |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
+ |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
+
+ Currently, cuda graph is disabled for chunked prefill, meaning there's no
+ padding between prefill and decode tokens.
"""
def __init__(
@@ -155,7 +162,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
- attn_metadata: FlashAttentionMetadata,
+ attn_metadata: AttentionMetadata[FlashAttentionMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -188,52 +195,70 @@ def forward(
attn_metadata.kv_cache_dtype,
kv_scale)
- if attn_metadata.is_prompt:
+ num_prefill_tokens = attn_metadata.num_prefill_tokens
+ num_decode_tokens = attn_metadata.num_decode_tokens
+ assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+ assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+ output = torch.empty_like(query)
+ # Query for decode. KV is not needed because it is already cached.
+ decode_query = query[num_prefill_tokens:]
+ # QKV for prefill.
+ query = query[:num_prefill_tokens]
+ key = key[:num_prefill_tokens]
+ value = value[:num_prefill_tokens]
+
+ assert query.shape[0] == num_prefill_tokens
+ assert decode_query.shape[0] == num_decode_tokens
+
+ if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
- if kv_cache is None or attn_metadata.block_tables.numel() == 0:
+ if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
- output = flash_attn_varlen_func(
+ out = flash_attn_varlen_func(
q=query,
k=key,
v=value,
- cu_seqlens_q=attn_metadata.seq_start_loc,
- cu_seqlens_k=attn_metadata.seq_start_loc,
- max_seqlen_q=attn_metadata.max_prompt_len,
- max_seqlen_k=attn_metadata.max_prompt_len,
+ cu_seqlens_q=prefill_meta.seq_start_loc,
+ cu_seqlens_k=prefill_meta.seq_start_loc,
+ max_seqlen_q=prefill_meta.max_prompt_len,
+ max_seqlen_k=prefill_meta.max_prompt_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
+ assert output[:num_prefill_tokens].shape == out.shape
+ output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
- output = PagedAttention.forward_prefix(
+ output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
- attn_metadata.block_tables,
- attn_metadata.subquery_start_loc,
- attn_metadata.prompt_lens_tensor,
- attn_metadata.context_lens,
- attn_metadata.max_subquery_len,
+ prefill_meta.block_tables,
+ prefill_meta.subquery_start_loc,
+ prefill_meta.prompt_lens_tensor,
+ prefill_meta.context_lens,
+ prefill_meta.max_subquery_len,
self.alibi_slopes,
)
- else:
+ if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
- output = PagedAttention.forward_decode(
- query,
+ output[num_prefill_tokens:] = PagedAttention.forward_decode(
+ decode_query,
key_cache,
value_cache,
- attn_metadata.block_tables,
- attn_metadata.context_lens,
- attn_metadata.max_context_len,
+ decode_meta.block_tables,
+ decode_meta.context_lens,
+ decode_meta.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
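
The forward pass above relies on the invariant that prefill tokens precede decode tokens in the flattened query, so each stage is recovered by slicing and the per-stage results are written back into one output buffer. A backend-independent sketch of that pattern (shapes are illustrative):

```python
import torch

num_prefill_tokens = 5   # e.g. two prompts of lengths 2 and 3
num_decode_tokens = 3    # three sequences generating one token each
hidden_size = 8

# Flattened 1D token batch: prefill tokens first, then decode tokens.
query = torch.randn(num_prefill_tokens + num_decode_tokens, hidden_size)

prefill_query = query[:num_prefill_tokens]
decode_query = query[num_prefill_tokens:]
assert prefill_query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens

# Each stage's attention output is written into its slice of a shared
# buffer, mirroring output[:num_prefill_tokens] / output[num_prefill_tokens:]
# in the backend code above. Identity stands in for the attention kernels.
output = torch.empty_like(query)
output[:num_prefill_tokens] = prefill_query
output[num_prefill_tokens:] = decode_query
```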
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 6019d917b4494..e55435cd2c947 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -6,7 +6,8 @@
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
- AttentionMetadata)
+ AttentionMetadata,
+ AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@@ -51,7 +52,8 @@ def copy_blocks(
@dataclass
-class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
+class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
+ PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
@@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
- # The number of prompt tokens. Doesn't include padding.
- num_prompt_tokens: int
- # The number of generation tokens. Doesn't include padding.
- num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
@@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
The prompts might have different lengths, while the generation tokens
always have length 1.
+
+ If chunked prefill is enabled, prefill tokens and decode tokens can be
+ batched together in a flattened 1D query.
+
+ |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
+ |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
+
+ Currently, cuda graph is disabled for chunked prefill, meaning there's no
+ padding between prefill and decode tokens.
"""
def __init__(
@@ -181,7 +188,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
- attn_metadata: ROCmFlashAttentionMetadata,
+ attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -218,9 +225,25 @@ def forward(
kv_scale,
)
- if attn_metadata.is_prompt:
+ num_prefill_tokens = attn_metadata.num_prefill_tokens
+ num_decode_tokens = attn_metadata.num_decode_tokens
+ assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+ assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+ output = torch.empty_like(query)
+ # Query for decode. KV is not needed because it is already cached.
+ decode_query = query[num_prefill_tokens:]
+ # QKV for prefill.
+ query = query[:num_prefill_tokens]
+ key = key[:num_prefill_tokens]
+ value = value[:num_prefill_tokens]
+
+ assert query.shape[0] == num_prefill_tokens
+ assert decode_query.shape[0] == num_decode_tokens
+
+ if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
- if kv_cache is None or attn_metadata.block_tables.numel() == 0:
+ if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
@@ -230,63 +253,69 @@ def forward(
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
if self.use_naive_attn:
- output = self.attn_fuc(
+ out = self.attn_fuc(
query,
key,
value,
- attn_metadata.prompt_lens,
+ prefill_meta.prompt_lens,
self.scale,
)
+ assert output[:num_prefill_tokens].shape == out.shape
+ output[:num_prefill_tokens] = out
else:
- output, _ = self.attn_func(
+ out, _ = self.attn_func(
query,
key,
value,
None,
- attn_metadata.seq_start_loc,
- attn_metadata.seq_start_loc,
- attn_metadata.max_prompt_len,
- attn_metadata.max_prompt_len,
+ prefill_meta.seq_start_loc,
+ prefill_meta.seq_start_loc,
+ prefill_meta.max_prompt_len,
+ prefill_meta.max_prompt_len,
True,
self.scale,
)
+ assert output[:num_prefill_tokens].shape == out.shape
+ output[:num_prefill_tokens] = out
else:
- output = self.attn_func(
+ out = self.attn_func(
q=query,
k=key,
v=value,
- cu_seqlens_q=attn_metadata.seq_start_loc,
- cu_seqlens_k=attn_metadata.seq_start_loc,
- max_seqlen_q=attn_metadata.max_prompt_len,
- max_seqlen_k=attn_metadata.max_prompt_len,
+ cu_seqlens_q=prefill_meta.seq_start_loc,
+ cu_seqlens_k=prefill_meta.seq_start_loc,
+ max_seqlen_q=prefill_meta.max_prompt_len,
+ max_seqlen_k=prefill_meta.max_prompt_len,
softmax_scale=self.scale,
causal=True,
)
-
+ assert output[:num_prefill_tokens].shape == out.shape
+ output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
- output = PagedAttention.forward_prefix(
+ output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
- attn_metadata.block_tables,
- attn_metadata.subquery_start_loc,
- attn_metadata.prompt_lens_tensor,
- attn_metadata.context_lens,
- attn_metadata.max_subquery_len,
+ prefill_meta.block_tables,
+ prefill_meta.subquery_start_loc,
+ prefill_meta.prompt_lens_tensor,
+ prefill_meta.context_lens,
+ prefill_meta.max_subquery_len,
self.alibi_slopes,
)
- else:
+
+ if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
- output = PagedAttention.forward_decode(
- query,
+ output[num_prefill_tokens:] = PagedAttention.forward_decode(
+ decode_query,
key_cache,
value_cache,
- attn_metadata.block_tables,
- attn_metadata.context_lens,
- attn_metadata.max_context_len,
+ decode_meta.block_tables,
+ decode_meta.context_lens,
+ decode_meta.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 9706e1910cb79..63904ea929870 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -7,7 +7,8 @@
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
- AttentionMetadata)
+ AttentionMetadata,
+ AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@@ -49,17 +50,14 @@ def copy_blocks(
@dataclass
-class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
+class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
- slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]]
prompt_lens_tensor: Optional[torch.Tensor]
- num_prompt_tokens: int
- num_generation_tokens: int
max_subquery_len: Optional[int] = None
max_prompt_len: Optional[int] = None
@@ -113,7 +111,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
- attn_metadata: TorchSDPAMetadata,
+ attn_metadata: AttentionMetadata[TorchSDPAMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -142,36 +140,51 @@ def forward(
attn_metadata.kv_cache_dtype,
kv_scale)
- if attn_metadata.is_prompt:
- if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
+ num_prefill_tokens = attn_metadata.num_prefill_tokens
+ num_decode_tokens = attn_metadata.num_decode_tokens
+ assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+ assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+ output = torch.empty_like(query)
+ # Query for decode. KV is not needed because it is already cached.
+ decode_query = query[num_prefill_tokens:]
+ # QKV for prefill.
+ query = query[:num_prefill_tokens]
+ key = key[:num_prefill_tokens]
+ value = value[:num_prefill_tokens]
+
+ assert query.shape[0] == num_prefill_tokens
+ assert decode_query.shape[0] == num_decode_tokens
+
+ if prefill_meta := attn_metadata.prefill_metadata:
+ if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)
- if attn_metadata.attn_bias is None:
+ if prefill_meta.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
- attn_metadata.prompt_lens) # type: ignore
+ prefill_meta.prompt_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
- attn_metadata.prompt_lens, self.sliding_window,
+ prefill_meta.prompt_lens, self.sliding_window,
query.dtype) # type: ignore
else:
- att_masks = [None] * len(attn_metadata.prompt_lens)
- attn_metadata.attn_bias = att_masks
+ att_masks = [None] * len(prefill_meta.prompt_lens)
+ prefill_meta.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
start = 0
- output = torch.empty(
- (num_tokens, self.num_heads, self.head_size),
- dtype=query.dtype)
- for prompt_len, mask in zip(attn_metadata.prompt_lens,
- attn_metadata.attn_bias):
+ out = torch.empty((num_tokens, self.num_heads, self.head_size),
+ dtype=query.dtype)
+ for prompt_len, mask in zip(prefill_meta.prompt_lens,
+ prefill_meta.attn_bias):
end = start + prompt_len
sub_out = scaled_dot_product_attention(
query[:, start:end, :],
@@ -181,28 +194,32 @@ def forward(
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0)
- output[start:end, :, :] = sub_out
+ out[start:end, :, :] = sub_out
start = end
+ assert out.shape == output[:num_prefill_tokens].shape
+ output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
- else:
+ if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
- output = PagedAttention.forward_decode(
- query,
+ out = PagedAttention.forward_decode(
+ decode_query,
key_cache,
value_cache,
- attn_metadata.block_tables,
- attn_metadata.context_lens,
- attn_metadata.max_context_len,
+ decode_meta.block_tables,
+ decode_meta.context_lens,
+ decode_meta.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
+ assert out.shape == output[num_prefill_tokens:].shape
+ output[num_prefill_tokens:] = out
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
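
The rewritten forward pass above replaces the old `is_prompt` branch with a single flattened batch that is sliced into a prefill part and a decode part, each writing into its own region of a shared output tensor. A minimal, self-contained sketch of that slicing pattern (the token counts and tensor shapes below are illustrative, not taken from a real batch):

import torch

# Illustrative counts: two prefill requests of lengths 3 and 4, plus 5 decode tokens.
num_prefill_tokens = 3 + 4
num_decode_tokens = 5
num_heads, head_size = 8, 64

# Flattened query for the whole batch: prefill tokens first, then decode tokens.
query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)
output = torch.empty_like(query)

decode_query = query[num_prefill_tokens:]   # rows handled by the paged-attention decode path
prefill_query = query[:num_prefill_tokens]  # rows handled by the prefill attention path
assert prefill_query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens

# Each stage writes its result back into its own slice of the shared output.
output[:num_prefill_tokens] = prefill_query   # stand-in for the prefill attention result
output[num_prefill_tokens:] = decode_query    # stand-in for the decode attention result
print(output.shape)  # torch.Size([12, 8, 64])
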
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 05b68bba5e6eb..b745a04a143b4 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -9,7 +9,8 @@
LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
- AttentionMetadata)
+ AttentionMetadata,
+ AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@@ -54,7 +55,7 @@ def copy_blocks(
@dataclass
-class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
+class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
@@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
- # (num_tokens,). The indices of the token slots that input tokens will be
- # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
- # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
- # in block 0, and 1st slot in block 1, respectively.
- slot_mapping: torch.Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
- # The number of prompt tokens. Doesn't include padding.
- num_prompt_tokens: int
- # The number of generation tokens. Doesn't include padding.
- num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
@@ -123,18 +115,27 @@ def __post_init__(self):
class XFormersImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
- |<--------------- num_prompt_tokens --------------->|
- |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
+ |<--------------- num_prefill_tokens ----------------->|
+ |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
- |<------------------ num_generation_tokens (M) ----------------->|
- |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+ |<----------------- num_decode_tokens ------------------>|
+ |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
+
+ If chunked prefill is enabled, prefill tokens and decode tokens can be
+ batched together in a flattened 1D query.
+
+ |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
+ |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
+
+ Currently, cuda graph is disabled for chunked prefill, meaning there's no
+ padding between prefill and decode tokens.
"""
def __init__(
@@ -170,7 +171,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
- attn_metadata: XFormersMetadata,
+ attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@@ -202,59 +203,61 @@ def forward(
attn_metadata.kv_cache_dtype,
kv_scale)
- if attn_metadata.is_prompt:
+ num_prefill_tokens = attn_metadata.num_prefill_tokens
+ num_decode_tokens = attn_metadata.num_decode_tokens
+ assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+ assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+ output = torch.empty_like(query)
+ # Query for decode. KV is not needed because it is already cached.
+ decode_query = query[num_prefill_tokens:]
+ # QKV for prefill.
+ query = query[:num_prefill_tokens]
+ key = key[:num_prefill_tokens]
+ value = value[:num_prefill_tokens]
+
+ assert query.shape[0] == num_prefill_tokens
+ assert decode_query.shape[0] == num_decode_tokens
+
+ if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
- if kv_cache is None or attn_metadata.block_tables.numel() == 0:
+ if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
- if self.num_kv_heads != self.num_heads:
- # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
- # project the key and value tensors to the desired number of
- # heads.
- # TODO(woosuk): Use MQA/GQA kernels for higher performance.
- query = query.view(query.shape[0], self.num_kv_heads,
- self.num_queries_per_kv,
- query.shape[-1])
- key = key[:, :,
- None, :].expand(key.shape[0], self.num_kv_heads,
- self.num_queries_per_kv,
- key.shape[-1])
- value = value[:, :,
- None, :].expand(value.shape[0],
- self.num_kv_heads,
- self.num_queries_per_kv,
- value.shape[-1])
-
- output = self._run_memory_efficient_xformers_forward(
- query, key, value, attn_metadata)
+ out = self._run_memory_efficient_xformers_forward(
+ query, key, value, prefill_meta)
+ assert out.shape == output[:num_prefill_tokens].shape
+ output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
# TODO(Hai): this Triton kernel has a regression (currently broken) when
# the KV dtype and the FP8 KV-cache dtype differ; to be addressed
# separately.
- output = PagedAttention.forward_prefix(
+ out = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
- attn_metadata.block_tables,
- attn_metadata.subquery_start_loc,
- attn_metadata.prompt_lens_tensor,
- attn_metadata.context_lens,
- attn_metadata.max_subquery_len,
+ prefill_meta.block_tables,
+ prefill_meta.subquery_start_loc,
+ prefill_meta.prompt_lens_tensor,
+ prefill_meta.context_lens,
+ prefill_meta.max_subquery_len,
self.alibi_slopes,
)
- else:
- # Decoding run.
- output = PagedAttention.forward_decode(
- query,
+ assert output[:num_prefill_tokens].shape == out.shape
+ output[:num_prefill_tokens] = out
+
+ if decode_meta := attn_metadata.decode_metadata:
+ output[num_prefill_tokens:] = PagedAttention.forward_decode(
+ decode_query,
key_cache,
value_cache,
- attn_metadata.block_tables,
- attn_metadata.context_lens,
- attn_metadata.max_context_len,
+ decode_meta.block_tables,
+ decode_meta.context_lens,
+ decode_meta.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@@ -275,13 +278,30 @@ def _run_memory_efficient_xformers_forward(
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened into the `query` input.
+ See https://facebookresearch.github.io/xformers/components/ops.html
+ for API spec.
+
Args:
- output: shape = [num_prompt_tokens, num_heads, head_size]
- query: shape = [num_prompt_tokens, num_heads, head_size]
- key: shape = [num_prompt_tokens, num_kv_heads, head_size]
- value: shape = [num_prompt_tokens, num_kv_heads, head_size]
+ output: shape = [num_prefill_tokens, num_heads, head_size]
+ query: shape = [num_prefill_tokens, num_heads, head_size]
+ key: shape = [num_prefill_tokens, num_kv_heads, head_size]
+ value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
+ original_query = query
+ if self.num_kv_heads != self.num_heads:
+ # GQA/MQA requires the shape [B, M, G, H, K].
+ # Note that the output also has the same shape (which is different
+ # from a spec from the doc).
+ query = query.view(query.shape[0], self.num_kv_heads,
+ self.num_queries_per_kv, query.shape[-1])
+ key = key[:, :,
+ None, :].expand(key.shape[0], self.num_kv_heads,
+ self.num_queries_per_kv, key.shape[-1])
+ value = value[:, :,
+ None, :].expand(value.shape[0], self.num_kv_heads,
+ self.num_queries_per_kv,
+ value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
@@ -302,6 +322,7 @@ def _run_memory_efficient_xformers_forward(
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
+ # Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
@@ -312,14 +333,13 @@ def _run_memory_efficient_xformers_forward(
attn_bias=attn_metadata.attn_bias[0],
p=0.0,
scale=self.scale)
-
- return out.view_as(query)
+ return out.view_as(original_query)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
- output = torch.empty_like(query)
+ output = torch.empty_like(original_query)
start = 0
for i, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len
@@ -331,7 +351,7 @@ def _run_memory_efficient_xformers_forward(
p=0.0,
scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize.
- output[start:end].copy_(out.squeeze(0))
+ output[start:end].copy_(out.view_as(original_query[start:end]))
start += prompt_len
return output
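
For GQA/MQA, `_run_memory_efficient_xformers_forward` now reshapes the flattened tensors into the [B, M, G, H, K] layout that xformers expects and views the result back as the original query. A standalone sketch of that reshape, assuming 8 query heads sharing 2 KV heads; xformers itself is not needed to run it:

import torch

num_tokens, num_heads, num_kv_heads, head_size = 10, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

original_query = query
# Group the query heads by the KV head they share: [M, G, H, K].
query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
# Broadcast each KV head across its query-head group without copying data.
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)
value = value[:, :, None, :].expand(num_tokens, num_kv_heads,
                                    num_queries_per_kv, head_size)

# unsqueeze(0) adds the leading batch dimension B before calling xformers; the
# attention output keeps this grouped shape, so it is viewed back to the
# original 3-D query shape at the end.
out = query.unsqueeze(0)  # stand-in for the attention output
print(out.view_as(original_query).shape)  # torch.Size([10, 8, 64])
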
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 9856654fc5f94..fc65ae108dbb1 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -4,7 +4,8 @@
import torch
import torch.nn as nn
-from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.abstract import (AttentionMetadata,
+ AttentionMetadataPerStage)
from vllm.attention.selector import get_attn_backend
@@ -41,7 +42,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
- attn_metadata: AttentionMetadata,
+ attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 256bffdf032eb..2d918491d6576 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -13,11 +13,6 @@
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
- # (num_tokens,). The indices of the token slots that input tokens will be
- # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
- # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
- # in block 0, and 1st slot in block 1, respectively.
- slot_mapping: torch.Tensor
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
@@ -31,7 +26,6 @@ class PagedAttentionMetadata:
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
- kv_cache_dtype: str
class PagedAttention:
diff --git a/vllm/config.py b/vllm/config.py
index bca250e922288..4102edbe01d35 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -565,9 +565,16 @@ def __init__(
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
- # If max_model_len is too short, use 2048 as the default value for
- # higher throughput.
- self.max_num_batched_tokens = max(max_model_len, 2048)
+ if enable_chunked_prefill:
+ # For chunked prefill, choose the well-tuned batch size.
+ self.max_num_batched_tokens = 768
+ else:
+ # If max_model_len is too short, use 2048 as the default value
+ # for higher throughput.
+ self.max_num_batched_tokens = max(max_model_len, 2048)
+ if enable_chunked_prefill:
+ logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
+
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.use_v2_block_manager = use_v2_block_manager
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 0ae53f9374960..2942eab735a92 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -140,7 +140,11 @@ def _sort_by_lora_ids(self) -> bool:
@property
def lora_requests(self) -> Set[LoRARequest]:
- return {g.seq_group.lora_request for g in self.scheduled_seq_groups}
+ return {
+ g.seq_group.lora_request
+ for g in self.scheduled_seq_groups
+ if g.seq_group.lora_request is not None
+ }
@dataclass
@@ -826,13 +830,12 @@ def _schedule_chunked_prefill(self):
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
-
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
- running_scheduled.decode_seq_groups +
running_scheduled.prefill_seq_groups +
- swapped_in.decode_seq_groups +
- swapped_in.prefill_seq_groups),
+ swapped_in.prefill_seq_groups +
+ running_scheduled.decode_seq_groups +
+ swapped_in.decode_seq_groups),
num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)),
@@ -907,7 +910,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
- is_prompt = i < scheduler_outputs.num_prefill_groups
+ is_prompt = seq_group.is_prefill()
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index cf15db099b304..1004d626b6a4b 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -173,10 +173,18 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=group)
+ async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = tensor_dict[key]
- torch.distributed.broadcast(tensor, src=src, group=group)
+ async_handles.append(
+ torch.distributed.broadcast(tensor,
+ src=src,
+ group=group,
+ async_op=True))
+ for async_handle in async_handles:
+ async_handle.wait()
+
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
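
The sender side above now launches every tensor broadcast with `async_op=True` and only waits on the collected work handles afterwards, letting the transfers overlap instead of running one by one. A minimal sketch of that pattern against the plain torch.distributed API; the single-process gloo group exists only to make the sketch runnable:

import torch
import torch.distributed as dist

def broadcast_all_async(tensors, src=0, group=None):
    # Kick off every broadcast first, then wait on all handles together.
    handles = [
        dist.broadcast(t, src=src, group=group, async_op=True) for t in tensors
    ]
    for handle in handles:
        handle.wait()

if __name__ == "__main__":
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    broadcast_all_async([torch.randn(4), torch.randn(8)])
    dist.destroy_process_group()
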
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index d4b573992c06c..daefddc01b431 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -386,9 +386,8 @@ def add_cli_args(
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
- type=bool,
- default=False,
- help='If True, the prefill requests can be chunked based on the '
+ action='store_true',
+ help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument(
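
The flag above is switched from `type=bool` to `action='store_true'` because argparse feeds the raw string through `bool()`, so even `--enable-chunked-prefill False` used to evaluate to True. A small sketch contrasting the two behaviours; the parsers below are illustrative, not vLLM's real CLI:

import argparse

broken = argparse.ArgumentParser()
broken.add_argument("--enable-chunked-prefill", type=bool, default=False)
# bool("False") is True: any non-empty string value enables the flag.
print(broken.parse_args(["--enable-chunked-prefill", "False"])
      .enable_chunked_prefill)  # True

fixed = argparse.ArgumentParser()
fixed.add_argument("--enable-chunked-prefill", action="store_true")
print(fixed.parse_args([]).enable_chunked_prefill)                            # False
print(fixed.parse_args(["--enable-chunked-prefill"]).enable_chunked_prefill)  # True
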
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1c639af696544..ddfdda898a5c6 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -633,7 +633,10 @@ def _process_model_outputs(
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
- self._process_sequence_group_outputs(seq_group, outputs)
+ # If uncomputed tokens > 0, it means prefill is chunked.
+ # We don't need to process outputs in that case.
+ if seq_group.get_num_uncomputed_tokens() == 0:
+ self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index dd33868f76302..84a94091486d7 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -267,12 +267,13 @@ def set_mapping(
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
- indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
+ embedding_len = self.indices_len[3]
+ indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices,
self.lora_a_stacked_2d,
)
- indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
+ indices = self.embeddings_indices[0][:embedding_len].view_as(x)
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 576bbe8c4f6c4..77029908c2218 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -500,7 +500,8 @@ def update_num_computed_tokens(self, num_new_computed_tokens: int):
def get_num_uncomputed_tokens(self) -> int:
num_uncomputed_tokens = 0
for seq in self.get_seqs():
- num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
+ if not seq.is_finished():
+ num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 1de4748b7bcc9..47ad8f0c9b78b 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1,12 +1,14 @@
import contextlib
import time
-from typing import Dict, List, Optional, Set, Tuple
+from enum import IntEnum
+from typing import Dict, List, NamedTuple, Optional, Set, Tuple
import numpy as np
import torch
import torch.nn as nn
-from vllm.attention import AttentionMetadata, get_attn_backend
+from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
+ get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@@ -37,6 +39,66 @@
]
+class PreparePromptMetadata(NamedTuple):
+ input_tokens: List[int]
+ input_positions: List[int]
+ attn_metadata: Optional[AttentionMetadataPerStage]
+ prompt_lens: List[int]
+ subquery_lens: List[int]
+ lora_index_mapping: List[int]
+ lora_prompt_mapping: List[int]
+ lora_requests: Set[LoRARequest]
+ multi_modal_input: Optional[torch.Tensor]
+ slot_mapping: List[int]
+
+ @classmethod
+ def empty(cls):
+ return PreparePromptMetadata(
+ input_tokens=[],
+ input_positions=[],
+ attn_metadata=None,
+ prompt_lens=[],
+ subquery_lens=[],
+ lora_index_mapping=[],
+ lora_prompt_mapping=[],
+ lora_requests=set(),
+ multi_modal_input=None,
+ slot_mapping=[],
+ )
+
+
+class PrepareDecodeMetadata(NamedTuple):
+ input_tokens: List[int]
+ input_positions: List[int]
+ attn_metadata: Optional[AttentionMetadata]
+ lora_index_mapping: List[int]
+ lora_prompt_mapping: List[int]
+ lora_requests: Set[LoRARequest]
+ slot_mapping: List[int]
+
+ @classmethod
+ def empty(cls):
+ return PrepareDecodeMetadata(
+ input_tokens=[],
+ input_positions=[],
+ attn_metadata=None,
+ lora_index_mapping=[],
+ lora_prompt_mapping=[],
+ lora_requests=set(),
+ slot_mapping=[],
+ )
+
+
+# How batches are constructed.
+class BatchType(IntEnum):
+ # Every batch is prefill.
+ PREFILL = 0
+ # Every batch is decode.
+ DECODE = 1
+ # Batch is a mixture of prefill and decode.
+ MIXED = 2
+
+
class ModelRunner:
def __init__(
@@ -152,10 +214,7 @@ def get_max_block_per_batch(self) -> int:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
- ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
- List[int], List[int], List[int], Set[LoRARequest],
- torch.Tensor]:
- assert len(seq_group_metadata_list) > 0
+ ) -> PreparePromptMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
@@ -169,6 +228,9 @@ def _prepare_prompt(
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
+ if len(seq_group_metadata_list) == 0:
+ return PreparePromptMetadata.empty()
+
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -178,7 +240,8 @@ def _prepare_prompt(
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
- and computed_block_nums is not None):
+ and not (computed_block_nums is None
+ or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
@@ -190,13 +253,8 @@ def _prepare_prompt(
# it contains output tokens.
prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size)
- # TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
- prompt_len = len(prompt_tokens)
- # Right now, the prefill_end is always same as the length of
- # sequence. However, once chunked prefill is introduced, this
- # assumption can be changed.
- assert prefill_end == seq_data.get_len()
+ prompt_len = prefill_end
prompt_lens.append(prompt_len)
# NOTE: This only works for oooooooxxx style attention.
@@ -206,6 +264,14 @@ def _prepare_prompt(
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums)
+ elif self.scheduler_config.chunked_prefill_enabled:
+ if seq_group_metadata.block_tables is not None:
+ # Prefill has chunked before.
+ block_table = seq_group_metadata.block_tables[seq_id]
+ prefix_block_tables.append(block_table)
+ else:
+ # The first prefill.
+ prefix_block_tables.append([])
else:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
@@ -267,20 +333,8 @@ def _prepare_prompt(
max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens)
- num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0
- input_tokens = torch.tensor(input_tokens,
- dtype=torch.long,
- device=self.device)
- input_positions = torch.tensor(input_positions,
- dtype=torch.long,
- device=self.device)
- slot_mapping = torch.tensor(slot_mapping,
- dtype=torch.long,
- device=self.device)
- lora_index_mapping = lora_index_mapping
-
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
@@ -332,11 +386,8 @@ def _prepare_prompt(
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
- slot_mapping=slot_mapping,
prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor,
- num_prompt_tokens=num_prompt_tokens,
- num_generation_tokens=0,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_prompt_len=max_prompt_len,
@@ -345,18 +396,25 @@ def _prepare_prompt(
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
- kv_cache_dtype=self.kv_cache_dtype,
)
- return (input_tokens, input_positions, attn_metadata, prompt_lens,
- subquery_lens, lora_index_mapping, lora_prompt_mapping,
- lora_requests, multi_modal_input)
+
+ return PreparePromptMetadata(
+ input_tokens=input_tokens,
+ input_positions=input_positions,
+ attn_metadata=attn_metadata,
+ prompt_lens=prompt_lens,
+ subquery_lens=subquery_lens,
+ lora_index_mapping=lora_index_mapping,
+ lora_prompt_mapping=lora_prompt_mapping,
+ lora_requests=lora_requests,
+ multi_modal_input=multi_modal_input,
+ slot_mapping=slot_mapping,
+ )
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
- ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
- List[int], Set[LoRARequest]]:
- assert len(seq_group_metadata_list) > 0
+ ) -> PrepareDecodeMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
@@ -366,6 +424,9 @@ def _prepare_decode(
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
+ if len(seq_group_metadata_list) == 0:
+ return PrepareDecodeMetadata.empty()
+
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
@@ -424,15 +485,6 @@ def _prepare_decode(
lora_index_mapping.append(0)
batch_size = graph_batch_size
- input_tokens = torch.tensor(input_tokens,
- dtype=torch.long,
- device=self.device)
- input_positions = torch.tensor(input_positions,
- dtype=torch.long,
- device=self.device)
- slot_mapping = torch.tensor(slot_mapping,
- dtype=torch.long,
- device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
@@ -440,9 +492,9 @@ def _prepare_decode(
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
- assert context_lens.shape[0] == input_tokens.shape[0]
- assert context_lens.shape[0] == input_positions.shape[0]
- assert context_lens.shape[0] == slot_mapping.shape[0]
+ assert context_lens.shape[0] == len(input_tokens)
+ assert context_lens.shape[0] == len(input_positions)
+ assert context_lens.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
@@ -464,11 +516,8 @@ def _prepare_decode(
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
- slot_mapping=slot_mapping,
prompt_lens=None,
prompt_lens_tensor=None,
- num_prompt_tokens=0,
- num_generation_tokens=len(input_tokens),
max_subquery_len=None,
max_context_len=max_context_len,
max_prompt_len=None,
@@ -477,10 +526,16 @@ def _prepare_decode(
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
- kv_cache_dtype=self.kv_cache_dtype,
)
- return (input_tokens, input_positions, attn_metadata,
- lora_index_mapping, lora_prompt_mapping, lora_requests)
+ return PrepareDecodeMetadata(
+ input_tokens=input_tokens,
+ input_positions=input_positions,
+ attn_metadata=attn_metadata,
+ lora_index_mapping=lora_index_mapping,
+ lora_prompt_mapping=lora_prompt_mapping,
+ lora_requests=lora_requests,
+ slot_mapping=slot_mapping,
+ )
def _prepare_sample(
self,
@@ -586,26 +641,66 @@ def prepare_input_tensors(
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
- # NOTE: We assume that all sequences in the group are all prompts or
- # all decodes.
- is_prompt = seq_group_metadata_list[0].is_prompt
+ prefill_reqs = []
+ decode_reqs = []
+ for seq_group_meta in seq_group_metadata_list:
+ if seq_group_meta.is_prompt:
+ prefill_reqs.append(seq_group_meta)
+ else:
+ decode_reqs.append(seq_group_meta)
+
# Prepare input tensors.
- if is_prompt:
- (input_tokens, input_positions, attn_metadata, prompt_lens,
- subquery_lens, lora_index_mapping, lora_prompt_mapping,
- lora_requests, multi_modal_input
- ) = self._prepare_prompt(seq_group_metadata_list)
- else:
- (input_tokens, input_positions, attn_metadata,
- lora_index_mapping, lora_prompt_mapping,
- lora_requests) = self._prepare_decode(seq_group_metadata_list)
- prompt_lens = []
- subquery_lens = None
- multi_modal_input = None
+ (
+ input_tokens,
+ input_positions,
+ prefill_attn_metadata,
+ prompt_lens,
+ subquery_lens,
+ lora_index_mapping,
+ lora_prompt_mapping,
+ lora_requests,
+ multi_modal_input,
+ slot_mapping,
+ ) = self._prepare_prompt(prefill_reqs)
+ (
+ decode_input_tokens,
+ decode_input_positions,
+ decode_attn_metadata,
+ decode_lora_index_mapping,
+ decode_lora_prompt_mapping,
+ decode_lora_requests,
+ decode_slot_mapping,
+ ) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens)
+ if not self.scheduler_config.chunked_prefill_enabled:
+ assert (len(prefill_reqs) and len(decode_reqs)) == 0
+
+ num_prefills = len(prompt_lens)
+ num_prefill_tokens = len(input_tokens)
+ num_decode_tokens = len(decode_input_tokens)
+
+ # Coalesce tensors. Note that attn_metadata is currently not
+ # coalesced for simplicity.
+ input_tokens.extend(decode_input_tokens)
+ input_positions.extend(decode_input_positions)
+ slot_mapping.extend(decode_slot_mapping)
+ lora_index_mapping.extend(decode_lora_index_mapping)
+ lora_prompt_mapping.extend(decode_lora_prompt_mapping)
+ lora_requests.update(decode_lora_requests)
+
+ input_tokens = torch.tensor(input_tokens,
+ dtype=torch.long,
+ device=self.device)
+ input_positions = torch.tensor(input_positions,
+ dtype=torch.long,
+ device=self.device)
+ slot_mapping = torch.tensor(slot_mapping,
+ dtype=torch.long,
+ device=self.device)
+
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
@@ -615,6 +710,16 @@ def prepare_input_tensors(
lora_mapping = None
# Broadcast the metadata.
+ # If the batch contains both prefill and decode, it sends two broadcasts.
+ # If it contains only one type, it sends a single broadcast.
+ if (prefill_attn_metadata is not None
+ and decode_attn_metadata is not None):
+ batch_type = BatchType.MIXED
+ elif prefill_attn_metadata is not None:
+ batch_type = BatchType.PREFILL
+ else:
+ batch_type = BatchType.DECODE
+
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
@@ -623,19 +728,49 @@ def prepare_input_tensors(
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input,
+ "num_prefill_tokens": num_prefill_tokens,
+ "num_decode_tokens": num_decode_tokens,
+ "slot_mapping": slot_mapping,
+ "num_prefills": num_prefills,
+ "batch_type": batch_type,
}
- metadata_dict.update(attn_metadata.asdict_zerocopy())
+ if prefill_attn_metadata is not None:
+ metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
+ else:
+ metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
+
+ # Broadcast decode attn metadata for mixed batch type.
+ # The additional broadcast costs 300us overhead on 4 A10 GPUs.
+ # We can potentially reduce the overhead by coalescing tensors.
+ if batch_type == BatchType.MIXED:
+ assert decode_attn_metadata is not None
+ metadata_dict = decode_attn_metadata.asdict_zerocopy()
+ broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
+ slot_mapping = metadata_dict.pop("slot_mapping")
+ num_prefills = metadata_dict.pop("num_prefills")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
- attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
+ num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
+ num_decode_tokens = metadata_dict.pop("num_decode_tokens")
+ batch_type = metadata_dict.pop("batch_type")
+
+ # Create an attention metadata.
+ prefill_attn_metadata = None
+ decode_attn_metadata = None
+ if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
+ prefill_attn_metadata = self.attn_backend.make_metadata(
+ **metadata_dict)
+ else:
+ decode_attn_metadata = self.attn_backend.make_metadata(
+ **metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
@@ -646,6 +781,23 @@ def prepare_input_tensors(
perform_sampling=False,
)
+ # if it is a mixed batch, decode attn_metadata is broadcasted
+ # separately.
+ if batch_type == BatchType.MIXED:
+ metadata_dict = broadcast_tensor_dict(src=0)
+ decode_attn_metadata = self.attn_backend.make_metadata(
+ **metadata_dict)
+
+ attn_metadata = AttentionMetadata(
+ num_prefills=num_prefills,
+ slot_mapping=slot_mapping,
+ num_prefill_tokens=num_prefill_tokens,
+ num_decode_tokens=num_decode_tokens,
+ prefill_metadata=prefill_attn_metadata,
+ decode_metadata=decode_attn_metadata,
+ kv_cache_dtype=self.kv_cache_dtype,
+ )
+
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_input)
@@ -663,8 +815,10 @@ def execute_model(
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
- # Execute the model.
- if attn_metadata.use_cuda_graph:
+ # Currently, CUDA graphs are only supported for the decode phase.
+ prefill_meta = attn_metadata.prefill_metadata
+ decode_meta = attn_metadata.decode_metadata
+ if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
@@ -842,13 +996,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata.
- attn_metadata = self.attn_backend.make_metadata(
+ decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
- slot_mapping=slot_mapping[:batch_size],
prompt_lens=None,
prompt_lens_tensor=None,
- num_prompt_tokens=0,
- num_generation_tokens=batch_size,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_prompt_len=None,
@@ -857,6 +1008,14 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
+ )
+ attn_metadata = AttentionMetadata(
+ num_prefills=0,
+ num_prefill_tokens=0,
+ num_decode_tokens=batch_size,
+ slot_mapping=slot_mapping[:batch_size],
+ prefill_metadata=None,
+ decode_metadata=decode_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
@@ -950,8 +1109,8 @@ def capture(
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
- "context_lens": attn_metadata.context_lens,
- "block_tables": attn_metadata.block_tables,
+ "context_lens": attn_metadata.decode_metadata.context_lens,
+ "block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return
@@ -972,10 +1131,10 @@ def forward(
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
- self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
- non_blocking=True)
- self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
- non_blocking=True)
+ self.input_buffers["context_lens"].copy_(
+ attn_metadata.decode_metadata.context_lens, non_blocking=True)
+ self.input_buffers["block_tables"].copy_(
+ attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph.
self.graph.replay()
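
Worker processes learn how the batch was built from the broadcast `batch_type`: a PREFILL or DECODE batch needs a single metadata broadcast, while a MIXED batch (only possible with chunked prefill) triggers a second broadcast carrying the decode metadata. A condensed sketch of that decision logic; the dictionaries stand in for the real per-stage metadata objects:

from enum import IntEnum
from typing import Optional

class BatchType(IntEnum):
    PREFILL = 0  # batch holds only prefill requests
    DECODE = 1   # batch holds only decode requests
    MIXED = 2    # chunked prefill: both kinds in one batch

def classify(prefill_meta: Optional[dict], decode_meta: Optional[dict]) -> BatchType:
    if prefill_meta is not None and decode_meta is not None:
        return BatchType.MIXED
    return BatchType.PREFILL if prefill_meta is not None else BatchType.DECODE

def num_metadata_broadcasts(batch_type: BatchType) -> int:
    # MIXED batches pay for one extra broadcast carrying the decode metadata.
    return 2 if batch_type == BatchType.MIXED else 1

print(num_metadata_broadcasts(classify({"stage": "prefill"}, None)))                 # 1
print(num_metadata_broadcasts(classify({"stage": "prefill"}, {"stage": "decode"})))  # 2
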
From caada5e50aa16cd5f59bd7889128a83588ca1f99 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 10 Apr 2024 18:48:26 -0700
Subject: [PATCH 011/413] [Core][Model] torch.compile for layernorm in commandr
(#3985)
[Core][Model] Use torch.compile to accelerate layernorm in commandr (#3985)
---
vllm/model_executor/models/commandr.py | 23 +++++++++++++++--------
1 file changed, 15 insertions(+), 8 deletions(-)
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index aa27f0a96c745..aa9b28b676e0b 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -48,6 +48,18 @@
from vllm.sequence import SamplerOutput
+@torch.compile
+def layer_norm_func(hidden_states, weight, variance_epsilon):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ mean = hidden_states.mean(-1, keepdim=True)
+ variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+ hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
+ variance_epsilon)
+ hidden_states = weight.to(torch.float32) * hidden_states
+ return hidden_states.to(input_dtype)
+
+
class LayerNorm(nn.Module):
def __init__(self, param_shape=None, eps=1e-5):
@@ -57,14 +69,9 @@ def __init__(self, param_shape=None, eps=1e-5):
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
def forward(self, hidden_states, residuals=None):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- mean = hidden_states.mean(-1, keepdim=True)
- variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
- hidden_states = (hidden_states -
- mean) * torch.rsqrt(variance + self.variance_epsilon)
- hidden_states = self.weight.to(torch.float32) * hidden_states
- return hidden_states.to(input_dtype), residuals
+ hidden_states = layer_norm_func(hidden_states, self.weight,
+ self.variance_epsilon)
+ return hidden_states, residuals
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
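
Hoisting the normalization math into a module-level function lets torch.compile trace it once and reuse the compiled kernel across every Cohere-style LayerNorm instance. A standalone sketch that runs the same math as the patch on illustrative shapes, without vLLM's weight-loading machinery:

import torch

@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    # Normalize in float32 for stability, then cast back to the input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon)
    return (weight.to(torch.float32) * hidden_states).to(input_dtype)

hidden = torch.randn(2, 16, dtype=torch.float16)
weight = torch.ones(16, dtype=torch.float16)
print(layer_norm_func(hidden, weight, 1e-5).dtype)  # torch.float16
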
From e42df7227d18e2b96785f8ee52053663ade05b63 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Thu, 11 Apr 2024 12:09:50 +0900
Subject: [PATCH 012/413] [Test] Add xformer and flash attn tests (#3961)
Co-authored-by: Simon Mo
---
tests/basic_correctness/test_basic_correctness.py | 6 ++++++
vllm/attention/selector.py | 9 +++++++++
2 files changed, 15 insertions(+)
diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 97cff623c5e1d..bd4c7ea3301be 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -4,6 +4,8 @@
"""
import pytest
+from vllm.attention.selector import VLLM_ATTENTION_BACKEND
+
MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
@@ -14,6 +16,7 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_models(
hf_runner,
vllm_runner,
@@ -22,7 +25,10 @@ def test_models(
dtype: str,
max_tokens: int,
enforce_eager: bool,
+ attn_backend: str,
+ monkeypatch,
) -> None:
+ monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 4c699aed48d49..554e802cd5513 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -1,4 +1,5 @@
import enum
+import os
from functools import lru_cache
from typing import Type
@@ -10,6 +11,8 @@
logger = init_logger(__name__)
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
+
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
@@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
"Cannot use FlashAttention backend because the flash_attn package "
"is not found. Please install it for better performance.")
return _Backend.XFORMERS
+
+ backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+ if backend_by_env_var is not None:
+ return _Backend[backend_by_env_var]
+
+ # Default case.
return _Backend.FLASH_ATTN
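
With this override in place, setting the `VLLM_ATTENTION_BACKEND` environment variable before the backend is selected forces a specific implementation; the test above drives it through pytest's `monkeypatch.setenv`. A hedged usage sketch outside the test suite (the value must be a valid `_Backend` member name, such as the two the test exercises):

import os

# Names exercised by the test above; any valid _Backend member name works.
backend = "XFORMERS"
assert backend in ("XFORMERS", "FLASH_ATTN")

# Must be set in this process before vLLM picks a backend, i.e. before the
# engine (or vllm.LLM) is constructed.
os.environ["VLLM_ATTENTION_BACKEND"] = backend
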
From e9da5a40c63ce7f8a85438d3c7d919b46e7939f5 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Thu, 11 Apr 2024 03:26:07 +0000
Subject: [PATCH 013/413] [Misc] Add indirection layer for custom ops (#3913)
---
.../kernels/benchmark_paged_attention.py | 2 +-
tests/kernels/test_attention.py | 6 +-
tests/kernels/test_cache.py | 25 ++-
vllm/_custom_ops.py | 193 ++++++++++++++++++
vllm/attention/ops/paged_attn.py | 10 +-
vllm/model_executor/layers/activation.py | 2 +-
.../layers/fused_moe/fused_moe.py | 2 +-
vllm/model_executor/layers/layernorm.py | 2 +-
.../model_executor/layers/quantization/awq.py | 2 +-
.../layers/quantization/gptq.py | 2 +-
.../layers/quantization/marlin.py | 2 +-
.../layers/quantization/squeezellm.py | 2 +-
.../model_executor/layers/rotary_embedding.py | 2 +-
vllm/utils.py | 4 +-
14 files changed, 224 insertions(+), 32 deletions(-)
create mode 100644 vllm/_custom_ops.py
diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index f71d1fcaaef50..5c3650fa72d17 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -5,7 +5,7 @@
import torch
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
NUM_BLOCKS = 1024
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 03ea72924921e..9b1f3e30b6dca 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -7,7 +7,7 @@
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@@ -237,14 +237,14 @@ def test_paged_attention(
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
- cache_ops.convert_fp8(key_cache, dequantized_key_cache)
+ ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
- cache_ops.convert_fp8(value_cache, dequantized_value_cache)
+ ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache
ref_output = torch.empty_like(query)
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 4141aacafd0b2..d1051fd7e2f4d 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -4,7 +4,7 @@
import pytest
import torch
-from vllm._C import cache_ops
+from vllm import _custom_ops as ops
from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -80,7 +80,7 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel.
- cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+ ops.copy_blocks(key_caches, value_caches, block_mapping)
# Run the reference implementation.
for src, dsts in block_mapping.items():
@@ -145,9 +145,9 @@ def test_reshape_and_cache(
# Clone the KV caches.
if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
- cache_ops.convert_fp8(key_cache, cloned_key_cache)
+ ops.convert_fp8(key_cache, cloned_key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
- cache_ops.convert_fp8(value_cache, cloned_value_cache)
+ ops.convert_fp8(value_cache, cloned_value_cache)
else:
cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone()
@@ -156,14 +156,14 @@ def test_reshape_and_cache(
kv_scale = 1.0
# Call the reshape_and_cache kernel.
- cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
- slot_mapping, kv_cache_dtype, kv_scale)
+ ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
+ kv_cache_dtype, kv_scale)
if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
- cache_ops.convert_fp8(key_cache, result_key_cache)
+ ops.convert_fp8(key_cache, result_key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
- cache_ops.convert_fp8(value_cache, result_value_cache)
+ ops.convert_fp8(value_cache, result_value_cache)
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@@ -251,9 +251,8 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
- cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
- cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
- block_mapping)
+ ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
+ ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
for src, dst in block_mapping.items():
assert torch.allclose(src_key_caches_clone[src].cpu(),
@@ -291,9 +290,9 @@ def test_fp8_conversion(
cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
- cache_ops.convert_fp8(cache, cache_fp8)
+ ops.convert_fp8(cache, cache_fp8)
converted_cache = torch.empty_like(cache)
- cache_ops.convert_fp8(cache_fp8, converted_cache)
+ ops.convert_fp8(cache_fp8, converted_cache)
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
new file mode 100644
index 0000000000000..a0837a20875fe
--- /dev/null
+++ b/vllm/_custom_ops.py
@@ -0,0 +1,193 @@
+from typing import Dict, Optional
+
+import torch
+
+try:
+ from vllm._C import cache_ops as vllm_cache_ops
+ from vllm._C import ops as vllm_ops
+except ImportError:
+ pass
+
+
+# activation ops
+def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+ vllm_ops.silu_and_mul(out, x)
+
+
+def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+ vllm_ops.gelu_and_mul(out, x)
+
+
+def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
+ vllm_ops.gelu_tanh_and_mul(out, x)
+
+
+def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
+ vllm_ops.gelu_fast(out, x)
+
+
+def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
+ vllm_ops.gelu_new(out, x)
+
+
+# page attention ops
+def paged_attention_v1(
+ out: torch.Tensor,
+ query: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ num_kv_heads: int,
+ scale: float,
+ block_tables: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_size: int,
+ max_context_len: int,
+ alibi_slopes: Optional[torch.Tensor],
+ kv_cache_dtype: str,
+ kv_scale: float,
+) -> None:
+ vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
+ num_kv_heads, scale, block_tables,
+ context_lens, block_size, max_context_len,
+ alibi_slopes, kv_cache_dtype, kv_scale)
+
+
+def paged_attention_v2(
+ out: torch.Tensor,
+ exp_sum: torch.Tensor,
+ max_logits: torch.Tensor,
+ tmp_out: torch.Tensor,
+ query: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ num_kv_heads: int,
+ scale: float,
+ block_tables: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_size: int,
+ max_context_len: int,
+ alibi_slopes: Optional[torch.Tensor],
+ kv_cache_dtype: str,
+ kv_scale: float,
+) -> None:
+ vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
+ key_cache, value_cache, num_kv_heads, scale,
+ block_tables, context_lens, block_size,
+ max_context_len, alibi_slopes, kv_cache_dtype,
+ kv_scale)
+
+
+# pos encoding ops
+def rotary_embedding(
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ head_size: int,
+ cos_sin_cache: torch.Tensor,
+ is_neox: bool,
+) -> None:
+ vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
+ is_neox)
+
+
+def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
+ key: torch.Tensor, head_size: int,
+ cos_sin_cache: torch.Tensor, is_neox: bool,
+ rot_dim: int,
+ cos_sin_cache_offsets: torch.Tensor) -> None:
+ vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
+ cos_sin_cache, is_neox, rot_dim,
+ cos_sin_cache_offsets)
+
+
+# layer norm ops
+def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
+ epsilon: float) -> None:
+ vllm_ops.rms_norm(out, input, weight, epsilon)
+
+
+def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
+ weight: torch.Tensor, epsilon: float) -> None:
+ vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
+
+
+# quantization ops
+# awq
+def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
+ zeros: torch.Tensor, split_k_iters: int, thx: int,
+ thy: int) -> torch.Tensor:
+ return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
+ thy)
+
+
+def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
+ scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
+ return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
+
+
+# gptq
+def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+ b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
+ b_g_idx: torch.Tensor, use_exllama: bool,
+ bit: int) -> torch.Tensor:
+ return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
+ b_g_idx, use_exllama, bit)
+
+
+def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
+ bit: int) -> None:
+ vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
+
+
+# squeezellm
+def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
+ lookup_table: torch.Tensor) -> None:
+ vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
+
+
+# marlin
+def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+ b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
+ size_n: int, size_k: int) -> torch.Tensor:
+ return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
+ size_n, size_k)
+
+
+# moe
+def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
+ block_size: int, sorted_token_ids: torch.Tensor,
+ experts_ids: torch.Tensor,
+ num_tokens_post_pad: torch.Tensor) -> None:
+ vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
+ sorted_token_ids, experts_ids,
+ num_tokens_post_pad)
+
+
+def reshape_and_cache(
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ slot_mapping: torch.Tensor,
+ kv_cache_dtype: str,
+ kv_scale: float,
+) -> None:
+ vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
+ slot_mapping, kv_cache_dtype, kv_scale)
+
+
+def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
+ block_mapping: torch.Tensor) -> None:
+ vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
+ block_mapping: Dict[int, int]) -> None:
+ vllm_cache_ops.swap_blocks(src, dst, block_mapping)
+
+
+def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
+ vllm_cache_ops.convert_fp8(output, input)
+
+
+#TODO: cuda_utils, custom_ar
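
With `vllm/_custom_ops.py` in place, call sites import a pure-Python wrapper module instead of the compiled `vllm._C` extension directly, so imports no longer fail outright when the extension is missing. A generic sketch of the same indirection pattern; the pure-torch fallback is illustrative only, since vLLM's real wrappers simply forward to the extension:

import torch

try:
    from vllm._C import ops as _ext_ops  # compiled extension; may be absent
except ImportError:
    _ext_ops = None

def rms_norm(out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    # Forward to the extension kernel when available, otherwise fall back to
    # plain torch (the fallback is for illustration, not part of vLLM).
    if _ext_ops is not None:
        _ext_ops.rms_norm(out, x, weight, epsilon)
        return
    variance = x.pow(2).mean(-1, keepdim=True)
    out.copy_(x * torch.rsqrt(variance + epsilon) * weight)

x = torch.randn(2, 8)
out = torch.empty_like(x)
rms_norm(out, x, torch.ones(8), 1e-6)
print(out.shape)  # torch.Size([2, 8])
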
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 2d918491d6576..cd0690a4ba957 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -3,7 +3,7 @@
import torch
-from vllm._C import cache_ops, ops
+from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -69,7 +69,7 @@ def write_to_paged_cache(
kv_cache_dtype: str,
kv_scale: float,
) -> None:
- cache_ops.reshape_and_cache(
+ ops.reshape_and_cache(
key,
value,
key_cache,
@@ -199,11 +199,11 @@ def swap_blocks(
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
- cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+ ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
- cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+ ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
@@ -212,4 +212,4 @@ def copy_blocks(
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
- cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+ ops.copy_blocks(key_caches, value_caches, src_to_dists)
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 6786c48e0caba..baf1d4f266181 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -6,7 +6,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.quantization import QuantizationConfig
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 1ec09f0cd4c28..377b6588dbf47 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -8,7 +8,7 @@
import triton
import triton.language as tl
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import is_hip
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index cb3cee2bad5ad..a6619714b8aab 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -4,7 +4,7 @@
import torch
import torch.nn as nn
-from vllm._C import ops
+from vllm import _custom_ops as ops
class RMSNorm(nn.Module):
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 2caef5f1ebf50..daea5ac73e429 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -3,7 +3,7 @@
import torch
from torch.nn.parameter import Parameter
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 53baf710ed811..757ab1af8392e 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -6,7 +6,7 @@
import torch
from torch.nn.parameter import Parameter
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index 784229878edf4..a6482c059cc41 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -3,7 +3,7 @@
import torch
from torch.nn.parameter import Parameter
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
index ed25455e6ec1f..bb295df2acc3f 100644
--- a/vllm/model_executor/layers/quantization/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -3,7 +3,7 @@
import torch
from torch.nn.parameter import Parameter
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index d80e73bbe39e9..eb8d5f6dfb2a9 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -27,7 +27,7 @@
import torch
import torch.nn as nn
-from vllm._C import ops
+from vllm import _custom_ops as ops
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
diff --git a/vllm/utils.py b/vllm/utils.py
index 8ba03333d3b6c..8ab8927512cc9 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -279,10 +279,10 @@ def _generate_random_fp8(
#-----|-------------|-------------------
# Inf | N/A | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11}
- from vllm._C import cache_ops
+ from vllm import _custom_ops as ops
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
tensor_tmp.uniform_(low, high)
- cache_ops.convert_fp8(tensor_tmp, tensor)
+ ops.convert_fp8(tensor_tmp, tensor)
del tensor_tmp
From f3d0bf7589d6e63a691dcbb9d1db538c184fde29 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 10 Apr 2024 20:33:02 -0700
Subject: [PATCH 014/413] [Doc][Installation] delete python setup.py develop
(#3989)
---
docs/source/getting_started/installation.rst | 10 ----------
1 file changed, 10 deletions(-)
diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst
index 5dfb32080f97a..e7826114ffa9d 100644
--- a/docs/source/getting_started/installation.rst
+++ b/docs/source/getting_started/installation.rst
@@ -85,13 +85,3 @@ You can also build and install vLLM from source:
$ nvcc --version # verify that nvcc is in your PATH
$ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME
-
-.. note::
- If you are developing the C++ backend of vLLM, consider building vLLM with
-
- .. code-block:: console
-
- $ python setup.py develop
-
- since it will give you incremental builds. The downside is that this method
- is `deprecated by setuptools `_.
From c1dc547129f5faaa2ca5ba557145b8ec8838693c Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Thu, 11 Apr 2024 07:50:00 -0700
Subject: [PATCH 015/413] [Kernel] Fused MoE Config for Mixtral 8x22 (#4002)
---
...048,device_name=NVIDIA_A100-SXM4-80GB.json | 146 ++++++++++++++++++
...048,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++
...096,device_name=NVIDIA_A100-SXM4-80GB.json | 146 ++++++++++++++++++
...096,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++++++++++++++++++
4 files changed, 584 insertions(+)
create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json
create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json
create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json
new file mode 100644
index 0000000000000..0bb423b28f5ab
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "128": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000000..26bcbf26970c7
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "128": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json
new file mode 100644
index 0000000000000..dbc624731f5cb
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "128": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000000..32c0c9da471cb
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "48": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "64": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "96": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "128": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "256": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+}
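Each JSON file above maps a benchmarked token-batch size ("1" through "4096") to a tuned Triton tile shape for a given expert count E, intermediate size N and GPU. As a rough illustration of how such a table could be consumed, here is a hedged sketch; the filename pattern comes from the files added above, while the nearest-key selection policy is an assumption rather than the exact lookup logic in fused_moe.

import json
from pathlib import Path

# Illustrative only: nearest-batch-size selection is an assumed policy.
def load_moe_configs(e: int, n: int, device_name: str, config_dir: Path) -> dict:
    path = config_dir / f"E={e},N={n},device_name={device_name}.json"
    with open(path) as f:
        # Batch sizes are stored as JSON string keys; convert to int for lookup.
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict, num_tokens: int) -> dict:
    # Use the tuned entry whose benchmarked batch size is closest to the input.
    return configs[min(configs, key=lambda k: abs(k - num_tokens))]

# Example: pick_config(load_moe_configs(8, 2048, "NVIDIA_H100_80GB_HBM3",
#          Path("vllm/model_executor/layers/fused_moe/configs")), 100)
# returns the "96" entry.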
From 08ccee1e830d39ecdb3c6cf382c843dbf5ae830e Mon Sep 17 00:00:00 2001
From: "fuchen.ljl"
Date: Thu, 11 Apr 2024 23:59:26 +0800
Subject: [PATCH 016/413] punica fix-bgmv-kernel-640 (#4007)
---
csrc/punica/bgmv/bgmv_config.h | 1 +
1 file changed, 1 insertion(+)
diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index 2219d960ae62f..1084a0f20df6b 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
+ f(in_T, out_T, W_T, narrow, 640) \
f(in_T, out_T, W_T, narrow, 768) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1152) \
From 8afca50889bad6ad987c523c48c31fc52fcb72e4 Mon Sep 17 00:00:00 2001
From: bigPYJ1151
Date: Fri, 12 Apr 2024 02:56:49 +0800
Subject: [PATCH 017/413] [Hardware][Intel] Isolate CPUModelRunner and
ModelRunner for better maintenance (#3824)
---
vllm/attention/backends/torch_sdpa.py | 72 ++---
vllm/executor/cpu_executor.py | 10 +
vllm/utils.py | 1 -
vllm/worker/cpu_model_runner.py | 408 ++++++++++++++++++++++++++
vllm/worker/cpu_worker.py | 13 +-
5 files changed, 443 insertions(+), 61 deletions(-)
create mode 100644 vllm/worker/cpu_model_runner.py
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 63904ea929870..d21b54b16db4b 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -50,20 +50,15 @@ def copy_blocks(
@dataclass
-class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
+ AttentionMetadataPerStage):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
+ slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]]
- prompt_lens_tensor: Optional[torch.Tensor]
-
- max_subquery_len: Optional[int] = None
- max_prompt_len: Optional[int] = None
- subquery_start_loc: Optional[torch.Tensor] = None
- seq_start_loc: Optional[torch.Tensor] = None
- use_cuda_graph: bool = False
def __post_init__(self):
# Set during the execution of the first attention op.
@@ -111,7 +106,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
- attn_metadata: AttentionMetadata[TorchSDPAMetadata],
+ attn_metadata: TorchSDPAMetadata,
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -140,51 +135,36 @@ def forward(
attn_metadata.kv_cache_dtype,
kv_scale)
- num_prefill_tokens = attn_metadata.num_prefill_tokens
- num_decode_tokens = attn_metadata.num_decode_tokens
- assert key.shape[0] == num_prefill_tokens + num_decode_tokens
- assert value.shape[0] == num_prefill_tokens + num_decode_tokens
-
- output = torch.empty_like(query)
- # Query for decode. KV is not needed because it is already cached.
- decode_query = query[num_prefill_tokens:]
- # QKV for prefill.
- query = query[:num_prefill_tokens]
- key = key[:num_prefill_tokens]
- value = value[:num_prefill_tokens]
-
- assert query.shape[0] == num_prefill_tokens
- assert decode_query.shape[0] == num_decode_tokens
-
- if prefill_meta := attn_metadata.prefill_metadata:
- if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
+ if attn_metadata.is_prompt:
+ if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)
- if prefill_meta.attn_bias is None:
+ if attn_metadata.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
- prefill_meta.prompt_lens) # type: ignore
+ attn_metadata.prompt_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
- prefill_meta.prompt_lens, self.sliding_window,
+ attn_metadata.prompt_lens, self.sliding_window,
query.dtype) # type: ignore
else:
- att_masks = [None] * len(prefill_meta.prompt_lens)
- prefill_meta.attn_bias = att_masks
+ att_masks = [None] * len(attn_metadata.prompt_lens)
+ attn_metadata.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
start = 0
- out = torch.empty((num_tokens, self.num_heads, self.head_size),
- dtype=query.dtype)
- for prompt_len, mask in zip(prefill_meta.prompt_lens,
- prefill_meta.attn_bias):
+ output = torch.empty(
+ (num_tokens, self.num_heads, self.head_size),
+ dtype=query.dtype)
+ for prompt_len, mask in zip(attn_metadata.prompt_lens,
+ attn_metadata.attn_bias):
end = start + prompt_len
sub_out = scaled_dot_product_attention(
query[:, start:end, :],
@@ -194,32 +174,28 @@ def forward(
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0)
- out[start:end, :, :] = sub_out
+ output[start:end, :, :] = sub_out
start = end
- assert out.shape == output[:num_prefill_tokens].shape
- output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
- if decode_meta := attn_metadata.decode_metadata:
+ else:
# Decoding run.
- out = PagedAttention.forward_decode(
- decode_query,
+ output = PagedAttention.forward_decode(
+ query,
key_cache,
value_cache,
- decode_meta.block_tables,
- decode_meta.context_lens,
- decode_meta.max_context_len,
+ attn_metadata.block_tables,
+ attn_metadata.context_lens,
+ attn_metadata.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
- assert out.shape == output[num_prefill_tokens:].shape
- output[num_prefill_tokens:]
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
@@ -241,7 +217,7 @@ def _make_alibi_bias(
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
- bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
+ bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, prompt_len, prompt_len),
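For context on the prefill branch rewritten in this file, the per-prompt loop amounts to calling torch's built-in scaled_dot_product_attention on one prompt segment at a time. The following is a standalone approximation, assuming PyTorch 2.1+ for the scale keyword and a [num_heads, num_tokens, head_size] layout; it is not the backend itself.

import torch
import torch.nn.functional as F

# Standalone sketch of the per-prompt SDPA loop; the tensor layout and the
# PyTorch 2.1+ `scale` keyword are assumptions made for this illustration.
def sdpa_over_prompts(query, key, value, prompt_lens, attn_biases, scale):
    # query/key/value: [num_heads, total_tokens, head_size]
    output = torch.empty_like(query)
    start = 0
    for prompt_len, bias in zip(prompt_lens, attn_biases):
        end = start + prompt_len
        output[:, start:end, :] = F.scaled_dot_product_attention(
            query[:, start:end, :],
            key[:, start:end, :],
            value[:, start:end, :],
            attn_mask=bias,          # additive bias (e.g. ALiBi) or None
            dropout_p=0.0,
            is_causal=bias is None,  # let SDPA apply causal masking itself
            scale=scale,
        )
        start = end
    return output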
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 2bf97338da0ed..eda4e8989c163 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -25,6 +25,7 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
assert lora_config is None, "cpu backend doesn't support LoRA"
model_config = _verify_and_get_model_config(model_config)
cache_config = _verify_and_get_cache_config(cache_config)
+ scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
self.model_config = model_config
self.cache_config = cache_config
@@ -116,6 +117,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
return config
+def _verify_and_get_scheduler_config(
+ config: SchedulerConfig) -> SchedulerConfig:
+ if config.chunked_prefill_enabled:
+ logger.warning("Chunked prefill is not supported on CPU, disable it.")
+ config.chunked_prefill_enabled = False
+
+ return config
+
+
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
_GB = 1 << 30
if config.enable_prefix_caching:
diff --git a/vllm/utils.py b/vllm/utils.py
index 8ab8927512cc9..fdb0a3768ab0d 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool:
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif is_cpu():
- print_warning_once("Pin memory is not supported on CPU.")
return False
return True
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
new file mode 100644
index 0000000000000..49e1ad5709f5d
--- /dev/null
+++ b/vllm/worker/cpu_model_runner.py
@@ -0,0 +1,408 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from vllm.attention import AttentionMetadata, get_attn_backend
+from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
+ SchedulerConfig)
+from vllm.distributed import broadcast_tensor_dict
+from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.model_loader import get_model
+from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import make_tensor_with_pad, maybe_expand_dim
+
+logger = init_logger(__name__)
+
+_PAD_SLOT_ID = -1
+
+
+class CPUModelRunner:
+
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig,
+ device_config: DeviceConfig,
+ lora_config: Optional[LoRAConfig],
+ kv_cache_dtype: Optional[str] = "auto",
+ is_driver_worker: bool = False,
+ *args,
+ **kwargs,
+ ):
+ self.model_config = model_config
+ self.parallel_config = parallel_config
+ self.scheduler_config = scheduler_config
+ self.lora_config = lora_config
+ self.is_driver_worker = is_driver_worker
+
+ # model_config can be None in tests/samplers/test_sampler.py.
+ # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
+ self.sliding_window = (model_config.get_sliding_window()
+ if model_config is not None else None)
+ self.device_config = (device_config
+ if device_config is not None else DeviceConfig())
+ self.device = self.device_config.device
+
+ self.model = None
+ self.block_size = None # Set after initial profiling.
+
+ self.kv_cache_dtype = kv_cache_dtype
+
+ self.attn_backend = get_attn_backend(
+ self.model_config.dtype if model_config is not None else None)
+
+ def load_model(self) -> None:
+ self.model = get_model(self.model_config,
+ self.device_config,
+ lora_config=self.lora_config,
+ parallel_config=self.parallel_config,
+ scheduler_config=self.scheduler_config)
+
+ def _prepare_prompt(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]:
+ assert len(seq_group_metadata_list) > 0
+ input_tokens: List[int] = []
+ input_positions: List[int] = []
+ slot_mapping: List[int] = []
+ prompt_lens: List[int] = []
+
+ for seq_group_metadata in seq_group_metadata_list:
+ assert seq_group_metadata.is_prompt
+ seq_ids = list(seq_group_metadata.seq_data.keys())
+ assert len(seq_ids) == 1
+ seq_id = seq_ids[0]
+
+ seq_data = seq_group_metadata.seq_data[seq_id]
+ prompt_tokens = seq_data.get_token_ids()
+ computed_len = seq_data.get_num_computed_tokens()
+ prompt_len = len(prompt_tokens)
+
+ prompt_lens.append(prompt_len) # Prompt token num
+ input_tokens.extend(prompt_tokens) # Token ids
+
+ # Token position ids
+ # NOTE(woosuk): Here we assume that the first token in the prompt
+ # is always the first token in the sequence.
+ input_positions.extend(list(range(computed_len, prompt_len)))
+
+ # Compute the slot mapping.
+ block_table = seq_group_metadata.block_tables[seq_id]
+ # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
+ # where start_idx is max(0, prompt_len - sliding_window).
+ # For example, if the prompt len is 10, sliding window is 8, and
+ # block size is 4, the first two tokens are masked and the slot
+ # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
+ start_idx = 0
+ if self.sliding_window is not None:
+ start_idx = max(0, prompt_len - self.sliding_window)
+
+ for i in range(computed_len, prompt_len):
+ if i < start_idx:
+ slot_mapping.append(_PAD_SLOT_ID)
+ continue
+
+ block_number = block_table[i //
+ self.block_size] # type: ignore
+ block_offset = i % self.block_size # type: ignore
+ slot = block_number * self.block_size + block_offset
+ slot_mapping.append(slot)
+
+ num_prompt_tokens = len(input_tokens)
+
+ input_tokens = torch.tensor(input_tokens,
+ dtype=torch.long,
+ device=self.device) # type: ignore
+ input_positions = torch.tensor(input_positions,
+ dtype=torch.long,
+ device=self.device) # type: ignore
+ slot_mapping = torch.tensor(slot_mapping,
+ dtype=torch.long,
+ device=self.device) # type: ignore
+
+ attn_metadata = self.attn_backend.make_metadata(
+ is_prompt=True,
+ prompt_lens=prompt_lens,
+ num_prefills=len(prompt_lens),
+ num_prefill_tokens=num_prompt_tokens,
+ num_decode_tokens=0,
+ prefill_metadata=None,
+ decode_metadata=None,
+ max_context_len=None,
+ context_lens=None,
+ block_tables=torch.tensor([]),
+ slot_mapping=slot_mapping,
+ kv_cache_dtype=self.kv_cache_dtype,
+ )
+ return (
+ input_tokens,
+ input_positions,
+ attn_metadata,
+ prompt_lens,
+ )
+
+ def _prepare_decode(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
+ assert len(seq_group_metadata_list) > 0
+ input_tokens: List[int] = []
+ input_positions: List[int] = []
+ slot_mapping: List[int] = []
+ context_lens: List[int] = []
+ block_tables: List[List[int]] = []
+
+ for seq_group_metadata in seq_group_metadata_list:
+ assert not seq_group_metadata.is_prompt
+ assert seq_group_metadata.token_chunk_size == 1
+
+ seq_ids = list(seq_group_metadata.seq_data.keys())
+
+ for seq_id in seq_ids:
+ seq_data = seq_group_metadata.seq_data[seq_id]
+ generation_token = seq_data.get_last_token_id()
+ input_tokens.append(generation_token)
+
+ seq_len = seq_data.get_len()
+ position = seq_len - 1
+ input_positions.append(position)
+
+ context_len = seq_len if self.sliding_window is None else min(
+ seq_len, self.sliding_window)
+ context_lens.append(context_len)
+
+ block_table = seq_group_metadata.block_tables[seq_id]
+ block_number = block_table[position // self.block_size]
+ block_offset = position % self.block_size
+ slot = block_number * self.block_size + block_offset
+ slot_mapping.append(slot)
+
+ if self.sliding_window is not None:
+ sliding_window_blocks = (self.sliding_window //
+ self.block_size)
+ block_table = block_table[-sliding_window_blocks:]
+ block_tables.append(block_table)
+
+ max_context_len = max(context_lens)
+
+ input_tokens = torch.tensor(input_tokens,
+ dtype=torch.long,
+ device=self.device)
+ input_positions = torch.tensor(input_positions,
+ dtype=torch.long,
+ device=self.device)
+ slot_mapping = torch.tensor(slot_mapping,
+ dtype=torch.long,
+ device=self.device)
+ context_lens = torch.tensor(context_lens,
+ dtype=torch.int,
+ device=self.device)
+
+ max_block_table_len = max(
+ len(block_table) for block_table in block_tables)
+ block_tables = make_tensor_with_pad(
+ block_tables,
+ max_len=max_block_table_len,
+ pad=0,
+ dtype=torch.int,
+ device=self.device,
+ )
+
+ attn_metadata = self.attn_backend.make_metadata(
+ is_prompt=False,
+ slot_mapping=slot_mapping,
+ prompt_lens=None,
+ num_prefill_tokens=0,
+ num_decode_tokens=len(input_tokens),
+ max_context_len=max_context_len,
+ num_prefills=0,
+ prefill_metadata=None,
+ decode_metadata=None,
+ context_lens=context_lens,
+ block_tables=block_tables,
+ kv_cache_dtype=self.kv_cache_dtype,
+ )
+ return (
+ input_tokens,
+ input_positions,
+ attn_metadata,
+ )
+
+ def _prepare_sample(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ prompt_lens: List[int],
+ ) -> SamplingMetadata:
+ seq_groups: List[Tuple[List[int], SamplingParams]] = []
+ selected_token_indices: List[int] = []
+ generators: List[torch.Generator] = []
+ selected_token_start_idx = 0
+ categorized_sample_indices = {t: [] for t in SamplingType}
+ categorized_sample_indices_start_idx = 0
+ categorized_sampled_token_indices_start_idx = 0
+
+ for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+ seq_ids = list(seq_group_metadata.seq_data.keys())
+ sampling_params = seq_group_metadata.sampling_params
+ seq_groups.append((seq_ids, sampling_params))
+
+ if seq_group_metadata.is_prompt:
+ assert len(seq_ids) == 1
+ subquery_len = prompt_lens[i]
+ if sampling_params.prompt_logprobs is not None:
+ # NOTE: prompt token positions do not need to be sampled; skip them
+ categorized_sample_indices_start_idx += subquery_len - 1
+
+ categorized_sample_indices[
+ sampling_params.sampling_type].append([
+ categorized_sample_indices_start_idx,
+ categorized_sampled_token_indices_start_idx
+ ])
+ categorized_sample_indices_start_idx += 1
+ categorized_sampled_token_indices_start_idx += 1
+
+ if sampling_params.prompt_logprobs is not None:
+ selected_token_indices.extend(
+ range(selected_token_start_idx,
+ selected_token_start_idx + subquery_len - 1))
+ selected_token_indices.append(selected_token_start_idx +
+ subquery_len - 1)
+ selected_token_start_idx += subquery_len
+
+ if sampling_params.seed is not None:
+ seq_group_metadata.state.generator = torch.Generator(
+ device=self.device).manual_seed(sampling_params.seed)
+ else:
+ num_seqs = len(seq_ids)
+ selected_token_indices.extend(
+ range(selected_token_start_idx,
+ selected_token_start_idx + num_seqs))
+ selected_token_start_idx += num_seqs
+
+ categorized_sample_indices[
+ sampling_params.sampling_type].extend(
+ zip(
+ range(
+ categorized_sample_indices_start_idx,
+ categorized_sample_indices_start_idx +
+ num_seqs),
+ range(
+ categorized_sampled_token_indices_start_idx,
+ categorized_sampled_token_indices_start_idx +
+ num_seqs)))
+ categorized_sample_indices_start_idx += num_seqs
+ categorized_sampled_token_indices_start_idx += num_seqs
+
+ if sampling_params.seed is not None:
+ generators.append(seq_group_metadata.state.generator)
+
+ selected_token_indices = torch.tensor(selected_token_indices,
+ dtype=torch.long)
+
+ categorized_sample_indices = {
+ t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
+ for t, seq_ids in categorized_sample_indices.items()
+ }
+
+ seq_data: Dict[int, SequenceData] = {}
+ for seq_group_metadata in seq_group_metadata_list:
+ seq_data.update(seq_group_metadata.seq_data)
+
+ sampling_metadata = SamplingMetadata(
+ seq_groups=seq_groups,
+ seq_data=seq_data,
+ prompt_lens=prompt_lens,
+ selected_token_indices=selected_token_indices,
+ categorized_sample_indices=categorized_sample_indices,
+ generators=generators,
+ )
+ return sampling_metadata
+
+ def prepare_input_tensors(
+ self,
+ seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+ ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
+ SamplingMetadata]:
+ if self.is_driver_worker:
+ # NOTE: We assume that all sequences in the group are either all
+ # prompts or all decodes.
+ is_prompt = seq_group_metadata_list[0].is_prompt
+ # Prepare input tensors.
+ if is_prompt:
+ (input_tokens, input_positions, attn_metadata,
+ prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+ else:
+ (input_tokens, input_positions,
+ attn_metadata) = self._prepare_decode(seq_group_metadata_list)
+ prompt_lens = []
+ sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+ prompt_lens)
+ # Broadcast the metadata.
+ metadata_dict = {
+ "input_tokens": input_tokens,
+ "input_positions": input_positions,
+ "selected_token_indices":
+ sampling_metadata.selected_token_indices,
+ }
+ metadata_dict.update(attn_metadata.asdict_zerocopy())
+ broadcast_tensor_dict(metadata_dict, src=0)
+ else:
+ metadata_dict = broadcast_tensor_dict(src=0)
+ input_tokens = metadata_dict.pop("input_tokens")
+ input_positions = metadata_dict.pop("input_positions")
+ selected_token_indices = metadata_dict.pop(
+ "selected_token_indices")
+ attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
+ sampling_metadata = SamplingMetadata(
+ seq_groups=None,
+ seq_data=None,
+ prompt_lens=None,
+ selected_token_indices=selected_token_indices,
+ categorized_sample_indices=None,
+ generators=None,
+ perform_sampling=False,
+ )
+
+ return (
+ input_tokens,
+ input_positions,
+ attn_metadata,
+ sampling_metadata,
+ )
+
+ @torch.inference_mode()
+ def execute_model(
+ self,
+ seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+ kv_caches: List[torch.Tensor],
+ ) -> Optional[SamplerOutput]:
+ (input_tokens, input_positions, attn_metadata, sampling_metadata
+ ) = self.prepare_input_tensors(seq_group_metadata_list)
+
+ model_executable = self.model
+ execute_model_kwargs = {
+ "input_ids": input_tokens,
+ "positions": input_positions,
+ "kv_caches": kv_caches,
+ "attn_metadata": attn_metadata,
+ }
+
+ hidden_states = model_executable(**execute_model_kwargs)
+
+ # Compute the logits.
+ logits = self.model.compute_logits(hidden_states, sampling_metadata)
+
+ # Only perform sampling in the driver worker.
+ if not sampling_metadata.perform_sampling:
+ return None
+
+ # Sample the next token.
+ output = self.model.sample(
+ logits=logits,
+ sampling_metadata=sampling_metadata,
+ )
+ return output
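The slot-mapping rule in _prepare_prompt above is the one non-obvious piece of bookkeeping in this new file, so here is a self-contained sketch that reproduces the worked example from its comment (prompt_len=10, sliding_window=8, block_size=4). The function name and the example block table are local to this sketch.

_PAD_SLOT_ID = -1

def compute_prompt_slot_mapping(block_table, prompt_len, block_size,
                                sliding_window=None):
    # Tokens that fall outside the sliding window are padded out of the cache.
    start_idx = max(0, prompt_len - sliding_window) if sliding_window else 0
    slots = []
    for i in range(prompt_len):
        if i < start_idx:
            slots.append(_PAD_SLOT_ID)
            continue
        block_number = block_table[i // block_size]
        slots.append(block_number * block_size + i % block_size)
    return slots

# With an assumed block_table of [0, 1, 0] this reproduces the comment's
# example: [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
print(compute_prompt_slot_mapping([0, 1, 0], prompt_len=10, block_size=4,
                                  sliding_window=8))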
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 751384eb72af3..3989207e8dd83 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -12,25 +12,14 @@
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
-from vllm.model_executor.model_loader import get_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
-from vllm.worker.model_runner import ModelRunner
+from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
logger = init_logger(__name__)
-class CPUModelRunner(ModelRunner):
-
- def load_model(self) -> None:
- self.model = get_model(self.model_config,
- self.device_config,
- lora_config=self.lora_config,
- parallel_config=self.parallel_config,
- scheduler_config=self.scheduler_config)
-
-
class CPUCacheEngine:
"""Manages the KV cache for CPU backend.
From a10d3056da644c31e4ebf95a2b6ad65a626a7350 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Thu, 11 Apr 2024 13:35:51 -0700
Subject: [PATCH 018/413] [Core] Set `linear_weights` directly on the layer
(#3977)
---
csrc/quantization/gptq/q_gemm.cu | 2 +-
tests/kernels/test_moe.py | 2 +-
vllm/lora/layers.py | 12 +--
vllm/model_executor/layers/linear.py | 77 ++++++++++---------
.../model_executor/layers/quantization/awq.py | 29 +++----
.../layers/quantization/gptq.py | 47 ++++++-----
.../layers/quantization/marlin.py | 23 +++---
.../layers/quantization/squeezellm.py | 24 +++---
8 files changed, 114 insertions(+), 102 deletions(-)
diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu
index 655158e38f557..cc56649917a8a 100644
--- a/csrc/quantization/gptq/q_gemm.cu
+++ b/csrc/quantization/gptq/q_gemm.cu
@@ -2067,7 +2067,7 @@ void gptq_shuffle
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight(
(uint32_t*) q_weight.data_ptr(),
- q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
+ q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
q_weight.size(0) * 32 / bit,
q_weight.size(1),
bit
diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index affbbfb4aa94e..046f11d957bdd 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
).cuda()
# Load the weights
- vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
+ vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 84a94091486d7..a8ec4dcfd6137 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -368,7 +368,7 @@ def set_mapping(
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
- self.base_layer.linear_weights, x, bias)
+ self.base_layer, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
@@ -402,10 +402,6 @@ def forward(self, input_):
if self.base_layer.skip_bias_add else None)
return output, output_bias
- @property
- def linear_weights(self):
- return self.base_layer.linear_weights
-
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
@@ -505,7 +501,7 @@ def set_lora(
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
- self.base_layer.linear_weights, x, bias)
+ self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
@@ -746,7 +742,7 @@ def set_lora(
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
- self.base_layer.linear_weights, x, bias)
+ self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
@@ -838,7 +834,7 @@ def set_mapping(
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
- self.base_layer.linear_weights, x)
+ self.base_layer, x)
_apply_lora(
x,
self.lora_a_stacked,
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 8f42b3e8a4abe..3ca870742efc5 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
import torch
import torch.nn.functional as F
@@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
- def create_weights(self, input_size_per_partition: int,
+ def create_weights(self, layer: torch.nn.Module,
+ input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
- output_size: int,
- params_dtype: torch.dtype) -> Dict[str, Any]:
- """Create weights for a linear layer."""
+ output_size: int, params_dtype: torch.dtype,
+ **extra_weight_attrs):
+ """Create weights for a linear layer.
+
+ The weights will be set as attributes of the layer."""
raise NotImplementedError
@abstractmethod
def apply_weights(self,
- weights: Dict[str, torch.Tensor],
+ layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
- """Apply the weights to the input tensor."""
+ """Apply the weights in layer to the input tensor.
+
+ Expects create_weights to have been called before on the layer."""
raise NotImplementedError
@@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add
- def create_weights(self, input_size_per_partition: int,
+ def create_weights(self, layer: torch.nn.Module,
+ input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
- output_size: int,
- params_dtype: torch.dtype) -> Dict[str, Any]:
+ output_size: int, params_dtype: torch.dtype,
+ **extra_weight_attrs):
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
- return {"weight": weight}
+ layer.register_parameter("weight", weight)
+ set_weight_attrs(weight, extra_weight_attrs)
def apply_weights(self,
- weights: Dict[str, torch.Tensor],
+ layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
- weight = weights["weight"]
+ weight = layer.weight
if self.separate_bias_add:
if bias is not None:
return F.linear(x, weight) + bias
@@ -111,12 +118,9 @@ def __init__(
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
- self.linear_weights = self.linear_method.create_weights(
- self.input_size, self.output_size, self.input_size,
- self.output_size, self.params_dtype)
- for name, weight in self.linear_weights.items():
- if isinstance(weight, torch.Tensor):
- self.register_parameter(name, weight)
+ self.linear_method.create_weights(self, self.input_size,
+ self.output_size, self.input_size,
+ self.output_size, self.params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
@@ -126,7 +130,7 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
- output = self.linear_method.apply_weights(self.linear_weights, x, bias)
+ output = self.linear_method.apply_weights(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
@@ -177,13 +181,13 @@ def __init__(
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
- self.linear_weights = self.linear_method.create_weights(
- self.input_size, self.output_size_per_partition, self.input_size,
- self.output_size, self.params_dtype)
- for name, weight in self.linear_weights.items():
- if isinstance(weight, torch.Tensor):
- self.register_parameter(name, weight)
- set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+ self.linear_method.create_weights(self,
+ self.input_size,
+ self.output_size_per_partition,
+ self.input_size,
+ self.output_size,
+ self.params_dtype,
+ weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
@@ -211,8 +215,7 @@ def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
- output_parallel = self.linear_method.apply_weights(
- self.linear_weights, input_, bias)
+ output_parallel = self.linear_method.apply_weights(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
@@ -523,13 +526,13 @@ def __init__(
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
- self.linear_weights = self.linear_method.create_weights(
- self.input_size_per_partition, self.output_size, self.input_size,
- self.output_size, self.params_dtype)
- for name, weight in self.linear_weights.items():
- if isinstance(weight, torch.Tensor):
- self.register_parameter(name, weight)
- set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+ self.linear_method.create_weights(self,
+ self.input_size_per_partition,
+ self.output_size,
+ self.input_size,
+ self.output_size,
+ self.params_dtype,
+ weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
@@ -569,7 +572,7 @@ def forward(self, input_):
# Matrix multiply.
output_parallel = self.linear_method.apply_weights(
- self.linear_weights, input_parallel)
+ self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index daea5ac73e429..98651aed8be0e 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config
- def create_weights(self, input_size_per_partition: int,
+ def create_weights(self, layer: torch.nn.Module,
+ input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
- output_size: int,
- params_dtype: torch.dtype) -> Dict[str, Any]:
+ output_size: int, params_dtype: torch.dtype,
+ **extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
@@ -136,19 +137,21 @@ def create_weights(self, input_size_per_partition: int,
"input_dim": 0,
"output_dim": 1,
})
- return {
- "qweight": qweight,
- "qzeros": qzeros,
- "scales": scales,
- }
+
+ layer.register_parameter("qweight", qweight)
+ set_weight_attrs(qweight, extra_weight_attrs)
+ layer.register_parameter("qzeros", qzeros)
+ set_weight_attrs(qzeros, extra_weight_attrs)
+ layer.register_parameter("scales", scales)
+ set_weight_attrs(scales, extra_weight_attrs)
def apply_weights(self,
- weights: Dict[str, Any],
+ layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
- qweight = weights["qweight"]
- scales = weights["scales"]
- qzeros = weights["qzeros"]
+ qweight = layer.qweight
+ scales = layer.scales
+ qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
@@ -163,5 +166,5 @@ def apply_weights(self,
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
- out = out + bias
+ out.add_(bias)
return out.reshape(out_shape)
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 757ab1af8392e..f370b94a210ee 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig):
def create_weights(
self,
+ layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
- ) -> Dict[str, Any]:
+ **extra_weight_attrs,
+ ):
del output_size # Unused.
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
@@ -179,37 +181,40 @@ def create_weights(
"input_dim": scale_and_zero_input_dim,
"output_dim": 1,
})
- return {
- "qweight": qweight,
- "g_idx": g_idx,
- "qzeros": qzeros,
- "scales": scales,
- "exllama_state": exllama_state,
- }
+
+ layer.register_parameter("qweight", qweight)
+ set_weight_attrs(qweight, extra_weight_attrs)
+ layer.register_parameter("g_idx", g_idx)
+ set_weight_attrs(g_idx, extra_weight_attrs)
+ layer.register_parameter("qzeros", qzeros)
+ set_weight_attrs(qzeros, extra_weight_attrs)
+ layer.register_parameter("scales", scales)
+ set_weight_attrs(scales, extra_weight_attrs)
+
+ layer.exllama_state = exllama_state
def apply_weights(self,
- weights: Dict[str, Any],
+ layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
- qweight = weights["qweight"]
+ qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
- if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
+ if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
- weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
- torch.int)
+ layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
- weights["g_idx"] = torch.empty((1, 1), device="meta")
- weights["exllama_state"] = ExllamaState.READY
- ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+ layer.g_idx.data = torch.empty((0, ),
+ device=layer.g_idx.device)
+ layer.exllama_state = ExllamaState.READY
+ ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
- output = ops.gptq_gemm(reshaped_x, weights["qweight"],
- weights["qzeros"], weights["scales"],
- weights["g_idx"],
- weights["exllama_state"] == ExllamaState.READY,
+ output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
+ layer.scales, layer.g_idx,
+ layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits)
if bias is not None:
- output = output + bias
+ output.add_(bias)
return output.reshape(out_shape)
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index a6482c059cc41..bf0500f1155a1 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -91,12 +91,14 @@ def __init__(self, quant_config: MarlinConfig):
def create_weights(
self,
+ layer: torch.nn.Module,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
- ) -> Dict[str, Any]:
+ **extra_weight_attrs,
+ ):
del output_size # Unused.
if params_dtype != torch.float16:
@@ -187,21 +189,22 @@ def create_weights(
dtype=torch.int),
requires_grad=False)
- return {
- "B": qweight,
- "s": scales,
- "workspace": workspace,
- }
+ layer.register_parameter("B", qweight)
+ set_weight_attrs(qweight, extra_weight_attrs)
+ layer.register_parameter("s", scales)
+ set_weight_attrs(scales, extra_weight_attrs)
+ layer.register_parameter("workspace", workspace)
+ set_weight_attrs(workspace, extra_weight_attrs)
def apply_weights(
self,
- weights: Dict[str, Any],
+ layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- qweight = weights["B"]
- scales = weights["s"]
- workspace = weights["workspace"]
+ qweight = layer.B
+ scales = layer.s
+ workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
index bb295df2acc3f..661ff9c55d0d1 100644
--- a/vllm/model_executor/layers/quantization/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -68,10 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
def __init__(self, quant_config: SqueezeLLMConfig):
self.quant_config = quant_config
- def create_weights(self, input_size_per_partition: int,
+ def create_weights(self, layer: torch.nn.Module,
+ input_size_per_partition: int,
output_size_per_partition: int, input_size: int,
- output_size: int,
- params_dtype: torch.dtype) -> Dict[str, Any]:
+ output_size: int, params_dtype: torch.dtype,
+ **extra_weight_attrs):
if input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The input size is not aligned with the quantized "
@@ -103,17 +104,18 @@ def create_weights(self, input_size_per_partition: int,
set_weight_attrs(lookup_table, {
"output_dim": 0,
})
- return {
- "qweight": qweight,
- "lookup_table": lookup_table,
- }
+
+ layer.register_parameter("qweight", qweight)
+ set_weight_attrs(qweight, extra_weight_attrs)
+ layer.register_parameter("lookup_table", lookup_table)
+ set_weight_attrs(lookup_table, extra_weight_attrs)
def apply_weights(self,
- weights: Dict[str, Any],
+ layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
- qweight = weights["qweight"]
- lookup_table = weights["lookup_table"]
+ qweight = layer.qweight
+ lookup_table = layer.lookup_table
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
if is_hip():
@@ -126,5 +128,5 @@ def apply_weights(self,
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
if bias is not None:
- out = out + bias
+ out.add_(bias)
return out.reshape(out_shape)
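The commit above changes the LinearMethodBase contract: create_weights now registers parameters directly on the layer, and apply_weights reads them back from that layer instead of a weights dict. A hypothetical method written against the new interface might look like the sketch below; the class itself is an assumption, only the method signatures come from the diff.

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

# Hypothetical linear method following the new layer-centric interface.
class PlainLinearMethod:

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        # Weights now live on the layer rather than in a returned dict.
        layer.register_parameter("weight", weight)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: torch.Tensor = None) -> torch.Tensor:
        # Reads the parameter back from the layer it was registered on.
        return F.linear(x, layer.weight, bias)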
From 559eb852f83fe7867390dd2986b4f93a6572cf10 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 11 Apr 2024 14:00:48 -0700
Subject: [PATCH 019/413] [Core] init_distributed_environment align with
 init_process_group (#4014)
[Core][Distributed] make init_distributed_environment compatible with init_process_group (#4014)
---
vllm/distributed/parallel_state.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 4bb77146295af..9fceffe7cb88b 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -39,9 +39,9 @@
def init_distributed_environment(
- world_size: int,
- rank: int,
- distributed_init_method: Optional[str] = None,
+ world_size: int = -1,
+ rank: int = -1,
+ distributed_init_method: str = "env://",
local_rank: int = -1,
backend: str = "nccl",
):
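With these defaults the function can be called the way torch.distributed.init_process_group usually is: under a launcher that exports the standard environment variables, no arguments are needed. A hedged usage sketch, assuming a torchrun-style launch with RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT set:

# Assumes the process was started by a launcher such as torchrun, so the
# default "env://" init method can read rank and world size from the
# environment; explicit values can still be passed as before.
from vllm.distributed.parallel_state import init_distributed_environment

init_distributed_environment()  # world_size=-1, rank=-1, "env://", backend="nccl"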
From 95e7d4a97cd64f8c6dc226ec0bbceebef6458701 Mon Sep 17 00:00:00 2001
From: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com>
Date: Thu, 11 Apr 2024 15:15:50 -0700
Subject: [PATCH 020/413] Fix echo/logprob OpenAI completion bug (#3441)
Co-authored-by: Dylan Hawk
---
tests/entrypoints/test_openai_server.py | 31 ++++++++++++
vllm/entrypoints/openai/serving_chat.py | 9 ++--
vllm/entrypoints/openai/serving_completion.py | 15 ++++--
vllm/entrypoints/openai/serving_engine.py | 47 +++++++++++--------
4 files changed, 73 insertions(+), 29 deletions(-)
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 6f2086c4dd269..7940430b8b654 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -742,5 +742,36 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI):
assert content.strip() == ground_truth
+@pytest.mark.parametrize(
+ # first test base model, then test loras
+ "model_name",
+ [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
+ model_name: str):
+ tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
+ # test using text and token IDs
+ for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
+ completion = await client.completions.create(model=model_name,
+ prompt=prompt,
+ max_tokens=5,
+ temperature=0.0,
+ echo=True,
+ logprobs=1)
+
+ prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
+ list) else prompt
+ assert (completion.choices[0].text is not None
+ and re.search(r"^" + prompt_text, completion.choices[0].text))
+ logprobs = completion.choices[0].logprobs
+ assert logprobs is not None
+ assert len(logprobs.text_offset) > 5
+ assert (len(logprobs.token_logprobs) > 5
+ and logprobs.token_logprobs[0] is None)
+ assert (len(logprobs.top_logprobs) > 5
+ and logprobs.top_logprobs[0] is None)
+ assert len(logprobs.tokens) > 5
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 0980c3d3cb614..a03c5dc88108f 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -63,8 +63,9 @@ async def create_chat_completion(
request_id = f"cmpl-{random_uuid()}"
try:
- token_ids = self._validate_prompt_and_tokenize(request,
- prompt=prompt)
+ # Tokenize/detokenize depending on prompt format (string/token list)
+ prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+ request, prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
guided_decode_logits_processor = (
@@ -78,8 +79,8 @@ async def create_chat_completion(
except ValueError as e:
return self.create_error_response(str(e))
- result_generator = self.engine.generate(prompt, sampling_params,
- request_id, token_ids,
+ result_generator = self.engine.generate(prompt_text, sampling_params,
+ request_id, prompt_ids,
lora_request)
# Streaming response
if request.stream:
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 06e7a9225fefb..c1f1744a118bd 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -136,23 +136,24 @@ async def create_completion(self, request: CompletionRequest,
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
- input_ids = self._validate_prompt_and_tokenize(
+ prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
- input_ids = self._validate_prompt_and_tokenize(
+ prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
+ prompt_ids, prompt_text = prompt_formats
generators.append(
- self.engine.generate(prompt,
+ self.engine.generate(prompt_text,
sampling_params,
f"{request_id}-{i}",
- prompt_token_ids=input_ids,
+ prompt_token_ids=prompt_ids,
lora_request=lora_request))
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
@@ -326,7 +327,8 @@ def request_output_to_completion_response(
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
- top_logprobs = prompt_logprobs + output.logprobs
+ top_logprobs = (prompt_logprobs + output.logprobs
+ if request.logprobs else None)
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
@@ -334,6 +336,9 @@ def request_output_to_completion_response(
output_text = output.text
if request.logprobs is not None:
+ assert top_logprobs is not None, (
+ "top_logprobs must be provided when logprobs "
+ "is requested")
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 8f69388c0251e..77a568b564039 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -2,7 +2,7 @@
import json
from dataclasses import dataclass
from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
from pydantic import conint
@@ -99,27 +99,32 @@ def _create_logprobs(
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
+
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
- if step_top_logprobs is not None:
- token_logprob = step_top_logprobs[token_id].logprob
+ if step_top_logprobs is None:
+ token = self.tokenizer.decode(token_id)
+ logprobs.tokens.append(token)
+ logprobs.token_logprobs.append(None)
+ logprobs.top_logprobs.append(None)
else:
- token_logprob = None
- token = step_top_logprobs[token_id].decoded_token
- logprobs.tokens.append(token)
- logprobs.token_logprobs.append(token_logprob)
+ token_logprob = step_top_logprobs[token_id].logprob
+ token = step_top_logprobs[token_id].decoded_token
+ logprobs.tokens.append(token)
+ logprobs.token_logprobs.append(token_logprob)
+
+ if num_output_top_logprobs:
+ logprobs.top_logprobs.append({
+ p.decoded_token: p.logprob
+ for i, p in step_top_logprobs.items()
+ } if step_top_logprobs else None)
+
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
-
- if num_output_top_logprobs:
- logprobs.top_logprobs.append({
- p.decoded_token: p.logprob
- for i, p in step_top_logprobs.items()
- } if step_top_logprobs else None)
return logprobs
def create_error_response(
@@ -164,12 +169,12 @@ def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
raise ValueError("The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize(
- self,
- request: Union[ChatCompletionRequest, CompletionRequest],
- prompt: Optional[str] = None,
- prompt_ids: Optional[List[int]] = None,
- truncate_prompt_tokens: Optional[conint(ge=1)] = None
- ) -> List[int]:
+ self,
+ request: Union[ChatCompletionRequest, CompletionRequest],
+ prompt: Optional[str] = None,
+ prompt_ids: Optional[List[int]] = None,
+ truncate_prompt_tokens: Optional[conint(ge=1)] = None
+ ) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
@@ -187,6 +192,8 @@ def _validate_prompt_and_tokenize(
else:
input_ids = prompt_ids
+ input_text = prompt if prompt is not None else self.tokenizer.decode(
+ prompt_ids)
token_num = len(input_ids)
if request.max_tokens is None:
@@ -201,4 +208,4 @@ def _validate_prompt_and_tokenize(
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
else:
- return input_ids
+ return input_ids, input_text
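Illustrative sketch (not part of this patch): _validate_prompt_and_tokenize now returns a (prompt_ids, prompt_text) pair regardless of whether the client sent a string or a token list, so echo can always prepend text and the engine can always be fed ids. The toy whitespace tokenizer and vocabulary below are made up for the demonstration:

    from typing import List, Tuple, Union

    VOCAB = {"Hello,": 11, "my": 12, "name": 13, "is": 14}
    INV_VOCAB = {v: k for k, v in VOCAB.items()}

    def toy_encode(text: str) -> List[int]:
        return [VOCAB[w] for w in text.split()]

    def toy_decode(ids: List[int]) -> str:
        return " ".join(INV_VOCAB[i] for i in ids)

    def normalize_prompt(prompt: Union[str, List[int]]) -> Tuple[List[int], str]:
        """Return (prompt_ids, prompt_text) for either input format."""
        if isinstance(prompt, str):
            return toy_encode(prompt), prompt
        return list(prompt), toy_decode(prompt)

    # Both input formats yield the same pair.
    print(normalize_prompt("Hello, my name is"))
    print(normalize_prompt([11, 12, 13, 14]))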
From 1e96c3341a4e055ae392085fecc7a672295b71c2 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Thu, 11 Apr 2024 15:18:57 -0700
Subject: [PATCH 021/413] Add extra punica sizes to support bigger vocabs
(#4015)
---
csrc/punica/bgmv/bgmv_config.h | 12 +++++-
csrc/punica/punica_ops.cc | 14 +++---
tests/lora/test_layers.py | 78 +++++++++++++++++++---------------
tests/lora/test_punica.py | 49 +++++++++++++++++++--
vllm/lora/layers.py | 4 +-
5 files changed, 109 insertions(+), 48 deletions(-)
diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index 1084a0f20df6b..9b76b98ab3322 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
-// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
+ f(in_T, out_T, W_T, narrow, 64000) \
+ f(in_T, out_T, W_T, narrow, 64256) \
+ f(in_T, out_T, W_T, narrow, 64512) \
+ f(in_T, out_T, W_T, narrow, 102400) \
+ f(in_T, out_T, W_T, narrow, 102656) \
+ f(in_T, out_T, W_T, narrow, 102912) \
+ f(in_T, out_T, W_T, narrow, 128000) \
+ f(in_T, out_T, W_T, narrow, 128256) \
+ f(in_T, out_T, W_T, narrow, 128512) \
+// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
+// and vllm/tests/lora/test_punica.py
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc
index 28739be14b862..7ebfd851c4feb 100644
--- a/csrc/punica/punica_ops.cc
+++ b/csrc/punica/punica_ops.cc
@@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
}
}
-inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
- return (uint32_t(a) << 16) | uint32_t(b);
+inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
+ return (uint64_t(a) << 32) | uint64_t(b);
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
@@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
template
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices,
- uint16_t in_features, uint16_t out_features,
+ uint32_t in_features, uint32_t out_features,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
- switch (pack_u16(in_features, out_features)) {
+ switch (pack_u32(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
- case pack_u16(feat_in, feat_out): \
+ case pack_u32(feat_in, feat_out): \
bgmv_kernel(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
@@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
- if (h_in < 65536 && h_out < 65536) {
+ if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
@@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false;
- if (h_in < 65536 && h_out < 65536) {
+ if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
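Illustrative sketch (not part of this patch), written in Python rather than C++ and using a made-up kernel table: widening the dispatch key from two 16-bit halves to two 32-bit halves lets feature sizes above 65535 (such as 128256-token vocabularies) be matched exactly in the switch.

    def pack_u32(a: int, b: int) -> int:
        # Mirrors the C++ pack_u32: (uint64(a) << 32) | uint64(b).
        assert 0 <= a < 2**32 and 0 <= b < 2**32
        return (a << 32) | b

    # Hypothetical registry keyed by packed (in_features, out_features).
    KERNELS = {
        pack_u32(4096, 32000): "bgmv_4096_32000",
        pack_u32(4096, 128256): "bgmv_4096_128256",
    }

    def dispatch(in_features: int, out_features: int) -> str:
        key = pack_u32(in_features, out_features)
        if key not in KERNELS:
            raise ValueError(f"no kernel for {in_features}x{out_features}")
        return KERNELS[key]

    # 128256 does not fit in 16 bits, so the old 16-bit packing could not
    # have represented this shape without collisions.
    print(dispatch(4096, 128256))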
diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 71ce6f1764832..e9e0c8554c1ef 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -170,7 +170,8 @@ def create_random_inputs(
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_embeddings(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
torch.set_default_device(device)
max_loras = 8
@@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16)
def create_random_embedding_layer():
- embedding = VocabParallelEmbedding(512, 256)
+ embedding = VocabParallelEmbedding(vocab_size, 256)
embedding.weight.data = torch.rand_like(embedding.weight.data)
- embedding.weight.data[512:, :] = 0
+ embedding.weight.data[vocab_size:, :] = 0
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
lora_embedding.create_lora_weights(max_loras, lora_config)
@@ -203,12 +204,13 @@ def create_random_embedding_layer():
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200, ),
- input_range=(1, 512),
+ input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
- 512, lora_config.lora_extra_vocab_size)
+ vocab_size,
+ lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info)
lora_result = lora_embedding(torch.cat(inputs))
@@ -240,12 +242,13 @@ def create_random_embedding_layer():
active_lora_ids=[0],
num_inputs=num_loras * 3,
input_size=(200, ),
- input_range=(1, 512),
+ input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
- 512, lora_config.lora_extra_vocab_size)
+ vocab_size,
+ lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(inputs))
@@ -263,7 +266,9 @@ def create_random_embedding_layer():
# reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
+ vocab_size) -> None:
torch.set_default_device(device)
max_loras = 8
@@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16)
def create_random_embedding_layer():
- embedding = VocabParallelEmbedding(512, 256)
+ embedding = VocabParallelEmbedding(vocab_size, 256)
embedding_data = torch.rand_like(embedding.weight.data)
embedding.weight.data = embedding_data
- embedding.weight.data[512:, :] = 0
+ embedding.weight.data[vocab_size:, :] = 0
expanded_embedding = VocabParallelEmbedding(
- 512 + lora_config.lora_extra_vocab_size * max_loras,
+ vocab_size + lora_config.lora_extra_vocab_size * max_loras,
256,
- org_num_embeddings=512)
- expanded_embedding.weight.data[:512, :] = embedding_data
+ org_num_embeddings=vocab_size)
+ expanded_embedding.weight.data[:vocab_size, :] = embedding_data
# We need to deepcopy the embedding as it will be modified
# in place
lora_embedding = VocabParallelEmbeddingWithLoRA(
@@ -298,7 +303,7 @@ def create_random_embedding_layer():
id_to_index,
layer=lora_embedding,
layer_weights=torch.zeros(
- (256, 512 + lora_config.lora_extra_vocab_size)),
+ (256, vocab_size + lora_config.lora_extra_vocab_size)),
generate_embeddings_tensor=256,
)
@@ -316,7 +321,7 @@ def create_random_embedding_layer():
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200, ),
- input_range=(1, 512),
+ input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
@@ -327,16 +332,18 @@ def create_random_embedding_layer():
for input_, original_input_, lora_id in zip(inputs, original_inputs,
prompt_mapping):
embedding_id = lora_id - 1
- input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
- original_input_[-1] = 512
- input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
- original_input_[-2] = 512 + embeddings_tensor_len - 1
+ input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
+ original_input_[-1] = vocab_size
+ input_[-2] = vocab_size + (
+ (embedding_id + 1) * embeddings_tensor_len - 1)
+ original_input_[-2] = vocab_size + embeddings_tensor_len - 1
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
- 512, lora_config.lora_extra_vocab_size)
+ vocab_size,
+ lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
- expanded_embedding.weight[512:512 +
+ expanded_embedding.weight[vocab_size:vocab_size +
(embeddings_tensor_len *
max_loras)] = torch.cat(embeddings_tensors)
@@ -370,14 +377,15 @@ def create_random_embedding_layer():
active_lora_ids=[0],
num_inputs=num_loras * 3,
input_size=(200, ),
- input_range=(1, 512),
+ input_range=(1, vocab_size),
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
original_inputs = deepcopy(inputs)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
- 512, lora_config.lora_extra_vocab_size)
+ vocab_size,
+ lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(original_inputs))
@@ -393,7 +401,9 @@ def create_random_embedding_layer():
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_lm_head_logits_processor(dist_init, num_loras, device,
+ vocab_size) -> None:
torch.set_default_device(device)
max_loras = 8
@@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16)
def _pretest():
- linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
- 1024, 32000)
+ linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
+ 1024, vocab_size)
linear.weight.data = torch.rand_like(linear.weight.data)
- linear.weight.data[:, 32000:] = 0
+ linear.weight.data[:, vocab_size:] = 0
logits_processor = LogitsProcessor(
- 32000 + lora_config.lora_extra_vocab_size, 32000)
+ vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
lora_logits_processor.create_lora_weights(max_loras, lora_config)
@@ -444,7 +454,7 @@ def _pretest():
lora_mapping,
id_to_index,
max_loras,
- 32000,
+ vocab_size,
lora_config.lora_extra_vocab_size,
)
lora_logits_processor.set_mapping(*mapping_info, )
@@ -460,7 +470,7 @@ def _pretest():
org_vocab_size:logits_processor.org_vocab_size +
embeddings_tensor_len] = embeddings_tensor
- logits_processor.org_vocab_size = (32000 +
+ logits_processor.org_vocab_size = (vocab_size +
lora_config.lora_extra_vocab_size)
expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping):
@@ -468,11 +478,11 @@ def _pretest():
result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight,
embedding_bias=None)
- result[:, 32000 + embeddings_tensor_len:] = float("-inf")
+ result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
- logits_processor.org_vocab_size = 32000
+ logits_processor.org_vocab_size = vocab_size
# Check that resetting the lora weights succeeds
@@ -489,14 +499,14 @@ def _pretest():
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
- 32000,
+ vocab_size,
lora_config.lora_extra_vocab_size)
lora_logits_processor.set_mapping(*mapping_info, )
lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
- embedding_bias=None)[:, :32000]
+ embedding_bias=None)[:, :vocab_size]
expected_result = logits_processor._get_logits(
hidden_states=torch.cat(inputs),
embedding=original_weight,
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index 2736a1c7ade27..cab8b44ccd2df 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -43,10 +43,51 @@ def _lora_ref_impl(
H1 = H2 = [
- 128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
- 3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
- 10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
- 32768, 33024
+ 128,
+ 256,
+ 512,
+ 1024,
+ 1152,
+ 1280,
+ 1536,
+ 2048,
+ 2304,
+ 2560,
+ 2752,
+ 3072,
+ 3456,
+ 3584,
+ 4096,
+ 4608,
+ 5120,
+ 5504,
+ 5632,
+ 6144,
+ 6848,
+ 6912,
+ 7168,
+ 8192,
+ 9216,
+ 10240,
+ 11008,
+ 13824,
+ 14336,
+ 22016,
+ 24576,
+ 27392,
+ 32000,
+ 32256,
+ 32512,
+ 32768,
+ 33024,
+ 36864,
+ 49152,
+ 64000,
+ 64256,
+ 102400,
+ 102656,
+ 128000,
+ 128256,
]
SEED = [0xabcdabcd987]
CUDA_DEVICES = [
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index a8ec4dcfd6137..5456b5613c47a 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -935,9 +935,9 @@ def create_lora_weights(
model_config: Optional[PretrainedConfig] = None,
) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
- if 32000 < self.base_layer.vocab_size > 33024:
+ if not 32000 <= self.base_layer.vocab_size <= 128512:
raise ValueError("When using LoRA, vocab size must be "
- "32000 >= vocab_size <= 33024")
+ "32000 >= vocab_size <= 128512")
self.lora_a_stacked = torch.zeros(
(
max_loras,
From e46a60aa4c90cf3dfd9b90782f2eeabbda935eef Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 11 Apr 2024 23:34:12 +0100
Subject: [PATCH 022/413] [BugFix] Fix handling of stop strings and stop token
ids (#3672)
---
tests/conftest.py | 2 +-
.../{samplers => engine}/test_stop_reason.py | 2 +-
tests/engine/test_stop_strings.py | 111 ++++++++++++++++++
vllm/engine/llm_engine.py | 98 ++++++++++------
vllm/outputs.py | 4 +-
vllm/sampling_params.py | 9 ++
vllm/sequence.py | 6 +
vllm/transformers_utils/detokenizer.py | 7 +-
8 files changed, 202 insertions(+), 37 deletions(-)
rename tests/{samplers => engine}/test_stop_reason.py (97%)
create mode 100644 tests/engine/test_stop_strings.py
diff --git a/tests/conftest.py b/tests/conftest.py
index a7e8963af0eda..5c50fc2d1bab6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -401,7 +401,7 @@ def __del__(self):
cleanup()
-@pytest.fixture
+@pytest.fixture(scope="session")
def vllm_runner():
return VllmRunner
diff --git a/tests/samplers/test_stop_reason.py b/tests/engine/test_stop_reason.py
similarity index 97%
rename from tests/samplers/test_stop_reason.py
rename to tests/engine/test_stop_reason.py
index b242c405a4fb6..b2f521a8ae4ce 100644
--- a/tests/samplers/test_stop_reason.py
+++ b/tests/engine/test_stop_reason.py
@@ -3,7 +3,7 @@
2. One of the provided stop tokens
3. The EOS token
-Run `pytest tests/samplers/test_stop_reason.py`.
+Run `pytest tests/engine/test_stop_reason.py`.
"""
import pytest
diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py
new file mode 100644
index 0000000000000..6b747beb4b543
--- /dev/null
+++ b/tests/engine/test_stop_strings.py
@@ -0,0 +1,111 @@
+from typing import Any, List, Optional
+
+import pytest
+
+from vllm import CompletionOutput, LLMEngine, SamplingParams
+
+MODEL = "meta-llama/llama-2-7b-hf"
+MAX_TOKENS = 200
+
+
+@pytest.fixture(scope="session")
+def vllm_model(vllm_runner):
+ return vllm_runner(MODEL)
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_basic(vllm_model):
+ _test_stopping(vllm_model.model.llm_engine,
+ stop=["."],
+ include_in_output=False,
+ expected_output="VLLM is a 100% volunteer organization",
+ expected_reason=".")
+
+ _test_stopping(vllm_model.model.llm_engine,
+ stop=["."],
+ include_in_output=True,
+ expected_output="VLLM is a 100% volunteer organization.",
+ expected_reason=".")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_multi_tokens(vllm_model):
+ _test_stopping(
+ vllm_model.model.llm_engine,
+ stop=["group of peo", "short"],
+ include_in_output=False,
+ expected_output="VLLM is a 100% volunteer organization. We are a ",
+ expected_reason="group of peo")
+
+ _test_stopping(
+ vllm_model.model.llm_engine,
+ stop=["group of peo", "short"],
+ include_in_output=True,
+ expected_output=
+ "VLLM is a 100% volunteer organization. We are a group of peo",
+ expected_reason="group of peo")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_partial_token(vllm_model):
+ _test_stopping(vllm_model.model.llm_engine,
+ stop=["gani"],
+ include_in_output=False,
+ expected_output="VLLM is a 100% volunteer or",
+ expected_reason="gani")
+
+ _test_stopping(vllm_model.model.llm_engine,
+ stop=["gani"],
+ include_in_output=True,
+ expected_output="VLLM is a 100% volunteer organi",
+ expected_reason="gani")
+
+
+@pytest.mark.skip_global_cleanup
+def test_stop_token_id(vllm_model):
+ # token id 13013 => " organization"
+
+ _test_stopping(vllm_model.model.llm_engine,
+ stop_token_ids=[13013],
+ include_in_output=False,
+ expected_output="VLLM is a 100% volunteer",
+ expected_reason=13013)
+
+ _test_stopping(vllm_model.model.llm_engine,
+ stop_token_ids=[13013],
+ include_in_output=True,
+ expected_output="VLLM is a 100% volunteer organization",
+ expected_reason=13013)
+
+
+def _test_stopping(llm_engine: LLMEngine,
+ expected_output: str,
+ expected_reason: Any,
+ stop: Optional[List[str]] = None,
+ stop_token_ids: Optional[List[int]] = None,
+ include_in_output: bool = False) -> None:
+ llm_engine.add_request(
+ "id", "A story about vLLM:\n",
+ SamplingParams(
+ temperature=0.0,
+ max_tokens=MAX_TOKENS,
+ stop=stop,
+ stop_token_ids=stop_token_ids,
+ include_stop_str_in_output=include_in_output,
+ ), None)
+
+ output: Optional[CompletionOutput] = None
+ output_text = ""
+ stop_reason = None
+ while llm_engine.has_unfinished_requests():
+ (request_output, ) = llm_engine.step()
+ (output, ) = request_output.outputs
+
+ # Ensure we don't backtrack
+ assert output.text.startswith(output_text)
+ output_text = output.text
+ stop_reason = output.stop_reason
+
+ assert output is not None
+ assert output_text == expected_output
+ assert stop_reason == expected_reason
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index ddfdda898a5c6..a91629a630591 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -501,9 +501,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize:
- self.detokenizer.decode_sequence_inplace(
+ new_char_count = self.detokenizer.decode_sequence_inplace(
seq, seq_group.sampling_params)
- self._check_stop(seq, seq_group.sampling_params)
+ else:
+ new_char_count = 0
+ self._check_stop(seq, new_char_count, seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
@@ -798,56 +800,86 @@ def _get_stats(self,
time_e2e_requests=time_e2e_requests,
)
- def _check_stop(self, seq: Sequence,
+ def _check_stop(self, seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> None:
- """Stop the finished sequences."""
- # Check if the sequence has reached max_model_len.
- if seq.get_len() > self.scheduler_config.max_model_len:
- seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
- return
+ """Stop the finished sequences.
- # Check if the sequence has reached max_tokens.
- if seq.get_output_len() == sampling_params.max_tokens:
- seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
- return
+ new_char_count is the number of chars added to the
+ sequence's output text for the newly generated token
+ """
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
- if sampling_params.detokenize:
- for stop_str in sampling_params.stop:
- if seq.output_text.endswith(stop_str):
- self._finalize_sequence(seq, sampling_params, stop_str)
- seq.status = SequenceStatus.FINISHED_STOPPED
- seq.stop_reason = stop_str
- return
+ # Check if the sequence has generated the EOS token.
+ if ((not sampling_params.ignore_eos)
+ and seq.get_last_token_id() == seq.eos_token_id):
+ seq.status = SequenceStatus.FINISHED_STOPPED
+ return
+
+ # Check if a stop token was encountered.
+ # This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
- stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
- last_token_id)
- self._finalize_sequence(seq, sampling_params, stop_str)
+ if new_char_count and (
+ not sampling_params.include_stop_str_in_output):
+ # Remove last token
+ seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
- # Check if the sequence has generated the EOS token.
- if ((not sampling_params.ignore_eos)
- and seq.get_last_token_id() == seq.eos_token_id):
+ # Check if any stop strings are matched.
+ stop_str = self._check_stop_strings(seq, new_char_count,
+ sampling_params)
+ if stop_str is not None:
seq.status = SequenceStatus.FINISHED_STOPPED
+ seq.stop_reason = stop_str
return
- def _finalize_sequence(self, seq: Sequence,
- sampling_params: SamplingParams,
- stop_string: str) -> None:
- if sampling_params.include_stop_str_in_output:
+ # Check if the sequence has reached max_model_len.
+ if seq.get_len() > self.scheduler_config.max_model_len:
+ seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
- if stop_string and seq.output_text.endswith(stop_string):
- # Truncate the output text so that the stop string is
- # not included in the output.
- seq.output_text = seq.output_text[:-len(stop_string)]
+ # Check if the sequence has reached max_tokens.
+ if seq.get_output_len() == sampling_params.max_tokens:
+ seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+ return
+
+ @staticmethod
+ def _check_stop_strings(seq: Sequence, new_char_count: int,
+ sampling_params: SamplingParams) -> Optional[str]:
+ """Check if any stop strings are matched and truncate sequence
+ output text accordingly.
+
+ Returns the stop string if matched or else None.
+ """
+ if not new_char_count:
+ return None
+
+ for stop_str in sampling_params.stop:
+ stop_string_len = len(stop_str)
+ # Avoid searching already-searched text.
+ stop_index = seq.output_text.find(
+ stop_str, -new_char_count - stop_string_len)
+ if stop_index == -1:
+ continue
+
+ if sampling_params.include_stop_str_in_output:
+ # Truncate to end of stop string.
+ stop_index += stop_string_len
+ if stop_index >= len(seq.output_text):
+ # No truncation required.
+ return stop_str
+
+ # Truncate the output text to either the beginning
+ # or end of the stop string.
+ seq.output_text = seq.output_text[:stop_index]
+ return stop_str
+ return None
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request)
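Illustrative sketch (not part of this patch): the new_char_count argument means only the characters added by the latest decode step, plus enough lookback for a stop string that straddles the step boundary, are searched. The step size of 9 characters below is made up for the demonstration:

    from typing import List, Optional, Tuple

    def check_stop_strings(output_text: str, new_char_count: int,
                           stops: List[str],
                           include_in_output: bool) -> Tuple[str, Optional[str]]:
        """Return (possibly truncated text, matched stop string or None)."""
        if not new_char_count:
            return output_text, None
        for stop in stops:
            # Negative start index: search only the new chars plus an
            # overlap of len(stop), as in the hunk above.
            idx = output_text.find(stop, -new_char_count - len(stop))
            if idx == -1:
                continue
            if include_in_output:
                idx += len(stop)
            return output_text[:idx], stop
        return output_text, None

    text = "VLLM is a 100% volunteer organization"
    print(check_stop_strings(text, 9, ["gani"], include_in_output=False))
    print(check_stop_strings(text, 9, ["gani"], include_in_output=True))

With include_in_output toggled, the truncation matches the stop=["gani"] cases in tests/engine/test_stop_strings.py above ("...volunteer or" versus "...volunteer organi").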
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 61fe20bfc2744..d01be0eb0efd2 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -112,8 +112,10 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None
+ text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [
- CompletionOutput(seqs.index(seq), seq.output_text,
+ CompletionOutput(seqs.index(seq),
+ seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(),
seq.get_cumulative_logprob(),
seq.output_logprobs if include_logprobs else None,
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 4fdc3c6dedaef..0b9787608798c 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -166,6 +166,13 @@ def __init__(
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
+ # Number of characters to hold back for stop string evaluation
+ # until sequence is finished.
+ if self.stop and not include_stop_str_in_output:
+ self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
+ else:
+ self.output_text_buffer_length = 0
+
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
@@ -226,6 +233,8 @@ def _verify_args(self) -> None:
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
+ if any(not stop_str for stop_str in self.stop):
+ raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize:
raise ValueError(
"stop strings are only supported when detokenize is True. "
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 77029908c2218..cdb6cce6f0255 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -235,6 +235,12 @@ def __init__(
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
+ def get_output_text_to_return(self, buffer_length: int):
+ # We return the full output text if the sequence is finished.
+ truncate = buffer_length and not self.is_finished()
+ return self.output_text[:-buffer_length] if truncate else (
+ self.output_text)
+
def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
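Illustrative sketch (not part of this patch): holding back max(len(stop)) - 1 characters keeps a partially matched stop string out of streamed output until the sequence is finished. The example strings below are made up:

    from typing import List

    def output_text_buffer_length(stop: List[str],
                                  include_stop_str_in_output: bool) -> int:
        # Hold back enough characters to hide a partially matched stop string.
        if stop and not include_stop_str_in_output:
            return max(len(s) for s in stop) - 1
        return 0

    def get_output_text_to_return(output_text: str, buffer: int,
                                  finished: bool) -> str:
        truncate = buffer and not finished
        return output_text[:-buffer] if truncate else output_text

    buf = output_text_buffer_length(["<|im_end|>"], False)   # 9
    partial = "The color is #ff8050<|im_en"                  # stop string half-decoded
    # While generating, the last 9 characters are withheld from the client.
    print(get_output_text_to_return(partial, buf, finished=False))
    # Once finished, the full (already truncated) text is returned.
    print(get_output_text_to_return("The color is #ff8050", buf, finished=True))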
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index 486c1938e1e10..005932f1e3df4 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -87,12 +87,15 @@ def decode_prompt_logprobs_inplace(
prev_tokens.extend(next_iter_tokens)
def decode_sequence_inplace(self, seq: Sequence,
- prms: SamplingParams) -> None:
+ prms: SamplingParams) -> int:
"""Decodes the new token for a sequence. In-place operation.
Args:
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.
+
+ Returns:
+ The number of characters added to the output text.
"""
all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1]
@@ -151,6 +154,8 @@ def decode_sequence_inplace(self, seq: Sequence,
seq.read_offset = read_offset
seq.output_text += new_decoded_token_text
+ return len(new_decoded_token_text)
+
def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
From c2b4a1bce9a7707179cdfab2fb498c20b2b221e6 Mon Sep 17 00:00:00 2001
From: Michael Feil <63565275+michaelfeil@users.noreply.github.com>
Date: Thu, 11 Apr 2024 17:17:21 -0700
Subject: [PATCH 023/413] [Doc] Add typing hints / mypy types cleanup (#3816)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
---
benchmarks/backend_request_func.py | 62 ++++++++++---------
docs/source/conf.py | 3 +-
setup.py | 5 +-
vllm/core/block/interfaces.py | 31 ++++++----
vllm/engine/metrics.py | 10 ++-
vllm/logger.py | 8 ++-
.../model_executor/layers/rotary_embedding.py | 15 ++---
vllm/transformers_utils/config.py | 4 +-
vllm/transformers_utils/configs/dbrx.py | 2 +-
.../transformers_utils/tokenizers/baichuan.py | 10 +--
vllm/utils.py | 4 +-
11 files changed, 90 insertions(+), 64 deletions(-)
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index ad428bd1c3644..bab570252c929 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -27,8 +27,8 @@ class RequestFuncInput:
class RequestFuncOutput:
generated_text: str = ""
success: bool = False
- latency: float = 0
- ttft: float = 0 # Time to first token
+ latency: float = 0.0
+ ttft: float = 0.0 # Time to first token
itl: List[float] = field(
default_factory=list) # List of inter-token latencies
prompt_len: int = 0
@@ -58,23 +58,24 @@ async def async_request_tgi(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
- ttft = 0
+ ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
- async for chunk in response.content:
- chunk = chunk.strip()
- if not chunk:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
continue
- chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+ "data:")
data = json.loads(chunk)
timestamp = time.perf_counter()
# First token
- if ttft == 0:
+ if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
@@ -119,23 +120,24 @@ async def async_request_trt_llm(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
- ttft = 0
+ ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
- async for chunk in response.content:
- chunk = chunk.strip()
- if not chunk:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
continue
- chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+ "data:")
data = json.loads(chunk)
timestamp = time.perf_counter()
# First token
- if ttft == 0:
+ if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
@@ -151,7 +153,7 @@ async def async_request_trt_llm(
output.success = True
else:
- output.error = response.reason
+ output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
output.generated_text = parsed_resp["text"][0]
output.success = True
else:
- output.error = response.reason
+ output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
@@ -234,19 +236,20 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
- ttft = 0
+ ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
- async for chunk in response.content:
- chunk = chunk.strip()
- if not chunk:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
continue
- chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+ "data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
@@ -255,7 +258,7 @@ async def async_request_openai_completions(
if data["choices"][0]["text"]:
timestamp = time.perf_counter()
# First token
- if ttft == 0:
+ if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
- ttft = 0
+ ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
- async for chunk in response.content:
- chunk = chunk.strip()
- if not chunk:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
continue
- chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+ "data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
delta = data["choices"][0]["delta"]
if delta.get("content", None):
# First token
- if ttft == 0:
+ if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
output.success = True
output.latency = latency
else:
- output.error = response.reason
+ output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
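Illustrative sketch (not part of this patch): the renamed chunk_bytes variable and the 0.0 sentinel both serve the same streaming loop, which strips the SSE "data: " prefix, parses each JSON chunk, and records time-to-first-token once. A self-contained version against a fake stream (the remove_prefix helper below is a local copy of the benchmark's; no HTTP is involved) could look like this:

    import asyncio
    import json
    import time
    from typing import AsyncIterator

    def remove_prefix(text: str, prefix: str) -> str:
        return text[len(prefix):] if text.startswith(prefix) else text

    async def fake_sse_stream() -> AsyncIterator[bytes]:
        for piece in ("Hello", " world"):
            await asyncio.sleep(0.01)
            yield f"data: {json.dumps({'choices': [{'text': piece}]})}\n\n".encode()
        yield b"data: [DONE]\n\n"

    async def consume() -> None:
        generated, ttft = "", 0.0
        st = time.perf_counter()
        async for chunk_bytes in fake_sse_stream():
            chunk_bytes = chunk_bytes.strip()
            if not chunk_bytes:
                continue
            chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
            if chunk == "[DONE]":
                break
            data = json.loads(chunk)
            if data["choices"][0]["text"]:
                if ttft == 0.0:  # first token
                    ttft = time.perf_counter() - st
                generated += data["choices"][0]["text"]
        print(f"ttft={ttft:.3f}s text={generated!r}")

    asyncio.run(consume())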
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 44cda7c99cdd5..7a8c365ffb3bb 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -12,6 +12,7 @@
import logging
import sys
+from typing import List
from sphinx.ext import autodoc
@@ -45,7 +46,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = []
+exclude_patterns: List[str] = []
# Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ "
diff --git a/setup.py b/setup.py
index 98c92f9196e7e..9f0814e9f3bff 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
import subprocess
import sys
from shutil import which
-from typing import List
+from typing import Dict, List
import torch
from packaging.version import Version, parse
@@ -52,7 +52,7 @@ def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None:
class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
- did_config = {}
+ did_config: Dict[str, bool] = {}
#
# Determine number of compilation jobs and optionally nvcc compile threads.
@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version:
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
"""
+ assert CUDA_HOME is not None, "CUDA_HOME is not set"
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
universal_newlines=True)
output = nvcc_output.split()
diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py
index 9f466566f096b..fbceacf0ec417 100644
--- a/vllm/core/block/interfaces.py
+++ b/vllm/core/block/interfaces.py
@@ -1,5 +1,5 @@
-from abc import ABC, abstractmethod, abstractproperty
-from typing import Dict, List, Optional, Protocol
+from abc import ABC, abstractmethod
+from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device
@@ -10,23 +10,28 @@ class Block(ABC):
def append_token_ids(self, token_ids: List[int]) -> None:
pass
- @abstractproperty
+ @property
+ @abstractmethod
def block_id(self) -> Optional[int]:
pass
- @abstractproperty
+ @property
+ @abstractmethod
def token_ids(self) -> List[int]:
pass
- @abstractproperty
+ @property
+ @abstractmethod
def num_empty_slots(self) -> int:
pass
- @abstractproperty
+ @property
+ @abstractmethod
def is_full(self) -> bool:
pass
- @abstractproperty
+ @property
+ @abstractmethod
def prev_block(self) -> Optional["Block"]:
pass
@@ -47,12 +52,13 @@ def __call__(
class BlockAllocator(ABC):
@abstractmethod
- def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
+ def allocate_mutable(self, prev_block: Optional[Block],
+ device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
- token_ids: List[int]) -> Block:
+ token_ids: List[int], device: Device) -> Block:
pass
@abstractmethod
@@ -64,11 +70,12 @@ def fork(self, last_block: Block) -> List[Block]:
pass
@abstractmethod
- def get_num_free_blocks(self) -> int:
+ def get_num_free_blocks(self, device: Device) -> int:
pass
- @abstractproperty
- def all_block_ids(self) -> frozenset[int]:
+ @property
+ @abstractmethod
+ def all_block_ids(self) -> FrozenSet[int]:
pass
@abstractmethod
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 905db52a1912b..02560907a1282 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -1,6 +1,6 @@
import time
from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Protocol
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -119,6 +119,12 @@ class Stats:
time_e2e_requests: List[float]
+class SupportsMetricsInfo(Protocol):
+
+ def metrics_info(self) -> Dict[str, str]:
+ ...
+
+
class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
@@ -135,7 +141,7 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()))
- def info(self, type: str, obj: object) -> None:
+ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())
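Illustrative sketch (not part of this patch): SupportsMetricsInfo is a typing.Protocol, so any object with a matching metrics_info() method satisfies the annotation structurally, without inheriting from it. The ToyCacheConfig class below is invented for the example:

    from typing import Dict, Protocol

    class SupportsMetricsInfo(Protocol):
        def metrics_info(self) -> Dict[str, str]:
            ...

    class ToyCacheConfig:
        # No inheritance needed: matching the method shape is enough for mypy.
        def metrics_info(self) -> Dict[str, str]:
            return {"block_size": "16", "cache_dtype": "auto"}

    def log_info(obj: SupportsMetricsInfo) -> None:
        for key, value in obj.metrics_info().items():
            print(f"{key}={value}")

    log_info(ToyCacheConfig())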
diff --git a/vllm/logger.py b/vllm/logger.py
index e5e46f5cce3fe..af9575085ef37 100644
--- a/vllm/logger.py
+++ b/vllm/logger.py
@@ -4,6 +4,7 @@
import logging
import os
import sys
+from typing import Optional
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
@@ -26,7 +27,7 @@ def format(self, record):
_root_logger = logging.getLogger("vllm")
-_default_handler = None
+_default_handler: Optional[logging.Handler] = None
def _setup_logger():
@@ -55,7 +56,12 @@ def init_logger(name: str):
# Use the same settings as above for root logger
logger = logging.getLogger(name)
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
+
if VLLM_CONFIGURE_LOGGING:
+ if _default_handler is None:
+ raise ValueError(
+ "_default_handler is not set up. This should never happen!"
+ " Please open an issue on Github.")
logger.addHandler(_default_handler)
logger.propagate = False
return logger
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index eb8d5f6dfb2a9..6519781c8a8eb 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int,
# Find dim range bounds based on rotations
-def _yarn_find_correction_range(low_rot: int,
- high_rot: int,
- dim: int,
- base: float = 10000,
- max_position_embeddings: int = 2048) -> int:
+def _yarn_find_correction_range(
+ low_rot: int,
+ high_rot: int,
+ dim: int,
+ base: float = 10000,
+ max_position_embeddings: int = 2048) -> Tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
@@ -293,8 +294,8 @@ def __init__(
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
- beta_fast: float = 32,
- beta_slow: float = 1,
+ beta_fast: int = 32,
+ beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 8a6ba6c5b396c..ce7a30dce72fa 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1,10 +1,10 @@
-from typing import Optional
+from typing import Dict, Optional
from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import *
-_CONFIG_REGISTRY = {
+_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
"mpt": MPTConfig,
diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py
index 3a19af7129e73..1d2724f22abd6 100644
--- a/vllm/transformers_utils/configs/dbrx.py
+++ b/vllm/transformers_utils/configs/dbrx.py
@@ -12,7 +12,7 @@
logger = logging.get_logger(__name__)
-DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore
class DbrxAttentionConfig(PretrainedConfig):
diff --git a/vllm/transformers_utils/tokenizers/baichuan.py b/vllm/transformers_utils/tokenizers/baichuan.py
index 02045bdcb2ccf..79894035cb1f1 100644
--- a/vllm/transformers_utils/tokenizers/baichuan.py
+++ b/vllm/transformers_utils/tokenizers/baichuan.py
@@ -16,11 +16,11 @@
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
-PRETRAINED_VOCAB_FILES_MAP = {
+PRETRAINED_VOCAB_FILES_MAP = { # type: ignore
"vocab_file": {},
"tokenizer_file": {},
}
-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} # type: ignore
class BaichuanTokenizer(PreTrainedTokenizer):
@@ -148,9 +148,9 @@ def save_vocabulary(self,
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
- logger.error(f"Vocabulary path ({save_directory}) "
- "should be a directory")
- return
+ raise ValueError(f"Vocabulary path ({save_directory}) "
+ "should be a directory")
+
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") +
diff --git a/vllm/utils.py b/vllm/utils.py
index fdb0a3768ab0d..669b65891d0db 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -294,7 +294,7 @@ def create_kv_caches_with_random(
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
- seed: Optional[int] = 0,
+ seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
@@ -400,7 +400,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
gc.collect()
-def str_to_int_tuple(s: str) -> Tuple[int]:
+def str_to_int_tuple(s: str) -> Tuple[int, ...]:
"""Convert a string to a tuple of integers."""
try:
return tuple(map(int, s.split(",")))
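Illustrative sketch (not part of this patch): the interfaces.py hunk replaces the deprecated abstractproperty with stacked @property and @abstractmethod decorators, which behave the same at runtime and are understood by mypy. A minimal example of the pattern, with a made-up NaiveBlock subclass:

    from abc import ABC, abstractmethod
    from typing import Optional

    class Block(ABC):
        @property
        @abstractmethod
        def block_id(self) -> Optional[int]:
            ...

    class NaiveBlock(Block):
        def __init__(self, block_id: int) -> None:
            self._block_id = block_id

        @property
        def block_id(self) -> Optional[int]:
            return self._block_id

    print(NaiveBlock(7).block_id)   # 7
    try:
        Block()                     # type: ignore[abstract]
    except TypeError as e:
        print("abstract:", e)       # cannot instantiate without block_id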
From 1096717ae9e0b414ad625c1a12354dd1d949ffb1 Mon Sep 17 00:00:00 2001
From: Jee Li
Date: Fri, 12 Apr 2024 12:02:44 +0800
Subject: [PATCH 024/413] [Core] Support LoRA on quantized models (#4012)
---
tests/lora/conftest.py | 5 +
tests/lora/test_quant_model.py | 179 +++++++++++++++++++++++++++++++++
vllm/config.py | 9 +-
vllm/lora/layers.py | 67 +++++++-----
4 files changed, 234 insertions(+), 26 deletions(-)
create mode 100644 tests/lora/test_quant_model.py
diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index 207c635e2dc86..1127cc33183c9 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -143,6 +143,11 @@ def baichuan_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
+@pytest.fixture(scope="session")
+def tinyllama_lora_files():
+ return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
+
+
@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py
new file mode 100644
index 0000000000000..3d86a4366aa57
--- /dev/null
+++ b/tests/lora/test_quant_model.py
@@ -0,0 +1,179 @@
+# Adapted from
+# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
+from dataclasses import dataclass
+from typing import List
+
+import pytest
+
+import vllm
+from vllm.lora.request import LoRARequest
+
+from .conftest import cleanup
+
+
+@dataclass
+class ModelWithQuantization:
+ model_path: str
+ quantization: str
+
+
+MODELS: List[ModelWithQuantization] = [
+ ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
+ quantization="AWQ"),
+ ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
+ quantization="GPTQ"),
+]
+
+
+def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
+ raw_prompts = [
+ "Give me an orange-ish brown color",
+ "Give me a neon pink color",
+ ]
+
+ def format_prompt_tuples(prompt):
+ return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+
+ prompts = [format_prompt_tuples(p) for p in raw_prompts]
+
+ sampling_params = vllm.SamplingParams(temperature=0,
+ max_tokens=max_tokens,
+ stop=["<|im_end|>"])
+ outputs = llm.generate(
+ prompts,
+ sampling_params,
+ lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+ if lora_id else None)
+ # Print the outputs.
+ generated_texts = []
+ for output in outputs:
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ generated_texts.append(generated_text)
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+ return generated_texts
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("tp_size", [1])
+def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
+ # Cannot use as it will initialize torch.cuda too early...
+ # if torch.cuda.device_count() < tp_size:
+ # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
+
+ llm = vllm.LLM(model=model.model_path,
+ enable_lora=True,
+ max_num_seqs=16,
+ max_loras=4,
+ max_model_len=400,
+ tensor_parallel_size=tp_size,
+ quantization=model.quantization,
+ trust_remote_code=True)
+
+ if model.quantization is None:
+ expected_no_lora_output = [
+ "Here are some examples of orange-brown colors",
+ "I'm sorry, I don't have"
+ ]
+ expected_lora_output = [
+ "#ff8050",
+ "#ff8080",
+ ]
+ elif model.quantization == "AWQ":
+ expected_no_lora_output = [
+ "I'm sorry, I don't understand",
+ "I'm sorry, I don't understand",
+ ]
+ expected_lora_output = [
+ "#f07700: A v",
+ "#f00000: A v",
+ ]
+ elif model.quantization == "GPTQ":
+ expected_no_lora_output = [
+ "I'm sorry, I don't have",
+ "I'm sorry, I don't have",
+ ]
+ expected_lora_output = [
+ "#f08800: This is",
+ "#f07788 \n#",
+ ]
+
+ def expect_match(output, expected_output):
+ # HACK: GPTQ lora outputs are just incredibly unstable.
+ # Assert that the outputs changed.
+ if (model.quantization == "GPTQ"
+ and expected_output is expected_lora_output):
+ assert output != expected_no_lora_output
+ for i, o in enumerate(output):
+ assert o.startswith(
+ '#'), f"Expected example {i} to start with # but got {o}"
+ return
+ assert output == expected_output
+
+ max_tokens = 10
+
+ print("lora adapter created")
+ output = do_sample(llm,
+ tinyllama_lora_files,
+ lora_id=0,
+ max_tokens=max_tokens)
+ expect_match(output, expected_no_lora_output)
+
+ print("lora 1")
+ output = do_sample(llm,
+ tinyllama_lora_files,
+ lora_id=1,
+ max_tokens=max_tokens)
+ expect_match(output, expected_lora_output)
+
+ print("no lora")
+ output = do_sample(llm,
+ tinyllama_lora_files,
+ lora_id=0,
+ max_tokens=max_tokens)
+ expect_match(output, expected_no_lora_output)
+
+ print("lora 2")
+ output = do_sample(llm,
+ tinyllama_lora_files,
+ lora_id=2,
+ max_tokens=max_tokens)
+ expect_match(output, expected_lora_output)
+
+ print("removing lora")
+
+ del llm
+ cleanup()
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.skip("Requires multiple GPUs")
+def test_quant_model_tp_equality(tinyllama_lora_files, model):
+ # Cannot use as it will initialize torch.cuda too early...
+ # if torch.cuda.device_count() < 2:
+ # pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
+
+ llm_tp1 = vllm.LLM(model=model.model_path,
+ enable_lora=True,
+ max_num_seqs=16,
+ max_loras=4,
+ tensor_parallel_size=1,
+ quantization=model.quantization,
+ trust_remote_code=True)
+ output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
+
+ del llm_tp1
+ cleanup()
+
+ llm_tp2 = vllm.LLM(model=model.model_path,
+ enable_lora=True,
+ max_num_seqs=16,
+ max_loras=4,
+ tensor_parallel_size=2,
+ quantization=model.quantization)
+ output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
+
+ del llm_tp2
+ cleanup()
+
+ assert output_tp1 == output_tp2
diff --git a/vllm/config.py b/vllm/config.py
index 4102edbe01d35..da7eb2810ff05 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -822,9 +822,12 @@ def verify_with_model_config(self, model_config: ModelConfig):
self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
- if model_config.quantization is not None:
- raise ValueError(
- "LoRA is not supported with quantized models yet.")
+ if model_config.quantization and model_config.quantization not in [
+ "awq", "gptq"
+ ]:
+ # TODO support marlin and squeezellm
+ logger.warning(f"{model_config.quantization} quantization is not "
+ "tested with LoRA yet.")
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528:
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 5456b5613c47a..4b9653de73a88 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -29,6 +29,19 @@
pass
+def _get_lora_device(base_layer: nn.Module) -> torch.device:
+ # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
+ """Returns the device for where to place the LoRA tensors."""
+ if hasattr(base_layer, "weight"):
+ return base_layer.weight.device
+ if hasattr(base_layer, "linear_weights") and isinstance(
+ base_layer.linear_weights, dict):
+ values = list(base_layer.linear_weights.values())
+ if len(values) and isinstance(values[0], torch.Tensor):
+ return values[0].device
+ raise ValueError(f"Unsupported base layer: {base_layer}")
+
+
def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
@@ -302,6 +315,9 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size()
+ self.input_size = self.base_layer.input_size
+ self.output_size = self.base_layer.output_size_per_partition
+ self.device = _get_lora_device(self.base_layer)
def create_lora_weights(
self,
@@ -312,17 +328,17 @@ def create_lora_weights(
max_loras,
1,
lora_config.max_lora_rank,
- self.base_layer.weight.shape[1],
+ self.input_size,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
)
self.lora_b_stacked = torch.zeros(
max_loras,
1,
- self.base_layer.weight.shape[0],
+ self.output_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
)
self.indices: Optional[torch.Tensor] = None
@@ -442,18 +458,18 @@ def create_lora_weights(
max_loras,
1,
lora_config.max_lora_rank,
- self.base_layer.weight.shape[1],
+ self.input_size,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
) for _ in range(n_slices))
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
- self.base_layer.weight.shape[0] // 2,
+ self.output_size // 2,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None
@@ -619,25 +635,25 @@ def create_lora_weights(
max_loras,
1,
lora_config.max_lora_rank,
- self.base_layer.weight.shape[1],
+ self.input_size,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
- self.base_layer.weight.shape[1],
+ self.input_size,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
),
torch.zeros(
max_loras,
1,
lora_config.max_lora_rank,
- self.base_layer.weight.shape[1],
+ self.input_size,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
),
)
self.lora_b_stacked = (
@@ -647,7 +663,7 @@ def create_lora_weights(
self.q_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
),
torch.zeros(
max_loras,
@@ -655,7 +671,7 @@ def create_lora_weights(
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
),
torch.zeros(
max_loras,
@@ -663,7 +679,7 @@ def create_lora_weights(
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
),
)
@@ -766,6 +782,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
+ self.input_size = self.base_layer.input_size_per_partition
+ self.output_size = self.base_layer.output_size
+ self.device = _get_lora_device(self.base_layer)
def create_lora_weights(
self,
@@ -777,20 +796,20 @@ def create_lora_weights(
max_loras,
1,
lora_config.max_lora_rank,
- self.base_layer.weight.shape[1],
+ self.input_size,
),
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
- self.base_layer.weight.shape[0],
+ self.output_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
- device=self.base_layer.weight.device,
+ device=self.device,
)
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
@@ -809,7 +828,7 @@ def set_lora(
self.reset_lora(index)
if self.base_layer.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
- shard_size = self.base_layer.weight.shape[1]
+ shard_size = self.input_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
@@ -884,7 +903,9 @@ def forward(self, input_):
@property
def weight(self):
- return self.base_layer.weight
+
+ return self.base_layer.weight if hasattr(
+ self.base_layer, "weight") else self.base_layer.qweight
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
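Illustrative sketch (not part of this patch): quantized base layers may not expose a plain .weight tensor, so the LoRA code probes for whichever attribute carries the device. A simplified version of that probing pattern, using fake tensor objects instead of torch tensors and a qweight attribute as the quantized stand-in, could look like this:

    class FakeTensor:
        def __init__(self, device: str) -> None:
            self.device = device

    class DenseLinear:
        def __init__(self) -> None:
            self.weight = FakeTensor("cuda:0")

    class QuantLinear:
        # Quantized layers often carry packed buffers (e.g. qweight)
        # rather than a plain .weight attribute.
        def __init__(self) -> None:
            self.qweight = FakeTensor("cuda:0")

    def lora_device(base_layer) -> str:
        if hasattr(base_layer, "weight"):
            return base_layer.weight.device
        if hasattr(base_layer, "qweight"):
            return base_layer.qweight.device
        raise ValueError(f"Unsupported base layer: {base_layer!r}")

    print(lora_device(DenseLinear()))
    print(lora_device(QuantLinear()))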
From 7fd3949a0b1c6cd0dcd7066aca48d9d589f2f68e Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Fri, 12 Apr 2024 13:30:54 +0800
Subject: [PATCH 025/413] [Frontend][Core] Move `merge_async_iterators` to
utils (#4026)
---
vllm/entrypoints/openai/serving_completion.py | 38 +-----------------
vllm/utils.py | 40 ++++++++++++++++++-
2 files changed, 39 insertions(+), 39 deletions(-)
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index c1f1744a118bd..e24aa2489a80f 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -1,4 +1,3 @@
-import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple)
@@ -17,7 +16,7 @@
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
-from vllm.utils import random_uuid
+from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
@@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
return prompt_is_tokens, prompts
-def merge_async_iterators(*iterators):
- """Merge multiple asynchronous iterators into a single iterator.
-
- This method handle the case where some iterators finish before others.
- When it yields, it yields a tuple (i, item) where i is the index of the
- iterator that yields the item.
- """
- queue = asyncio.Queue()
-
- finished = [False] * len(iterators)
-
- async def producer(i, iterator):
- try:
- async for item in iterator:
- await queue.put((i, item))
- except Exception as e:
- await queue.put(e)
- finished[i] = True
-
- _tasks = [
- asyncio.create_task(producer(i, iterator))
- for i, iterator in enumerate(iterators)
- ]
-
- async def consumer():
- while not all(finished) or not queue.empty():
- item = await queue.get()
- if isinstance(item, Exception):
- raise item
- yield item
- await asyncio.gather(*_tasks)
-
- return consumer()
-
-
class OpenAIServingCompletion(OpenAIServing):
def __init__(self,
diff --git a/vllm/utils.py b/vllm/utils.py
index 669b65891d0db..0967dfc969c8a 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -9,8 +9,8 @@
from collections import OrderedDict, defaultdict
from functools import lru_cache, partial
from platform import uname
-from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List,
- Optional, Tuple, TypeVar, Union)
+from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
+ Hashable, List, Optional, Tuple, TypeVar, Union)
import psutil
import torch
@@ -181,6 +181,42 @@ def _async_wrapper(*args, **kwargs) -> asyncio.Future:
return _async_wrapper
+def merge_async_iterators(
+ *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
+ """Merge multiple asynchronous iterators into a single iterator.
+
+    This method handles the case where some iterators finish before others.
+ When it yields, it yields a tuple (i, item) where i is the index of the
+ iterator that yields the item.
+ """
+ queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
+
+ finished = [False] * len(iterators)
+
+ async def producer(i: int, iterator: AsyncIterator[T]):
+ try:
+ async for item in iterator:
+ await queue.put((i, item))
+ except Exception as e:
+ await queue.put(e)
+ finished[i] = True
+
+ _tasks = [
+ asyncio.create_task(producer(i, iterator))
+ for i, iterator in enumerate(iterators)
+ ]
+
+ async def consumer():
+ while not all(finished) or not queue.empty():
+ item = await queue.get()
+ if isinstance(item, Exception):
+ raise item
+ yield item
+ await asyncio.gather(*_tasks)
+
+ return consumer()
+
+
def get_ip() -> str:
host_ip = os.environ.get("HOST_IP")
if host_ip:
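A small usage sketch for the relocated helper (the generators and names below are illustrative): each yielded item is a tuple of the source iterator's index and its value, and iterators that finish early simply stop contributing.

import asyncio

from vllm.utils import merge_async_iterators


async def _demo():

    async def numbers(tag: str, n: int):
        for k in range(n):
            await asyncio.sleep(0)
            yield f"{tag}{k}"

    # Items arrive interleaved as (source_index, value) pairs.
    async for i, item in merge_async_iterators(numbers("a", 2), numbers("b", 3)):
        print(i, item)


asyncio.run(_demo())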
From 36729bac1303b655b816b77f45b17237bfafd692 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Sat, 13 Apr 2024 01:56:57 +0900
Subject: [PATCH 026/413] [Test] Test multiple attn backend for chunked
prefill. (#4023)
---
.buildkite/test-pipeline.yaml | 8 +++++++-
.../test_basic_correctness.py | 6 ------
.../basic_correctness/test_chunked_prefill.py | 4 ----
vllm/attention/backends/rocm_flash_attn.py | 18 ++++++------------
4 files changed, 13 insertions(+), 23 deletions(-)
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 695290ed74ab5..8d7d6304cf12e 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -12,7 +12,13 @@ steps:
command: pytest -v -s async_engine
- label: Basic Correctness Test
- command: pytest -v -s basic_correctness
+ commands:
+ - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
+ - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
+ - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
+ - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
+ - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
+ - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
- label: Core Test
command: pytest -v -s core
diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index bd4c7ea3301be..97cff623c5e1d 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -4,8 +4,6 @@
"""
import pytest
-from vllm.attention.selector import VLLM_ATTENTION_BACKEND
-
MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
@@ -16,7 +14,6 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
-@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_models(
hf_runner,
vllm_runner,
@@ -25,10 +22,7 @@ def test_models(
dtype: str,
max_tokens: int,
enforce_eager: bool,
- attn_backend: str,
- monkeypatch,
) -> None:
- monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
index 9ff07b3c09020..d83416eb51b43 100644
--- a/tests/basic_correctness/test_chunked_prefill.py
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -33,10 +33,6 @@ def test_models(
enforce_eager: bool,
tensor_parallel_size: int,
) -> None:
- if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
- and not enforce_eager):
- pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
- "for high TP to save testing time.")
max_num_seqs = min(chunked_prefill_token_size, 256)
enable_chunked_prefill = False
max_num_batched_tokens = None
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index e55435cd2c947..c42660fb8f74f 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -162,7 +162,7 @@ def __init__(
# AMD Radeon 7900 series (gfx1100) currently does not support
# xFormers nor FlashAttention. As a temporary workaround, we use
# naive PyTorch implementation of attention.
- self.attn_fuc = _naive_attention()
+ self.attn_fuc = _naive_attention
logger.debug("Using naive attention in ROCmBackend")
elif self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
@@ -334,26 +334,21 @@ def _naive_attention(
prompt_lens: List[int],
scale: float,
) -> torch.Tensor:
- num_tokens = query.shape[0]
output = torch.empty_like(query)
start = 0
for _, prompt_len in enumerate(prompt_lens):
end = start + prompt_len
out = _naive_masked_attention(
- query[None, start:end],
- key[None, start:end],
- value[None, start:end],
+ query[start:end],
+ key[start:end],
+ value[start:end],
scale,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len
- # Using view got RuntimeError: view size is not compatible
- # with input tensor's size and stride (at least one
- # dimension spans across two contiguous subspaces).
- # Use reshape instead.
- return output.reshape(num_tokens, -1)
+ return output
def _naive_masked_attention(
@@ -362,14 +357,13 @@ def _naive_masked_attention(
value: torch.Tensor,
scale: float,
) -> torch.Tensor:
- seq_len, _, _ = query.shape
+ seq_len, head_size, head_dim = query.shape
attn_mask = torch.triu(torch.ones(seq_len,
seq_len,
dtype=query.dtype,
device=query.device),
diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min
-
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
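For reference, a self-contained sketch of the naive causal attention the ROCm backend falls back to, following the einsum and mask structure shown above (the final output einsum is an assumption, since that part of the function is not in this hunk):

import torch


def naive_masked_attention_sketch(query, key, value, scale):
    # query/key/value: [seq_len, num_heads, head_dim]
    seq_len = query.shape[0]
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype,
                                 device=query.device), diagonal=1)
    mask = mask * torch.finfo(query.dtype).min  # large negative above the diagonal
    weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    weights = torch.softmax(weights + mask.float(), dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", weights, value)


q = k = v = torch.randn(5, 2, 8)
print(naive_masked_attention_sketch(q, k, v, scale=8 ** -0.5).shape)  # [5, 2, 8]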
From 96b6a6d790115d04bb87d410f3bdd5d7d85b43f1 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 12 Apr 2024 12:35:44 -0700
Subject: [PATCH 027/413] [Bugfix] fix type hint for py 3.8 (#4036)
---
vllm/executor/executor_base.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index c18edd75d7a4d..55bccfa8e3ca9 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -31,7 +31,7 @@ def __init__(
raise NotImplementedError
@abstractmethod
- def determine_num_available_blocks(self) -> tuple[int, int]:
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
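The motivation, as a tiny sketch: subscripting the builtin (`tuple[int, int]`) is only valid from Python 3.9 onward, so on 3.8 the annotation fails when the function is defined, while `typing.Tuple` works everywhere (alternatively, evaluation can be deferred with `from __future__ import annotations`).

from typing import Tuple


def determine_num_available_blocks_sketch() -> Tuple[int, int]:
    # On Python 3.8, writing "-> tuple[int, int]" raises
    # "TypeError: 'type' object is not subscriptable" at definition time.
    return 1024, 256  # illustrative (num_gpu_blocks, num_cpu_blocks)


print(determine_num_available_blocks_sketch())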
From d4ec9ffb9574988132d927fd1615180522877262 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Fri, 12 Apr 2024 13:56:04 -0700
Subject: [PATCH 028/413] [Misc] Fix typo in scheduler.py (#4022)
---
vllm/core/scheduler.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 2942eab735a92..e44f983e15374 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -674,7 +674,7 @@ def _schedule_prefills(
def _schedule_default(self) -> SchedulerOutputs:
"""Schedule queued requests.
- The current policy is designed to opimimize the throughput. First,
+ The current policy is designed to optimize the throughput. First,
it batches as many prefill requests as possible. And it schedules
decodes. If there's a pressure on GPU memory, decode requests can
be swapped or preempted.
From 09473ee41c0a22c4d18936ea7eb2328071c19308 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Sat, 13 Apr 2024 06:35:50 +0900
Subject: [PATCH 029/413] [mypy] Add mypy type annotation part 1 (#4006)
---
.github/workflows/mypy.yaml | 50 ++++++++++++++++++++++++++
format.sh | 22 +++++++++---
pyproject.toml | 5 ++-
requirements-common.txt | 3 +-
requirements-dev.txt | 2 +-
vllm/config.py | 9 +++--
vllm/core/block_manager_v1.py | 12 ++++---
vllm/core/block_manager_v2.py | 4 ++-
vllm/core/interfaces.py | 4 ++-
vllm/core/scheduler.py | 25 +++++++------
vllm/distributed/communication_op.py | 10 +++---
vllm/engine/ray_utils.py | 18 ++++++----
vllm/entrypoints/api_server.py | 1 +
vllm/entrypoints/llm.py | 8 +++--
vllm/executor/cpu_executor.py | 4 +--
vllm/executor/gpu_executor.py | 4 +--
vllm/executor/neuron_executor.py | 4 +--
vllm/executor/ray_gpu_executor.py | 11 +++---
vllm/sampling_params.py | 5 +--
vllm/sequence.py | 8 ++---
vllm/transformers_utils/config.py | 3 +-
vllm/transformers_utils/detokenizer.py | 7 ++--
vllm/transformers_utils/tokenizer.py | 4 +--
vllm/usage/usage_lib.py | 8 ++---
vllm/utils.py | 12 ++++---
25 files changed, 171 insertions(+), 72 deletions(-)
create mode 100644 .github/workflows/mypy.yaml
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
new file mode 100644
index 0000000000000..fbe0f816fd4af
--- /dev/null
+++ b/.github/workflows/mypy.yaml
@@ -0,0 +1,50 @@
+name: mypy
+
+on:
+ # Trigger the workflow on push or pull request,
+ # but only for the main branch
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ ruff:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.8"]
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install mypy==1.9.0
+ pip install types-setuptools
+ pip install types-PyYAML
+ pip install types-requests
+ pip install types-setuptools
+ - name: Mypy
+ run: |
+ mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+
+ # TODO(sang): Follow up
+ # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+ # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+ # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
+ # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+ # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+
diff --git a/format.sh b/format.sh
index deb57b2b049d1..1c195b899c742 100755
--- a/format.sh
+++ b/format.sh
@@ -93,9 +93,23 @@ fi
echo 'vLLM yapf: Done'
# Run mypy
-# TODO(zhuohan): Enable mypy
-# echo 'vLLM mypy:'
-# mypy
+echo 'vLLM mypy:'
+mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+
+# TODO(sang): Follow up
+# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+
CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**'
@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
exit 1
fi
-
-
diff --git a/pyproject.toml b/pyproject.toml
index 2a00d6796ee02..b870a4b85897b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,10 +46,13 @@ ignore = [
python_version = "3.8"
ignore_missing_imports = true
+ check_untyped_defs = true
files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
-exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
+exclude = [
+ "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
+]
[tool.codespell]
diff --git a/requirements-common.txt b/requirements-common.txt
index ff053388a23e1..c96f9c9937fb0 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -11,4 +11,5 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
-outlines == 0.0.34 # Requires torch >= 2.1.0
\ No newline at end of file
+outlines == 0.0.34 # Requires torch >= 2.1.0
+typing_extensions
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 75d22bbdb2a1b..96dfda6faf00f 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,7 +7,7 @@ codespell==2.2.6
isort==5.13.2
# type checking
-mypy==0.991
+mypy==1.9.0
types-PyYAML
types-requests
types-setuptools
diff --git a/vllm/config.py b/vllm/config.py
index da7eb2810ff05..bbda4ecf3cc56 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2,7 +2,7 @@
import json
import os
from dataclasses import dataclass, fields
-from typing import TYPE_CHECKING, ClassVar, Optional, Union
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
import torch
from packaging.version import Version
@@ -141,7 +141,7 @@ def _verify_load_format(self) -> None:
supported_load_format = [
"auto", "pt", "safetensors", "npcache", "dummy"
]
- rocm_not_supported_load_format = []
+ rocm_not_supported_load_format: List[str] = []
if load_format not in supported_load_format:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
@@ -679,6 +679,9 @@ def maybe_create_spec_config(
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")
+ assert (speculative_model is not None
+ and num_speculative_tokens is not None)
+
# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
draft_revision = None
@@ -993,7 +996,7 @@ def _get_and_verify_max_len(
derived_max_model_len *= scaling_factor
if max_model_len is None:
- max_model_len = derived_max_model_len
+ max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len:
# Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index e7e3b4dc1e9b4..e391a3b1e5a33 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -1,5 +1,6 @@
"""A block manager that manages token blocks."""
from abc import ABC, abstractmethod
+from collections.abc import Sequence as GenericSequence
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set
@@ -231,10 +232,10 @@ def __init__(
if self.enable_caching:
logger.info("Automatic prefix caching is enabled.")
- self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
- num_gpu_blocks)
- self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
- num_cpu_blocks)
+ self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
+ Device.GPU, block_size, num_gpu_blocks)
+ self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
+ Device.CPU, block_size, num_cpu_blocks)
else:
self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
@@ -588,7 +589,8 @@ def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
for b in takewhile(lambda b: b.computed, block_table[:-1])
]
- def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+ def get_common_computed_block_ids(
+ self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py
index 813e71ad883b2..19f0cf415eb34 100644
--- a/vllm/core/block_manager_v2.py
+++ b/vllm/core/block_manager_v2.py
@@ -1,4 +1,5 @@
"""A block manager that manages token blocks."""
+from collections.abc import Sequence as GenericSequence
from typing import Dict, List, Optional
from vllm.core.block.block_table import BlockTable
@@ -205,7 +206,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# as computed.
self.block_allocator.mark_blocks_as_computed()
- def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+ def get_common_computed_block_ids(
+ self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks.
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
index 711536bcc97be..c1f68a2e891bf 100644
--- a/vllm/core/interfaces.py
+++ b/vllm/core/interfaces.py
@@ -1,5 +1,6 @@
import enum
from abc import ABC, abstractmethod
+from collections.abc import Sequence as GenericSequence
from typing import Dict, List
from vllm.sequence import Sequence, SequenceGroup
@@ -103,7 +104,8 @@ def access_all_blocks_in_seq(
pass
@abstractmethod
- def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+ def get_common_computed_block_ids(
+ self, seqs: List[Sequence]) -> GenericSequence[int]:
pass
@abstractmethod
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index e44f983e15374..18ddcd1d6d466 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -42,8 +42,8 @@ class SchedulingBudget:
"""
token_budget: int
max_num_seqs: int
- _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
- _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
+ _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
+ _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0
@@ -133,7 +133,7 @@ def is_empty(self) -> bool:
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)
- def _sort_by_lora_ids(self) -> bool:
+ def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@@ -337,7 +337,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
self.free_seq(seq)
def has_unfinished_seqs(self) -> bool:
- return self.waiting or self.running or self.swapped
+ return len(self.waiting) != 0 or len(self.running) != 0 or len(
+ self.swapped) != 0
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -404,7 +405,7 @@ def _schedule_running(
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
- curr_loras.pop(seq_group.lora_int_id)
+ curr_loras.remove(seq_group.lora_int_id)
if running_queue:
# Preempt the lowest-priority sequence groups.
@@ -496,7 +497,7 @@ def _schedule_swapped(
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)
- leftover_swapped = deque()
+ leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue:
seq_group = swapped_queue[0]
@@ -507,7 +508,9 @@ def _schedule_swapped(
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
- if (lora_int_id > 0 and lora_int_id not in curr_loras
+ assert curr_loras is not None
+ assert self.lora_config is not None
+ if (lora_int_id > 0 and (lora_int_id not in curr_loras)
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
@@ -593,7 +596,7 @@ def _schedule_prefills(
# Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue])
- leftover_waiting_sequences = deque()
+ leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
seq_group = waiting_queue[0]
@@ -635,6 +638,8 @@ def _schedule_prefills(
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
+ assert curr_loras is not None
+ assert self.lora_config is not None
if (self.lora_enabled and lora_int_id > 0
and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
@@ -780,7 +785,7 @@ def _schedule_chunked_prefill(self):
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
- curr_loras = set()
+ curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
@@ -1087,7 +1092,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool,
- budget: SchedulingBudget) -> Tuple[int, bool]:
+ budget: SchedulingBudget) -> int:
"""Get the next new tokens to compute for a given sequence group
that's in a given `status`.
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index 1004d626b6a4b..a3e93691a1e8e 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -1,5 +1,5 @@
from collections import namedtuple
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch.distributed import ProcessGroup
@@ -144,7 +144,7 @@ def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
-) -> Dict[Any, Union[torch.Tensor, Any]]:
+) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
@@ -157,10 +157,10 @@ def broadcast_tensor_dict(
rank = torch.distributed.get_rank()
if rank == src:
+ metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
- metadata_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
assert value.is_cuda, (
@@ -190,10 +190,10 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=group)
- metadata_list = recv_metadata_list[0]
+ assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
- for key, value in metadata_list:
+ for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py
index 70d5c9b1fae05..04d4ed83976d0 100644
--- a/vllm/engine/ray_utils.py
+++ b/vllm/engine/ray_utils.py
@@ -1,9 +1,10 @@
import pickle
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
+from vllm.worker.worker import Worker
logger = init_logger(__name__)
@@ -18,15 +19,20 @@ def __init__(self, init_cached_hf_modules=False) -> None:
if init_cached_hf_modules:
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
- self.worker = None
+ self._worker: Optional[Worker] = None
# Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on
# that thread.
self.compiled_dag_cuda_device_set = False
- def init_worker(self, worker_init_fn):
- self.worker = worker_init_fn()
+ def init_worker(self, worker_init_fn: Callable[[], Worker]):
+ self._worker = worker_init_fn()
+
+ @property
+ def worker(self) -> Worker:
+ assert self._worker is not None
+ return self._worker
def __getattr__(self, name):
return getattr(self.worker, name)
@@ -70,8 +76,8 @@ def execute_model_compiled_dag_remote(self, ignored):
logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with "
"`pip install ray`.")
- ray = None
- RayWorkerVllm = None
+ ray = None # type: ignore
+ RayWorkerVllm = None # type: ignore
def initialize_ray_cluster(
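The pattern introduced here, sketched with placeholder classes: keep the nullable handle private and expose a property that asserts, so call sites (and mypy) deal with a concrete worker rather than Optional[Worker]. Names below are illustrative only.

from typing import Callable, Optional


class _FakeWorker:

    def ping(self) -> str:
        return "pong"


class WorkerHolderSketch:

    def __init__(self) -> None:
        self._worker: Optional[_FakeWorker] = None

    def init_worker(self, worker_init_fn: Callable[[], _FakeWorker]) -> None:
        self._worker = worker_init_fn()

    @property
    def worker(self) -> _FakeWorker:
        # Fails loudly if used before init_worker(), instead of returning None.
        assert self._worker is not None
        return self._worker


holder = WorkerHolderSketch()
holder.init_worker(_FakeWorker)
print(holder.worker.ping())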
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 2a47eae112c12..587142adb9c6b 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
+ assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 5777e8179a1c1..63ff0b30da552 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -170,8 +170,12 @@ def generate(
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
- num_requests = len(prompts) if prompts is not None else len(
- prompt_token_ids)
+ if prompts is not None:
+ num_requests = len(prompts)
+ else:
+ assert prompt_token_ids is not None
+ num_requests = len(prompt_token_ids)
+
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index eda4e8989c163..33e67d8b3eec2 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -1,5 +1,5 @@
import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
import torch
@@ -61,7 +61,7 @@ def _init_worker(self):
self.driver_worker.init_device()
self.driver_worker.load_model()
- def determine_num_available_blocks(self) -> tuple[int, int]:
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 80ca5cb7367c5..f20221a0b941a 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -66,7 +66,7 @@ def _init_worker(self):
self.driver_worker.init_device()
self.driver_worker.load_model()
- def determine_num_available_blocks(self) -> tuple[int, int]:
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index 57436a85cfa27..ee8e87432fa67 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -47,7 +47,7 @@ def _init_worker(self):
self.driver_worker.init_device()
self.driver_worker.load_model()
- def determine_num_available_blocks(self) -> tuple[int, int]:
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
"""
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 6c0ccd7e64c90..b937693c92257 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -3,7 +3,7 @@
import os
import pickle
from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -197,7 +197,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
max_parallel_loading_workers,
)
- def determine_num_available_blocks(self) -> tuple[int, int]:
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
@@ -205,7 +205,7 @@ def determine_num_available_blocks(self) -> tuple[int, int]:
compatible with all workers.
Returns:
- - tuple[num_gpu_blocks, num_cpu_blocks]
+ - Tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", )
@@ -276,7 +276,7 @@ def _run_workers(
self,
method: str,
*args,
- driver_args: Optional[List[Any]] = None,
+ driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
@@ -291,6 +291,7 @@ def _run_workers(
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
+ assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1)
else:
# Start the ray workers first.
@@ -369,7 +370,7 @@ async def _run_workers_async(
self,
method: str,
*args,
- driver_args: Optional[List[Any]] = None,
+ driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 0b9787608798c..53a38b25bfdac 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -5,7 +5,8 @@
from typing import Callable, List, Optional, Union
import torch
-from pydantic import conint
+from pydantic import Field
+from typing_extensions import Annotated
_SAMPLING_EPS = 1e-5
@@ -127,7 +128,7 @@ def __init__(
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
- truncate_prompt_tokens: Optional[conint(ge=1)] = None,
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
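A hedged sketch of the replacement pattern (the model below is illustrative, not vLLM code): with pydantic >= 2, `Annotated[int, Field(ge=1)]` carries the same lower-bound constraint as `conint(ge=1)` while presenting type checkers with a plain `int`.

from pydantic import BaseModel, Field, ValidationError
from typing_extensions import Annotated


class TruncationSketch(BaseModel):
    truncate_prompt_tokens: Annotated[int, Field(ge=1)] = 1


print(TruncationSketch(truncate_prompt_tokens=4).truncate_prompt_tokens)  # 4
try:
    TruncationSketch(truncate_prompt_tokens=0)
except ValidationError:
    print("rejected: value must be >= 1")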
diff --git a/vllm/sequence.py b/vllm/sequence.py
index cdb6cce6f0255..dcde81df19923 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -171,10 +171,10 @@ def get_last_token_id(self) -> int:
return self.prompt_token_ids[-1]
return self.output_token_ids[-1]
- def get_prompt_token_ids(self) -> int:
+ def get_prompt_token_ids(self) -> List[int]:
return self.prompt_token_ids
- def get_output_token_ids(self) -> int:
+ def get_output_token_ids(self) -> List[int]:
return self.output_token_ids
@property
@@ -370,7 +370,7 @@ class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
- generator: Optional = None
+ generator: Optional = None # type: ignore
class MultiModalData:
@@ -599,7 +599,7 @@ def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
- def token_chunk_size(self) -> int:
+ def token_chunk_size(self) -> Optional[int]:
"""Return the number of tokens to be processed (chunk size)."""
return self._token_chunk_size
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index ce7a30dce72fa..1756c91a612f0 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -2,7 +2,8 @@
from transformers import AutoConfig, PretrainedConfig
-from vllm.transformers_utils.configs import *
+from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
+ JAISConfig, MPTConfig, RWConfig)
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig,
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index 005932f1e3df4..f064c26c3f40c 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders(
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
- sub_texts = []
- current_sub_text = []
+ sub_texts: List[str] = []
+ current_sub_text: List[str] = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
@@ -263,6 +263,7 @@ def detokenize_incrementally(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
+ assert prev_tokens is not None
# If the new token id is out of bounds, return an empty string.
if new_token_id >= len(tokenizer):
@@ -271,6 +272,8 @@ def detokenize_incrementally(
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
+ if isinstance(new_tokens, str):
+ new_tokens = [new_tokens]
output_tokens = prev_tokens + new_tokens
# If this is the first iteration, return all tokens.
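In isolation, the normalization being added looks like the sketch below (names are illustrative): some tokenizers return a bare string from convert_ids_to_tokens when given a single id, so coercing to a list keeps the later prev_tokens + new_tokens concatenation well-typed.

from typing import List, Union


def ensure_token_list_sketch(new_tokens: Union[str, List[str]]) -> List[str]:
    return [new_tokens] if isinstance(new_tokens, str) else new_tokens


print(ensure_token_list_sketch("hello"))        # ['hello']
print(ensure_token_list_sketch(["he", "llo"]))  # ['he', 'llo']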
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index e216a99af91f9..5d3d5801c960d 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -5,7 +5,7 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizers import *
+from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async
logger = init_logger(__name__)
@@ -28,7 +28,7 @@ def get_cached_tokenizer(
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer)
- class CachedTokenizer(tokenizer.__class__):
+ class CachedTokenizer(tokenizer.__class__): # type: ignore
@property
def all_special_ids(self):
diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py
index 658fe5c98f5ee..b2672f7f1da61 100644
--- a/vllm/usage/usage_lib.py
+++ b/vllm/usage/usage_lib.py
@@ -7,7 +7,7 @@
from enum import Enum
from pathlib import Path
from threading import Thread
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
from uuid import uuid4
import cpuinfo
@@ -124,7 +124,7 @@ def __init__(self) -> None:
def report_usage(self,
model_architecture: str,
usage_context: UsageContext,
- extra_kvs: Dict[str, any] = None) -> None:
+ extra_kvs: Optional[Dict[str, Any]] = None) -> None:
t = Thread(target=self._report_usage_worker,
args=(model_architecture, usage_context, extra_kvs or {}),
daemon=True)
@@ -132,13 +132,13 @@ def report_usage(self,
def _report_usage_worker(self, model_architecture: str,
usage_context: UsageContext,
- extra_kvs: Dict[str, any]) -> None:
+ extra_kvs: Dict[str, Any]) -> None:
self._report_usage_once(model_architecture, usage_context, extra_kvs)
self._report_continous_usage()
def _report_usage_once(self, model_architecture: str,
usage_context: UsageContext,
- extra_kvs: Dict[str, any]) -> None:
+ extra_kvs: Dict[str, Any]) -> None:
# Platform information
if torch.cuda.is_available():
device_property = torch.cuda.get_device_properties(0)
diff --git a/vllm/utils.py b/vllm/utils.py
index 0967dfc969c8a..4c0dc9ca729a9 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -60,7 +60,7 @@ def __contains__(self, key: Hashable) -> bool:
def __len__(self) -> int:
return len(self.cache)
- def __getitem__(self, key: Hashable) -> T:
+ def __getitem__(self, key: Hashable) -> Optional[T]:
return self.get(key)
def __setitem__(self, key: Hashable, value: T) -> None:
@@ -76,7 +76,7 @@ def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
if key in self.cache:
- value = self.cache[key]
+ value: Optional[T] = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
@@ -87,7 +87,7 @@ def put(self, key: Hashable, value: T) -> None:
self.cache.move_to_end(key)
self._remove_old_if_needed()
- def _on_remove(self, key: Hashable, value: T):
+ def _on_remove(self, key: Hashable, value: Optional[T]):
pass
def remove_oldest(self):
@@ -100,9 +100,11 @@ def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
- def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
+ def pop(self,
+ key: Hashable,
+ default_value: Optional[T] = None) -> Optional[T]:
run_on_remove = key in self.cache
- value = self.cache.pop(key, default_value)
+ value: Optional[T] = self.cache.pop(key, default_value)
if run_on_remove:
self._on_remove(key, value)
return value
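A compact, stand-alone sketch of the Optional[T] contract these annotations codify: get and pop return Optional[T] because a miss hands back the (possibly None) default, while put always stores a real T.

from collections import OrderedDict
from typing import Generic, Hashable, Optional, TypeVar

T = TypeVar("T")


class TinyLRUSketch(Generic[T]):

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.cache: "OrderedDict[Hashable, T]" = OrderedDict()

    def get(self, key: Hashable, default: Optional[T] = None) -> Optional[T]:
        if key not in self.cache:
            return default
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key: Hashable, value: T) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        while len(self.cache) > self.capacity:
            self.cache.popitem(last=False)  # evict the least recently used


lru: TinyLRUSketch[int] = TinyLRUSketch(capacity=2)
lru.put("a", 1)
lru.put("b", 2)
lru.put("c", 3)                    # evicts "a"
print(lru.get("a"), lru.get("c"))  # None 3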
From fbb9d9eef48a29e0ea821bbf399e4bf9a08d6ac1 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 12 Apr 2024 16:40:39 -0700
Subject: [PATCH 030/413] [Core] fix custom allreduce default value (#4040)
---
vllm/entrypoints/llm.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 63ff0b30da552..9e08c253dc539 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -86,7 +86,7 @@ def __init__(
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
- disable_custom_all_reduce: bool = True,
+ disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
From d04973ad5446fe05c06035f6b2d99402fc3ac7bf Mon Sep 17 00:00:00 2001
From: Bellk17
Date: Fri, 12 Apr 2024 16:41:26 -0700
Subject: [PATCH 031/413] Fix triton compilation issue (#3984)
Co-authored-by: Woosuk Kwon
---
vllm/attention/ops/triton_flash_attention.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 87cf30cbef79a..e160411859f0b 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -415,7 +415,11 @@ def attn_fwd(
return
is_mqa = hq != hk
- off_h_k = off_h_q % hk if is_mqa else off_h_q
+ if is_mqa: # noqa: SIM108
+ off_h_k = off_h_q % hk
+ else:
+ off_h_k = off_h_q
+
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
From b8aacac31a4e2e03381fdaef6f1e4bbb895f3b64 Mon Sep 17 00:00:00 2001
From: Jee Li
Date: Sat, 13 Apr 2024 07:56:37 +0800
Subject: [PATCH 032/413] [Bugfix] Fix LoRA bug (#4032)
---
vllm/lora/layers.py | 15 +++++++++------
1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 4b9653de73a88..aac86351b15e1 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -32,14 +32,17 @@
def _get_lora_device(base_layer: nn.Module) -> torch.device:
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
"""Returns the device for where to place the LoRA tensors."""
+ # unquantizedLinear
if hasattr(base_layer, "weight"):
return base_layer.weight.device
- if hasattr(base_layer, "linear_weights") and isinstance(
- base_layer.linear_weights, dict):
- values = list(base_layer.linear_weights.values())
- if len(values) and isinstance(values[0], torch.Tensor):
- return values[0].device
- raise ValueError(f"Unsupported base layer: {base_layer}")
+ # GPTQ/AWQ/SqueezeLLM
+ elif hasattr(base_layer, "qweight"):
+ return base_layer.qweight.device
+ # marlin
+ elif hasattr(base_layer, "B"):
+ return base_layer.B.device
+ else:
+ raise ValueError(f"Unsupported base layer: {base_layer}")
def _apply_lora(
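The dispatch above, restated as a runnable sketch (attribute names follow the patch; the layer passed in is a plain nn.Linear for illustration only):

import torch
import torch.nn as nn


def get_lora_device_sketch(base_layer: nn.Module) -> torch.device:
    if hasattr(base_layer, "weight"):      # unquantized linear layers
        return base_layer.weight.device
    elif hasattr(base_layer, "qweight"):   # GPTQ/AWQ/SqueezeLLM
        return base_layer.qweight.device
    elif hasattr(base_layer, "B"):         # Marlin
        return base_layer.B.device
    raise ValueError(f"Unsupported base layer: {base_layer}")


print(get_lora_device_sketch(nn.Linear(8, 8)))  # cpu (or cuda:N if moved)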
From 546e7211684a28bbe53088961b4cf5123e235760 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 12 Apr 2024 18:43:37 -0700
Subject: [PATCH 033/413] [CI/Test] expand ruff and yapf for all supported
python version (#4037)
---
.github/workflows/mypy.yaml | 2 +-
.github/workflows/ruff.yml | 2 +-
.github/workflows/yapf.yml | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index fbe0f816fd4af..6db0bb7645ecd 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index e8060e369a889..e71033f828006 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml
index b163c960db555..04f307bcf8b0e 100644
--- a/.github/workflows/yapf.yml
+++ b/.github/workflows/yapf.yml
@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.10"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
From 5c2e66e4871917c5d59cc4a8b89ef53e690e9bd9 Mon Sep 17 00:00:00 2001
From: Dylan Hawk <51147702+dylanwhawk@users.noreply.github.com>
Date: Fri, 12 Apr 2024 21:07:04 -0700
Subject: [PATCH 034/413] [Bugfix] More type hint fixes for py 3.8 (#4039)
---
vllm/executor/executor_base.py | 2 +-
vllm/worker/cpu_worker.py | 4 ++--
vllm/worker/neuron_worker.py | 4 ++--
vllm/worker/worker_base.py | 6 +++---
4 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 55bccfa8e3ca9..bbfbfc689c99f 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -39,7 +39,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers.
- Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+ Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 3989207e8dd83..41341b063bed7 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -1,5 +1,5 @@
"""A CPU worker class."""
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed
@@ -157,7 +157,7 @@ def init_device(self) -> None:
def load_model(self):
self.model_runner.load_model()
- def determine_num_available_blocks(self) -> tuple[int, int]:
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of blocks available for the KV cache.
This determines how many KV blocks can fit into the configured CPU
diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py
index 6136d50d0c068..2f22f82c045db 100644
--- a/vllm/worker/neuron_worker.py
+++ b/vllm/worker/neuron_worker.py
@@ -1,5 +1,5 @@
"""A Neuron worker class."""
-from typing import List, Optional
+from typing import List, Optional, Tuple
import torch
import torch.distributed
@@ -40,7 +40,7 @@ def init_device(self) -> None:
def load_model(self):
self.model_runner.load_model()
- def determine_num_available_blocks(self) -> tuple[int, int]:
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks.
Swapping is not yet supported, so always return num_cpu_blocks=0.
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index e3027c406ffeb..d8c9febb11584 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import Dict, List, Tuple
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -18,14 +18,14 @@ def init_device(self) -> None:
raise NotImplementedError
@abstractmethod
- def determine_num_available_blocks(self) -> tuple[int, int]:
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
- Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+ Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
From 98afde19fc273b1e6a695990b93ec07157b856f1 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sat, 13 Apr 2024 07:12:53 -0700
Subject: [PATCH 035/413] [Core][Distributed] improve logging for init dist
(#4042)
---
vllm/distributed/parallel_state.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 9fceffe7cb88b..1258bf58cb453 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -8,6 +8,10 @@
import torch
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
@@ -45,6 +49,8 @@ def init_distributed_environment(
local_rank: int = -1,
backend: str = "nccl",
):
+ logger.debug(f"{world_size=} {rank=} {local_rank=} "
+ f"{distributed_init_method=} {backend=}")
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
From ec8e3c695f2dce080bde569746180300e91084a3 Mon Sep 17 00:00:00 2001
From: zspo
Date: Sat, 13 Apr 2024 22:52:36 +0800
Subject: [PATCH 036/413] [Bugfix] fix_log_time_in_metrics (#4050)
---
vllm/engine/metrics.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 02560907a1282..04e27e69ce0f3 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -130,7 +130,7 @@ class StatLogger:
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
# Metadata for logging locally.
- self.last_local_log = time.monotonic()
+ self.last_local_log = time.time()
self.local_interval = local_interval
# Tracked stats over current local logging interval.
From 0a430b4ae2763c2f161e3bfb1529acf4685f7caa Mon Sep 17 00:00:00 2001
From: zspo
Date: Sat, 13 Apr 2024 22:54:03 +0800
Subject: [PATCH 037/413] [Bugfix] fix_small_bug_in_neuron_executor (#4051)
---
vllm/executor/neuron_executor.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index ee8e87432fa67..d45f18e466256 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -25,6 +25,7 @@ def __init__(
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
+ self.cache_config = cache_config
assert lora_config is None, "LoRA is not supported for Neuron backend."
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
@@ -43,6 +44,7 @@ def _init_worker(self):
self.parallel_config,
self.scheduler_config,
self.device_config,
+ self.cache_config,
)
self.driver_worker.init_device()
self.driver_worker.load_model()
From 989ae2538df211ca3a31f77ac8e106c5c97c6e53 Mon Sep 17 00:00:00 2001
From: Jee Li
Date: Sat, 13 Apr 2024 22:55:05 +0800
Subject: [PATCH 038/413] [Kernel] Add punica dimension for Baichuan-13B
(#4053)
---
csrc/punica/bgmv/bgmv_config.h | 1 +
tests/lora/test_baichuan.py | 2 +-
tests/lora/test_punica.py | 1 +
3 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index 9b76b98ab3322..d2906914f927e 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -47,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 13696) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
+ f(in_T, out_T, W_T, narrow, 15360) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \
diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py
index 2178266d2e0c8..5ab863eea94b3 100644
--- a/tests/lora/test_baichuan.py
+++ b/tests/lora/test_baichuan.py
@@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files):
@pytest.mark.skip("Requires multiple GPUs")
-def test_llama_tensor_parallel_equality(baichuan_lora_files):
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index cab8b44ccd2df..8b174f01d87d4 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -72,6 +72,7 @@ def _lora_ref_impl(
11008,
13824,
14336,
+ 15360,
22016,
24576,
27392,
From 711a000255eac3e034f0b73aa5cc62b45201a571 Mon Sep 17 00:00:00 2001
From: Sanger Steel
Date: Sat, 13 Apr 2024 20:13:01 -0400
Subject: [PATCH 039/413] [Frontend] [Core] feat: Add model loading using
`tensorizer` (#3476)
---
.buildkite/test-pipeline.yaml | 3 +
docs/source/conf.py | 1 +
docs/source/models/engine_args.rst | 3 +-
examples/tensorize_vllm_model.py | 254 ++++++++++++++
requirements-cpu.txt | 2 +-
requirements-dev.txt | 1 +
setup.py | 3 +
tests/tensorizer/__init__.py | 0
.../tensorize_vllm_model_for_testing.py | 245 ++++++++++++++
tests/tensorizer/test_tensorizer.py | 302 +++++++++++++++++
vllm/config.py | 74 +++-
vllm/engine/arg_utils.py | 45 ++-
vllm/engine/llm_engine.py | 8 +-
vllm/executor/gpu_executor.py | 23 +-
vllm/executor/ray_gpu_executor.py | 6 +-
vllm/model_executor/model_loader.py | 61 +++-
vllm/model_executor/tensorizer_loader.py | 319 ++++++++++++++++++
vllm/model_executor/weight_utils.py | 34 +-
vllm/worker/model_runner.py | 9 +-
vllm/worker/worker.py | 9 +-
20 files changed, 1351 insertions(+), 51 deletions(-)
create mode 100644 examples/tensorize_vllm_model.py
create mode 100644 tests/tensorizer/__init__.py
create mode 100644 tests/tensorizer/tensorize_vllm_model_for_testing.py
create mode 100644 tests/tensorizer/test_tensorizer.py
create mode 100644 vllm/model_executor/tensorizer_loader.py
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 8d7d6304cf12e..aa4582bbda0c7 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -91,6 +91,9 @@ steps:
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4
+- label: Tensorizer Test
+ command: apt-get install curl libsodium23 && pytest -v -s tensorizer
+
- label: Metrics Test
command: pytest -v -s metrics
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 7a8c365ffb3bb..19cc8557a7541 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -83,6 +83,7 @@
"vllm._C",
"numpy",
"tqdm",
+ "tensorizer",
]
for mock_target in autodoc_mock_imports:
diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst
index d8a7ac72e0175..886a806934c04 100644
--- a/docs/source/models/engine_args.rst
+++ b/docs/source/models/engine_args.rst
@@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM:
Directory to download and load the weights, default to the default cache dir of huggingface.
-.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
+.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}
The format of the model weights to load.
@@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM:
* "safetensors" will load the weights in the safetensors format.
* "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
* "dummy" will initialize the weights with random values, mainly for profiling.
+ * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`.
.. option:: --dtype {auto,half,float16,bfloat16,float,float32}
diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py
new file mode 100644
index 0000000000000..3c20a38c7f726
--- /dev/null
+++ b/examples/tensorize_vllm_model.py
@@ -0,0 +1,254 @@
+import argparse
+import dataclasses
+import os
+import time
+import uuid
+from functools import partial
+from typing import Type
+
+import torch
+import torch.nn as nn
+from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
+ TensorSerializer, stream_io)
+from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
+from transformers import AutoConfig, PretrainedConfig
+
+from vllm.distributed import initialize_model_parallel
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.tensorizer_loader import TensorizerArgs
+
+# yapf conflicts with isort for this docstring
+# yapf: disable
+"""
+tensorize_vllm_model.py is a script that can be used to serialize and
+deserialize vLLM models. These models can be loaded using tensorizer directly
+to the GPU extremely quickly. Tensor encryption and decryption are also
+supported, although libsodium must be installed to use it. Install
+vllm with tensorizer support using `pip install vllm[tensorizer]`.
+
+To serialize a model, you can run something like this:
+
+python tensorize_vllm_model.py \
+ --model EleutherAI/gpt-j-6B \
+ --dtype float16 \
+ serialize \
+ --serialized-directory s3://my-bucket/ \
+ --suffix vllm
+
+This downloads the model from HuggingFace, loads it into vLLM, serializes it,
+and saves it to your S3 bucket. A local directory can also be used.
+
+You can also encrypt the model weights with a randomly-generated key by
+providing a `--keyfile` argument.
+
+To deserialize a model, you can run something like this:
+
+python tensorize_vllm_model.py \
+ --model EleutherAI/gpt-j-6B \
+ --dtype float16 \
+ deserialize \
+ --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
+
+This downloads the model tensors from your S3 bucket and deserializes them.
+To provide S3 credentials, you can pass `--s3-access-key-id`,
+`--s3-secret-access-key`, and `--s3-endpoint` as CLI args to this script or
+to the OpenAI entrypoint, pass them as arguments to LLM(), or set the
+environment variables `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and
+`S3_ENDPOINT`.
+
+
+You can also provide a `--keyfile` argument to decrypt the model weights if
+they were serialized with encryption.
+
+For more information on the available arguments, run
+`python tensorize_vllm_model.py --help`.
+"""
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="An example script that can be used to serialize and "
+ "deserialize vLLM models. These models "
+ "can be loaded using tensorizer directly to the GPU "
+        "extremely quickly. Tensor encryption and decryption are "
+ "also supported, although libsodium must be installed to "
+ "use it.")
+ parser = EngineArgs.add_cli_args(parser)
+ subparsers = parser.add_subparsers(dest='command')
+
+ serialize_parser = subparsers.add_parser(
+ 'serialize', help="Serialize a model to `--serialized-directory`")
+
+ serialize_parser.add_argument(
+ "--suffix",
+ type=str,
+ required=False,
+ help=(
+ "The suffix to append to the serialized model directory, which is "
+ "used to construct the location of the serialized model tensors, "
+ "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
+ "`--suffix` is `v1`, the serialized model tensors will be "
+ "saved to "
+ "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
+ "If none is provided, a random UUID will be used."))
+ serialize_parser.add_argument(
+ "--serialized-directory",
+ type=str,
+ required=True,
+ help="The directory to serialize the model to. "
+ "This can be a local directory or S3 URI. The path to where the "
+ "tensors are saved is a combination of the supplied `dir` and model "
+ "reference ID. For instance, if `dir` is the serialized directory, "
+ "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
+ "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
+ "where `suffix` is given by `--suffix` or a random UUID if not "
+ "provided.")
+
+ serialize_parser.add_argument(
+ "--keyfile",
+ type=str,
+ required=False,
+ help=("Encrypt the model weights with a randomly-generated binary key,"
+ " and save the key at this path"))
+
+ deserialize_parser = subparsers.add_parser(
+ 'deserialize',
+ help=("Deserialize a model from `--path-to-tensors`"
+ " to verify it can be loaded and used."))
+
+ deserialize_parser.add_argument(
+ "--path-to-tensors",
+ type=str,
+ required=True,
+ help="The local path or S3 URI to the model tensors to deserialize. ")
+
+ deserialize_parser.add_argument(
+ "--keyfile",
+ type=str,
+ required=False,
+ help=("Path to a binary key to use to decrypt the model weights,"
+ " if the model was serialized with encryption"))
+
+ return parser.parse_args()
+
+
+def make_model_contiguous(model):
+ # Ensure tensors are saved in memory contiguously
+ for param in model.parameters():
+ param.data = param.data.contiguous()
+
+
+def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+ architectures = getattr(config, "architectures", [])
+ for arch in architectures:
+ model_cls = ModelRegistry.load_model_cls(arch)
+ if model_cls is not None:
+ return model_cls
+ raise ValueError(
+ f"Model architectures {architectures} are not supported for now. "
+ f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
+
+def serialize():
+
+ eng_args_dict = {f.name: getattr(args, f.name) for f in
+ dataclasses.fields(EngineArgs)}
+ engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
+ engine = LLMEngine.from_engine_args(engine_args)
+
+ model = (engine.model_executor.driver_worker.
+ model_runner.model)
+
+ encryption_params = EncryptionParams.random() if keyfile else None
+ if keyfile:
+ with _write_stream(keyfile) as stream:
+ stream.write(encryption_params.key)
+
+ with _write_stream(model_path) as stream:
+ serializer = TensorSerializer(stream, encryption=encryption_params)
+ serializer.write_module(model)
+ serializer.close()
+
+ print("Serialization complete. Model tensors saved to", model_path)
+ if keyfile:
+ print("Key saved to", keyfile)
+
+
+def deserialize():
+ config = AutoConfig.from_pretrained(model_ref)
+
+ with no_init_or_tensor():
+ model_class = _get_vllm_model_architecture(config)
+ model = model_class(config)
+
+ before_mem = get_mem_usage()
+ start = time.time()
+
+ if keyfile:
+ with _read_stream(keyfile) as stream:
+ key = stream.read()
+ decryption_params = DecryptionParams.from_key(key)
+ tensorizer_args.deserializer_params['encryption'] = \
+ decryption_params
+
+ with (_read_stream(model_path)) as stream, TensorDeserializer(
+ stream, **tensorizer_args.deserializer_params) as deserializer:
+ deserializer.load_into_module(model)
+ end = time.time()
+
+ # Brag about how fast we are.
+ total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
+ duration = end - start
+ per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
+ after_mem = get_mem_usage()
+ print(
+ f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
+ )
+ print(f"Memory usage before: {before_mem}")
+ print(f"Memory usage after: {after_mem}")
+
+ return model
+
+
+args = parse_args()
+
+s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
+ or None)
+s3_secret_access_key = (args.s3_secret_access_key
+ or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
+
+s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
+
+_read_stream, _write_stream = (partial(
+ stream_io.open_stream,
+ mode=mode,
+ s3_access_key_id=s3_access_key_id,
+ s3_secret_access_key=s3_secret_access_key,
+ s3_endpoint=s3_endpoint,
+) for mode in ("rb", "wb+"))
+
+model_ref = args.model
+
+model_name = model_ref.split("/")[1]
+
+os.environ["MASTER_ADDR"] = "127.0.0.1"
+os.environ["MASTER_PORT"] = "8080"
+
+torch.distributed.init_process_group(world_size=1, rank=0)
+initialize_model_parallel()
+
+keyfile = args.keyfile if args.keyfile else None
+
+if args.command == "serialize":
+ input_dir = args.serialized_directory.rstrip('/')
+ suffix = args.suffix if args.suffix else uuid.uuid4().hex
+ base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
+ model_path = f"{base_path}/model.tensors"
+ serialize()
+elif args.command == "deserialize":
+ tensorizer_args = TensorizerArgs.from_cli_args(args)
+ model_path = args.path_to_tensors
+ deserialize()
+else:
+ raise ValueError("Either serialize or deserialize must be specified.")
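
The `--keyfile` flow above reduces to a handful of tensorizer calls. Below is a minimal standalone sketch of that encryption round-trip, assuming `tensorizer` is installed; the toy module and local paths are illustrative only:

    import torch.nn as nn
    from tensorizer import (DecryptionParams, EncryptionParams,
                            TensorDeserializer, TensorSerializer)

    model = nn.Linear(8, 8)

    # Serialize with a randomly generated key, saving the key alongside.
    encryption_params = EncryptionParams.random()
    with open("/tmp/model.key", "wb") as f:
        f.write(encryption_params.key)
    with open("/tmp/model.tensors", "wb") as f:
        serializer = TensorSerializer(f, encryption=encryption_params)
        serializer.write_module(model)
        serializer.close()

    # Deserialize by reading the key back and building DecryptionParams.
    with open("/tmp/model.key", "rb") as f:
        decryption_params = DecryptionParams.from_key(f.read())
    with open("/tmp/model.tensors", "rb") as f, TensorDeserializer(
            f, encryption=decryption_params, device="cpu") as deserializer:
        deserializer.load_into_module(model)
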
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
index 36d20bc9473ea..5779b38b24e69 100644
--- a/requirements-cpu.txt
+++ b/requirements-cpu.txt
@@ -3,4 +3,4 @@
# Dependencies for x86_64 CPUs
torch == 2.2.1+cpu
-triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error.
+triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error.
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 96dfda6faf00f..1317e51b2dd11 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -14,6 +14,7 @@ types-setuptools
# testing
pytest
+tensorizer==2.9.0a0
pytest-forked
pytest-asyncio
pytest-rerunfailures
diff --git a/setup.py b/setup.py
index 9f0814e9f3bff..813321efe796d 100644
--- a/setup.py
+++ b/setup.py
@@ -405,6 +405,9 @@ def _read_requirements(filename: str) -> List[str]:
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
+ extras_require={
+ "optional": ["tensorizer==2.9.0a1"],
+ },
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
package_data=package_data,
)
diff --git a/tests/tensorizer/__init__.py b/tests/tensorizer/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/tensorizer/tensorize_vllm_model_for_testing.py b/tests/tensorizer/tensorize_vllm_model_for_testing.py
new file mode 100644
index 0000000000000..d0be08329fd64
--- /dev/null
+++ b/tests/tensorizer/tensorize_vllm_model_for_testing.py
@@ -0,0 +1,245 @@
+import argparse
+import dataclasses
+import os
+import time
+import uuid
+from functools import partial
+from typing import Type
+
+import torch
+import torch.nn as nn
+from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
+ TensorSerializer, stream_io)
+from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
+from transformers import AutoConfig, PretrainedConfig
+
+from vllm.distributed import initialize_model_parallel
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.tensorizer_loader import TensorizerArgs
+
+# yapf conflicts with isort for this docstring
+# yapf: disable
+"""
+tensorize_vllm_model.py is a script that can be used to serialize and
+deserialize vLLM models. These models can be loaded using tensorizer directly
+to the GPU extremely quickly. Tensor encryption and decryption are also
+supported, although libsodium must be installed to use it. Install
+vllm with tensorizer support using `pip install vllm[tensorizer]`.
+
+To serialize a model, you can run something like this:
+
+python tensorize_vllm_model.py \
+ --model EleutherAI/gpt-j-6B \
+ --dtype float16 \
+ serialize \
+ --serialized-directory s3://my-bucket/ \
+ --suffix vllm
+
+This downloads the model from HuggingFace, loads it into vLLM, serializes it,
+and saves it to your S3 bucket. A local directory can also be used.
+
+You can also encrypt the model weights with a randomly-generated key by
+providing a `--keyfile` argument.
+
+To deserialize a model, you can run something like this:
+
+python tensorize_vllm_model.py \
+ --model EleutherAI/gpt-j-6B \
+ --dtype float16 \
+ deserialize \
+ --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
+
+This downloads the model tensors from your S3 bucket and deserializes them.
+To provide S3 credentials, you can pass `--s3-access-key-id`,
+`--s3-secret-access-key`, and `--s3-endpoint` as CLI args to this script or
+to the OpenAI entrypoint, pass them as arguments to LLM(), or set the
+environment variables `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and
+`S3_ENDPOINT`.
+
+
+You can also provide a `--keyfile` argument to decrypt the model weights if
+they were serialized with encryption.
+
+For more information on the available arguments, run
+`python tensorize_vllm_model.py --help`.
+"""
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="An example script that can be used to serialize and "
+ "deserialize vLLM models. These models "
+ "can be loaded using tensorizer directly to the GPU "
+        "extremely quickly. Tensor encryption and decryption are "
+ "also supported, although libsodium must be installed to "
+ "use it.")
+ parser = EngineArgs.add_cli_args(parser)
+ subparsers = parser.add_subparsers(dest='command')
+
+ serialize_parser = subparsers.add_parser(
+ 'serialize', help="Serialize a model to `--serialized-directory`")
+
+ serialize_parser.add_argument(
+ "--suffix",
+ type=str,
+ required=False,
+ help=(
+ "The suffix to append to the serialized model directory, which is "
+ "used to construct the location of the serialized model tensors, "
+ "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
+ "`--suffix` is `v1`, the serialized model tensors will be "
+ "saved to "
+ "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
+ "If none is provided, a random UUID will be used."))
+ serialize_parser.add_argument(
+ "--serialized-directory",
+ type=str,
+ required=True)
+
+ serialize_parser.add_argument(
+ "--keyfile",
+ type=str,
+ required=False,
+ help=("Encrypt the model weights with a randomly-generated binary key,"
+ " and save the key at this path"))
+
+ deserialize_parser = subparsers.add_parser(
+ 'deserialize',
+ help=("Deserialize a model from `--path-to-tensors`"
+ " to verify it can be loaded and used."))
+
+ deserialize_parser.add_argument(
+ "--path-to-tensors",
+ type=str,
+ required=True,
+ help="The local path or S3 URI to the model tensors to deserialize. ")
+
+ deserialize_parser.add_argument(
+ "--keyfile",
+ type=str,
+ required=False,
+ help=("Path to a binary key to use to decrypt the model weights,"
+ " if the model was serialized with encryption"))
+
+ return parser.parse_args()
+
+
+def make_model_contiguous(model):
+ # Ensure tensors are saved in memory contiguously
+ for param in model.parameters():
+ param.data = param.data.contiguous()
+
+
+def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+ architectures = getattr(config, "architectures", [])
+ for arch in architectures:
+ model_cls = ModelRegistry.load_model_cls(arch)
+ if model_cls is not None:
+ return model_cls
+ raise ValueError(
+ f"Model architectures {architectures} are not supported for now. "
+ f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
+
+def serialize():
+ eng_args_dict = {f.name: getattr(args, f.name) for f in
+ dataclasses.fields(EngineArgs)}
+ engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
+ engine = LLMEngine.from_engine_args(engine_args)
+
+ model = (engine.model_executor.driver_worker.
+ model_runner.model)
+
+ encryption_params = EncryptionParams.random() if keyfile else None
+ if keyfile:
+ with _write_stream(keyfile) as stream:
+ stream.write(encryption_params.key)
+
+ with _write_stream(model_path) as stream:
+ serializer = TensorSerializer(stream, encryption=encryption_params)
+ serializer.write_module(model)
+ serializer.close()
+
+ print("Serialization complete. Model tensors saved to", model_path)
+ if keyfile:
+ print("Key saved to", keyfile)
+
+
+def deserialize():
+ config = AutoConfig.from_pretrained(model_ref)
+
+ with no_init_or_tensor():
+ model_class = _get_vllm_model_architecture(config)
+ model = model_class(config)
+
+ before_mem = get_mem_usage()
+ start = time.time()
+
+ if keyfile:
+ with _read_stream(keyfile) as stream:
+ key = stream.read()
+ decryption_params = DecryptionParams.from_key(key)
+ tensorizer_args.deserializer_params['encryption'] = \
+ decryption_params
+
+ with (_read_stream(model_path)) as stream, TensorDeserializer(
+ stream, **tensorizer_args.deserializer_params) as deserializer:
+ deserializer.load_into_module(model)
+ end = time.time()
+
+ # Brag about how fast we are.
+ total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
+ duration = end - start
+ per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
+ after_mem = get_mem_usage()
+ print(
+ f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
+ )
+ print(f"Memory usage before: {before_mem}")
+ print(f"Memory usage after: {after_mem}")
+
+ return model
+
+
+args = parse_args()
+
+s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
+ or None)
+s3_secret_access_key = (args.s3_secret_access_key
+ or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
+
+s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
+
+_read_stream, _write_stream = (partial(
+ stream_io.open_stream,
+ mode=mode,
+ s3_access_key_id=s3_access_key_id,
+ s3_secret_access_key=s3_secret_access_key,
+ s3_endpoint=s3_endpoint,
+) for mode in ("rb", "wb+"))
+
+model_ref = args.model
+
+model_name = model_ref.split("/")[1]
+
+os.environ["MASTER_ADDR"] = "127.0.0.1"
+os.environ["MASTER_PORT"] = "8080"
+
+torch.distributed.init_process_group(world_size=1, rank=0)
+initialize_model_parallel()
+
+keyfile = args.keyfile if args.keyfile else None
+
+if args.command == "serialize":
+ input_dir = args.serialized_directory.rstrip('/')
+ suffix = args.suffix if args.suffix else uuid.uuid4().hex
+ base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
+ model_path = f"{base_path}/model.tensors"
+ serialize()
+elif args.command == "deserialize":
+ tensorizer_args = TensorizerArgs.from_cli_args(args)
+ model_path = args.path_to_tensors
+ deserialize()
+else:
+ raise ValueError("Either serialize or deserialize must be specified.")
diff --git a/tests/tensorizer/test_tensorizer.py b/tests/tensorizer/test_tensorizer.py
new file mode 100644
index 0000000000000..2ab893e95da9c
--- /dev/null
+++ b/tests/tensorizer/test_tensorizer.py
@@ -0,0 +1,302 @@
+import gc
+import subprocess
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from tests.entrypoints.test_openai_server import ServerRunner
+from vllm import SamplingParams
+from vllm.config import TensorizerConfig
+from vllm.model_executor.tensorizer_loader import (
+ EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer,
+ load_with_tensorizer, open_stream)
+
+prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
+
+model_ref = "facebook/opt-125m"
+
+
+def is_curl_installed():
+ try:
+ subprocess.check_call(['curl', '--version'])
+ return True
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ return False
+
+
+@pytest.fixture(autouse=True)
+def tensorizer_config():
+ config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)
+ return config
+
+
+@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent')
+def test_load_with_tensorizer(mock_agent, tensorizer_config):
+ mock_linear_method = MagicMock()
+ mock_agent_instance = mock_agent.return_value
+ mock_agent_instance.deserialize.return_value = MagicMock()
+
+ result = load_with_tensorizer(tensorizer_config,
+ linear_method=mock_linear_method)
+
+ mock_agent.assert_called_once_with(tensorizer_config,
+ linear_method=mock_linear_method)
+ mock_agent_instance.deserialize.assert_called_once()
+ assert result == mock_agent_instance.deserialize.return_value
+
+
+def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
+ tensorizer_config.vllm_tensorized = True
+
+ result = is_vllm_serialized_tensorizer(tensorizer_config)
+
+ assert result is True
+
+
+def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
+ tensorizer_config.vllm_tensorized = False
+
+ result = is_vllm_serialized_tensorizer(tensorizer_config)
+
+ assert result is False
+
+
+def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
+ vllm_model = vllm_runner(model_ref)
+ model_path = tmp_path / (model_ref + ".tensors")
+ outputs = vllm_model.generate(prompts, sampling_params)
+ model = (vllm_model.model.llm_engine.model_executor.driver_worker.
+ model_runner.model)
+ with open_stream(model_path, "wb+") as stream:
+ serializer = TensorSerializer(stream)
+ serializer.write_module(model)
+ del vllm_model, model
+ gc.collect()
+ torch.cuda.empty_cache()
+ loaded_vllm_model = vllm_runner(model_ref,
+ load_format="tensorizer",
+ tensorizer_uri=model_path,
+ num_readers=1,
+ vllm_tensorized=True)
+ deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
+
+ # Assumes SamplingParams being seeded ensures the outputs are deterministic
+ assert outputs == deserialized_outputs
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_can_deserialize_s3(vllm_runner):
+ model_ref = "EleutherAI/pythia-1.4b"
+ tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
+
+ loaded_hf_model = vllm_runner(
+ model_ref,
+ tensorizer_uri=tensorized_path,
+ load_format="tensorizer",
+ num_readers=1,
+ vllm_tensorized=False,
+ s3_endpoint="object.ord1.coreweave.com",
+ )
+
+ deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
+
+ assert deserialized_outputs
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_deserialized_encrypted_vllm_model_has_same_outputs(
+ vllm_runner, tmp_path):
+ vllm_model = vllm_runner(model_ref)
+ model_path = tmp_path / (model_ref + ".tensors")
+ key_path = tmp_path / (model_ref + ".key")
+ outputs = vllm_model.generate(prompts, sampling_params)
+ model = (vllm_model.model.llm_engine.model_executor.driver_worker.
+ model_runner.model)
+
+ encryption_params = EncryptionParams.random()
+ with open_stream(model_path, "wb+") as stream:
+ serializer = TensorSerializer(stream, encryption=encryption_params)
+ serializer.write_module(model)
+ with open_stream(key_path, "wb+") as stream:
+ stream.write(encryption_params.key)
+ del vllm_model, model
+ gc.collect()
+ torch.cuda.empty_cache()
+ loaded_vllm_model = vllm_runner(model_ref,
+ tensorizer_uri=model_path,
+ load_format="tensorizer",
+ encryption_keyfile=key_path,
+ num_readers=1,
+ vllm_tensorized=True)
+
+ deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
+
+ # Assumes SamplingParams being seeded ensures the outputs are deterministic
+ assert outputs == deserialized_outputs
+
+
+def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
+ tmp_path):
+ hf_model = hf_runner(model_ref)
+ model_path = tmp_path / (model_ref + ".tensors")
+ max_tokens = 50
+ outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
+ with open_stream(model_path, "wb+") as stream:
+ serializer = TensorSerializer(stream)
+ serializer.write_module(hf_model.model)
+ del hf_model
+ gc.collect()
+ torch.cuda.empty_cache()
+ loaded_hf_model = vllm_runner(model_ref,
+ tensorizer_uri=model_path,
+ load_format="tensorizer",
+ num_readers=1,
+ vllm_tensorized=False)
+
+ deserialized_outputs = loaded_hf_model.generate_greedy(
+ prompts, max_tokens=max_tokens)
+
+ assert outputs == deserialized_outputs
+
+
+def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
+ from huggingface_hub import snapshot_download
+
+ from examples.multilora_inference import (create_test_prompts,
+ process_requests)
+
+ model_ref = "meta-llama/Llama-2-7b-hf"
+ lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
+ test_prompts = create_test_prompts(lora_path)
+
+ # Serialize model before deserializing and binding LoRA adapters
+    vllm_model = vllm_runner(model_ref)
+ model_path = tmp_path / (model_ref + ".tensors")
+ model = (vllm_model.model.llm_engine.model_executor.driver_worker.
+ model_runner.model)
+ with open_stream(model_path, "wb+") as stream:
+ serializer = TensorSerializer(stream)
+ serializer.write_module(model)
+ del vllm_model, model
+ gc.collect()
+ torch.cuda.empty_cache()
+ loaded_vllm_model = vllm_runner(
+ model_ref,
+ tensorizer_uri=model_path,
+ load_format="tensorizer",
+ num_readers=1,
+ vllm_tensorized=True,
+ enable_lora=True,
+ max_loras=1,
+ max_lora_rank=8,
+ max_cpu_loras=2,
+ max_num_seqs=50,
+ max_model_len=1000,
+ )
+ process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
+
+ assert loaded_vllm_model
+
+
+def test_load_without_tensorizer_load_format(vllm_runner):
+ with pytest.raises(ValueError):
+ vllm_runner(model_ref, tensorizer_uri="test")
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_tensorize_vllm_model(tmp_path):
+ # Test serialize command
+ serialize_args = [
+ "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
+ model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
+ tmp_path, "--suffix", "tests"
+ ]
+ result = subprocess.run(serialize_args, capture_output=True, text=True)
+ print(result.stdout) # Print the output of the serialize command
+
+ assert result.returncode == 0, (f"Serialize command failed with output:"
+ f"\n{result.stdout}\n{result.stderr}")
+
+ path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
+
+ # Test deserialize command
+ deserialize_args = [
+ "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
+ model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors",
+ path_to_tensors
+ ]
+ result = subprocess.run(deserialize_args, capture_output=True, text=True)
+ assert result.returncode == 0, (f"Deserialize command failed with output:"
+ f"\n{result.stdout}\n{result.stderr}")
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_openai_apiserver_with_tensorizer(tmp_path):
+ ## Serialize model
+ serialize_args = [
+ "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
+ model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
+ tmp_path, "--suffix", "tests"
+ ]
+ result = subprocess.run(serialize_args, capture_output=True, text=True)
+ print(result.stdout) # Print the output of the serialize command
+
+ assert result.returncode == 0, (f"Serialize command failed with output:"
+ f"\n{result.stdout}\n{result.stderr}")
+
+ path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
+
+ ## Start OpenAI API server
+ openai_args = [
+ "--model", model_ref, "--dtype", "float16", "--load-format",
+ "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized",
+ "--port", "8000"
+ ]
+
+ server = ServerRunner.remote(openai_args)
+
+ print("Server ready.")
+ assert server.ready.remote()
+
+
+def test_raise_value_error_on_invalid_load_format(vllm_runner):
+ with pytest.raises(ValueError):
+ vllm_runner(model_ref,
+ load_format="safetensors",
+ tensorizer_uri="test")
+
+
+def test_tensorizer_with_tp(vllm_runner):
+ with pytest.raises(ValueError):
+ model_ref = "EleutherAI/pythia-1.4b"
+ tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
+
+ vllm_runner(
+ model_ref,
+ tensorizer_uri=tensorized_path,
+ load_format="tensorizer",
+ num_readers=1,
+ vllm_tensorized=False,
+ s3_endpoint="object.ord1.coreweave.com",
+ tensor_parallel_size=2,
+ )
+
+
+@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
+def test_tensorizer_warn_quant(tmp_path):
+ model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
+ serialize_args = [
+ "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
+ model_ref, "--quantization", "gptq", "--tensorizer-uri", "test",
+ "serialize", "--serialized-directory", tmp_path, "--suffix", "tests"
+ ]
+ result = subprocess.run(serialize_args, capture_output=True, text=True)
+ assert 'PerformanceWarning' in result.stderr
diff --git a/vllm/config.py b/vllm/config.py
index bbda4ecf3cc56..dce2944b2ee8a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1,6 +1,8 @@
import enum
+import io
import json
import os
+import typing
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
@@ -16,6 +18,8 @@
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
+ from vllm.model_executor.tensorizer_loader import TensorizerArgs
+
logger = init_logger(__name__)
_GB = 1 << 30
@@ -139,13 +143,14 @@ def __init__(
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
supported_load_format = [
- "auto", "pt", "safetensors", "npcache", "dummy"
+ "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer"
]
rocm_not_supported_load_format: List[str] = []
if load_format not in supported_load_format:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
- "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+ "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
+ "'dummy'.")
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f for f in supported_load_format
@@ -882,6 +887,65 @@ def get_image_input_enum_type(
f"{[x.name for x in cls.ImageInputType]}.") from e
+@dataclass
+class TensorizerConfig:
+ tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+ str, bytes, os.PathLike, int]
+ vllm_tensorized: bool
+ verify_hash: Optional[bool] = False
+ num_readers: Optional[int] = 1
+ encryption_keyfile: Optional[str] = None
+ s3_access_key_id: Optional[str] = None
+ s3_secret_access_key: Optional[str] = None
+ s3_endpoint: Optional[str] = None
+ model_class: Optional[torch.nn.Module] = None
+ hf_config: Optional[PretrainedConfig] = None
+ dtype: Union[str, torch.dtype] = None
+
+ def _construct_tensorizer_args(self) -> "TensorizerArgs":
+ from vllm.model_executor.tensorizer_loader import TensorizerArgs
+ tensorizer_args = {
+ "tensorizer_uri": self.tensorizer_uri,
+ "vllm_tensorized": self.vllm_tensorized,
+ "verify_hash": self.verify_hash,
+ "num_readers": self.num_readers,
+ "encryption_keyfile": self.encryption_keyfile,
+ "s3_access_key_id": self.s3_access_key_id,
+ "s3_secret_access_key": self.s3_secret_access_key,
+ "s3_endpoint": self.s3_endpoint,
+ }
+ return TensorizerArgs(**tensorizer_args)
+
+ def verify_with_parallel_config(
+ self,
+ parallel_config: "ParallelConfig",
+ ) -> None:
+ if (parallel_config.tensor_parallel_size > 1
+ and self.tensorizer_uri is not None):
+ raise ValueError(
+                "Loading to multiple GPUs is not currently supported with "
+                "vLLM-serialized models. Please set tensor_parallel_size=1, "
+                "or use a non-vLLM-serialized model, such as a "
+                "serialized Hugging Face `PreTrainedModel`.")
+
+ def verify_with_model_config(self, model_config) -> None:
+ if (model_config.quantization is not None
+ and self.tensorizer_uri is not None):
+ from vllm.model_executor.tensorizer_loader import (
+ tensorizer_warning)
+ tensorizer_warning(
+ "Loading a model using Tensorizer with quantization on vLLM"
+ " is unstable and may lead to errors.")
+
+ if (model_config.load_format != "tensorizer"
+ and self.tensorizer_uri is not None):
+ raise ValueError(
+ "A tensorizer uri was passed for tensorizer loading, but the "
+ f"load format was set to {model_config.load_format}. "
+ "Please set the load format to 'tensorizer' to use "
+                "tensorizer args.")
+
+
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
@@ -1029,6 +1093,7 @@ class EngineConfig:
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
speculative_config: Optional[SpeculativeConfig]
+ tensorizer_config: Optional[TensorizerConfig]
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
@@ -1036,6 +1101,11 @@ def __post_init__(self):
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
+ if self.tensorizer_config:
+ self.tensorizer_config.verify_with_parallel_config(
+ self.parallel_config)
+ self.tensorizer_config.verify_with_model_config(self.model_config)
+
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
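
A small sketch of how `TensorizerConfig` is meant to be used, with an illustrative local path; `_construct_tensorizer_args()` bridges it to the `TensorizerArgs` object consumed by the loader:

    from vllm.config import TensorizerConfig

    tensorizer_config = TensorizerConfig(
        tensorizer_uri="/tmp/opt-125m/model.tensors",  # illustrative path
        vllm_tensorized=True,
        num_readers=1,
    )
    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    # The deserializer-facing view of the same settings:
    print(tensorizer_args.deserializer_params)
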
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index daefddc01b431..831a03be65f61 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1,12 +1,15 @@
import argparse
import dataclasses
+import io
+import os
from dataclasses import dataclass
-from typing import Optional
+from typing import BinaryIO, Optional, Union
from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
- SpeculativeConfig, TokenizerPoolConfig,
- VisionLanguageConfig)
+ SpeculativeConfig, TensorizerConfig,
+ TokenizerPoolConfig, VisionLanguageConfig)
+from vllm.model_executor.tensorizer_loader import TensorizerArgs
from vllm.utils import str_to_int_tuple
@@ -58,12 +61,22 @@ class EngineArgs:
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
+ # Tensorizer configuration parameters
+ tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
+ bytes, os.PathLike, int] = None
+ vllm_tensorized: bool = False
+ verify_hash: Optional[bool] = False
+ num_readers: Optional[int] = 1
+ encryption_keyfile: Optional[str] = None
+ s3_access_key_id: Optional[str] = None
+ s3_secret_access_key: Optional[str] = None
+ s3_endpoint: Optional[str] = None
+
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None
-
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
@@ -135,7 +148,9 @@ def add_cli_args(
'--load-format',
type=str,
default=EngineArgs.load_format,
- choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
+ choices=[
+ 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
+ ],
help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
@@ -145,7 +160,10 @@ def add_cli_args(
'"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
- 'which is mainly for profiling.')
+            'which is mainly for profiling. '
+            '"tensorizer" will load the weights using tensorizer from '
+            'CoreWeave, which assumes tensorizer_uri is set to the location '
+            'of the serialized weights.')
parser.add_argument(
'--dtype',
type=str,
@@ -403,6 +421,7 @@ def add_cli_args(
default=None,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding')
+ parser = TensorizerArgs.add_cli_args(parser)
return parser
@classmethod
@@ -465,6 +484,17 @@ def create_engine_config(self, ) -> EngineConfig:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
+ tensorizer_config = TensorizerConfig(
+ tensorizer_uri=self.tensorizer_uri,
+ vllm_tensorized=self.vllm_tensorized,
+ verify_hash=self.verify_hash,
+ num_readers=self.num_readers,
+ encryption_keyfile=self.encryption_keyfile,
+ s3_access_key_id=self.s3_access_key_id,
+ s3_secret_access_key=self.s3_secret_access_key,
+ s3_endpoint=self.s3_endpoint,
+ )
+
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
@@ -488,7 +518,8 @@ def create_engine_config(self, ) -> EngineConfig:
device_config=device_config,
lora_config=lora_config,
vision_language_config=vision_language_config,
- speculative_config=speculative_config)
+ speculative_config=speculative_config,
+ tensorizer_config=tensorizer_config)
@dataclass
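
A sketch of how the new engine arguments flow through; the model name and path are illustrative. The tensorizer settings end up on `EngineConfig.tensorizer_config`, where `__post_init__` cross-checks them against the model and parallel configs:

    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(
        model="facebook/opt-125m",
        load_format="tensorizer",
        tensorizer_uri="/tmp/opt-125m/model.tensors",
        vllm_tensorized=True,
        num_readers=1,
    )
    engine_config = engine_args.create_engine_config()
    assert engine_config.tensorizer_config.vllm_tensorized
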
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index a91629a630591..8c37c5a9d6ee9 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -6,7 +6,7 @@
import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
- VisionLanguageConfig)
+ TensorizerConfig, VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
@@ -74,6 +74,7 @@ def __init__(
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
+ tensorizer_config: Optional[TensorizerConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -110,6 +111,7 @@ def __init__(
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
+ self.tensorizer_config = tensorizer_config
self.log_stats = log_stats
self._init_tokenizer()
@@ -125,6 +127,7 @@ def __init__(
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
+ tensorizer_config=tensorizer_config,
)
self._initialize_kv_caches()
@@ -264,6 +267,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs):
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
+ if self.tensorizer_config:
+ self.tensorizer_config.verify_with_parallel_config(
+ self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index f20221a0b941a..30577ecf62faa 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -2,7 +2,7 @@
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
- VisionLanguageConfig)
+ TensorizerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -15,17 +15,14 @@
class GPUExecutor(ExecutorBase):
- def __init__(
- self,
- model_config: ModelConfig,
- cache_config: CacheConfig,
- parallel_config: ParallelConfig,
- scheduler_config: SchedulerConfig,
- device_config: DeviceConfig,
- lora_config: Optional[LoRAConfig],
- vision_language_config: Optional[VisionLanguageConfig],
- speculative_config: Optional[SpeculativeConfig],
- ) -> None:
+ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig,
+ device_config: DeviceConfig,
+ lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig],
+ speculative_config: Optional[SpeculativeConfig],
+ tensorizer_config: Optional[TensorizerConfig]) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
@@ -33,6 +30,7 @@ def __init__(
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
+ self.tensorizer_config = tensorizer_config
assert (not speculative_config
), "Speculative decoding not yet supported for GPU backend"
@@ -61,6 +59,7 @@ def _init_worker(self):
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
+ tensorizer_config=self.tensorizer_config,
is_driver_worker=True,
)
self.driver_worker.init_device()
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index b937693c92257..28dc3e0db312a 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -7,7 +7,7 @@
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
- VisionLanguageConfig)
+ TensorizerConfig, VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
@@ -42,6 +42,7 @@ def __init__(
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
+ tensorizer_config: Optional[TensorizerConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
@@ -50,6 +51,7 @@ def __init__(
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
+ self.tensorizer_config = tensorizer_config
assert (not speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
@@ -171,6 +173,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
distributed_init_method=distributed_init_method,
lora_config=lora_config,
vision_language_config=vision_language_config,
+ tensorizer_config=self.tensorizer_config,
))
# Initialize the driver worker with the Worker class.
@@ -187,6 +190,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
+ tensorizer_config=self.tensorizer_config,
is_driver_worker=True,
)
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 2745dbd89ab0f..c70ca48bca70a 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -3,11 +3,14 @@
from typing import Tuple, Type
import torch
-import torch.nn as nn
+from torch import nn
from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
+from vllm.model_executor.tensorizer_loader import (
+ ParameterizedLoadFormat, is_vllm_serialized_tensorizer,
+ load_with_tensorizer)
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
@@ -51,6 +54,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None)
vision_language_config = kwargs.get("vision_language_config", None)
+ tensorizer_config = kwargs.get("tensorizer_config", None)
model_class = _get_model_architecture(model_config)[0]
# Get the (maybe quantized) linear method.
@@ -71,33 +75,54 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
+
linear_method = quant_config.get_linear_method()
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
+ extra_kwargs = {}
+ if hasattr(model_class, "supported_lora_modules"):
+ extra_kwargs["lora_config"] = lora_config
+ elif lora_config:
+ raise ValueError(
+ f"Model {model_class.__name__} does not support LoRA, "
+ "but LoRA is enabled. Support for this model may "
+ "be added in the future. If this is important to you, "
+ "please open an issue on github.")
+ elif model_class in _VISION_MODEL_CLASSES:
+ extra_kwargs["vision_language_config"] = vision_language_config
+
with torch.device(device_config.device):
- if hasattr(model_class, "supported_lora_modules"):
- model = model_class(model_config.hf_config, linear_method,
- lora_config)
- elif lora_config:
- raise ValueError(
- f"Model {model_class.__name__} does not support LoRA, "
- "but LoRA is enabled. Support for this model may "
- "be added in the future. If this is important to you, "
- "please open an issue on github.")
- else:
- if model_class not in _VISION_MODEL_CLASSES:
- model = model_class(model_config.hf_config, linear_method)
- else:
- model = model_class(model_config.hf_config,
- vision_language_config, linear_method)
+ if (model_config.load_format == "tensorizer"
+ and is_vllm_serialized_tensorizer(tensorizer_config)):
+ extra_kwargs["linear_method"] = linear_method
+ tensorizer_config.model_class = model_class
+ tensorizer_config.hf_config = model_config.hf_config
+ tensorizer_config.dtype = model_config.dtype
+ model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
+ return model.eval()
+ model = model_class(config=model_config.hf_config,
+ linear_method=linear_method,
+ **extra_kwargs)
if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
else:
# Load the weights from the cached or downloaded files.
- model.load_weights(model_config.model, model_config.download_dir,
- model_config.load_format, model_config.revision)
+ if model_config.load_format == "tensorizer":
+ # Provide a dynamic load format for `model.load_weights`
+ # to retain tensorizer args from CLI.
+ model_config.load_format = ParameterizedLoadFormat(
+ model_config.load_format)
+ model_config.load_format.params = (
+ tensorizer_config._construct_tensorizer_args())
+
+ model.load_weights(
+ model_config.model,
+ model_config.download_dir,
+ model_config.load_format,
+ model_config.revision,
+ )
return model.eval()
diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py
new file mode 100644
index 0000000000000..ed3ad9e2ffa15
--- /dev/null
+++ b/vllm/model_executor/tensorizer_loader.py
@@ -0,0 +1,319 @@
+import argparse
+import dataclasses
+import io
+import os
+import time
+import typing
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from vllm.config import TensorizerConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import LinearMethodBase
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ VocabParallelEmbedding)
+
+tensorizer_load_fail = False
+
+try:
+ from tensorizer import (DecryptionParams, EncryptionParams,
+ TensorDeserializer, TensorSerializer)
+ from tensorizer.stream_io import open_stream
+ from tensorizer.utils import (convert_bytes, get_mem_usage,
+ no_init_or_tensor)
+except ImportError:
+ tensorizer_load_fail = True
+
+__all__ = [
+ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
+ 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
+ 'no_init_or_tensor'
+]
+
+logger = init_logger(__name__)
+
+
+def load_with_tensorizer(tensorizer_config: TensorizerConfig,
+ **extra_kwargs) -> nn.Module:
+ tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
+ return tensorizer.deserialize()
+
+
+def tensorizer_warning(message: str):
+ return warnings.warn(message, category=PerformanceWarning, stacklevel=2)
+
+
+def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
+ if tensorizer_config is None:
+ return False
+ return tensorizer_config.vllm_tensorized
+
+
+class ParameterizedLoadFormat(str):
+ __slots__ = "params"
+
+
+class PerformanceWarning(UserWarning):
+
+ def __str__(self):
+ return (f"{super().__str__()}"
+ " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS"
+ " environment variable to hide this)")
+
+
+if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower()
+ not in ("", "0", "n", "no", "off", "disable")):
+ warnings.simplefilter("ignore", category=PerformanceWarning)
+
+
+@dataclass
+class TensorizerArgs:
+ tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+ str, bytes, os.PathLike, int]
+ vllm_tensorized: bool
+ verify_hash: Optional[bool] = False
+ num_readers: Optional[int] = 1
+ encryption_keyfile: Optional[str] = None
+ s3_access_key_id: Optional[str] = None
+ s3_secret_access_key: Optional[str] = None
+ s3_endpoint: Optional[str] = None
+ """
+ Args for the TensorizerAgent class. These are used to configure the behavior
+ of the TensorDeserializer when loading tensors from a serialized model.
+
+ Args:
+ tensorizer_uri: Path to serialized model tensors. Can be a local file
+ path or a S3 URI.
+ vllm_tensorized: If True, indicates that the serialized model is a
+ vLLM model. This is used to determine the behavior of the
+ TensorDeserializer when loading tensors from a serialized model.
+ It is far faster to deserialize a vLLM model as it utilizes
+ tensorizer's optimized GPU loading.
+ verify_hash: If True, the hashes of each tensor will be verified against
+ the hashes stored in the metadata. A `HashMismatchError` will be
+ raised if any of the hashes do not match.
+ num_readers: Controls how many threads are allowed to read concurrently
+          from the source file. Defaults to 1. Increasing this can greatly
+          improve loading performance for large models.
+ encryption_keyfile: File path to a binary file containing a
+ binary key to use for decryption. `None` (the default) means
+ no decryption. See the example script in
+ examples/tensorize_vllm_model.py.
+ s3_access_key_id: The access key for the S3 bucket. Can also be set via
+ the S3_ACCESS_KEY_ID environment variable.
+ s3_secret_access_key: The secret access key for the S3 bucket. Can also
+ be set via the S3_SECRET_ACCESS_KEY environment variable.
+ s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
+ S3_ENDPOINT_URL environment variable.
+ """
+
+ def __post_init__(self):
+ self.file_obj = self.tensorizer_uri
+ self.s3_access_key_id = (self.s3_access_key_id
+ or os.environ.get("S3_ACCESS_KEY_ID")) or None
+ self.s3_secret_access_key = (
+ self.s3_secret_access_key
+ or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
+ self.s3_endpoint = (self.s3_endpoint
+ or os.environ.get("S3_ENDPOINT_URL")) or None
+ self.stream_params = {
+ "s3_access_key_id": self.s3_access_key_id,
+ "s3_secret_access_key": self.s3_secret_access_key,
+ "s3_endpoint": self.s3_endpoint,
+ }
+
+ # Omitting self.dtype and self.device as this behaves weirdly
+ self.deserializer_params = {
+ "verify_hash": self.verify_hash,
+ "encryption": self.encryption_keyfile,
+ "num_readers": self.num_readers
+ }
+ if self.encryption_keyfile:
+ with open_stream(
+ self.encryption_keyfile,
+ **self.stream_params,
+ ) as stream:
+ key = stream.read()
+ decryption_params = DecryptionParams.from_key(key)
+ self.deserializer_params['encryption'] = decryption_params
+
+ def add_cli_args(
+ parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ """Tensorizer CLI arguments"""
+
+ # Create the argument group
+ group = parser.add_argument_group(
+ 'tensorizer options',
+ description=('Options for configuring the behavior of the'
+ ' tensorizer deserializer when '
+ '--load-format=tensorizer'))
+
+ group.add_argument(
+ "--tensorizer-uri",
+ help="Path to serialized model tensors. Can be a local file path,"
+ " or an HTTP(S) or S3 URI.",
+ )
+ group.add_argument(
+ "--verify-hash",
+ action="store_true",
+ help="If enabled, the hashes of each tensor will be verified"
+ " against the hashes stored in the file metadata. An exception"
+ " will be raised if any of the hashes do not match.",
+ )
+ group.add_argument(
+ "--encryption-keyfile",
+ default=None,
+ help="The file path to a binary file containing a binary key to "
+ "use for decryption. Can be a file path or S3 network URI.")
+ group.add_argument(
+ "--num-readers",
+ default=1,
+ type=int,
+ help="Controls how many threads are allowed to read concurrently "
+ "from the source file.")
+ group.add_argument(
+ "--s3-access-key-id",
+ default=None,
+ help="The access key for the S3 bucket. Can also be set via the "
+ "S3_ACCESS_KEY_ID environment variable.",
+ )
+ group.add_argument(
+ "--s3-secret-access-key",
+ default=None,
+ help="The secret access key for the S3 bucket. Can also be set via "
+ "the S3_SECRET_ACCESS_KEY environment variable.",
+ )
+ group.add_argument(
+ "--s3-endpoint",
+ default=None,
+ help="The endpoint for the S3 bucket. Can also be set via the "
+ "S3_ENDPOINT_URL environment variable.",
+ )
+ group.add_argument(
+ "--vllm-tensorized",
+ action="store_true",
+ help="If enabled, indicates that the serialized model is a vLLM "
+ "model. This is used to determine the behavior of the "
+ "TensorDeserializer when loading tensors from a "
+ "serialized model.")
+
+ return parser
+
+ @classmethod
+ def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
+ # Get the list of attributes of this dataclass.
+ attrs = [attr.name for attr in dataclasses.fields(cls)]
+ # Set the attributes from the parsed arguments.
+ tensorizer_args = cls(**{
+ attr: getattr(args, attr)
+ for attr in attrs if hasattr(args, attr)
+ })
+ return tensorizer_args
+
+
+class TensorizerAgent:
+ """
+ A class for performing tensorizer deserializations specifically for
+ vLLM models using plaid_mode. Uses TensorizerArgs to configure the
+ behavior of the TensorDeserializer when loading tensors from a serialized
+ model. For deserializations of HuggingFace models, TensorDeserializer is
+    instead used as an iterator directly in the function
+    hf_model_weights_iterator in vllm/model_executor/weight_utils.py.
+ """
+
+    def __init__(self, tensorizer_config: TensorizerConfig,
+                 linear_method: LinearMethodBase, **extra_kwargs):
+        # Fail fast before _init_model() touches any tensorizer APIs.
+        if tensorizer_load_fail:
+            raise ImportError(
+                "Tensorizer is not installed. Please install tensorizer "
+                "to use this feature with `pip install vllm[tensorizer]`.")
+
+        self.tensorizer_config = tensorizer_config
+        self.tensorizer_args = (
+            self.tensorizer_config._construct_tensorizer_args())
+        self.extra_kwargs = extra_kwargs
+        if extra_kwargs.get("linear_method", None) is not None:
+            self.linear_method = extra_kwargs["linear_method"]
+        else:
+            self.linear_method = linear_method
+        self.model = self._init_model()
+
+ def _init_model(self):
+ model_args = self.tensorizer_config.hf_config
+ model_args.torch_dtype = self.tensorizer_config.dtype
+ with no_init_or_tensor():
+ return self.tensorizer_config.model_class(
+ config=model_args,
+ linear_method=self.linear_method,
+ **self.extra_kwargs)
+
+ def _resize_lora_embeddings(self):
+ """Modify LoRA embedding layers to use bigger tensors
+ to allow for adapter added tokens."""
+ for child in self.model.modules():
+ if (isinstance(child, VocabParallelEmbedding)
+ and child.weight.shape[0] <
+ child.num_embeddings_per_partition):
+ new_weight = torch.empty(child.num_embeddings_per_partition,
+ child.embedding_dim,
+ dtype=child.weight.dtype,
+ device=child.weight.device)
+ new_weight[:child.weight.shape[0]].copy_(child.weight.data)
+ new_weight[child.weight.shape[0]:].fill_(0)
+ child.weight.data = new_weight
+
+ def _check_tensors_on_meta_device(self):
+ for tensor in self.model.state_dict().values():
+ if tensor.device.type == 'meta':
+ raise ValueError(
+ "The serialized model contains tensors on the meta device,"
+ " indicating that some tensors were not loaded properly."
+ " Please check that the parameters of the model being"
+ " specified match that of the serialized model, such as"
+ " its quantization.")
+
+ def deserialize(self):
+ """
+ Deserialize the model using the TensorDeserializer. This method is
+ specifically for vLLM models using tensorizer's plaid_mode.
+
+ The deserializer makes use of tensorizer_args.stream_params
+ to configure the behavior of the stream when loading tensors from a
+ serialized model. The deserializer_params are used to configure the
+ behavior of the TensorDeserializer when loading tensors themselves.
+        Documentation on these params can be found in TensorizerArgs.
+
+ Returns:
+ nn.Module: The deserialized model.
+ """
+ before_mem = get_mem_usage()
+ # Lazy load the tensors from S3 into the model.
+ start = time.perf_counter()
+ with open_stream(
+ self.tensorizer_args.tensorizer_uri,
+ mode="rb",
+ **self.tensorizer_args.stream_params,
+ ) as stream, TensorDeserializer(
+ stream,
+ dtype=self.tensorizer_config.dtype,
+ **self.tensorizer_args.deserializer_params) as deserializer:
+ deserializer.load_into_module(self.model)
+ end = time.perf_counter()
+
+ total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
+ duration = end - start
+ per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
+ after_mem = get_mem_usage()
+ deserializer.close()
+ logger.info(f"Deserialized {total_bytes_str} in "
+ f"{end - start:0.2f}s, {per_second}/s")
+ logger.info(f"Memory usage before: {before_mem}")
+ logger.info(f"Memory usage after: {after_mem}")
+
+ self._check_tensors_on_meta_device()
+ self._resize_lora_embeddings()
+ return self.model.eval()
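
A brief usage sketch of the CLI plumbing defined above, with illustrative argument values; `from_cli_args` copies every matching dataclass field from the parsed namespace:

    import argparse

    from vllm.model_executor.tensorizer_loader import TensorizerArgs

    parser = argparse.ArgumentParser()
    parser = TensorizerArgs.add_cli_args(parser)
    args = parser.parse_args([
        "--tensorizer-uri", "/tmp/opt-125m/model.tensors",
        "--vllm-tensorized",
        "--num-readers", "2",
    ])
    tensorizer_args = TensorizerArgs.from_cli_args(args)
    print(tensorizer_args.num_readers)  # -> 2
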
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 0961478930d74..08425604f0511 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -5,7 +5,7 @@
import json
import os
from collections import defaultdict
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
import filelock
import huggingface_hub.constants
@@ -161,7 +161,8 @@ def prepare_hf_model_weights(
revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface.
- is_local = os.path.isdir(model_name_or_path)
+ is_local = os.path.isdir(model_name_or_path) \
+ and load_format != "tensorizer"
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
@@ -173,13 +174,15 @@ def prepare_hf_model_weights(
allow_patterns = ["*.pt"]
elif load_format == "npcache":
allow_patterns = ["*.bin"]
+ elif load_format == "tensorizer":
+ allow_patterns = ["*.tensors"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
- if not is_local:
+ if not is_local and load_format != "tensorizer":
# Before we download we look at what is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
@@ -224,6 +227,9 @@ def prepare_hf_model_weights(
if not any(f.endswith(x) for x in blacklist)
]
+ if load_format == "tensorizer":
+ return hf_folder, hf_weights_files, use_safetensors
+
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
@@ -234,7 +240,7 @@ def prepare_hf_model_weights(
def hf_model_weights_iterator(
model_name_or_path: str,
cache_dir: Optional[str] = None,
- load_format: str = "auto",
+ load_format: Union[Tuple, str] = "auto",
revision: Optional[str] = None,
fall_back_to_pt: Optional[bool] = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
@@ -277,6 +283,26 @@ def hf_model_weights_iterator(
with open(param_path, "rb") as f:
param = np.load(f)
yield name, torch.from_numpy(param)
+ elif load_format == "tensorizer":
+ from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
+ open_stream,
+ tensorizer_warning)
+ tensorizer_args = load_format.params
+ tensorizer_warning(
+ "Deserializing HuggingFace models is not optimized for "
+ "loading on vLLM, as tensorizer is forced to load to CPU. "
+ "Consider deserializing a vLLM model instead for faster "
+ "load times. See the examples/tensorize_vllm_model.py example "
+ "script for serializing vLLM models.")
+
+ deserializer_args = tensorizer_args.deserializer_params
+ stream_params = tensorizer_args.stream_params
+ stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
+ with TensorDeserializer(stream, **deserializer_args,
+ device="cpu") as state:
+ for name, param in state.items():
+ yield name, param
+ del state
elif use_safetensors:
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
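
The new "tensorizer" branch of `hf_model_weights_iterator` boils down to streaming a serialized HuggingFace checkpoint onto the CPU and yielding (name, tensor) pairs. A minimal sketch of that idea, assuming tensorizer is installed (`iter_tensorized_weights` is an illustrative helper, not part of this patch):

    from tensorizer import TensorDeserializer
    from tensorizer.stream_io import open_stream

    def iter_tensorized_weights(uri: str):
        # Open the serialized checkpoint and yield CPU tensors one by one,
        # mirroring the "tensorizer" branch above.
        stream = open_stream(uri, mode="rb")
        with TensorDeserializer(stream, device="cpu") as state:
            for name, param in state.items():
                yield name, param
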
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 47ad8f0c9b78b..7dbe14ead0976 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -10,7 +10,8 @@
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
- SchedulerConfig, VisionLanguageConfig)
+ SchedulerConfig, TensorizerConfig,
+ VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
pynccl_utils)
@@ -111,11 +112,13 @@ def __init__(
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
+ tensorizer_config: Optional[TensorizerConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
+ self.tensorizer_config = tensorizer_config
self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
@@ -158,7 +161,9 @@ def load_model(self) -> None:
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
parallel_config=self.parallel_config,
- scheduler_config=self.scheduler_config)
+ scheduler_config=self.scheduler_config,
+ tensorizer_config=self.tensorizer_config,
+ )
self.model_memory_usage = m.consumed_memory
logger.info(f"Loading model weights took "
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 3f0b2fd83f3e5..82491c6df6616 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -7,7 +7,8 @@
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+ ParallelConfig, SchedulerConfig, TensorizerConfig,
+ VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
@@ -42,6 +43,7 @@ def __init__(
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
+ tensorizer_config: Optional[TensorizerConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
@@ -53,6 +55,7 @@ def __init__(
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
+ self.tensorizer_config = tensorizer_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
@@ -70,7 +73,9 @@ def __init__(
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
- vision_language_config=vision_language_config)
+ vision_language_config=vision_language_config,
+ tensorizer_config=tensorizer_config,
+ )
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine = None
From 2cd6b4f3625466eb5849bcfd7a6fb316735adab8 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sat, 13 Apr 2024 23:40:21 -0700
Subject: [PATCH 040/413] [Core] avoid too many cuda context by caching p2p
test (#4021)
---
.../device_communicators/custom_all_reduce.py | 53 +++++------
vllm/distributed/parallel_state.py | 9 ++
vllm/distributed/utils.py | 87 ++++++++++++++++++-
3 files changed, 116 insertions(+), 33 deletions(-)
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 84238d2e46076..f83caef879da3 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -42,12 +42,17 @@ def init_custom_ar() -> None:
" disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES))
return
- if not _can_p2p(rank, world_size):
+ num_dev = torch.cuda.device_count()
+ # note: num dev can be larger than world_size if we're only using
+ # first few GPUs
+ if num_dev < world_size:
logger.warn(
- "Custom allreduce is disabled because your platform lacks GPU P2P"
- " capability or P2P test failed. To silence this warning, specify"
- " disable_custom_all_reduce=True explicitly.")
- return
+ "Cannot test GPU P2P because not all GPUs are visible to the "
+ "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
+ " is set.")
+ return False
+ # test nvlink first, this will filter out most of the cases
+ # where custom allreduce is not supported
full_nvlink = _is_full_nvlink(rank, world_size)
if world_size > 2 and not full_nvlink:
logger.warn(
@@ -55,6 +60,15 @@ def init_custom_ar() -> None:
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
+ # test P2P capability
+ # this is expensive to compute at the first time
+ # then we cache the result
+ if not _can_p2p(rank, world_size):
+ logger.warn(
+ "Custom allreduce is disabled because your platform lacks GPU P2P"
+ " capability or P2P test failed. To silence this warning, specify"
+ " disable_custom_all_reduce=True explicitly.")
+ return
_CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
@@ -143,40 +157,15 @@ def _is_full_nvlink(rank, world_size):
def _can_p2p(rank: int, world_size: int) -> bool:
- num_dev = torch.cuda.device_count()
- # note: num dev can be larger than world_size if we're only using
- # first few GPUs
- if num_dev < world_size:
- logger.warn(
- "Cannot test GPU P2P because not all GPUs are visible to the "
- "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
- " is set.")
- return False
+ from vllm.distributed.utils import gpu_p2p_access_check
for i in range(world_size):
if i == rank:
continue
- if not torch.cuda.can_device_access_peer(rank, i):
- return False
- # on some platforms, P2P support might be buggy and we need
- # additional checks. See also:
- # https://github.com/vllm-project/vllm/issues/2728
- if not _can_actually_p2p(rank, i):
+ if not gpu_p2p_access_check(rank, i):
return False
return True
-# code partly borrowed from
-# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
-# License: MIT
-def _can_actually_p2p(idx_a, idx_b):
- dev_i = f"cuda:{idx_a}"
- dev_j = f"cuda:{idx_b}"
- a = torch.randn(5, device=dev_i) + 123.0
- b = a.to(dev_j)
- c = b.to(dev_i)
- return torch.all(a == c)
-
-
class CustomAllreduce:
# max_size: max supported allreduce size
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 1258bf58cb453..e2473736375e0 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -41,6 +41,13 @@
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
+_LOCAL_RANK = -1
+
+
+def get_local_rank():
+ global _LOCAL_RANK
+ return _LOCAL_RANK
+
def init_distributed_environment(
world_size: int = -1,
@@ -66,6 +73,8 @@ def init_distributed_environment(
ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
backend="gloo")
+ global _LOCAL_RANK
+ _LOCAL_RANK = local_rank
def initialize_model_parallel(
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 0cd420c8e11b5..e0a871ebe1756 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -2,9 +2,18 @@
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-from typing import Sequence
+import json
+import os
+from typing import Dict, Optional, Sequence
import torch
+import torch.distributed as dist
+
+from vllm.logger import init_logger
+
+from .parallel_state import get_cpu_world_group, get_local_rank
+
+logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator):
@@ -46,3 +55,79 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
+
+
+# code partly borrowed from
+# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
+# License: MIT
+def _can_actually_p2p(idx_a, idx_b):
+ dev_i = f"cuda:{idx_a}"
+ dev_j = f"cuda:{idx_b}"
+ a = torch.randn(5, device=dev_i) + 123.0
+ b = a.to(dev_j)
+ c = b.to(dev_i)
+ return torch.all(a == c).cpu().item()
+
+
+# why do we need this cache?
+# 1. we can have runtime checks for P2P access, where every process checks
+# P2P access to all other GPUs. Unfortunately, the test might create many
+# (world_size * world_size) CUDA contexts and reduce the memory available
+# for the model. See https://github.com/vllm-project/vllm/issues/3821
+# 2. alternatively, we can have a p2p map that is generated by the master
+# process and broadcast to all other processes. This still requires
+# #world_size CUDA contexts, belonging to the master process, on each GPU.
+# 3. we can have a cache file that records the p2p access status. The first
+# time the master process checks the p2p access, it generates the cache
+# file, at the cost of #world_size CUDA contexts. Later on, all processes
+# can read the cache file to check the p2p access status without creating
+# any additional CUDA contexts.
+# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
+# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
+# e.g. used by different vllm engines. The device id in the cache file is a
+# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
+# of visible devices in the vllm engine.
+_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
+
+
+def gpu_p2p_access_check(i: int, j: int) -> bool:
+ """Check if GPU i can access GPU j."""
+
+ # if the cache variable is already calculated,
+ # read from the cache instead of checking it again
+ global _gpu_p2p_access_cache
+ if _gpu_p2p_access_cache is not None:
+ return _gpu_p2p_access_cache[f"{i}->{j}"]
+
+ is_distributed = dist.is_initialized()
+
+ num_dev = torch.cuda.device_count()
+ cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ if cuda_visible_devices is None:
+ cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
+ path = os.path.expanduser(
+ f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ if (not is_distributed or get_local_rank() == 0) \
+ and (not os.path.exists(path)):
+ # only the local master process (with local_rank == 0) can
+ # enter this block to calculate the cache
+ logger.info(f"generating GPU P2P access cache in {path}")
+ cache = {}
+ for _i in range(num_dev):
+ for _j in range(num_dev):
+ # on some platforms, P2P support might be buggy and we need
+ # additional checks. See also:
+ # https://github.com/vllm-project/vllm/issues/2728
+ cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
+ _i, _j) and _can_actually_p2p(_i, _j)
+ with open(path, "w") as f:
+ json.dump(cache, f, indent=4)
+ if is_distributed:
+ cpu_world_group = get_cpu_world_group()
+ dist.barrier(cpu_world_group)
+ logger.info(f"reading GPU P2P access cache from {path}")
+ with open(path, "r") as f:
+ cache = json.load(f)
+ _gpu_p2p_access_cache = cache
+ return _gpu_p2p_access_cache[f"{i}->{j}"]
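
The cache file written by `gpu_p2p_access_check` is plain JSON mapping "i->j" keys to booleans, and its filename is keyed by the `CUDA_VISIBLE_DEVICES` setting. A minimal inspection sketch (the helper name and default devices string are illustrative, not part of this patch):

    import json
    import os

    def read_gpu_p2p_cache(cuda_visible_devices: str = "0,1") -> dict:
        # Load the JSON cache written by gpu_p2p_access_check, e.g.
        # {"0->1": true, "1->0": true, "0->0": true, ...}
        path = os.path.expanduser(
            f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
        with open(path) as f:
            return json.load(f)
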
From 563c54f760f870ae44c7662c8a9ec3a223a3c4c4 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sun, 14 Apr 2024 22:12:42 +0100
Subject: [PATCH 041/413] [BugFix] Fix tensorizer extra in setup.py (#4072)
---
setup.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/setup.py b/setup.py
index 813321efe796d..19a9150ad2e64 100644
--- a/setup.py
+++ b/setup.py
@@ -406,7 +406,7 @@ def _read_requirements(filename: str) -> List[str]:
install_requires=get_requirements(),
ext_modules=ext_modules,
extras_require={
- "optional": ["tensorizer==2.9.0a1"],
+ "tensorizer": ["tensorizer==2.9.0a1"],
},
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
package_data=package_data,
From aceb17cf2d629175a484c3d9df355f44bd334cb3 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Sun, 14 Apr 2024 14:35:55 -0700
Subject: [PATCH 042/413] [Docs] document that mixtral 8x22b is supported
(#4073)
---
README.md | 2 +-
docs/source/models/supported_models.rst | 36 ++++++++++++-------------
2 files changed, 19 insertions(+), 19 deletions(-)
diff --git a/README.md b/README.md
index d53227b82d87a..8434c11883341 100644
--- a/README.md
+++ b/README.md
@@ -72,7 +72,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
-- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
+- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index c09b0ff250437..5e5ce871f61dd 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -30,23 +30,23 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`CohereForCausalLM`
- Command-R
- :code:`CohereForAI/c4ai-command-r-v01`, etc.
- -
+ -
* - :code:`DbrxForCausalLM`
- DBRX
- :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc.
- -
+ -
* - :code:`DeciLMForCausalLM`
- DeciLM
- :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
- -
+ -
* - :code:`BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
- -
+ -
* - :code:`FalconForCausalLM`
- Falcon
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
- -
+ -
* - :code:`GemmaForCausalLM`
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
@@ -54,19 +54,19 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
- -
+ -
* - :code:`GPTBigCodeForCausalLM`
- StarCoder, SantaCoder, WizardCoder
- :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
- -
+ -
* - :code:`GPTJForCausalLM`
- GPT-J
- :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
- -
+ -
* - :code:`GPTNeoXForCausalLM`
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
- -
+ -
* - :code:`InternLMForCausalLM`
- InternLM
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
@@ -93,32 +93,32 @@ Alongside each architecture, we include some popular models that use it.
- ✅︎
* - :code:`MixtralForCausalLM`
- Mixtral-8x7B, Mixtral-8x7B-Instruct
- - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
+ - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc.
- ✅︎
* - :code:`MPTForCausalLM`
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
- -
+ -
* - :code:`OLMoForCausalLM`
- OLMo
- :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc.
- -
+ -
* - :code:`OPTForCausalLM`
- OPT, OPT-IML
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
- -
+ -
* - :code:`OrionForCausalLM`
- Orion
- :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc.
- -
+ -
* - :code:`PhiForCausalLM`
- Phi
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
- -
+ -
* - :code:`QWenLMHeadModel`
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
- -
+ -
* - :code:`Qwen2ForCausalLM`
- Qwen2
- :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
@@ -126,11 +126,11 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`Qwen2MoeForCausalLM`
- Qwen2MoE
- :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.
- -
+ -
* - :code:`StableLmForCausalLM`
- StableLM
- :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc.
- -
+ -
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model ` for instructions on how to implement support for your model.
From 8db1bf32f8924403c6a845b5ce71ba0f14beb038 Mon Sep 17 00:00:00 2001
From: Roy
Date: Mon, 15 Apr 2024 08:43:54 +0800
Subject: [PATCH 043/413] [Misc] Upgrade triton to 2.2.0 (#4061)
---
requirements-cpu.txt | 2 +-
requirements-cuda.txt | 1 -
2 files changed, 1 insertion(+), 2 deletions(-)
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
index 5779b38b24e69..e911ad03295f0 100644
--- a/requirements-cpu.txt
+++ b/requirements-cpu.txt
@@ -3,4 +3,4 @@
# Dependencies for x86_64 CPUs
torch == 2.2.1+cpu
-triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error.
\ No newline at end of file
+triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
\ No newline at end of file
diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index 6ee75e8139c04..c6d2cd46aee54 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -7,4 +7,3 @@ pynvml == 11.5.0
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.2.1
xformers == 0.0.25 # Requires PyTorch 2.2.1
-triton >= 2.1.0
From e11e2007368b22fce05b9ecdf00dd48eda471f9e Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Sun, 14 Apr 2024 21:50:08 -0700
Subject: [PATCH 044/413] [Bugfix] Fix filelock version requirement (#4075)
---
requirements-common.txt | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/requirements-common.txt b/requirements-common.txt
index c96f9c9937fb0..90a3bc8abc1db 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -12,4 +12,5 @@ pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
outlines == 0.0.34 # Requires torch >= 2.1.0
-typing_extensions
\ No newline at end of file
+typing_extensions
+filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
From 0003e9154bf1091d0de7ca7a6c7f1253df1eca5b Mon Sep 17 00:00:00 2001
From: "Li, Jiang"
Date: Mon, 15 Apr 2024 23:35:55 +0800
Subject: [PATCH 045/413] [Misc][Minor] Fix CPU block num log in CPUExecutor.
(#4088)
---
vllm/executor/cpu_executor.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 33e67d8b3eec2..e63a88be7868f 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -74,7 +74,10 @@ def initialize_cache(self, num_gpu_blocks: int,
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
- logger.info(f"# CPU blocks: {num_cpu_blocks}")
+ # NOTE: a `cpu block` for the CPU backend is located in CPU memory but is
+ # referred to as a `gpu block` because we want to reuse the existing block
+ # management procedure.
+ logger.info(f"# CPU blocks: {num_gpu_blocks}")
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self,
From eb46fbfda25348422918c4a876e17aef05fc5e34 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 15 Apr 2024 13:05:09 -0700
Subject: [PATCH 046/413] [Core] Simplifications to executor classes (#4071)
---
vllm/executor/cpu_executor.py | 31 +++++++++-------------------
vllm/executor/executor_base.py | 27 +++++++++++++++++-------
vllm/executor/gpu_executor.py | 32 ++++-------------------------
vllm/executor/neuron_executor.py | 29 ++++++--------------------
vllm/executor/ray_gpu_executor.py | 34 ++++---------------------------
5 files changed, 44 insertions(+), 109 deletions(-)
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index e63a88be7868f..f562e4e0ae3de 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -1,10 +1,9 @@
import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple
import torch
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig)
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -16,23 +15,13 @@
class CPUExecutor(ExecutorBase):
- def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
- parallel_config: ParallelConfig,
- scheduler_config: SchedulerConfig,
- device_config: DeviceConfig,
- lora_config: Optional[LoRAConfig], *args, **kwargs) -> None:
- assert device_config.device_type == "cpu"
- assert lora_config is None, "cpu backend doesn't support LoRA"
- model_config = _verify_and_get_model_config(model_config)
- cache_config = _verify_and_get_cache_config(cache_config)
- scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
-
- self.model_config = model_config
- self.cache_config = cache_config
- self.lora_config = lora_config
- self.parallel_config = parallel_config
- self.scheduler_config = scheduler_config
- self.device_config = device_config
+ def _init_executor(self) -> None:
+ assert self.device_config.device_type == "cpu"
+ assert self.lora_config is None, "cpu backend doesn't support LoRA"
+ self.model_config = _verify_and_get_model_config(self.model_config)
+ self.cache_config = _verify_and_get_cache_config(self.cache_config)
+ self.scheduler_config = _verify_and_get_scheduler_config(
+ self.scheduler_config)
# Instantiate the worker and load the model to CPU.
self._init_worker()
@@ -99,7 +88,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
- def list_loras(self) -> List[int]:
+ def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def check_health(self) -> None:
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index bbfbfc689c99f..bbb6ec80f7b7e 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
- VisionLanguageConfig)
+ TensorizerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -16,7 +16,6 @@ class ExecutorBase(ABC):
that can execute the model on multiple devices.
"""
- @abstractmethod
def __init__(
self,
model_config: ModelConfig,
@@ -27,8 +26,23 @@ def __init__(
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
+ tensorizer_config: Optional[TensorizerConfig],
) -> None:
- raise NotImplementedError
+ self.model_config = model_config
+ self.cache_config = cache_config
+ self.lora_config = lora_config
+ self.parallel_config = parallel_config
+ self.scheduler_config = scheduler_config
+ self.device_config = device_config
+ self.vision_language_config = vision_language_config
+ self.speculative_config = speculative_config
+ self.tensorizer_config = tensorizer_config
+
+ self._init_executor()
+
+ @abstractmethod
+ def _init_executor(self) -> None:
+ pass
@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
@@ -71,7 +85,7 @@ def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
@abstractmethod
- def list_loras(self) -> List[int]:
+ def list_loras(self) -> Set[int]:
raise NotImplementedError
@abstractmethod
@@ -94,8 +108,7 @@ async def execute_model_async(
"""Executes one model step on the given sequences."""
raise NotImplementedError
- @abstractmethod
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
- raise NotImplementedError
+ self.check_health()
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 30577ecf62faa..bae509f48025b 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -1,8 +1,5 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig, SpeculativeConfig,
- TensorizerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -15,24 +12,8 @@
class GPUExecutor(ExecutorBase):
- def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
- parallel_config: ParallelConfig,
- scheduler_config: SchedulerConfig,
- device_config: DeviceConfig,
- lora_config: Optional[LoRAConfig],
- vision_language_config: Optional[VisionLanguageConfig],
- speculative_config: Optional[SpeculativeConfig],
- tensorizer_config: Optional[TensorizerConfig]) -> None:
- self.model_config = model_config
- self.cache_config = cache_config
- self.lora_config = lora_config
- self.parallel_config = parallel_config
- self.scheduler_config = scheduler_config
- self.device_config = device_config
- self.vision_language_config = vision_language_config
- self.tensorizer_config = tensorizer_config
-
- assert (not speculative_config
+ def _init_executor(self) -> None:
+ assert (not self.speculative_config
), "Speculative decoding not yet supported for GPU backend"
# Instantiate the worker and load the model to GPU.
@@ -103,7 +84,7 @@ def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)
- def list_loras(self) -> List[int]:
+ def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def check_health(self) -> None:
@@ -127,8 +108,3 @@ async def execute_model_async(
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy)
return output
-
- async def check_health_async(self) -> None:
- # GPUExecutor will always be healthy as long as
- # it's running.
- return
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index d45f18e466256..273b17a927efd 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -1,8 +1,5 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig, SpeculativeConfig,
- VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -13,24 +10,10 @@
class NeuronExecutor(ExecutorBase):
- def __init__(
- self,
- model_config: ModelConfig,
- cache_config: CacheConfig,
- parallel_config: ParallelConfig,
- scheduler_config: SchedulerConfig,
- device_config: DeviceConfig,
- lora_config: Optional[LoRAConfig],
- vision_language_config: Optional[VisionLanguageConfig],
- speculative_config: Optional[SpeculativeConfig],
- ) -> None:
- self.model_config = model_config
- self.cache_config = cache_config
- assert lora_config is None, "LoRA is not supported for Neuron backend."
- self.parallel_config = parallel_config
- self.scheduler_config = scheduler_config
- self.device_config = device_config
- assert (not speculative_config
+ def _init_executor(self) -> None:
+ assert (self.lora_config is
+ None), "LoRA is not supported for Neuron backend."
+ assert (not self.speculative_config
), "Speculative decoding not yet supported for Neuron backend."
# Instantiate the worker and load the model to the device.
@@ -80,7 +63,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)
- def list_loras(self) -> List[int]:
+ def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
def check_health(self) -> None:
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 28dc3e0db312a..5db2f3f652532 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -3,11 +3,8 @@
import os
import pickle
from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig, SpeculativeConfig,
- TensorizerConfig, VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
@@ -32,27 +29,8 @@
class RayGPUExecutor(ExecutorBase):
- def __init__(
- self,
- model_config: ModelConfig,
- cache_config: CacheConfig,
- parallel_config: ParallelConfig,
- scheduler_config: SchedulerConfig,
- device_config: DeviceConfig,
- lora_config: Optional[LoRAConfig],
- vision_language_config: Optional[VisionLanguageConfig],
- speculative_config: Optional[SpeculativeConfig],
- tensorizer_config: Optional[TensorizerConfig],
- ) -> None:
- self.model_config = model_config
- self.cache_config = cache_config
- self.lora_config = lora_config
- self.parallel_config = parallel_config
- self.scheduler_config = scheduler_config
- self.device_config = device_config
- self.vision_language_config = vision_language_config
- self.tensorizer_config = tensorizer_config
- assert (not speculative_config
+ def _init_executor(self) -> None:
+ assert (not self.speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.worker_use_ray
@@ -273,7 +251,7 @@ def remove_lora(self, lora_id: int) -> bool:
lora_id=lora_id,
)
- def list_loras(self) -> List[int]:
+ def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")
def _run_workers(
@@ -416,7 +394,3 @@ async def execute_model_async(
# Only the driver worker returns the sampling results.
output = all_outputs[0]
return output
-
- async def check_health_async(self) -> None:
- """Raises an error if engine is unhealthy."""
- self._check_if_any_actor_is_dead()
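
With this refactor, `ExecutorBase.__init__` stores every config on `self` and then calls `_init_executor()`, so backends only override that hook instead of re-implementing the constructor. A partial sketch of the pattern with a hypothetical backend (the remaining abstract methods are omitted):

    from vllm.executor.executor_base import ExecutorBase

    class NoOpExecutor(ExecutorBase):
        def _init_executor(self) -> None:
            # By the time this runs, self.model_config, self.parallel_config,
            # self.scheduler_config, etc. have already been set by the base class.
            assert not self.speculative_config, (
                "Speculative decoding not supported by this sketch")
            # A real backend would instantiate its workers here.
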
From d619ae2d19c41d9aa8f68fa0e5e32cc410dc2522 Mon Sep 17 00:00:00 2001
From: Sanger Steel
Date: Mon, 15 Apr 2024 16:28:25 -0400
Subject: [PATCH 047/413] [Doc] Add better clarity for tensorizer usage (#4090)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
---
docs/source/models/engine_args.rst | 2 +-
examples/tensorize_vllm_model.py | 60 +++++++++++++++++-------
vllm/model_executor/tensorizer_loader.py | 6 +--
3 files changed, 46 insertions(+), 22 deletions(-)
diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst
index 886a806934c04..235cb4e128c99 100644
--- a/docs/source/models/engine_args.rst
+++ b/docs/source/models/engine_args.rst
@@ -45,7 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM:
* "safetensors" will load the weights in the safetensors format.
* "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
* "dummy" will initialize the weights with random values, mainly for profiling.
- * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_. See `tensorized_vllm_model.py` in the examples folder to serialize a vLLM model, and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`.
+ * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_ See `examples/tensorize_vllm_model.py `_ to serialize a vLLM model, and for more information.
.. option:: --dtype {auto,half,float16,bfloat16,float,float32}
diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py
index 3c20a38c7f726..8cf8be09d0b9c 100644
--- a/examples/tensorize_vllm_model.py
+++ b/examples/tensorize_vllm_model.py
@@ -23,14 +23,16 @@
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
-deserialize vLLM models. These models can be loaded using tensorizer directly
-to the GPU extremely quickly. Tensor encryption and decryption is also
-supported, although libsodium must be installed to use it. Install
-vllm with tensorizer support using `pip install vllm[tensorizer]`.
+deserialize vLLM models. These models can be loaded using tensorizer
+to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
+or locally. Tensor encryption and decryption is also supported, although
+libsodium must be installed to use it. Install vllm with tensorizer support
+using `pip install vllm[tensorizer]`.
-To serialize a model, you can run something like this:
+To serialize a model, install vLLM from source, then run something
+like this from the root level of this repository:
-python tensorize_vllm_model.py \
+python -m examples.tensorize_vllm_model \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
serialize \
@@ -38,31 +40,57 @@
--suffix vllm
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
-and saves it to your S3 bucket. A local directory can also be used.
+and saves it to your S3 bucket. A local directory can also be used. This
+assumes your S3 credentials are specified as environment variables
+in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
+To provide S3 credentials directly, you can provide `--s3-access-key-id` and
+`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this
+script.
You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.
-To deserialize a model, you can run something like this:
+To deserialize a model, you can run something like this from the root
+level of this repository:
-python tensorize_vllm_model.py \
+python -m examples.tensorize_vllm_model \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
deserialize \
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
Which downloads the model tensors from your S3 bucket and deserializes them.
-To provide S3 credentials, you can provide `--s3-access-key-id` and
-`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
-the OpenAI entrypoint, as arguments for LLM(), or as environment variables
-in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
-
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
-For more information on the available arguments, run
-`python tensorize_vllm_model.py --help`.
+For more information on the available arguments for serializing, run
+`python -m examples.tensorize_vllm_model serialize --help`.
+
+Or for deserializing:
+
+`python -m examples.tensorize_vllm_model deserialize --help`.
+
+Once a model is serialized, it can be loaded by the OpenAI inference server at
+`vllm/entrypoints/openai/api_server.py` by providing the `--tensorizer-uri`
+CLI argument, which is functionally the same as the `--path-to-tensors`
+argument in this script, along with `--vllm-tensorized` to signify that the
+model to be deserialized is a vLLM model rather than a HuggingFace
+`PreTrainedModel` (which can also be deserialized using tensorizer in the same
+inference server, albeit without the speed optimizations). To deserialize an
+encrypted file, pass the path to the keyfile used to encrypt the model weights
+via the `--encryption-keyfile` argument. For information on all the arguments
+that can be used to configure tensorizer's deserialization, check out the
+tensorizer options argument group in the `vllm/entrypoints/openai/api_server.py`
+script with `--help`.
+
+Tensorizer can also be invoked with the `LLM` class directly to load models:
+
+ llm = LLM(model="facebook/opt-125m",
+ load_format="tensorizer",
+ tensorizer_uri=path_to_opt_tensors,
+ num_readers=3,
+ vllm_tensorized=True)
"""
diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/tensorizer_loader.py
index ed3ad9e2ffa15..8550cc97aefe8 100644
--- a/vllm/model_executor/tensorizer_loader.py
+++ b/vllm/model_executor/tensorizer_loader.py
@@ -126,7 +126,6 @@ def __post_init__(self):
"s3_endpoint": self.s3_endpoint,
}
- # Omitting self.dtype and self.device as this behaves weirdly
self.deserializer_params = {
"verify_hash": self.verify_hash,
"encryption": self.encryption_keyfile,
@@ -145,7 +144,7 @@ def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Tensorizer CLI arguments"""
- # Create the argument group
+ # Tensorizer options arg group
group = parser.add_argument_group(
'tensorizer options',
description=('Options for configuring the behavior of the'
@@ -205,9 +204,7 @@ def add_cli_args(
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
- # Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
- # Set the attributes from the parsed arguments.
tensorizer_args = cls(**{
attr: getattr(args, attr)
for attr in attrs if hasattr(args, attr)
@@ -291,7 +288,6 @@ def deserialize(self):
nn.Module: The deserialized model.
"""
before_mem = get_mem_usage()
- # Lazy load the tensors from S3 into the model.
start = time.perf_counter()
with open_stream(
self.tensorizer_args.tensorizer_uri,
From 4695397dcfef693a0a10f1eb8bf77ea905c54829 Mon Sep 17 00:00:00 2001
From: Ricky Xu
Date: Mon, 15 Apr 2024 14:24:45 -0700
Subject: [PATCH 048/413] [Bugfix] Fix ray workers profiling with nsight
(#4095)
---
vllm/executor/ray_gpu_executor.py | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 5db2f3f652532..7aca5e36107aa 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -48,6 +48,21 @@ def _init_executor(self) -> None:
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
+ def _configure_ray_workers_use_nsight(self,
+ ray_remote_kwargs) -> Dict[str, Any]:
+ # If nsight profiling is enabled, we need to set the profiling
+ # configuration for the ray workers as runtime env.
+ runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
+ runtime_env.update({
+ "nsight": {
+ "t": "cuda,cudnn,cublas",
+ "o": "'worker_process_%p'",
+ "cuda-graph-trace": "node",
+ }
+ })
+
+ return ray_remote_kwargs
+
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
@@ -63,6 +78,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerVllm] = []
+ if self.parallel_config.ray_workers_use_nsight:
+ ray_remote_kwargs = self._configure_ray_workers_use_nsight(
+ ray_remote_kwargs)
+
# Create the workers.
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
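
Concretely, `_configure_ray_workers_use_nsight` leaves the remote kwargs looking like the sketch below before the worker actors are created (illustrative `num_gpus` value; assumes a Ray version whose `runtime_env` accepts an "nsight" entry):

    ray_remote_kwargs = {"num_gpus": 1}
    ray_remote_kwargs.setdefault("runtime_env", {}).update({
        "nsight": {
            "t": "cuda,cudnn,cublas",
            "o": "'worker_process_%p'",
            "cuda-graph-trace": "node",
        }
    })
    # The workers are then created roughly as
    # ray.remote(**ray_remote_kwargs)(RayWorkerVllm).remote(...)
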
From 37e84a403d6d11b670a42e84153204cd8b76b849 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Tue, 16 Apr 2024 06:47:31 +0900
Subject: [PATCH 049/413] [Typing] Fix Sequence type GenericAlias only
available after Python 3.9. (#4092)
---
vllm/core/block_manager_v1.py | 5 +++--
vllm/core/block_manager_v2.py | 2 +-
vllm/core/interfaces.py | 2 +-
vllm/utils.py | 7 ++++---
4 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index e391a3b1e5a33..be093922b84f2 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -1,9 +1,10 @@
"""A block manager that manages token blocks."""
from abc import ABC, abstractmethod
-from collections.abc import Sequence as GenericSequence
from itertools import count, takewhile
from os.path import commonprefix
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional
+from typing import Sequence as GenericSequence
+from typing import Set
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py
index 19f0cf415eb34..6339a6baf4161 100644
--- a/vllm/core/block_manager_v2.py
+++ b/vllm/core/block_manager_v2.py
@@ -1,6 +1,6 @@
"""A block manager that manages token blocks."""
-from collections.abc import Sequence as GenericSequence
from typing import Dict, List, Optional
+from typing import Sequence as GenericSequence
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
index c1f68a2e891bf..56c2c5995c38b 100644
--- a/vllm/core/interfaces.py
+++ b/vllm/core/interfaces.py
@@ -1,7 +1,7 @@
import enum
from abc import ABC, abstractmethod
-from collections.abc import Sequence as GenericSequence
from typing import Dict, List
+from typing import Sequence as GenericSequence
from vllm.sequence import Sequence, SequenceGroup
diff --git a/vllm/utils.py b/vllm/utils.py
index 4c0dc9ca729a9..aad62516ad1b9 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -6,11 +6,12 @@
import subprocess
import uuid
import warnings
-from collections import OrderedDict, defaultdict
+from collections import defaultdict
from functools import lru_cache, partial
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
- Hashable, List, Optional, Tuple, TypeVar, Union)
+ Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
+ Union)
import psutil
import torch
@@ -51,7 +52,7 @@ def reset(self) -> None:
class LRUCache(Generic[T]):
def __init__(self, capacity: int):
- self.cache = OrderedDict[Hashable, T]()
+ self.cache: OrderedDict[Hashable, T] = OrderedDict()
self.capacity = capacity
def __contains__(self, key: Hashable) -> bool:
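
The motivation for the import switch is that `collections.abc.Sequence` only became subscriptable in Python 3.9 (PEP 585), while `typing.Sequence` can be subscripted on 3.8 as well. A minimal illustration, not part of the patch:

    from typing import Sequence as GenericSequence

    def first_block(blocks: GenericSequence[int]) -> int:
        # typing.Sequence is subscriptable on every supported Python version;
        # collections.abc.Sequence only gained this in Python 3.9.
        return blocks[0]

    # On Python 3.8 the commented lines below fail with
    # "TypeError: 'ABCMeta' object is not subscriptable":
    # from collections.abc import Sequence
    # def f(x: Sequence[int]) -> int: ...
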
From 4e7ee664e201442e24e2298a36a5264b98691626 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Tue, 16 Apr 2024 14:24:53 +0900
Subject: [PATCH 050/413] [Core] Fix engine-use-ray broken (#4105)
---
tests/async_engine/test_api_server.py | 17 +++++++++++++----
vllm/engine/async_llm_engine.py | 7 +++----
2 files changed, 16 insertions(+), 8 deletions(-)
diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py
index 248bfbc8ab5c0..7f57d5cf9b182 100644
--- a/tests/async_engine/test_api_server.py
+++ b/tests/async_engine/test_api_server.py
@@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict:
@pytest.fixture
-def api_server(tokenizer_pool_size: int):
+def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
+ worker_use_ray: bool):
script_path = Path(__file__).parent.joinpath(
"api_server_async_engine.py").absolute()
- uvicorn_process = subprocess.Popen([
+ commands = [
sys.executable, "-u",
str(script_path), "--model", "facebook/opt-125m", "--host",
"127.0.0.1", "--tokenizer-pool-size",
str(tokenizer_pool_size)
- ])
+ ]
+ if engine_use_ray:
+ commands.append("--engine-use-ray")
+ if worker_use_ray:
+ commands.append("--worker-use-ray")
+ uvicorn_process = subprocess.Popen(commands)
yield
uvicorn_process.terminate()
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
-def test_api_server(api_server, tokenizer_pool_size: int):
+@pytest.mark.parametrize("worker_use_ray", [False, True])
+@pytest.mark.parametrize("engine_use_ray", [False, True])
+def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
+ engine_use_ray: bool):
"""
Run the API server and test it.
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index f610495135121..1dbf58904541c 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -333,8 +333,7 @@ def from_engine_args(
if engine_config.device_config.device_type == "neuron":
raise NotImplementedError("Neuron is not supported for "
"async engine yet.")
- elif (engine_config.parallel_config.worker_use_ray
- or engine_args.engine_use_ray):
+ elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
@@ -410,8 +409,8 @@ def _init_engine(self, *args,
else:
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
# order of the arguments.
- cache_config = args[1]
- parallel_config = args[2]
+ cache_config = kwargs["cache_config"]
+ parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1:
num_gpus = cache_config.gpu_memory_utilization
else:
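
The `_init_engine` fix reads the configs from keyword arguments by name rather than by positional index before computing `num_gpus` for the Ray actor. A self-contained sketch of that calculation using stand-in config objects (illustrative values only, not the real vllm.config classes):

    from dataclasses import dataclass

    @dataclass
    class _CacheCfg:  # stand-in for vllm.config.CacheConfig
        gpu_memory_utilization: float = 0.9

    @dataclass
    class _ParallelCfg:  # stand-in for vllm.config.ParallelConfig
        tensor_parallel_size: int = 1

    kwargs = {"cache_config": _CacheCfg(), "parallel_config": _ParallelCfg()}
    cache_config = kwargs["cache_config"]
    parallel_config = kwargs["parallel_config"]
    # With tensor parallel size 1, the engine actor only needs a fraction
    # of a GPU; otherwise it reserves a full one.
    num_gpus = (cache_config.gpu_memory_utilization
                if parallel_config.tensor_parallel_size == 1 else 1)
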
From 05434764cd99990035779cf9a4ed86623b528825 Mon Sep 17 00:00:00 2001
From: Noam Gat
Date: Tue, 16 Apr 2024 08:54:57 +0300
Subject: [PATCH 051/413] LM Format Enforcer Guided Decoding Support (#3868)
Co-authored-by: Simon Mo
---
requirements-common.txt | 1 +
tests/entrypoints/test_guided_processors.py | 42 +++++++-
tests/entrypoints/test_openai_server.py | 69 ++++++++----
vllm/config.py | 26 ++++-
vllm/engine/arg_utils.py | 18 +++-
vllm/engine/llm_engine.py | 10 +-
vllm/entrypoints/openai/protocol.py | 12 +++
vllm/entrypoints/openai/serving_chat.py | 6 +-
vllm/entrypoints/openai/serving_completion.py | 6 +-
.../guided_decoding/__init__.py | 25 +++++
.../lm_format_enforcer_decoding.py | 69 ++++++++++++
.../outlines_decoding.py} | 7 +-
.../outlines_logits_processors.py} | 100 +++++++++---------
13 files changed, 304 insertions(+), 87 deletions(-)
create mode 100644 vllm/model_executor/guided_decoding/__init__.py
create mode 100644 vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
rename vllm/model_executor/{guided_decoding.py => guided_decoding/outlines_decoding.py} (93%)
rename vllm/model_executor/{guided_logits_processors.py => guided_decoding/outlines_logits_processors.py} (70%)
diff --git a/requirements-common.txt b/requirements-common.txt
index 90a3bc8abc1db..c1614d2537b25 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -11,6 +11,7 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
+lm-format-enforcer == 0.9.3
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py
index 5622744566bcc..30f0ad5d8272f 100644
--- a/tests/entrypoints/test_guided_processors.py
+++ b/tests/entrypoints/test_guided_processors.py
@@ -1,11 +1,14 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
-
+import pytest
import torch
from transformers import AutoTokenizer
-from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
- RegexLogitsProcessor)
+from vllm.entrypoints.openai.protocol import CompletionRequest
+from vllm.model_executor.guided_decoding import (
+ get_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.outlines_logits_processors import (
+ JSONLogitsProcessor, RegexLogitsProcessor)
TEST_SCHEMA = {
"type": "object",
@@ -73,3 +76,36 @@ def test_guided_logits_processors():
json_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
+async def test_guided_logits_processor_black_box(backend: str):
+ tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+ token_ids = tokenizer.encode(
+ f"Give an example IPv4 address with this regex: {TEST_REGEX}")
+ regex_request = CompletionRequest(model='test',
+ prompt=token_ids,
+ guided_regex=TEST_REGEX)
+ regex_lp = await get_guided_decoding_logits_processor(
+ backend, regex_request, tokenizer)
+ assert regex_lp is not None
+ tensor = torch.rand(32000)
+ original_tensor = torch.clone(tensor)
+ tensor = regex_lp(token_ids, tensor)
+ assert tensor.shape == original_tensor.shape
+ assert not torch.allclose(tensor, original_tensor)
+
+ token_ids = tokenizer.encode(
+ f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
+ json_request = CompletionRequest(model='test',
+ prompt=token_ids,
+ guided_json=TEST_SCHEMA)
+ json_lp = await get_guided_decoding_logits_processor(
+ backend, json_request, tokenizer)
+ assert json_lp is not None
+ tensor = torch.rand(32000)
+ original_tensor = torch.clone(tensor)
+ tensor = json_lp(token_ids, tensor)
+ assert tensor.shape == original_tensor.shape
+ assert not torch.allclose(tensor, original_tensor)
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 7940430b8b654..14e6ee0ffe9d9 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text
-async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+ ["outlines", "lm-format-enforcer"])
+async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
+ guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example JSON for an employee profile "
@@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
n=3,
temperature=1.0,
max_tokens=500,
- extra_body=dict(guided_json=TEST_SCHEMA))
+ extra_body=dict(guided_json=TEST_SCHEMA,
+ guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
@@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
-async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+ ["outlines", "lm-format-enforcer"])
+async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
+ guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
- max_tokens=500,
- extra_body=dict(guided_json=TEST_SCHEMA))
+ max_tokens=1000,
+ extra_body=dict(guided_json=TEST_SCHEMA,
+ guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json1 = json.loads(message.content)
@@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
- max_tokens=500,
- extra_body=dict(guided_json=TEST_SCHEMA))
+ max_tokens=1000,
+ extra_body=dict(guided_json=TEST_SCHEMA,
+ guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message
assert message.content is not None
json2 = json.loads(message.content)
@@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
assert json1["age"] != json2["age"]
-async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+ ["outlines", "lm-format-enforcer"])
+async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
+ guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
n=3,
temperature=1.0,
max_tokens=20,
- extra_body=dict(guided_regex=TEST_REGEX))
+ extra_body=dict(guided_regex=TEST_REGEX,
+ guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3
@@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
-async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+ ["outlines", "lm-format-enforcer"])
+async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
+ guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=20,
- extra_body=dict(guided_regex=TEST_REGEX))
+ extra_body=dict(guided_regex=TEST_REGEX,
+ guided_decoding_backend=guided_decoding_backend))
ip1 = chat_completion.choices[0].message.content
assert ip1 is not None
assert re.fullmatch(TEST_REGEX, ip1) is not None
@@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=20,
- extra_body=dict(guided_regex=TEST_REGEX))
+ extra_body=dict(guided_regex=TEST_REGEX,
+ guided_decoding_backend=guided_decoding_backend))
ip2 = chat_completion.choices[0].message.content
assert ip2 is not None
assert re.fullmatch(TEST_REGEX, ip2) is not None
assert ip1 != ip2
-async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+ ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
+ guided_decoding_backend: str):
completion = await client.completions.create(
model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ",
n=2,
temperature=1.0,
max_tokens=10,
- extra_body=dict(guided_choice=TEST_CHOICE))
+ extra_body=dict(guided_choice=TEST_CHOICE,
+ guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 2
@@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
assert completion.choices[i].text in TEST_CHOICE
-async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+ ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
+ guided_decoding_backend: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=10,
- extra_body=dict(guided_choice=TEST_CHOICE))
+ extra_body=dict(guided_choice=TEST_CHOICE,
+ guided_decoding_backend=guided_decoding_backend))
choice1 = chat_completion.choices[0].message.content
assert choice1 in TEST_CHOICE
@@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME,
messages=messages,
max_tokens=10,
- extra_body=dict(guided_choice=TEST_CHOICE))
+ extra_body=dict(guided_choice=TEST_CHOICE,
+ guided_decoding_backend=guided_decoding_backend))
choice2 = chat_completion.choices[0].message.content
assert choice2 in TEST_CHOICE
assert choice1 != choice2
-async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+ ["outlines", "lm-format-enforcer"])
+async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
+ guided_decoding_backend: str):
with pytest.raises(openai.BadRequestError):
_ = await client.completions.create(
model=MODEL_NAME,
prompt="Give an example JSON that fits this schema: 42",
- extra_body=dict(guided_json=42))
+ extra_body=dict(guided_json=42,
+ guided_decoding_backend=guided_decoding_backend))
messages = [{
"role": "system",
diff --git a/vllm/config.py b/vllm/config.py
index dce2944b2ee8a..bf31b03b7c6c4 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -66,8 +66,8 @@ class ModelConfig:
weights. If None, we assume the model weights are not quantized.
quantization_param_path: Path to JSON file containing scaling factors.
Used to load KV cache scaling factors into the model when KV cache
- type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
- be used to load activation and weight scaling factors when the
+ type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
+ be used to load activation and weight scaling factors when the
model dtype is FP8_E4M3 on ROCm.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
@@ -422,7 +422,7 @@ def verify_with_parallel_config(
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.
-
+
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
@@ -446,9 +446,9 @@ def create_config(
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
-
+
If tokenizer_pool_size is 0, return None.
-
+
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
@@ -1079,6 +1079,21 @@ def _get_and_verify_max_len(
return int(max_model_len)
+@dataclass
+class DecodingConfig:
+ """Dataclass which contains the decoding strategy of the engine"""
+
+ # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
+ guided_decoding_backend: str = 'outlines'
+
+ def __post_init__(self):
+ valid_guided_backends = ['outlines', 'lm-format-enforcer']
+ backend = self.guided_decoding_backend
+ if backend not in valid_guided_backends:
+ raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
+ f"must be one of {valid_guided_backends}")
+
+
@dataclass(frozen=True)
class EngineConfig:
"""Dataclass which contains all engine-related configuration. This
@@ -1093,6 +1108,7 @@ class EngineConfig:
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
speculative_config: Optional[SpeculativeConfig]
+ decoding_config: Optional[DecodingConfig]
tensorizer_config: Optional[TensorizerConfig]
def __post_init__(self):
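
A minimal sketch (not part of the patch) of how the new DecodingConfig behaves, based on the dataclass added above: the default backend is 'outlines', and __post_init__ rejects anything else besides 'lm-format-enforcer'.

from vllm.config import DecodingConfig

config = DecodingConfig()
assert config.guided_decoding_backend == "outlines"  # default backend

try:
    DecodingConfig(guided_decoding_backend="not-a-backend")
except ValueError as err:
    # Raised by __post_init__ for backends other than
    # 'outlines' or 'lm-format-enforcer'.
    print(err)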
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 831a03be65f61..3de74b0ac28b9 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -5,9 +5,9 @@
from dataclasses import dataclass
from typing import BinaryIO, Optional, Union
-from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
- ModelConfig, ParallelConfig, SchedulerConfig,
- SpeculativeConfig, TensorizerConfig,
+from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
+ EngineConfig, LoRAConfig, ModelConfig, ParallelConfig,
+ SchedulerConfig, SpeculativeConfig, TensorizerConfig,
TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.tensorizer_loader import TensorizerArgs
from vllm.utils import str_to_int_tuple
@@ -80,6 +80,7 @@ class EngineArgs:
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False
+ guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
@@ -200,6 +201,13 @@ def add_cli_args(
default=EngineArgs.max_model_len,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
+ parser.add_argument(
+ '--guided-decoding-backend',
+ type=str,
+ default='outlines',
+ choices=['outlines', 'lm-format-enforcer'],
+ help='Which engine will be used for guided decoding'
+ ' (JSON schema / regex etc)')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
@@ -511,6 +519,9 @@ def create_engine_config(self, ) -> EngineConfig:
else:
vision_language_config = None
+ decoding_config = DecodingConfig(
+ guided_decoding_backend=self.guided_decoding_backend)
+
return EngineConfig(model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
@@ -519,6 +530,7 @@ def create_engine_config(self, ) -> EngineConfig:
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
+ decoding_config=decoding_config,
tensorizer_config=tensorizer_config)
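
For illustration only (not part of the patch): the new setting can be supplied via the --guided-decoding-backend CLI flag added above, or programmatically through EngineArgs. A sketch of the programmatic path, assuming the EngineArgs changes above and that the model's Hugging Face config is reachable:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m",
                         guided_decoding_backend="lm-format-enforcer")
engine_config = engine_args.create_engine_config()
# The backend chosen on EngineArgs is carried into the decoding config.
assert (engine_config.decoding_config.guided_decoding_backend ==
        "lm-format-enforcer")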
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 8c37c5a9d6ee9..f06c1d18ace4b 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -4,9 +4,10 @@
from transformers import PreTrainedTokenizer
import vllm
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig, SpeculativeConfig,
- TensorizerConfig, VisionLanguageConfig)
+from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
+ ModelConfig, ParallelConfig, SchedulerConfig,
+ SpeculativeConfig, TensorizerConfig,
+ VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
@@ -74,6 +75,7 @@ def __init__(
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
+ decoding_config: Optional[DecodingConfig],
tensorizer_config: Optional[TensorizerConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
@@ -100,6 +102,7 @@ def __init__(
f"kv_cache_dtype={cache_config.cache_dtype}, "
f"quantization_param_path={model_config.quantization_param_path}, "
f"device_config={device_config.device}, "
+ f"decoding_config={decoding_config!r}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.
@@ -111,6 +114,7 @@ def __init__(
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
+ self.decoding_config = decoding_config or DecodingConfig()
self.tensorizer_config = tensorizer_config
self.log_stats = log_stats
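
As a small illustration (not part of the patch), the engine falls back to a default DecodingConfig when the caller passes None, mirroring the `decoding_config or DecodingConfig()` line above:

from vllm.config import DecodingConfig

decoding_config = None  # e.g. an older caller that does not pass the config
effective_config = decoding_config or DecodingConfig()
assert effective_config.guided_decoding_backend == "outlines"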
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index f94d22d279cc4..cf779d44c816b 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel):
description=(
"If specified, the output will follow the context free grammar."),
)
+ guided_decoding_backend: Optional[str] = Field(
+ default=None,
+ description=(
+ "If specified, will override the default guided decoding backend "
+ "of the server for this specific request. If set, must be one of "
+ "'outlines' / 'lm-format-enforcer'"))
# doc: end-chat-completion-extra-params
@@ -265,6 +271,12 @@ class CompletionRequest(BaseModel):
description=(
"If specified, the output will follow the context free grammar."),
)
+ guided_decoding_backend: Optional[str] = Field(
+ default=None,
+ description=(
+ "If specified, will override the default guided decoding backend "
+ "of the server for this specific request. If set, must be one of "
+ "'outlines' / 'lm-format-enforcer'"))
# doc: end-completion-extra-params
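
For illustration only: with the request fields above, a client can override the server-wide backend per request through extra_body, mirroring the tests earlier in this patch. This sketch assumes a server running locally; the URL, API key, model, and regex are placeholders.

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1",
                       api_key="token-abc123")
completion = client.completions.create(
    model="facebook/opt-125m",
    prompt="Give an example IPv4 address: ",
    max_tokens=20,
    extra_body=dict(guided_regex=r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}",
                    guided_decoding_backend="lm-format-enforcer"))
print(completion.choices[0].text)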
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index a03c5dc88108f..c9ed4a9de20f4 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -68,9 +68,13 @@ async def create_chat_completion(
request, prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
+ decoding_config = self.engine.engine.decoding_config
+ guided_decoding_backend = request.guided_decoding_backend \
+ or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
- request, await self.engine.get_tokenizer()))
+ guided_decoding_backend, request, await
+ self.engine.get_tokenizer()))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index e24aa2489a80f..a71f2d6a4426a 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -88,9 +88,13 @@ async def create_completion(self, request: CompletionRequest,
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
+ decoding_config = self.engine.engine.decoding_config
+ guided_decoding_backend = request.guided_decoding_backend \
+ or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
- request, await self.engine.get_tokenizer()))
+ guided_decoding_backend, request, await
+ self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
new file mode 100644
index 0000000000000..0558d6c95d97b
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -0,0 +1,25 @@
+from typing import Optional, Union
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ CompletionRequest)
+from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
+ get_lm_format_enforcer_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.outlines_decoding import (
+ get_outlines_guided_decoding_logits_processor)
+from vllm.sampling_params import LogitsProcessor
+
+
+async def get_guided_decoding_logits_processor(
+ guided_decoding_backend: str, request: Union[CompletionRequest,
+ ChatCompletionRequest],
+ tokenizer) -> Optional[LogitsProcessor]:
+ if guided_decoding_backend == 'outlines':
+ return await get_outlines_guided_decoding_logits_processor(
+ request, tokenizer)
+ if guided_decoding_backend == 'lm-format-enforcer':
+ return await get_lm_format_enforcer_guided_decoding_logits_processor(
+ request, tokenizer)
+
+ raise ValueError(
+ f"Unknown guided decoding backend '{guided_decoding_backend}'. "
+ "Must be one of 'outlines', 'lm-format-enforcer'")
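
A quick sketch (not part of the patch) of the dispatcher's error path. A real call passes a CompletionRequest or ChatCompletionRequest plus the engine's tokenizer; this assumes both backend packages are importable.

import asyncio

from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)

try:
    asyncio.run(
        get_guided_decoding_logits_processor("unknown-backend", None, None))
except ValueError as err:
    # Raised before the request or tokenizer are touched.
    print(err)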
diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
new file mode 100644
index 0000000000000..0d74a5f8e81ff
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -0,0 +1,69 @@
+from functools import lru_cache
+from json import loads as json_loads
+from typing import Optional, Union
+
+from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
+ RegexParser, StringParser,
+ TokenEnforcerTokenizerData, UnionParser)
+from lmformatenforcer.integrations.vllm import (
+ build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
+from pydantic import BaseModel
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ CompletionRequest)
+from vllm.model_executor.guided_decoding.outlines_decoding import (
+ get_outlines_guided_decoding_logits_processor)
+from vllm.sampling_params import LogitsProcessor
+
+
+async def get_lm_format_enforcer_guided_decoding_logits_processor(
+ request: Union[CompletionRequest, ChatCompletionRequest],
+ tokenizer) -> Optional[LogitsProcessor]:
+ """
+ Given an OpenAI-compatible request, check for guided decoding parameters
+ and get the necessary logits processor for the given guide.
+ The token enforcer tokenizer data is cached per tokenizer (via the
+ lru_cache helper below) so it is only built once.
+ """
+
+ tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
+ tokenizer)
+ character_level_parser: CharacterLevelParser
+ if request.guided_json:
+ schema = _normalize_json_schema_object(request.guided_json)
+ character_level_parser = JsonSchemaParser(schema)
+ elif request.guided_choice:
+ character_level_parser = UnionParser(
+ [StringParser(choice) for choice in request.guided_choice])
+ elif request.guided_regex:
+ character_level_parser = RegexParser(request.guided_regex)
+ elif request.guided_grammar:
+ # CFG grammar not supported by LMFE, revert to outlines
+ return await get_outlines_guided_decoding_logits_processor(
+ request, tokenizer)
+ elif (request.response_format is not None
+ and request.response_format.type == "json_object"):
+ character_level_parser = JsonSchemaParser(
+ None) # None means any json object
+ else:
+ return None
+
+ logits_processor = build_vllm_logits_processor(tokenizer_data,
+ character_level_parser)
+ return logits_processor
+
+
+def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
+ if isinstance(schema, str):
+ return json_loads(schema)
+ if isinstance(schema, dict):
+ return schema
+ if isinstance(schema, BaseModel):
+ return schema.model_json_schema()
+
+
+@lru_cache
+def _cached_build_vllm_token_enforcer_tokenizer_data(
+ tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
+ return build_vllm_token_enforcer_tokenizer_data(tokenizer)
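
A standalone sketch (not part of the patch) of the normalization performed by _normalize_json_schema_object above: guided_json may arrive as a JSON string, a dict, or a pydantic model, and is reduced to a plain dict before it reaches JsonSchemaParser. The Answer model is a hypothetical example.

import json

from pydantic import BaseModel


class Answer(BaseModel):  # hypothetical request schema
    winner: str


def normalize(schema):
    # Same three cases as _normalize_json_schema_object above.
    if isinstance(schema, str):
        return json.loads(schema)
    if isinstance(schema, dict):
        return schema
    if isinstance(schema, BaseModel):
        return schema.model_json_schema()


as_dict = Answer.model_json_schema()
as_str = json.dumps(as_dict)
assert normalize(as_str) == normalize(as_dict) == normalize(Answer(winner="x"))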
diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
similarity index 93%
rename from vllm/model_executor/guided_decoding.py
rename to vllm/model_executor/guided_decoding/outlines_decoding.py
index 8e710f1ac2b53..bd4564a36e1ed 100644
--- a/vllm/model_executor/guided_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -12,9 +12,8 @@
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
-from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor,
- JSONLogitsProcessor,
- RegexLogitsProcessor)
+from vllm.model_executor.guided_decoding.outlines_logits_processors import (
+ CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
class GuidedDecodingMode(Enum):
@@ -54,7 +53,7 @@ class GuidedDecodingMode(Enum):
global_thread_pool = None # used for generating logits processor fsm
-async def get_guided_decoding_logits_processor(
+async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
"""
diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
similarity index 70%
rename from vllm/model_executor/guided_logits_processors.py
rename to vllm/model_executor/guided_decoding/outlines_logits_processors.py
index 035fe00037328..28041695546dc 100644
--- a/vllm/model_executor/guided_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import json
import math
from collections import defaultdict
+from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union
import torch
@@ -27,50 +29,6 @@
class BaseLogitsProcessor:
- def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
- """Adapt vLLM's tokenizer to use to compile the FSM.
-
- The API of Outlines tokenizers is slightly different to that of
- `transformers`. The decoder of outlines, returns a list whereas
- the decode of vLLM returns an str. To sync the vLLM decoder with
- outlines internal api, the decoder should be adapted. In addition
- we need to handle the missing spaces to Llama's tokenizer to be
- able to compile FSMs for this model.
-
- """
- if getattr(tokenizer, "_outlines_adapted", False):
- return tokenizer
-
- tokenizer.vocabulary = tokenizer.get_vocab()
- tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
- def convert_token_to_string(token: str) -> str:
- from transformers.file_utils import SPIECE_UNDERLINE
-
- string = tokenizer.convert_tokens_to_string([token])
-
- # A hack to handle missing spaces to HF's Llama tokenizers
- if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
- return " " + string
-
- return string
-
- def change_decoder(
- decoder: Callable[[List[int]], str]
- ) -> Callable[[List[int]], List[str]]:
- """Sync vLLM's decoder with the outlines by returning list."""
-
- def new_decoder(inp_tokens: List[int]) -> List[str]:
- return [decoder(inp_tokens)]
-
- return new_decoder
-
- tokenizer.convert_token_to_string = convert_token_to_string
- tokenizer.decode = change_decoder(tokenizer.decode)
- setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
-
- return tokenizer
-
def init_state(self):
"""Initialize the FSM states."""
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
@@ -78,7 +36,6 @@ def init_state(self):
def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
-
seq_id = hash(tuple(input_ids))
if len(input_ids) == 0:
@@ -96,7 +53,6 @@ def __call__(self, input_ids: List[int],
device=scores.device)
mask[allowed_tokens] = 0
scores.add_(mask)
-
return scores
@@ -113,7 +69,7 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
The model's tokenizer
"""
- tokenizer = self.adapt_tokenizer(tokenizer)
+ tokenizer = _adapt_tokenizer(tokenizer)
fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm
@@ -167,6 +123,54 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
The model's tokenizer
"""
- tokenizer = self.adapt_tokenizer(tokenizer)
+ tokenizer = _adapt_tokenizer(tokenizer)
fsm = CFGFSM(cfg, tokenizer)
self.fsm = fsm
+
+
+@lru_cache
+def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
+ """Adapt vLLM's tokenizer so it can be used to compile the FSM.
+
+ The API of Outlines tokenizers differs slightly from that of
+ `transformers`: the Outlines decoder returns a list, whereas vLLM's
+ decode returns a str, so the decoder is wrapped to match the Outlines
+ internal API. In addition, the missing spaces produced by Llama's
+ tokenizer are handled so that FSMs can be compiled for this model.
+ """
+ if getattr(tokenizer, "_outlines_adapted", False):
+ return tokenizer
+
+ tokenizer = copy.deepcopy(tokenizer)
+
+ tokenizer.vocabulary = tokenizer.get_vocab()
+ tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+ def convert_token_to_string(token: str) -> str:
+ from transformers.file_utils import SPIECE_UNDERLINE
+
+ string = tokenizer.convert_tokens_to_string([token])
+
+ # A hack to handle missing spaces to HF's Llama tokenizers
+ if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+ return " " + string
+
+ return string
+
+ def change_decoder(
+ decoder: Callable[[List[int]],
+ str]) -> Callable[[List[int]], List[str]]:
+ """Sync vLLM's decoder with the outlines by returning list."""
+
+ def new_decoder(inp_tokens: List[int]) -> List[str]:
+ return [decoder(inp_tokens)]
+
+ return new_decoder
+
+ tokenizer.convert_token_to_string = convert_token_to_string
+ tokenizer.decode = change_decoder(tokenizer.decode)
+ setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
+
+ return tokenizer
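
To illustrate the refactor above (not part of the patch): _adapt_tokenizer now deep-copies the tokenizer before mutating it and is wrapped in lru_cache, so the engine's shared tokenizer is left untouched and the adaptation runs once per tokenizer instance. A standalone sketch of that pattern with a hypothetical stand-in object:

import copy
from functools import lru_cache


class FakeTokenizer:  # hypothetical stand-in for a real tokenizer
    adapted = False


@lru_cache
def adapt(tokenizer: FakeTokenizer) -> FakeTokenizer:
    adapted = copy.deepcopy(tokenizer)  # leave the shared object untouched
    adapted.adapted = True
    return adapted


tok = FakeTokenizer()
first, second = adapt(tok), adapt(tok)
assert first is second         # lru_cache: adapted only once per tokenizer
assert tok.adapted is False    # the original tokenizer was not mutated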
From 69e1d2fb6922b2d388bae41286d8867976cbd6c6 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 16 Apr 2024 11:34:39 -0700
Subject: [PATCH 052/413] [Core] Refactor model loading code (#4097)
---
.buildkite/test-pipeline.yaml | 2 +-
examples/fp8/extract_scales.py | 4 +-
examples/tensorize_vllm_model.py | 2 +-
tests/lora/conftest.py | 10 +-
tests/lora/test_worker.py | 10 +-
tests/model_executor/weight_utils.py | 2 +-
.../test_autogptq_marlin_configs.py | 4 -
tests/samplers/test_sampler.py | 14 +-
tests/spec_decode/utils.py | 1 +
.../__init__.py | 0
.../tensorize_vllm_model_for_testing.py | 4 +-
.../test_tensorizer.py | 139 ++++---
tests/test_config.py | 4 -
tests/test_logits_processor.py | 7 +-
tests/worker/test_model_runner.py | 39 +-
tests/worker/test_swap.py | 1 +
vllm/config.py | 201 ++++------
vllm/engine/arg_utils.py | 59 ++-
vllm/engine/llm_engine.py | 19 +-
vllm/executor/cpu_executor.py | 1 +
vllm/executor/executor_base.py | 10 +-
vllm/executor/gpu_executor.py | 2 +-
vllm/executor/ray_gpu_executor.py | 5 +-
vllm/model_executor/model_loader.py | 128 -------
vllm/model_executor/model_loader/__init__.py | 30 ++
vllm/model_executor/model_loader/loader.py | 354 ++++++++++++++++++
.../neuron.py} | 0
.../tensorizer.py} | 116 ++++--
vllm/model_executor/model_loader/utils.py | 40 ++
.../{ => model_loader}/weight_utils.py | 295 +++++++--------
vllm/model_executor/models/baichuan.py | 14 +-
vllm/model_executor/models/bloom.py | 14 +-
vllm/model_executor/models/chatglm.py | 14 +-
vllm/model_executor/models/commandr.py | 16 +-
vllm/model_executor/models/dbrx.py | 16 +-
vllm/model_executor/models/decilm.py | 14 +-
vllm/model_executor/models/deepseek.py | 20 +-
vllm/model_executor/models/falcon.py | 14 +-
vllm/model_executor/models/gemma.py | 14 +-
vllm/model_executor/models/gpt2.py | 14 +-
vllm/model_executor/models/gpt_bigcode.py | 14 +-
vllm/model_executor/models/gpt_j.py | 14 +-
vllm/model_executor/models/gpt_neox.py | 14 +-
vllm/model_executor/models/internlm2.py | 14 +-
vllm/model_executor/models/jais.py | 16 +-
vllm/model_executor/models/llama.py | 16 +-
vllm/model_executor/models/llava.py | 14 +-
vllm/model_executor/models/minicpm.py | 14 +-
vllm/model_executor/models/mixtral.py | 20 +-
vllm/model_executor/models/mixtral_quant.py | 19 +-
vllm/model_executor/models/mpt.py | 14 +-
vllm/model_executor/models/olmo.py | 16 +-
vllm/model_executor/models/opt.py | 14 +-
vllm/model_executor/models/orion.py | 14 +-
vllm/model_executor/models/phi.py | 14 +-
vllm/model_executor/models/qwen.py | 14 +-
vllm/model_executor/models/qwen2.py | 14 +-
vllm/model_executor/models/qwen2_moe.py | 20 +-
vllm/model_executor/models/stablelm.py | 14 +-
vllm/model_executor/models/starcoder2.py | 14 +-
vllm/model_executor/models/xverse.py | 14 +-
vllm/transformers_utils/tokenizer.py | 21 +-
vllm/worker/cpu_model_runner.py | 12 +-
vllm/worker/cpu_worker.py | 7 +-
vllm/worker/model_runner.py | 15 +-
vllm/worker/neuron_model_runner.py | 2 +-
vllm/worker/worker.py | 10 +-
67 files changed, 1064 insertions(+), 973 deletions(-)
rename tests/{tensorizer => tensorizer_loader}/__init__.py (100%)
rename tests/{tensorizer => tensorizer_loader}/tensorize_vllm_model_for_testing.py (98%)
rename tests/{tensorizer => tensorizer_loader}/test_tensorizer.py (67%)
delete mode 100644 vllm/model_executor/model_loader.py
create mode 100644 vllm/model_executor/model_loader/__init__.py
create mode 100644 vllm/model_executor/model_loader/loader.py
rename vllm/model_executor/{neuron_model_loader.py => model_loader/neuron.py} (100%)
rename vllm/model_executor/{tensorizer_loader.py => model_loader/tensorizer.py} (78%)
create mode 100644 vllm/model_executor/model_loader/utils.py
rename vllm/model_executor/{ => model_loader}/weight_utils.py (53%)
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index aa4582bbda0c7..f39c3302ac2e9 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -92,7 +92,7 @@ steps:
parallelism: 4
- label: Tensorizer Test
- command: apt-get install curl libsodium23 && pytest -v -s tensorizer
+ command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
- label: Metrics Test
command: pytest -v -s metrics
diff --git a/examples/fp8/extract_scales.py b/examples/fp8/extract_scales.py
index 5e5b31265e3af..1eb961a5a76e3 100644
--- a/examples/fp8/extract_scales.py
+++ b/examples/fp8/extract_scales.py
@@ -11,7 +11,7 @@
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
-# Adapted from vllm/model_executor/weight_utils.py
+# Adapted from vllm/model_executor/model_loader/weight_utils.py
# The main differences are that we add the NPZ format and simplify
# its functionality drastically for our purposes (e.g. we assume that
# the quantized model exists locally and there is no need to download it)
@@ -71,7 +71,7 @@ def _prepare_hf_weights(
return hf_weights_files, use_safetensors
-# Adapted from vllm/model_executor/weight_utils.py
+# Adapted from vllm/model_executor/model_loader/weight_utils.py
def _hf_tensorfile_iterator(filename: str, load_format: str,
use_safetensors: bool):
if load_format == "npz":
diff --git a/examples/tensorize_vllm_model.py b/examples/tensorize_vllm_model.py
index 8cf8be09d0b9c..e2456168de9d5 100644
--- a/examples/tensorize_vllm_model.py
+++ b/examples/tensorize_vllm_model.py
@@ -16,8 +16,8 @@
from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
+from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.tensorizer_loader import TensorizerArgs
# yapf conflicts with isort for this docstring
# yapf: disable
diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index 1127cc33183c9..2dabfb6b4337c 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -153,11 +153,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
get_model_old = get_model
- def get_model_patched(model_config, device_config, **kwargs):
- return get_model_old(model_config,
- device_config,
- lora_config=LoRAConfig(max_loras=4,
- max_lora_rank=8))
+ def get_model_patched(*, model_config, device_config, **kwargs):
+ kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
+ return get_model_old(model_config=model_config,
+ device_config=device_config,
+ **kwargs)
with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py
index 54594690f7922..732e91a52c0a9 100644
--- a/tests/lora/test_worker.py
+++ b/tests/lora/test_worker.py
@@ -3,8 +3,8 @@
import tempfile
from unittest.mock import patch
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig)
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+ ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.worker.worker import Worker
@@ -18,12 +18,14 @@ def test_worker_apply_lora(sql_lora_files):
"meta-llama/Llama-2-7b-hf",
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
),
+ load_config=LoadConfig(
+ download_dir=None,
+ load_format="dummy",
+ ),
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32),
device_config=DeviceConfig("cuda"),
diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py
index 3154f2826d10c..b0086dd7a7d71 100644
--- a/tests/model_executor/weight_utils.py
+++ b/tests/model_executor/weight_utils.py
@@ -3,7 +3,7 @@
import huggingface_hub.constants
import pytest
-from vllm.model_executor.weight_utils import enable_hf_transfer
+from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer
def test_hf_transfer_auto_activation():
diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py
index cd64622e2226f..1310b4da218b5 100644
--- a/tests/quantization/test_autogptq_marlin_configs.py
+++ b/tests/quantization/test_autogptq_marlin_configs.py
@@ -36,8 +36,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
model_path,
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
@@ -49,8 +47,6 @@ def test_auto_gptq(model_quant_type: str, ) -> None:
model_path,
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 26e2d29ffd04c..dbbe13b8da060 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -32,7 +32,12 @@ def _prepare_test(
1e-2,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
- model_runner = ModelRunner(None, None, None, None, None)
+ model_runner = ModelRunner(model_config=None,
+ parallel_config=None,
+ scheduler_config=None,
+ device_config=None,
+ load_config=None,
+ lora_config=None)
return input_tensor, fake_logits, sampler, model_runner
@@ -591,7 +596,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
- model_runner = ModelRunner(None, None, None, None, None)
+ model_runner = ModelRunner(model_config=None,
+ parallel_config=None,
+ scheduler_config=None,
+ device_config=None,
+ load_config=None,
+ lora_config=None)
generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k,
diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py
index 4637826f254d6..edba4c226b289 100644
--- a/tests/spec_decode/utils.py
+++ b/tests/spec_decode/utils.py
@@ -118,6 +118,7 @@ def create_worker(cls: type,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
+ load_config=engine_config.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
diff --git a/tests/tensorizer/__init__.py b/tests/tensorizer_loader/__init__.py
similarity index 100%
rename from tests/tensorizer/__init__.py
rename to tests/tensorizer_loader/__init__.py
diff --git a/tests/tensorizer/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
similarity index 98%
rename from tests/tensorizer/tensorize_vllm_model_for_testing.py
rename to tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
index d0be08329fd64..e4b15fd57add4 100644
--- a/tests/tensorizer/tensorize_vllm_model_for_testing.py
+++ b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
@@ -16,8 +16,8 @@
from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
+from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.tensorizer_loader import TensorizerArgs
# yapf conflicts with isort for this docstring
# yapf: disable
@@ -74,7 +74,7 @@ def parse_args():
"extremely quickly. Tensor encryption and decryption is "
"also supported, although libsodium must be installed to "
"use it.")
- parser = EngineArgs.add_cli_args(parser)
+ parser = TensorizerArgs.add_cli_args(EngineArgs.add_cli_args(parser))
subparsers = parser.add_subparsers(dest='command')
serialize_parser = subparsers.add_parser(
diff --git a/tests/tensorizer/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py
similarity index 67%
rename from tests/tensorizer/test_tensorizer.py
rename to tests/tensorizer_loader/test_tensorizer.py
index 2ab893e95da9c..a97cc0b3706b4 100644
--- a/tests/tensorizer/test_tensorizer.py
+++ b/tests/tensorizer_loader/test_tensorizer.py
@@ -1,16 +1,19 @@
import gc
+import json
+import os
import subprocess
from unittest.mock import MagicMock, patch
+import openai
import pytest
+import ray
import torch
from tests.entrypoints.test_openai_server import ServerRunner
from vllm import SamplingParams
-from vllm.config import TensorizerConfig
-from vllm.model_executor.tensorizer_loader import (
- EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer,
- load_with_tensorizer, open_stream)
+from vllm.model_executor.model_loader.tensorizer import (
+ EncryptionParams, TensorizerConfig, TensorSerializer,
+ is_vllm_serialized_tensorizer, load_with_tensorizer, open_stream)
prompts = [
"Hello, my name is",
@@ -22,6 +25,8 @@
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
model_ref = "facebook/opt-125m"
+tensorize_model_for_testing_script = os.path.join(
+ os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
def is_curl_installed():
@@ -38,7 +43,7 @@ def tensorizer_config():
return config
-@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent')
+@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_linear_method = MagicMock()
mock_agent_instance = mock_agent.return_value
@@ -81,11 +86,13 @@ def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
- loaded_vllm_model = vllm_runner(model_ref,
- load_format="tensorizer",
- tensorizer_uri=model_path,
- num_readers=1,
- vllm_tensorized=True)
+ loaded_vllm_model = vllm_runner(
+ model_ref,
+ load_format="tensorizer",
+ model_loader_extra_config=TensorizerConfig(tensorizer_uri=model_path,
+ num_readers=1,
+ vllm_tensorized=True),
+ )
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
# Assumes SamplingParams being seeded ensures the outputs are deterministic
@@ -97,14 +104,14 @@ def test_can_deserialize_s3(vllm_runner):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
- loaded_hf_model = vllm_runner(
- model_ref,
- tensorizer_uri=tensorized_path,
- load_format="tensorizer",
- num_readers=1,
- vllm_tensorized=False,
- s3_endpoint="object.ord1.coreweave.com",
- )
+ loaded_hf_model = vllm_runner(model_ref,
+ load_format="tensorizer",
+ model_loader_extra_config=TensorizerConfig(
+ tensorizer_uri=tensorized_path,
+ num_readers=1,
+ vllm_tensorized=False,
+ s3_endpoint="object.ord1.coreweave.com",
+ ))
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
@@ -131,11 +138,12 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(model_ref,
- tensorizer_uri=model_path,
load_format="tensorizer",
- encryption_keyfile=key_path,
- num_readers=1,
- vllm_tensorized=True)
+ model_loader_extra_config=TensorizerConfig(
+ tensorizer_uri=model_path,
+ encryption_keyfile=key_path,
+ num_readers=1,
+ vllm_tensorized=True))
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
@@ -156,10 +164,11 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
gc.collect()
torch.cuda.empty_cache()
loaded_hf_model = vllm_runner(model_ref,
- tensorizer_uri=model_path,
load_format="tensorizer",
- num_readers=1,
- vllm_tensorized=False)
+ model_loader_extra_config=TensorizerConfig(
+ tensorizer_uri=model_path,
+ num_readers=1,
+ vllm_tensorized=False))
deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens)
@@ -190,10 +199,12 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(
model_ref,
- tensorizer_uri=model_path,
load_format="tensorizer",
- num_readers=1,
- vllm_tensorized=True,
+ model_loader_extra_config=TensorizerConfig(
+ tensorizer_uri=model_path,
+ num_readers=1,
+ vllm_tensorized=True,
+ ),
enable_lora=True,
max_loras=1,
max_lora_rank=8,
@@ -208,16 +219,18 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
def test_load_without_tensorizer_load_format(vllm_runner):
with pytest.raises(ValueError):
- vllm_runner(model_ref, tensorizer_uri="test")
+ vllm_runner(model_ref,
+ model_loader_extra_config=TensorizerConfig(
+ tensorizer_uri="test", vllm_tensorized=False))
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorize_vllm_model(tmp_path):
# Test serialize command
serialize_args = [
- "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
- model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
- tmp_path, "--suffix", "tests"
+ "python3", tensorize_model_for_testing_script, "--model", model_ref,
+ "--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
+ "--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command
@@ -229,8 +242,8 @@ def test_tensorize_vllm_model(tmp_path):
# Test deserialize command
deserialize_args = [
- "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
- model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors",
+ "python3", tensorize_model_for_testing_script, "--model", model_ref,
+ "--dtype", "float16", "deserialize", "--path-to-tensors",
path_to_tensors
]
result = subprocess.run(deserialize_args, capture_output=True, text=True)
@@ -242,9 +255,9 @@ def test_tensorize_vllm_model(tmp_path):
def test_openai_apiserver_with_tensorizer(tmp_path):
## Serialize model
serialize_args = [
- "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
- model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
- tmp_path, "--suffix", "tests"
+ "python3", tensorize_model_for_testing_script, "--model", model_ref,
+ "--dtype", "float16", "serialize", "--serialized-directory", tmp_path,
+ "--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command
@@ -253,25 +266,47 @@ def test_openai_apiserver_with_tensorizer(tmp_path):
f"\n{result.stdout}\n{result.stderr}")
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
+ model_loader_extra_config = {
+ "tensorizer_uri": path_to_tensors,
+ "vllm_tensorized": True
+ }
## Start OpenAI API server
openai_args = [
"--model", model_ref, "--dtype", "float16", "--load-format",
- "tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized",
- "--port", "8000"
+ "tensorizer", "--model-loader-extra-config",
+ json.dumps(model_loader_extra_config), "--port", "8000"
]
server = ServerRunner.remote(openai_args)
+ assert ray.get(server.ready.remote())
print("Server ready.")
- assert server.ready.remote()
+
+ client = openai.OpenAI(
+ base_url="http://localhost:8000/v1",
+ api_key="token-abc123",
+ )
+ completion = client.completions.create(model=model_ref,
+ prompt="Hello, my name is",
+ max_tokens=5,
+ temperature=0.0)
+
+ assert completion.id is not None
+ assert completion.choices is not None and len(completion.choices) == 1
+ assert completion.choices[0].text is not None and len(
+ completion.choices[0].text) >= 5
+ assert completion.choices[0].finish_reason == "length"
+ assert completion.usage == openai.types.CompletionUsage(
+ completion_tokens=5, prompt_tokens=6, total_tokens=11)
def test_raise_value_error_on_invalid_load_format(vllm_runner):
with pytest.raises(ValueError):
vllm_runner(model_ref,
load_format="safetensors",
- tensorizer_uri="test")
+ model_loader_extra_config=TensorizerConfig(
+ tensorizer_uri="test", vllm_tensorized=False))
def test_tensorizer_with_tp(vllm_runner):
@@ -281,22 +316,12 @@ def test_tensorizer_with_tp(vllm_runner):
vllm_runner(
model_ref,
- tensorizer_uri=tensorized_path,
load_format="tensorizer",
- num_readers=1,
- vllm_tensorized=False,
- s3_endpoint="object.ord1.coreweave.com",
+ model_loader_extra_config=TensorizerConfig(
+ tensorizer_uri=tensorized_path,
+ num_readers=1,
+ vllm_tensorized=False,
+ s3_endpoint="object.ord1.coreweave.com",
+ ),
tensor_parallel_size=2,
)
-
-
-@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
-def test_tensorizer_warn_quant(tmp_path):
- model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
- serialize_args = [
- "python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
- model_ref, "--quantization", "gptq", "--tensorizer-uri", "test",
- "serialize", "--serialized-directory", tmp_path, "--suffix", "tests"
- ]
- result = subprocess.run(serialize_args, capture_output=True, text=True)
- assert 'PerformanceWarning' in result.stderr
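
For reference (not part of the patch): the updated server test above now passes all tensorizer settings through a single --model-loader-extra-config JSON blob instead of dedicated --tensorizer-uri/--vllm-tensorized flags. A sketch of how such an argument list is built; the path is hypothetical.

import json

model_loader_extra_config = {
    "tensorizer_uri": "/tmp/vllm/facebook/opt-125m/tests/model.tensors",
    "vllm_tensorized": True,
}
openai_args = [
    "--model", "facebook/opt-125m", "--dtype", "float16",
    "--load-format", "tensorizer",
    "--model-loader-extra-config", json.dumps(model_loader_extra_config),
    "--port", "8000",
]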
diff --git a/tests/test_config.py b/tests/test_config.py
index 13a9f76212679..19db10630bbae 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -11,8 +11,6 @@ def test_get_sliding_window():
"Qwen/Qwen1.5-7B",
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
@@ -30,8 +28,6 @@ def test_get_sliding_window():
"mistralai/Mistral-7B-v0.1",
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py
index fe321520114f7..5bb93ca74855b 100644
--- a/tests/test_logits_processor.py
+++ b/tests/test_logits_processor.py
@@ -37,7 +37,12 @@ def _prepare_test(
1e-2,
dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
- model_runner = ModelRunner(None, None, None, None, None)
+ model_runner = ModelRunner(model_config=None,
+ parallel_config=None,
+ scheduler_config=None,
+ device_config=None,
+ load_config=None,
+ lora_config=None)
return input_tensor, fake_logits, logits_processor, model_runner
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index dcaae4af4a6f8..59bed2ce0dad3 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -12,7 +12,12 @@ def test_prepare_prompt(batch_size):
100000,
100000,
enable_chunked_prefill=False)
- model_runner = ModelRunner(None, None, scheduler_config, None, None)
+ model_runner = ModelRunner(model_config=None,
+ parallel_config=None,
+ scheduler_config=scheduler_config,
+ device_config=None,
+ load_config=None,
+ lora_config=None)
model_runner.set_block_size(16)
prompt_lens = []
@@ -118,8 +123,6 @@ def test_prepare_decode_cuda_graph(batch_size):
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
@@ -129,8 +132,12 @@ def test_prepare_decode_cuda_graph(batch_size):
100000,
100000,
enable_chunked_prefill=False)
- model_runner = ModelRunner(model_config, None, scheduler_config, None,
- None)
+ model_runner = ModelRunner(model_config=model_config,
+ parallel_config=None,
+ scheduler_config=scheduler_config,
+ device_config=None,
+ load_config=None,
+ lora_config=None)
model_runner.set_block_size(16)
prompt_lens = []
@@ -205,14 +212,17 @@ def test_empty_seq_group():
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
- model_runner = ModelRunner(model_config, None, None, None, None)
+ model_runner = ModelRunner(model_config=model_config,
+ parallel_config=None,
+ scheduler_config=None,
+ device_config=None,
+ load_config=None,
+ lora_config=None)
model_runner.set_block_size(16)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
@@ -251,8 +261,6 @@ def mock_get_process_group_ranks(group=None):
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
- download_dir=None,
- load_format="dummy",
seed=0,
dtype="float16",
revision=None,
@@ -262,11 +270,12 @@ def mock_get_process_group_ranks(group=None):
100000,
100000,
enable_chunked_prefill=True)
- model_runner = ModelRunner(model_config,
- None,
- scheduler_config,
- None,
- None,
+ model_runner = ModelRunner(model_config=model_config,
+ parallel_config=None,
+ scheduler_config=scheduler_config,
+ device_config=None,
+ load_config=None,
+ lora_config=None,
is_driver_worker=True)
model_runner.set_block_size(16)
diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py
index 8edb1cf05c08e..1804cf78d8003 100644
--- a/tests/worker/test_swap.py
+++ b/tests/worker/test_swap.py
@@ -23,6 +23,7 @@ def test_swap() -> None:
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
cache_config=engine_config.cache_config,
+ load_config=engine_config.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
diff --git a/vllm/config.py b/vllm/config.py
index bf31b03b7c6c4..5a29620e85ac6 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1,9 +1,7 @@
import enum
-import io
import json
import os
-import typing
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
import torch
@@ -18,10 +16,14 @@
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
- from vllm.model_executor.tensorizer_loader import TensorizerArgs
+ from vllm.model_executor.model_loader.loader import BaseModelLoader
logger = init_logger(__name__)
+# If true, will load models from ModelScope instead of Hugging Face Hub.
+VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE",
+ "False").lower() == "true"
+
_GB = 1 << 30
@@ -35,18 +37,6 @@ class ModelConfig:
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
- download_dir: Directory to download and load the weights, default to the
- default cache directory of huggingface.
- load_format: The format of the model weights to load:
- "auto" will try to load the weights in the safetensors format and
- fall back to the pytorch bin format if safetensors format is
- not available.
- "pt" will load the weights in the pytorch bin format.
- "safetensors" will load the weights in the safetensors format.
- "npcache" will load the weights in pytorch format and store
- a numpy cache to speed up the loading.
- "dummy" will initialize the weights with random values, which is
- mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
@@ -83,8 +73,6 @@ def __init__(
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
- download_dir: Optional[str],
- load_format: str,
dtype: Union[str, torch.dtype],
seed: int,
revision: Optional[str] = None,
@@ -101,8 +89,6 @@ def __init__(
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
- self.download_dir = download_dir
- self.load_format = load_format
self.seed = seed
self.revision = revision
self.code_revision = code_revision
@@ -113,64 +99,16 @@ def __init__(
self.max_context_len_to_capture = max_context_len_to_capture
self.max_logprobs = max_logprobs
- if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
- # download model from ModelScope hub,
- # lazy import so that modelscope is not required for normal use.
- # pylint: disable=C.
- from modelscope.hub.snapshot_download import snapshot_download
-
- if not os.path.exists(model):
- model_path = snapshot_download(model_id=model,
- cache_dir=download_dir,
- revision=revision)
- else:
- model_path = model
- self.model = model_path
- self.download_dir = model_path
- self.tokenizer = model_path
-
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
- self._verify_load_format()
self._verify_tokenizer_mode()
self._verify_quantization()
self._verify_cuda_graph()
- def _verify_load_format(self) -> None:
- load_format = self.load_format.lower()
- supported_load_format = [
- "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer"
- ]
- rocm_not_supported_load_format: List[str] = []
- if load_format not in supported_load_format:
- raise ValueError(
- f"Unknown load format: {self.load_format}. Must be one of "
- "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
- "'dummy'.")
- if is_hip() and load_format in rocm_not_supported_load_format:
- rocm_supported_load_format = [
- f for f in supported_load_format
- if (f not in rocm_not_supported_load_format)
- ]
- raise ValueError(
- f"load format '{load_format}' is not supported in ROCm. "
- f"Supported load format are "
- f"{rocm_supported_load_format}")
-
- # TODO: Remove this check once HF updates the pt weights of Mixtral.
- architectures = getattr(self.hf_config, "architectures", [])
- # architectures can be None instead of []
- if architectures and "MixtralForCausalLM" in architectures \
- and load_format == "pt":
- raise ValueError(
- "Currently, the 'pt' format is not supported for Mixtral. "
- "Please use the 'safetensors' format instead. ")
- self.load_format = load_format
-
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:
@@ -471,6 +409,65 @@ def create_config(
return tokenizer_pool_config
+class LoadFormat(str, enum.Enum):
+ AUTO = "auto"
+ PT = "pt"
+ SAFETENSORS = "safetensors"
+ NPCACHE = "npcache"
+ DUMMY = "dummy"
+ TENSORIZER = "tensorizer"
+
+
+@dataclass
+class LoadConfig:
+ """
+ download_dir: Directory to download and load the weights; defaults to
+ the default Hugging Face cache directory.
+ load_format: The format of the model weights to load:
+ "auto" will try to load the weights in the safetensors format and
+ fall back to the pytorch bin format if safetensors format is
+ not available.
+ "pt" will load the weights in the pytorch bin format.
+ "safetensors" will load the weights in the safetensors format.
+ "npcache" will load the weights in pytorch format and store
+ a numpy cache to speed up the loading.
+ "dummy" will initialize the weights with random values, which is
+ mainly for profiling.
+ "tensorizer" will use CoreWeave's tensorizer library for
+ fast weight loading.
+ """
+
+ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
+ download_dir: Optional[str] = None
+ model_loader_extra_config: Optional[Union[str, dict]] = field(
+ default_factory=dict)
+
+ def __post_init__(self):
+ model_loader_extra_config = self.model_loader_extra_config or {}
+ if isinstance(model_loader_extra_config, str):
+ self.model_loader_extra_config = json.loads(
+ model_loader_extra_config)
+ self._verify_load_format()
+
+ def _verify_load_format(self) -> None:
+ if not isinstance(self.load_format, str):
+ return
+
+ load_format = self.load_format.lower()
+ self.load_format = LoadFormat(load_format)
+
+ rocm_not_supported_load_format: List[str] = []
+ if is_hip() and load_format in rocm_not_supported_load_format:
+ rocm_supported_load_format = [
+ f for f in LoadFormat.__members__
+ if (f not in rocm_not_supported_load_format)
+ ]
+ raise ValueError(
+ f"load format '{load_format}' is not supported in ROCm. "
+ f"Supported load formats are "
+ f"{rocm_supported_load_format}")
+
+
class ParallelConfig:
"""Configuration for the distributed execution.
@@ -699,8 +696,6 @@ def maybe_create_spec_config(
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
- download_dir=target_model_config.download_dir,
- load_format=target_model_config.load_format,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
@@ -887,65 +882,6 @@ def get_image_input_enum_type(
f"{[x.name for x in cls.ImageInputType]}.") from e
-@dataclass
-class TensorizerConfig:
- tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
- str, bytes, os.PathLike, int]
- vllm_tensorized: bool
- verify_hash: Optional[bool] = False
- num_readers: Optional[int] = 1
- encryption_keyfile: Optional[str] = None
- s3_access_key_id: Optional[str] = None
- s3_secret_access_key: Optional[str] = None
- s3_endpoint: Optional[str] = None
- model_class: Optional[torch.nn.Module] = None
- hf_config: Optional[PretrainedConfig] = None
- dtype: Union[str, torch.dtype] = None
-
- def _construct_tensorizer_args(self) -> "TensorizerArgs":
- from vllm.model_executor.tensorizer_loader import TensorizerArgs
- tensorizer_args = {
- "tensorizer_uri": self.tensorizer_uri,
- "vllm_tensorized": self.vllm_tensorized,
- "verify_hash": self.verify_hash,
- "num_readers": self.num_readers,
- "encryption_keyfile": self.encryption_keyfile,
- "s3_access_key_id": self.s3_access_key_id,
- "s3_secret_access_key": self.s3_secret_access_key,
- "s3_endpoint": self.s3_endpoint,
- }
- return TensorizerArgs(**tensorizer_args)
-
- def verify_with_parallel_config(
- self,
- parallel_config: "ParallelConfig",
- ) -> None:
- if (parallel_config.tensor_parallel_size > 1
- and self.tensorizer_uri is not None):
- raise ValueError(
- "Loading to multiple GPUs is not currently supported with "
- "vLLM-serialized models. Please set tensor_parallel_size=1."
- " or use a non-vLLM-serialized model, such as a "
- "serialized Hugging Face `PretrainedModel`.")
-
- def verify_with_model_config(self, model_config) -> None:
- if (model_config.quantization is not None
- and self.tensorizer_uri is not None):
- from vllm.model_executor.tensorizer_loader import (
- tensorizer_warning)
- tensorizer_warning(
- "Loading a model using Tensorizer with quantization on vLLM"
- " is unstable and may lead to errors.")
-
- if (model_config.load_format != "tensorizer"
- and self.tensorizer_uri is not None):
- raise ValueError(
- "A tensorizer uri was passed for tensorizer loading, but the "
- f"load format was set to {model_config.load_format}. "
- "Please set the load format to 'tensorizer' to use "
- f"tensorizer args.")
-
-
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
@@ -1105,11 +1041,11 @@ class EngineConfig:
parallel_config: ParallelConfig
scheduler_config: SchedulerConfig
device_config: DeviceConfig
+ load_config: LoadConfig
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
- tensorizer_config: Optional[TensorizerConfig]
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
@@ -1117,11 +1053,6 @@ def __post_init__(self):
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
- if self.tensorizer_config:
- self.tensorizer_config.verify_with_parallel_config(
- self.parallel_config)
- self.tensorizer_config.verify_with_model_config(self.model_config)
-
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 3de74b0ac28b9..c61c0cc67d7a2 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1,15 +1,12 @@
import argparse
import dataclasses
-import io
-import os
from dataclasses import dataclass
-from typing import BinaryIO, Optional, Union
+from typing import Optional
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
- EngineConfig, LoRAConfig, ModelConfig, ParallelConfig,
- SchedulerConfig, SpeculativeConfig, TensorizerConfig,
+ EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
+ ParallelConfig, SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig, VisionLanguageConfig)
-from vllm.model_executor.tensorizer_loader import TensorizerArgs
from vllm.utils import str_to_int_tuple
@@ -60,17 +57,7 @@ class EngineArgs:
ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
-
- # Tensorizer configuration parameters
- tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
- bytes, os.PathLike, int] = None
- vllm_tensorized: bool = False
- verify_hash: Optional[bool] = False
- num_readers: Optional[int] = 1
- encryption_keyfile: Optional[str] = None
- s3_access_key_id: Optional[str] = None
- s3_secret_access_key: Optional[str] = None
- s3_endpoint: Optional[str] = None
+ model_loader_extra_config: Optional[dict] = None
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
@@ -429,7 +416,16 @@ def add_cli_args(
default=None,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding')
- parser = TensorizerArgs.add_cli_args(parser)
+
+ parser.add_argument('--model-loader-extra-config',
+ type=str,
+ default=EngineArgs.model_loader_extra_config,
+ help='Extra config for model loader. '
+ 'This will be passed to the model loader '
+ 'corresponding to the chosen load_format. '
+ 'This should be a JSON string that will be '
+ 'parsed into a dictionary.')
+
return parser
@classmethod
@@ -444,11 +440,11 @@ def create_engine_config(self, ) -> EngineConfig:
device_config = DeviceConfig(self.device)
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
- self.trust_remote_code, self.download_dir, self.load_format,
- self.dtype, self.seed, self.revision, self.code_revision,
- self.tokenizer_revision, self.max_model_len, self.quantization,
- self.quantization_param_path, self.enforce_eager,
- self.max_context_len_to_capture, self.max_logprobs)
+ self.trust_remote_code, self.dtype, self.seed, self.revision,
+ self.code_revision, self.tokenizer_revision, self.max_model_len,
+ self.quantization, self.quantization_param_path,
+ self.enforce_eager, self.max_context_len_to_capture,
+ self.max_logprobs)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
@@ -492,15 +488,10 @@ def create_engine_config(self, ) -> EngineConfig:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
- tensorizer_config = TensorizerConfig(
- tensorizer_uri=self.tensorizer_uri,
- vllm_tensorized=self.vllm_tensorized,
- verify_hash=self.verify_hash,
- num_readers=self.num_readers,
- encryption_keyfile=self.encryption_keyfile,
- s3_access_key_id=self.s3_access_key_id,
- s3_secret_access_key=self.s3_secret_access_key,
- s3_endpoint=self.s3_endpoint,
+ load_config = LoadConfig(
+ load_format=self.load_format,
+ download_dir=self.download_dir,
+ model_loader_extra_config=self.model_loader_extra_config,
)
if self.image_input_type:
@@ -530,8 +521,8 @@ def create_engine_config(self, ) -> EngineConfig:
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
- decoding_config=decoding_config,
- tensorizer_config=tensorizer_config)
+ load_config=load_config,
+ decoding_config=decoding_config)
@dataclass
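
A rough sketch (not part of the patch) of what the new --model-loader-extra-config flag turns into on the Python side. Assumptions: the LoadConfig field names are exactly those used in create_engine_config above, any remaining fields default to None, and the JSON payload and file path are made up for illustration.

import json

from vllm.config import LoadConfig, LoadFormat

# The CLI accepts a JSON string; here it is parsed into the dict that
# LoadConfig.model_loader_extra_config carries to the chosen loader.
cli_value = '{"tensorizer_uri": "opt-125m.tensors", "vllm_tensorized": true}'
load_config = LoadConfig(
    load_format=LoadFormat.TENSORIZER,
    download_dir=None,
    model_loader_extra_config=json.loads(cli_value),
)
print(load_config.model_loader_extra_config["tensorizer_uri"])
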
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f06c1d18ace4b..563694946d16e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -4,9 +4,9 @@
from transformers import PreTrainedTokenizer
import vllm
-from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
- ModelConfig, ParallelConfig, SchedulerConfig,
- SpeculativeConfig, TensorizerConfig,
+from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
+ LoRAConfig, ModelConfig, ParallelConfig,
+ SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
@@ -72,11 +72,11 @@ def __init__(
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
+ load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
- tensorizer_config: Optional[TensorizerConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@@ -92,8 +92,8 @@ def __init__(
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, "
- f"download_dir={model_config.download_dir!r}, "
- f"load_format={model_config.load_format}, "
+ f"download_dir={load_config.download_dir!r}, "
+ f"load_format={load_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"disable_custom_all_reduce="
f"{parallel_config.disable_custom_all_reduce}, "
@@ -114,8 +114,8 @@ def __init__(
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
+ self.load_config = load_config
self.decoding_config = decoding_config or DecodingConfig()
- self.tensorizer_config = tensorizer_config
self.log_stats = log_stats
self._init_tokenizer()
@@ -131,7 +131,7 @@ def __init__(
lora_config=lora_config,
vision_language_config=vision_language_config,
speculative_config=speculative_config,
- tensorizer_config=tensorizer_config,
+ load_config=load_config,
)
self._initialize_kv_caches()
@@ -271,9 +271,6 @@ def _init_tokenizer(self, **tokenizer_init_kwargs):
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config)
- if self.tensorizer_config:
- self.tensorizer_config.verify_with_parallel_config(
- self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index f562e4e0ae3de..426e2c41d8427 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -40,6 +40,7 @@ def _init_worker(self):
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
+ load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index bbb6ec80f7b7e..8cc04c5299ca1 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Tuple
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig, SpeculativeConfig,
- TensorizerConfig, VisionLanguageConfig)
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+ ModelConfig, ParallelConfig, SchedulerConfig,
+ SpeculativeConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -23,20 +23,20 @@ def __init__(
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
+ load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
- tensorizer_config: Optional[TensorizerConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
+ self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.speculative_config = speculative_config
- self.tensorizer_config = tensorizer_config
self._init_executor()
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index bae509f48025b..3a9537effe6d9 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -35,12 +35,12 @@ def _init_worker(self):
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
+ load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
- tensorizer_config=self.tensorizer_config,
is_driver_worker=True,
)
self.driver_worker.init_device()
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 7aca5e36107aa..4065c0868d79a 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -147,6 +147,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
+ load_config = copy.deepcopy(self.load_config)
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
cache_config = copy.deepcopy(self.cache_config)
@@ -165,12 +166,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
+ load_config=load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=lora_config,
vision_language_config=vision_language_config,
- tensorizer_config=self.tensorizer_config,
))
# Initialize the driver worker with the Worker class.
@@ -187,7 +188,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
- tensorizer_config=self.tensorizer_config,
+ load_config=self.load_config,
is_driver_worker=True,
)
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
deleted file mode 100644
index c70ca48bca70a..0000000000000
--- a/vllm/model_executor/model_loader.py
+++ /dev/null
@@ -1,128 +0,0 @@
-"""Utilities for selecting and loading models."""
-import contextlib
-from typing import Tuple, Type
-
-import torch
-from torch import nn
-
-from vllm.config import DeviceConfig, ModelConfig
-from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.models.llava import LlavaForConditionalGeneration
-from vllm.model_executor.tensorizer_loader import (
- ParameterizedLoadFormat, is_vllm_serialized_tensorizer,
- load_with_tensorizer)
-from vllm.model_executor.weight_utils import (get_quant_config,
- initialize_dummy_weights)
-
-_VISION_MODEL_CLASSES = [
- LlavaForConditionalGeneration,
-]
-
-
-@contextlib.contextmanager
-def _set_default_torch_dtype(dtype: torch.dtype):
- """Sets the default torch dtype to the given dtype."""
- old_dtype = torch.get_default_dtype()
- torch.set_default_dtype(dtype)
- yield
- torch.set_default_dtype(old_dtype)
-
-
-def _get_model_architecture(
- model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
- architectures = getattr(model_config.hf_config, "architectures", [])
- # Special handling for quantized Mixtral.
- # FIXME(woosuk): This is a temporary hack.
- if (model_config.quantization is not None
- and "MixtralForCausalLM" in architectures):
- architectures = ["QuantMixtralForCausalLM"]
-
- for arch in architectures:
- model_cls = ModelRegistry.load_model_cls(arch)
- if model_cls is not None:
- return (model_cls, arch)
- raise ValueError(
- f"Model architectures {architectures} are not supported for now. "
- f"Supported architectures: {ModelRegistry.get_supported_archs()}")
-
-
-def get_architecture_class_name(model_config: ModelConfig) -> str:
- return _get_model_architecture(model_config)[1]
-
-
-def get_model(model_config: ModelConfig, device_config: DeviceConfig,
- **kwargs) -> nn.Module:
- lora_config = kwargs.get("lora_config", None)
- vision_language_config = kwargs.get("vision_language_config", None)
- tensorizer_config = kwargs.get("tensorizer_config", None)
- model_class = _get_model_architecture(model_config)[0]
-
- # Get the (maybe quantized) linear method.
- linear_method = None
- if model_config.quantization is not None:
- quant_config = get_quant_config(model_config)
- capability = torch.cuda.get_device_capability()
- capability = capability[0] * 10 + capability[1]
- if capability < quant_config.get_min_capability():
- raise ValueError(
- f"The quantization method {model_config.quantization} is not "
- "supported for the current GPU. "
- f"Minimum capability: {quant_config.get_min_capability()}. "
- f"Current capability: {capability}.")
- supported_dtypes = quant_config.get_supported_act_dtypes()
- if model_config.dtype not in supported_dtypes:
- raise ValueError(
- f"{model_config.dtype} is not supported for quantization "
- f"method {model_config.quantization}. Supported dtypes: "
- f"{supported_dtypes}")
-
- linear_method = quant_config.get_linear_method()
-
- with _set_default_torch_dtype(model_config.dtype):
- # Create a model instance.
- # The weights will be initialized as empty tensors.
- extra_kwargs = {}
- if hasattr(model_class, "supported_lora_modules"):
- extra_kwargs["lora_config"] = lora_config
- elif lora_config:
- raise ValueError(
- f"Model {model_class.__name__} does not support LoRA, "
- "but LoRA is enabled. Support for this model may "
- "be added in the future. If this is important to you, "
- "please open an issue on github.")
- elif model_class in _VISION_MODEL_CLASSES:
- extra_kwargs["vision_language_config"] = vision_language_config
-
- with torch.device(device_config.device):
- if (model_config.load_format == "tensorizer"
- and is_vllm_serialized_tensorizer(tensorizer_config)):
- extra_kwargs["linear_method"] = linear_method
- tensorizer_config.model_class = model_class
- tensorizer_config.hf_config = model_config.hf_config
- tensorizer_config.dtype = model_config.dtype
- model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
- return model.eval()
- model = model_class(config=model_config.hf_config,
- linear_method=linear_method,
- **extra_kwargs)
- if model_config.load_format == "dummy":
- # NOTE(woosuk): For accurate performance evaluation, we assign
- # random values to the weights.
- initialize_dummy_weights(model)
- else:
- # Load the weights from the cached or downloaded files.
- if model_config.load_format == "tensorizer":
- # Provide a dynamic load format for `model.load_weights`
- # to retain tensorizer args from CLI.
- model_config.load_format = ParameterizedLoadFormat(
- model_config.load_format)
- model_config.load_format.params = (
- tensorizer_config._construct_tensorizer_args())
-
- model.load_weights(
- model_config.model,
- model_config.download_dir,
- model_config.load_format,
- model_config.revision,
- )
- return model.eval()
diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py
new file mode 100644
index 0000000000000..6f90e49994fb2
--- /dev/null
+++ b/vllm/model_executor/model_loader/__init__.py
@@ -0,0 +1,30 @@
+from typing import Optional
+
+from torch import nn
+
+from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
+ ParallelConfig, SchedulerConfig, VisionLanguageConfig)
+from vllm.model_executor.model_loader.loader import (BaseModelLoader,
+ get_model_loader)
+from vllm.model_executor.model_loader.utils import (
+ get_architecture_class_name, get_model_architecture)
+
+
+def get_model(
+ *, model_config: ModelConfig, load_config: LoadConfig,
+ device_config: DeviceConfig, parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
+ loader = get_model_loader(load_config)
+ return loader.load_model(model_config=model_config,
+ device_config=device_config,
+ lora_config=lora_config,
+ vision_language_config=vision_language_config,
+ parallel_config=parallel_config,
+ scheduler_config=scheduler_config)
+
+
+__all__ = [
+ "get_model", "get_model_loader", "BaseModelLoader",
+ "get_architecture_class_name", "get_model_architecture"
+]
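
A quick sketch (illustration only, not part of the patch) of the re-exported entry points in this new package. It assumes LoadConfig's other fields default to None; get_model() itself needs an initialized worker (device plus distributed state), so only the lightweight loader lookup is exercised here.

from vllm.config import LoadConfig, LoadFormat
from vllm.model_executor.model_loader import get_model_loader

# Select a loader from the load format alone; DUMMY avoids touching any files.
loader = get_model_loader(LoadConfig(load_format=LoadFormat.DUMMY))
print(type(loader).__name__)  # expected: "DummyModelLoader"
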
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
new file mode 100644
index 0000000000000..3b1d125ef8a67
--- /dev/null
+++ b/vllm/model_executor/model_loader/loader.py
@@ -0,0 +1,354 @@
+# ruff: noqa: SIM117
+import copy
+import glob
+import os
+from abc import ABC, abstractmethod
+from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
+ Type)
+
+import torch
+from torch import nn
+
+from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
+ LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
+ SchedulerConfig, VisionLanguageConfig)
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader.tensorizer import (
+ TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
+ tensorizer_weights_iterator)
+from vllm.model_executor.model_loader.utils import (get_model_architecture,
+ set_default_torch_dtype)
+from vllm.model_executor.model_loader.weight_utils import (
+ download_weights_from_hf, filter_files_not_needed_for_inference,
+ get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
+ pt_weights_iterator, safetensors_weights_iterator)
+from vllm.model_executor.models.llava import LlavaForConditionalGeneration
+
+if TYPE_CHECKING:
+ from vllm.model_executor.layers.linear import LinearMethodBase
+
+_VISION_MODEL_CLASSES = [
+ LlavaForConditionalGeneration,
+]
+
+logger = init_logger(__name__)
+
+
+def _get_linear_method(
+ model_config: ModelConfig,
+ load_config: LoadConfig) -> Optional["LinearMethodBase"]:
+ """Get the (maybe quantized) linear method."""
+ linear_method = None
+ if model_config.quantization is not None:
+ quant_config = get_quant_config(model_config, load_config)
+ capability = torch.cuda.get_device_capability()
+ capability = capability[0] * 10 + capability[1]
+ if capability < quant_config.get_min_capability():
+ raise ValueError(
+ f"The quantization method {model_config.quantization} is not "
+ "supported for the current GPU. "
+ f"Minimum capability: {quant_config.get_min_capability()}. "
+ f"Current capability: {capability}.")
+ supported_dtypes = quant_config.get_supported_act_dtypes()
+ if model_config.dtype not in supported_dtypes:
+ raise ValueError(
+ f"{model_config.dtype} is not supported for quantization "
+ f"method {model_config.quantization}. Supported dtypes: "
+ f"{supported_dtypes}")
+
+ linear_method = quant_config.get_linear_method()
+ return linear_method
+
+
+def _get_model_initialization_kwargs(
+ model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig]
+) -> Dict[str, Any]:
+ """Get extra kwargs for model initialization."""
+ extra_kwargs = {}
+ if hasattr(model_class, "supported_lora_modules"):
+ extra_kwargs["lora_config"] = lora_config
+ elif lora_config:
+ raise ValueError(
+ f"Model {model_class.__name__} does not support LoRA, "
+ "but LoRA is enabled. Support for this model may "
+ "be added in the future. If this is important to you, "
+ "please open an issue on github.")
+ elif model_class in _VISION_MODEL_CLASSES:
+ extra_kwargs["vision_language_config"] = vision_language_config
+ return extra_kwargs
+
+
+def _initialize_model(
+ model_config: ModelConfig, load_config: LoadConfig,
+ lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
+ """Initialize a model with the given configurations."""
+ model_class = get_model_architecture(model_config)[0]
+ linear_method = _get_linear_method(model_config, load_config)
+
+ return model_class(config=model_config.hf_config,
+ linear_method=linear_method,
+ **_get_model_initialization_kwargs(
+ model_class, lora_config, vision_language_config))
+
+
+class BaseModelLoader(ABC):
+ """Base class for model loaders."""
+
+ def __init__(self, load_config: LoadConfig):
+ self.load_config = load_config
+
+ @abstractmethod
+ def load_model(self, *, model_config: ModelConfig,
+ device_config: DeviceConfig,
+ lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig],
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig) -> nn.Module:
+ """Load a model with the given configurations."""
+ ...
+
+
+class DefaultModelLoader(BaseModelLoader):
+ """Model loader that can load different file types from disk."""
+
+ def __init__(self, load_config: LoadConfig):
+ super().__init__(load_config)
+ if load_config.model_loader_extra_config:
+ raise ValueError(f"Model loader extra config is not supported for "
+ f"load format {load_config.load_format}")
+
+ def _maybe_download_from_modelscope(
+ self, model: str, revision: Optional[str]) -> Optional[str]:
+ """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
+
+ Returns the path to the downloaded model, or None if the model is not
+ downloaded from ModelScope."""
+ if VLLM_USE_MODELSCOPE:
+ # download model from ModelScope hub,
+ # lazy import so that modelscope is not required for normal use.
+ # pylint: disable=C.
+ from modelscope.hub.snapshot_download import snapshot_download
+
+ if not os.path.exists(model):
+ model_path = snapshot_download(
+ model_id=model,
+ cache_dir=self.load_config.download_dir,
+ revision=revision)
+ else:
+ model_path = model
+ return model_path
+ return None
+
+ def _prepare_weights(self, model_name_or_path: str,
+ revision: Optional[str],
+ fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
+ """Prepare weights for the model.
+
+ If the model is not local, it will be downloaded."""
+ model_name_or_path = self._maybe_download_from_modelscope(
+ model_name_or_path, revision) or model_name_or_path
+
+ is_local = os.path.isdir(model_name_or_path)
+ load_format = self.load_config.load_format
+ use_safetensors = False
+ # Some quantized models use .pt files for storing the weights.
+ if load_format == LoadFormat.AUTO:
+ allow_patterns = ["*.safetensors", "*.bin"]
+ elif load_format == LoadFormat.SAFETENSORS:
+ use_safetensors = True
+ allow_patterns = ["*.safetensors"]
+ elif load_format == LoadFormat.PT:
+ allow_patterns = ["*.pt"]
+ elif load_format == LoadFormat.NPCACHE:
+ allow_patterns = ["*.bin"]
+ else:
+ raise ValueError(f"Unknown load_format: {load_format}")
+
+ if fall_back_to_pt:
+ allow_patterns += ["*.pt"]
+
+ if not is_local:
+ hf_folder = download_weights_from_hf(model_name_or_path,
+ self.load_config.download_dir,
+ allow_patterns)
+ else:
+ hf_folder = model_name_or_path
+
+ hf_weights_files: List[str] = []
+ for pattern in allow_patterns:
+ hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+ if len(hf_weights_files) > 0:
+ if pattern == "*.safetensors":
+ use_safetensors = True
+ break
+
+ if not use_safetensors:
+ hf_weights_files = filter_files_not_needed_for_inference(
+ hf_weights_files)
+
+ if len(hf_weights_files) == 0:
+ raise RuntimeError(
+ f"Cannot find any model weights with `{model_name_or_path}`")
+
+ return hf_folder, hf_weights_files, use_safetensors
+
+ def _get_weights_iterator(
+ self, model_name_or_path: str, revision: Optional[str],
+ fall_back_to_pt: bool
+ ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+ """Get an iterator for the model weights based on the load format."""
+ hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
+ model_name_or_path, revision, fall_back_to_pt)
+ if self.load_config.load_format == LoadFormat.NPCACHE:
+            # Currently np_cache only supports *.bin checkpoints
+ assert use_safetensors is False
+ return np_cache_weights_iterator(model_name_or_path,
+ self.load_config.download_dir,
+ hf_folder, hf_weights_files)
+ if use_safetensors:
+ return safetensors_weights_iterator(hf_weights_files)
+ return pt_weights_iterator(hf_weights_files)
+
+ def load_model(self, *, model_config: ModelConfig,
+ device_config: DeviceConfig,
+ lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig],
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig) -> nn.Module:
+ with set_default_torch_dtype(model_config.dtype):
+ with torch.device(device_config.device):
+ model = _initialize_model(model_config, self.load_config,
+ lora_config, vision_language_config)
+ model.load_weights(
+ self._get_weights_iterator(model_config.model,
+ model_config.revision,
+ fall_back_to_pt=getattr(
+ model,
+ "fall_back_to_pt_during_load",
+ True)), )
+ return model.eval()
+
+
+class DummyModelLoader(BaseModelLoader):
+ """Model loader that will set model weights to random values."""
+
+ def __init__(self, load_config: LoadConfig):
+ super().__init__(load_config)
+ if load_config.model_loader_extra_config:
+ raise ValueError(f"Model loader extra config is not supported for "
+ f"load format {load_config.load_format}")
+
+ def load_model(self, *, model_config: ModelConfig,
+ device_config: DeviceConfig,
+ lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig],
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig) -> nn.Module:
+ with set_default_torch_dtype(model_config.dtype):
+ with torch.device(device_config.device):
+ model = _initialize_model(model_config, self.load_config,
+ lora_config, vision_language_config)
+ # NOTE(woosuk): For accurate performance evaluation, we assign
+ # random values to the weights.
+ initialize_dummy_weights(model)
+ return model.eval()
+
+
+class TensorizerLoader(BaseModelLoader):
+ """Model loader using CoreWeave's tensorizer library."""
+
+ def __init__(self, load_config: LoadConfig):
+ super().__init__(load_config)
+ if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
+ self.tensorizer_config = load_config.model_loader_extra_config
+ else:
+ self.tensorizer_config = TensorizerConfig(
+ **load_config.model_loader_extra_config)
+
+ def _verify_config(self, model_config: ModelConfig,
+ parallel_config: ParallelConfig):
+ self.tensorizer_config.verify_with_model_config(model_config)
+ self.tensorizer_config.verify_with_parallel_config(parallel_config)
+
+ def _get_weights_iterator(
+ self) -> Generator[Tuple[str, torch.Tensor], None, None]:
+ tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
+ return tensorizer_weights_iterator(tensorizer_args)
+
+ def _load_model_unserialized(
+ self, model_config: ModelConfig, device_config: DeviceConfig,
+ lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig]
+ ) -> nn.Module:
+ """Load an unserialized model with tensorizer.
+
+ Unserialized here means "not serialized with tensorizer". This
+ should still be faster than default HuggingFace loading, but will
+ be slower than loading a tensorizer-serialized model.
+ """
+ with set_default_torch_dtype(model_config.dtype):
+ with torch.device(device_config.device):
+ model = _initialize_model(model_config, self.load_config,
+ lora_config, vision_language_config)
+
+ model.load_weights(self._get_weights_iterator())
+ return model.eval()
+
+ def _load_model_serialized(
+ self, model_config: ModelConfig, device_config: DeviceConfig,
+ lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig]
+ ) -> nn.Module:
+ """Load a serialized model with tensorizer.
+
+        See the examples/tensorize_vllm_model.py example
+        script for serializing vLLM models."""
+ with set_default_torch_dtype(model_config.dtype):
+ with torch.device(device_config.device):
+ model_class = get_model_architecture(model_config)[0]
+ linear_method = _get_linear_method(model_config,
+ self.load_config)
+ extra_kwargs = _get_model_initialization_kwargs(
+ model_class, lora_config, vision_language_config)
+ extra_kwargs["linear_method"] = linear_method
+
+ tensorizer_config = copy.copy(self.tensorizer_config)
+ tensorizer_config.model_class = model_class
+ tensorizer_config.hf_config = model_config.hf_config
+ tensorizer_config.dtype = model_config.dtype
+
+ model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
+ return model.eval()
+
+ def load_model(self, *, model_config: ModelConfig,
+ device_config: DeviceConfig,
+ lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig],
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig) -> nn.Module:
+ self._verify_config(model_config, parallel_config)
+
+ if is_vllm_serialized_tensorizer(self.tensorizer_config):
+ return self._load_model_serialized(model_config, device_config,
+ lora_config,
+ vision_language_config)
+ return self._load_model_unserialized(model_config, device_config,
+ lora_config,
+ vision_language_config)
+
+
+def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
+ """Get a model loader based on the load format."""
+
+ if isinstance(load_config.load_format, type):
+ return load_config.load_format(load_config)
+
+ if load_config.load_format == LoadFormat.DUMMY:
+ return DummyModelLoader(load_config)
+
+ if load_config.load_format == LoadFormat.TENSORIZER:
+ return TensorizerLoader(load_config)
+
+ return DefaultModelLoader(load_config)
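
Note that get_model_loader() also dispatches on a class: if load_format is itself a type, it is instantiated with the LoadConfig. A toy sketch of that branch, assuming LoadConfig does not reject non-string load formats; NoopModelLoader is a made-up name used only to illustrate the mechanism.

from vllm.config import LoadConfig
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
                                                     get_model_loader)


class NoopModelLoader(BaseModelLoader):
    """Illustration only: a loader that refuses to load anything."""

    def load_model(self, *, model_config, device_config, lora_config,
                   vision_language_config, parallel_config, scheduler_config):
        raise NotImplementedError("sketch only")


# The isinstance(load_format, type) branch hands the LoadConfig to the class.
loader = get_model_loader(LoadConfig(load_format=NoopModelLoader))
assert isinstance(loader, NoopModelLoader)
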
diff --git a/vllm/model_executor/neuron_model_loader.py b/vllm/model_executor/model_loader/neuron.py
similarity index 100%
rename from vllm/model_executor/neuron_model_loader.py
rename to vllm/model_executor/model_loader/neuron.py
diff --git a/vllm/model_executor/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer.py
similarity index 78%
rename from vllm/model_executor/tensorizer_loader.py
rename to vllm/model_executor/model_loader/tensorizer.py
index 8550cc97aefe8..ad554844384eb 100644
--- a/vllm/model_executor/tensorizer_loader.py
+++ b/vllm/model_executor/model_loader/tensorizer.py
@@ -4,20 +4,20 @@
import os
import time
import typing
-import warnings
from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Generator, Optional, Tuple, Type, Union
import torch
from torch import nn
+from transformers import PretrainedConfig
-from vllm.config import TensorizerConfig
+from vllm.config import ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
-tensorizer_load_fail = False
+tensorizer_load_fail = None
try:
from tensorizer import (DecryptionParams, EncryptionParams,
@@ -25,51 +25,78 @@
from tensorizer.stream_io import open_stream
from tensorizer.utils import (convert_bytes, get_mem_usage,
no_init_or_tensor)
-except ImportError:
- tensorizer_load_fail = True
+except ImportError as e:
+ tensorizer_load_fail = e
__all__ = [
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
- 'no_init_or_tensor'
+ 'no_init_or_tensor', 'TensorizerConfig'
]
logger = init_logger(__name__)
+@dataclass
+class TensorizerConfig:
+ tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
+ str, bytes, os.PathLike, int]
+ vllm_tensorized: bool
+ verify_hash: Optional[bool] = False
+ num_readers: Optional[int] = 1
+ encryption_keyfile: Optional[str] = None
+ s3_access_key_id: Optional[str] = None
+ s3_secret_access_key: Optional[str] = None
+ s3_endpoint: Optional[str] = None
+ model_class: Optional[Type[torch.nn.Module]] = None
+ hf_config: Optional[PretrainedConfig] = None
+ dtype: Optional[Union[str, torch.dtype]] = None
+
+ def _construct_tensorizer_args(self) -> "TensorizerArgs":
+ tensorizer_args = {
+ "tensorizer_uri": self.tensorizer_uri,
+ "vllm_tensorized": self.vllm_tensorized,
+ "verify_hash": self.verify_hash,
+ "num_readers": self.num_readers,
+ "encryption_keyfile": self.encryption_keyfile,
+ "s3_access_key_id": self.s3_access_key_id,
+ "s3_secret_access_key": self.s3_secret_access_key,
+ "s3_endpoint": self.s3_endpoint,
+ }
+ return TensorizerArgs(**tensorizer_args)
+
+ def verify_with_parallel_config(
+ self,
+ parallel_config: "ParallelConfig",
+ ) -> None:
+ if (parallel_config.tensor_parallel_size > 1
+ and self.tensorizer_uri is not None):
+ raise ValueError(
+ "Loading to multiple GPUs is not currently supported with "
+ "vLLM-serialized models. Please set tensor_parallel_size=1."
+ " or use a non-vLLM-serialized model, such as a "
+ "serialized Hugging Face `PretrainedModel`.")
+
+ def verify_with_model_config(self, model_config: "ModelConfig") -> None:
+ if (model_config.quantization is not None
+ and self.tensorizer_uri is not None):
+ logger.warning(
+ "Loading a model using Tensorizer with quantization on vLLM"
+ " is unstable and may lead to errors.")
+
+
def load_with_tensorizer(tensorizer_config: TensorizerConfig,
**extra_kwargs) -> nn.Module:
tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs)
return tensorizer.deserialize()
-def tensorizer_warning(message: str):
- return warnings.warn(message, category=PerformanceWarning, stacklevel=2)
-
-
def is_vllm_serialized_tensorizer(tensorizer_config: TensorizerConfig) -> bool:
if tensorizer_config is None:
return False
return tensorizer_config.vllm_tensorized
-class ParameterizedLoadFormat(str):
- __slots__ = "params"
-
-
-class PerformanceWarning(UserWarning):
-
- def __str__(self):
- return (f"{super().__str__()}"
- " (set the VLLM_SILENCE_PERFORMANCE_WARNINGS"
- " environment variable to hide this)")
-
-
-if (os.getenv("VLLM_SILENCE_PERFORMANCE_WARNINGS", "").lower()
- not in ("", "0", "n", "no", "off", "disable")):
- warnings.simplefilter("ignore", category=PerformanceWarning)
-
-
@dataclass
class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
@@ -219,11 +246,17 @@ class TensorizerAgent:
behavior of the TensorDeserializer when loading tensors from a serialized
model. For deserializations of HuggingFace models, TensorDeserializer is
instead used as an iterator directly in the func hf_model_weights_iterator
- in vllm/model_executor/weight_utils.py
+ in vllm/model_executor/model_loader/weight_utils.py
"""
def __init__(self, tensorizer_config: TensorizerConfig,
linear_method: LinearMethodBase, **extra_kwargs):
+ if tensorizer_load_fail is not None:
+ raise ImportError(
+ "Tensorizer is not installed. Please install tensorizer "
+ "to use this feature with `pip install vllm[tensorizer]`."
+ ) from tensorizer_load_fail
+
self.tensorizer_config = tensorizer_config
self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args())
@@ -234,11 +267,6 @@ def __init__(self, tensorizer_config: TensorizerConfig,
self.linear_method = linear_method
self.model = self._init_model()
- if tensorizer_load_fail:
- raise ImportError(
- "Tensorizer is not installed. Please install tensorizer "
- "to use this feature with `pip install vllm[tensorizer]`.")
-
def _init_model(self):
model_args = self.tensorizer_config.hf_config
model_args.torch_dtype = self.tensorizer_config.dtype
@@ -313,3 +341,23 @@ def deserialize(self):
self._check_tensors_on_meta_device()
self._resize_lora_embeddings()
return self.model.eval()
+
+
+def tensorizer_weights_iterator(
+ tensorizer_args: "TensorizerArgs"
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+ logger.warning(
+ "Deserializing HuggingFace models is not optimized for "
+ "loading on vLLM, as tensorizer is forced to load to CPU. "
+ "Consider deserializing a vLLM model instead for faster "
+ "load times. See the examples/tensorize_vllm_model.py example "
+ "script for serializing vLLM models.")
+
+ deserializer_args = tensorizer_args.deserializer_params
+ stream_params = tensorizer_args.stream_params
+ stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
+ with TensorDeserializer(stream, **deserializer_args,
+ device="cpu") as state:
+ for name, param in state.items():
+ yield name, param
+ del state
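
With TensorizerConfig now living next to the loader, the same settings can be handed to TensorizerLoader either as this dataclass or as a plain dict of the same fields under model_loader_extra_config (see the isinstance check in loader.py above). A minimal sketch; the URI below is a placeholder, not a real file.

from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

# Dataclass form and dict form carry the same information.
as_dataclass = TensorizerConfig(tensorizer_uri="opt-125m.tensors",
                                vllm_tensorized=True)
as_dict = {"tensorizer_uri": "opt-125m.tensors", "vllm_tensorized": True}
assert TensorizerConfig(**as_dict).tensorizer_uri == as_dataclass.tensorizer_uri
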
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
new file mode 100644
index 0000000000000..a0a3b2784614d
--- /dev/null
+++ b/vllm/model_executor/model_loader/utils.py
@@ -0,0 +1,40 @@
+"""Utilities for selecting and loading models."""
+import contextlib
+from typing import Tuple, Type
+
+import torch
+from torch import nn
+
+from vllm.config import ModelConfig
+from vllm.model_executor.models import ModelRegistry
+
+
+@contextlib.contextmanager
+def set_default_torch_dtype(dtype: torch.dtype):
+ """Sets the default torch dtype to the given dtype."""
+ old_dtype = torch.get_default_dtype()
+ torch.set_default_dtype(dtype)
+ yield
+ torch.set_default_dtype(old_dtype)
+
+
+def get_model_architecture(
+ model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
+ architectures = getattr(model_config.hf_config, "architectures", [])
+ # Special handling for quantized Mixtral.
+ # FIXME(woosuk): This is a temporary hack.
+ if (model_config.quantization is not None
+ and "MixtralForCausalLM" in architectures):
+ architectures = ["QuantMixtralForCausalLM"]
+
+ for arch in architectures:
+ model_cls = ModelRegistry.load_model_cls(arch)
+ if model_cls is not None:
+ return (model_cls, arch)
+ raise ValueError(
+ f"Model architectures {architectures} are not supported for now. "
+ f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
+
+def get_architecture_class_name(model_config: ModelConfig) -> str:
+ return get_model_architecture(model_config)[1]
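
get_model_architecture() only reads hf_config and quantization from the config it is given, so a tiny stand-in object is enough to illustrate the lookup. This is a hedged sketch: the stand-in is not a real ModelConfig, and the model id is just an example that must be reachable locally or on the Hub.

from types import SimpleNamespace

from transformers import AutoConfig

from vllm.model_executor.model_loader.utils import get_model_architecture

# Stand-in for ModelConfig (illustration only); only the two attributes that
# get_model_architecture() actually touches are provided.
stub_config = SimpleNamespace(
    hf_config=AutoConfig.from_pretrained("facebook/opt-125m"),
    quantization=None)
model_cls, arch = get_model_architecture(stub_config)
print(arch)  # expected: "OPTForCausalLM"
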
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
similarity index 53%
rename from vllm/model_executor/weight_utils.py
rename to vllm/model_executor/model_loader/weight_utils.py
index 08425604f0511..1798db0136868 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -4,8 +4,9 @@
import hashlib
import json
import os
+import tempfile
from collections import defaultdict
-from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Any, Generator, Iterable, List, Optional, Tuple
import filelock
import huggingface_hub.constants
@@ -15,7 +16,7 @@
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
-from vllm.config import ModelConfig
+from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
@@ -27,8 +28,7 @@
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
-temp_dir = os.environ.get('TMPDIR') or os.environ.get(
- 'TEMP') or os.environ.get('TMP') or "/tmp/"
+temp_dir = tempfile.gettempdir()
def enable_hf_transfer():
@@ -46,7 +46,7 @@ def enable_hf_transfer():
enable_hf_transfer()
-class Disabledtqdm(tqdm):
+class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
@@ -114,7 +114,8 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
-def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
+def get_quant_config(model_config: ModelConfig,
+ load_config: LoadConfig) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
@@ -125,12 +126,12 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
- with get_lock(model_name_or_path, model_config.download_dir):
+ with get_lock(model_name_or_path, load_config.download_dir):
hf_folder = snapshot_download(model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
- cache_dir=model_config.download_dir,
- tqdm_class=Disabledtqdm)
+ cache_dir=load_config.download_dir,
+ tqdm_class=DisabledTqdm)
else:
hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
@@ -153,169 +154,127 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
return quant_cls.from_config(config)
-def prepare_hf_model_weights(
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- fall_back_to_pt: bool = True,
- revision: Optional[str] = None,
-) -> Tuple[str, List[str], bool]:
- # Download model weights from huggingface.
- is_local = os.path.isdir(model_name_or_path) \
- and load_format != "tensorizer"
- use_safetensors = False
- # Some quantized models use .pt files for storing the weights.
- if load_format == "auto":
- allow_patterns = ["*.safetensors", "*.bin"]
- elif load_format == "safetensors":
- use_safetensors = True
- allow_patterns = ["*.safetensors"]
- elif load_format == "pt":
- allow_patterns = ["*.pt"]
- elif load_format == "npcache":
- allow_patterns = ["*.bin"]
- elif load_format == "tensorizer":
- allow_patterns = ["*.tensors"]
- else:
- raise ValueError(f"Unknown load_format: {load_format}")
-
- if fall_back_to_pt:
- allow_patterns += ["*.pt"]
-
- if not is_local and load_format != "tensorizer":
- # Before we download we look at that is available:
- fs = HfFileSystem()
- file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
-
- # depending on what is available we download different things
- for pattern in allow_patterns:
- matching = fnmatch.filter(file_list, pattern)
- if len(matching) > 0:
- allow_patterns = [pattern]
- break
-
- logger.info(f"Using model weights format {allow_patterns}")
- # Use file lock to prevent multiple processes from
- # downloading the same model weights at the same time.
- with get_lock(model_name_or_path, cache_dir):
- hf_folder = snapshot_download(model_name_or_path,
- allow_patterns=allow_patterns,
- cache_dir=cache_dir,
- tqdm_class=Disabledtqdm,
- revision=revision)
- else:
- hf_folder = model_name_or_path
- hf_weights_files: List[str] = []
+def download_weights_from_hf(model_name_or_path: str,
+ cache_dir: Optional[str],
+ allow_patterns: List[str],
+ revision: Optional[str] = None) -> str:
+ """Download model weights from Hugging Face Hub.
+
+ Args:
+ model_name_or_path (str): The model name or path.
+ cache_dir (Optional[str]): The cache directory to store the model
+ weights. If None, will use HF defaults.
+ allow_patterns (List[str]): The allowed patterns for the
+ weight files. Files matched by any of the patterns will be
+ downloaded.
+ revision (Optional[str]): The revision of the model.
+
+ Returns:
+ str: The path to the downloaded model weights.
+ """
+    # Before we download, we look at what is available:
+ fs = HfFileSystem()
+ file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
+
+    # Depending on what is available, we download different things.
for pattern in allow_patterns:
- hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
- if len(hf_weights_files) > 0:
- if pattern == "*.safetensors":
- use_safetensors = True
+ matching = fnmatch.filter(file_list, pattern)
+ if len(matching) > 0:
+ allow_patterns = [pattern]
break
- if not use_safetensors:
- # Exclude files that are not needed for inference.
- # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
- blacklist = [
- "training_args.bin",
- "optimizer.bin",
- "optimizer.pt",
- "scheduler.pt",
- "scaler.pt",
- ]
- hf_weights_files = [
- f for f in hf_weights_files
- if not any(f.endswith(x) for x in blacklist)
- ]
-
- if load_format == "tensorizer":
- return hf_folder, hf_weights_files, use_safetensors
-
- if len(hf_weights_files) == 0:
- raise RuntimeError(
- f"Cannot find any model weights with `{model_name_or_path}`")
-
- return hf_folder, hf_weights_files, use_safetensors
-
-
-def hf_model_weights_iterator(
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: Union[Tuple, str] = "auto",
- revision: Optional[str] = None,
- fall_back_to_pt: Optional[bool] = True,
-) -> Iterator[Tuple[str, torch.Tensor]]:
- hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
- model_name_or_path,
- cache_dir=cache_dir,
- load_format=load_format,
- fall_back_to_pt=fall_back_to_pt,
- revision=revision)
-
- if load_format == "npcache":
- # Currently np_cache only support *.bin checkpoints
- assert use_safetensors is False
-
- # Convert the model weights from torch tensors to numpy arrays for
- # faster loading.
- np_folder = os.path.join(hf_folder, "np")
- os.makedirs(np_folder, exist_ok=True)
- weight_names_file = os.path.join(np_folder, "weight_names.json")
- # Use file lock to prevent multiple processes from
- # dumping the same model weights to numpy at the same time.
- with get_lock(model_name_or_path, cache_dir):
- if not os.path.exists(weight_names_file):
- weight_names = []
- for bin_file in hf_weights_files:
- state = torch.load(bin_file, map_location="cpu")
- for name, param in state.items():
- param_path = os.path.join(np_folder, name)
- with open(param_path, "wb") as f:
- np.save(f, param.cpu().detach().numpy())
- weight_names.append(name)
- with open(weight_names_file, "w") as f:
- json.dump(weight_names, f)
-
- with open(weight_names_file, "r") as f:
- weight_names = json.load(f)
-
- for name in weight_names:
- param_path = os.path.join(np_folder, name)
- with open(param_path, "rb") as f:
- param = np.load(f)
- yield name, torch.from_numpy(param)
- elif load_format == "tensorizer":
- from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
- open_stream,
- tensorizer_warning)
- tensorizer_args = load_format.params
- tensorizer_warning(
- "Deserializing HuggingFace models is not optimized for "
- "loading on vLLM, as tensorizer is forced to load to CPU. "
- "Consider deserializing a vLLM model instead for faster "
- "load times. See the examples/tensorize_vllm_model.py example "
- "script for serializing vLLM models.")
-
- deserializer_args = tensorizer_args.deserializer_params
- stream_params = tensorizer_args.stream_params
- stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
- with TensorDeserializer(stream, **deserializer_args,
- device="cpu") as state:
- for name, param in state.items():
+
+ logger.info(f"Using model weights format {allow_patterns}")
+ # Use file lock to prevent multiple processes from
+ # downloading the same model weights at the same time.
+ with get_lock(model_name_or_path, cache_dir):
+ hf_folder = snapshot_download(model_name_or_path,
+ allow_patterns=allow_patterns,
+ cache_dir=cache_dir,
+ tqdm_class=DisabledTqdm,
+ revision=revision)
+ return hf_folder
+
+
+def filter_files_not_needed_for_inference(
+ hf_weights_files: List[str]) -> List[str]:
+ """
+ Exclude files that are not needed for inference.
+
+ See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
+ """
+ blacklist = [
+ "training_args.bin",
+ "optimizer.bin",
+ "optimizer.pt",
+ "scheduler.pt",
+ "scaler.pt",
+ ]
+ hf_weights_files = [
+ f for f in hf_weights_files
+ if not any(f.endswith(x) for x in blacklist)
+ ]
+ return hf_weights_files
+
+
+def np_cache_weights_iterator(
+ model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
+ hf_weights_files: List[str]
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+ """Iterate over the weights in the model np files.
+
+ Will dump the model weights to numpy files if they are not already dumped.
+ """
+ # Convert the model weights from torch tensors to numpy arrays for
+ # faster loading.
+ np_folder = os.path.join(hf_folder, "np")
+ os.makedirs(np_folder, exist_ok=True)
+ weight_names_file = os.path.join(np_folder, "weight_names.json")
+ # Use file lock to prevent multiple processes from
+ # dumping the same model weights to numpy at the same time.
+ with get_lock(model_name_or_path, cache_dir):
+ if not os.path.exists(weight_names_file):
+ weight_names = []
+ for bin_file in hf_weights_files:
+ state = torch.load(bin_file, map_location="cpu")
+ for name, param in state.items():
+ param_path = os.path.join(np_folder, name)
+ with open(param_path, "wb") as f:
+ np.save(f, param.cpu().detach().numpy())
+ weight_names.append(name)
+ with open(weight_names_file, "w") as f:
+ json.dump(weight_names, f)
+
+ with open(weight_names_file, "r") as f:
+ weight_names = json.load(f)
+
+ for name in weight_names:
+ param_path = os.path.join(np_folder, name)
+ with open(param_path, "rb") as f:
+ param = np.load(f)
+ yield name, torch.from_numpy(param)
+
+
+def safetensors_weights_iterator(
+ hf_weights_files: List[str]
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+ """Iterate over the weights in the model safetensor files."""
+ for st_file in hf_weights_files:
+ with safe_open(st_file, framework="pt") as f:
+ for name in f.keys(): # noqa: SIM118
+ param = f.get_tensor(name)
yield name, param
+
+
+def pt_weights_iterator(
+ hf_weights_files: List[str]
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+ """Iterate over the weights in the model bin/pt files."""
+ for bin_file in hf_weights_files:
+ state = torch.load(bin_file, map_location="cpu")
+ for name, param in state.items():
+ yield name, param
del state
- elif use_safetensors:
- for st_file in hf_weights_files:
- with safe_open(st_file, framework="pt") as f:
- for name in f.keys(): # noqa: SIM118
- param = f.get_tensor(name)
- yield name, param
- else:
- for bin_file in hf_weights_files:
- state = torch.load(bin_file, map_location="cpu")
- for name, param in state.items():
- yield name, param
- del state
- torch.cuda.empty_cache()
+ torch.cuda.empty_cache()
def kv_cache_scales_loader(
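
A self-contained sketch (illustration only) of the split-out helpers above: write a tiny safetensors file and stream it back with the new iterator, after filtering out a file that is not needed for inference.

import torch
from safetensors.torch import save_file

from vllm.model_executor.model_loader.weight_utils import (
    filter_files_not_needed_for_inference, safetensors_weights_iterator)

# Create a throwaway checkpoint so the example does not depend on a real model.
save_file({"linear.weight": torch.zeros(2, 2)}, "tiny.safetensors")

files = filter_files_not_needed_for_inference(
    ["tiny.safetensors", "optimizer.pt"])  # "optimizer.pt" is dropped
for name, tensor in safetensors_weights_iterator(files):
    print(name, tuple(tensor.shape))
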
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 30588aecdebe9..69162b0a92d65 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -19,7 +19,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -40,9 +40,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -340,19 +339,14 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name == "lm_head.weight":
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index 40966ab33631a..14f325e624f41 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -17,7 +17,7 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -35,9 +35,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -298,14 +297,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
if not name.startswith("transformer."):
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 7b46ba306619a..3cdb7a7bca1c1 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -2,7 +2,7 @@
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -22,9 +22,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
@@ -370,14 +369,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index aa9b28b676e0b..d80969773e163 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -20,7 +20,7 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
import torch
import torch.utils.checkpoint
@@ -41,10 +41,9 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -335,13 +334,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(
- self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None,
- ):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -352,8 +345,7 @@ def load_weights(
]
params_dict = dict(self.named_parameters())
loaded_params = set()
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 49eb7f1b2c185..179094b8fd7aa 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -1,5 +1,5 @@
# coding=utf-8
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
@@ -18,10 +18,9 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -391,20 +390,13 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(
- self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None,
- ):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [(
"ws" if weight_name in ["w1", "v1"] else "w2s",
f"experts.mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py
index abf4a462871b0..d476630ee6f11 100644
--- a/vllm/model_executor/models/decilm.py
+++ b/vllm/model_executor/models/decilm.py
@@ -23,16 +23,15 @@
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
-from typing import Optional
+from typing import Iterable, Optional, Tuple
import torch
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
class DeciLMForCausalLM(LlamaForCausalLM):
@@ -65,11 +64,7 @@ def __init__(
linear_method=linear_method,
lora_config=lora_config)
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -79,8 +74,7 @@ def load_weights(self,
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py
index c7dd11d07e6da..46101a152ec0d 100644
--- a/vllm/model_executor/models/deepseek.py
+++ b/vllm/model_executor/models/deepseek.py
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -44,9 +44,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -316,6 +315,8 @@ def forward(
class DeepseekModel(nn.Module):
+ fall_back_to_pt_during_load = False
+
def __init__(
self,
config: PretrainedConfig,
@@ -395,11 +396,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -410,12 +407,7 @@ def load_weights(self,
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path,
- cache_dir,
- load_format,
- revision,
- fall_back_to_pt=False):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
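For reference, the MoE-style models touched here (Deepseek, Mixtral, Qwen2MoE) previously passed fall_back_to_pt=False to hf_model_weights_iterator; the series replaces that per-call flag with the class attribute fall_back_to_pt_during_load = False, which the loader can consult before deciding whether to fall back to .pt checkpoints. A minimal sketch of such a loader-side check, assuming a hypothetical helper name (not part of the patch):

import torch.nn as nn

def _should_fall_back_to_pt(model: nn.Module) -> bool:
    # Models that do not define the attribute keep the old behaviour (fallback allowed).
    return getattr(model, "fall_back_to_pt_during_load", True)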
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 4f1ebcd5fb43c..25ce239d14662 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -19,7 +19,7 @@
"""PyTorch Falcon model."""
import math
-from typing import List, Optional, Union
+from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -40,9 +40,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig
@@ -399,11 +398,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
@@ -413,8 +408,7 @@ def load_weights(self,
total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if name == "lm_head.weight":
# Falcon uses tied embeddings.
continue
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index fc1fc35570368..6d01537c5c344 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -15,7 +15,7 @@
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -36,9 +36,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
logger = init_logger(__name__)
@@ -346,11 +345,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -361,8 +356,7 @@ def load_weights(self,
]
params_dict = dict(self.named_parameters())
loaded_params = set()
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 43f0d47fcb122..850050c7232d0 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -34,9 +34,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -239,14 +238,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index cec2d771adfa8..8278ba02514d5 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -35,9 +35,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -260,14 +259,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if ".attn.bias" in name:
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 5660097652748..7a830d7f9c965 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -34,9 +34,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -248,11 +247,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -262,8 +257,7 @@ def load_weights(self,
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "attn.bias" in name or "attn.masked_bias" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 2f9e2171cf114..b946aed92ed35 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -34,9 +34,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -262,14 +261,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
continue
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 6e9cbd3f9f43f..db1da8bdc4fb9 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -18,9 +18,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -274,19 +273,14 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index a041b0c9a0452..e7ee749e824e4 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -20,7 +20,7 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import math
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -36,9 +36,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import JAISConfig
@@ -303,16 +302,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(
- self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None,
- ):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c86e292e7df1a..016e3b039d1e8 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -42,10 +42,9 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator,
- kv_cache_scales_loader)
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip
@@ -376,11 +375,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -390,8 +385,7 @@ def load_weights(self,
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index c2571d0893c8d..314a2792bf167 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -13,10 +13,9 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
_KEYS_TO_MODIFY_MAPPING = {
@@ -198,11 +197,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -213,8 +208,7 @@ def load_weights(self,
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 49eda9c9a8112..f0d72fafcaf70 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -45,10 +45,9 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -472,11 +471,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -493,8 +488,7 @@ def load_weights(self,
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index ff552a9d86536..4d1755f2bbe63 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -43,10 +43,9 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -319,6 +318,8 @@ def forward(
class MixtralForCausalLM(nn.Module):
+ fall_back_to_pt_during_load = False
+
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -393,11 +394,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -414,12 +411,7 @@ def load_weights(self,
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path,
- cache_dir,
- load_format,
- revision,
- fall_back_to_pt=False):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index 1f0c0e912beea..acd13cc27f159 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import numpy as np
import torch
@@ -43,9 +43,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -327,6 +326,7 @@ def forward(
class MixtralForCausalLM(nn.Module):
+ fall_back_to_pt_during_load = False
def __init__(
self,
@@ -366,11 +366,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -379,12 +375,7 @@ def load_weights(self,
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path,
- cache_dir,
- load_format,
- revision,
- fall_back_to_pt=False):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index af4cdce29d085..340f63286739b 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -1,7 +1,7 @@
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
@@ -18,9 +18,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig
@@ -284,14 +283,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 3513c72879102..b92003bc0e067 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -36,7 +36,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only OLMo model compatible with HuggingFace weights."""
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
import torch
# this model requires this dependency
@@ -56,9 +56,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -348,16 +347,9 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(
- self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None,
- ):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
# attention
if ".att" in name:
name = name.replace(".att", ".attn.att")
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 3a640850662c0..89263166bca81 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -35,9 +35,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -315,11 +314,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -327,8 +322,7 @@ def load_weights(self,
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if name.startswith("decoder."):
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index c606ac027e9d9..bbb9fa5347cc8 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -4,7 +4,7 @@
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -22,9 +22,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -280,11 +279,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -294,8 +289,7 @@ def load_weights(self,
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index e91624da90955..f974b78a0fbda 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -35,7 +35,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -53,9 +53,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -265,11 +264,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -278,8 +273,7 @@ def load_weights(self,
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 6213a2ded65ab..a77da7cb15984 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -4,7 +4,7 @@
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -23,9 +23,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -253,19 +252,14 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 796e30e633e85..71b906e20ac19 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -42,9 +42,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -331,11 +330,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -345,8 +340,7 @@ def load_weights(self,
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if self.config.tie_word_embeddings and "lm_head.weight" in name:
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index f920b4f5a40c7..59908bc9ef26a 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
import torch.nn.functional as F
@@ -46,9 +46,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -366,6 +365,8 @@ def forward(
class Qwen2MoeForCausalLM(nn.Module):
+ fall_back_to_pt_during_load = False
+
def __init__(
self,
config: PretrainedConfig,
@@ -404,11 +405,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -419,12 +416,7 @@ def load_weights(self,
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path,
- cache_dir,
- load_format,
- revision,
- fall_back_to_pt=False):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 651598b770f13..3e6c2db6f3c65 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -19,7 +19,7 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
-from typing import List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -37,9 +37,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -262,11 +261,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -276,8 +271,7 @@ def load_weights(self,
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 76e8e48673413..b90f3da141c2e 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -36,9 +36,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -274,11 +273,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -287,8 +282,7 @@ def load_weights(self,
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 7e9ce9e5c8e15..4e905390c2340 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -40,9 +40,8 @@
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.weight_utils import (default_weight_loader,
- hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -331,11 +330,7 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
- def load_weights(self,
- model_name_or_path: str,
- cache_dir: Optional[str] = None,
- load_format: str = "auto",
- revision: Optional[str] = None):
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
@@ -344,8 +339,7 @@ def load_weights(self,
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
- for name, loaded_weight in hf_model_weights_iterator(
- model_name_or_path, cache_dir, load_format, revision):
+ for name, loaded_weight in weights:
if ("rotary_emb.inv_freq" in name
or "rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 5d3d5801c960d..c98a673bfed4b 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -1,8 +1,10 @@
+import os
from typing import Optional, Union
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
+from vllm.config import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
@@ -57,9 +59,26 @@ def get_tokenizer(
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
+ download_dir: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
- """Gets a tokenizer for the given model name via Huggingface."""
+ """Gets a tokenizer for the given model name via Huggingface/modelscope."""
+ if VLLM_USE_MODELSCOPE:
+ # download model from ModelScope hub,
+ # lazy import so that modelscope is not required for normal use.
+ # pylint: disable=C.
+ from modelscope.hub.snapshot_download import snapshot_download
+
+ # Only set the tokenizer here, model will be downloaded on the workers.
+ if not os.path.exists(tokenizer_name):
+ tokenizer_path = snapshot_download(
+ model_id=tokenizer_name,
+ cache_dir=download_dir,
+ revision=tokenizer_revision,
+ # Ignore weights - we only need the tokenizer.
+ ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
+ tokenizer_name = tokenizer_path
+
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError(
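For reference, the new download_dir parameter lets get_tokenizer resolve a ModelScope model id to a local snapshot before handing the path to the regular tokenizer machinery. A hedged usage sketch; the model id and cache directory are examples only, and VLLM_USE_MODELSCOPE is assumed to be read from the environment when vllm.config is imported:

import os

# Assumed to take effect only if set before vllm is imported.
os.environ["VLLM_USE_MODELSCOPE"] = "True"

from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer(
    "qwen/Qwen-7B",                        # example ModelScope model id
    trust_remote_code=True,
    download_dir="/tmp/modelscope-cache",  # example cache location
)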
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 49e1ad5709f5d..d378e3a90e1e7 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -3,8 +3,8 @@
import torch
from vllm.attention import AttentionMetadata, get_attn_backend
-from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
- SchedulerConfig)
+from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
+ ParallelConfig, SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@@ -26,6 +26,7 @@ def __init__(
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
+ load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
@@ -36,6 +37,7 @@ def __init__(
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
+ self.load_config = load_config
self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
@@ -55,8 +57,10 @@ def __init__(
self.model_config.dtype if model_config is not None else None)
def load_model(self) -> None:
- self.model = get_model(self.model_config,
- self.device_config,
+ self.model = get_model(model_config=self.model_config,
+ load_config=self.load_config,
+ device_config=self.device_config,
+ vision_language_config=None,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 41341b063bed7..6610b9c4be876 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -5,8 +5,8 @@
import torch.distributed
from vllm.attention import get_attn_backend
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig)
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+ ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
@@ -117,6 +117,7 @@ def __init__(
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
+ load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
@@ -129,6 +130,7 @@ def __init__(
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
+ self.load_config = load_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
@@ -141,6 +143,7 @@ def __init__(
parallel_config,
scheduler_config,
device_config,
+ load_config=self.load_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 7dbe14ead0976..42c06a1b19361 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -9,9 +9,8 @@
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
-from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
- SchedulerConfig, TensorizerConfig,
- VisionLanguageConfig)
+from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
+ ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
pynccl_utils)
@@ -108,17 +107,17 @@ def __init__(
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
+ load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
- tensorizer_config: Optional[TensorizerConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
- self.tensorizer_config = tensorizer_config
+ self.load_config = load_config
self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py.
@@ -156,13 +155,13 @@ def __init__(
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(
- self.model_config,
- self.device_config,
+ model_config=self.model_config,
+ device_config=self.device_config,
+ load_config=self.load_config,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
- tensorizer_config=self.tensorizer_config,
)
self.model_memory_usage = m.consumed_memory
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index fff721a80c204..f70a7193effeb 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -6,7 +6,7 @@
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.neuron_model_loader import get_neuron_model
+from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 82491c6df6616..6a79285f60579 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -6,8 +6,8 @@
import torch
import torch.distributed
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig, TensorizerConfig,
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+ ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
@@ -38,12 +38,12 @@ def __init__(
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
+ load_config: LoadConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
- tensorizer_config: Optional[TensorizerConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
@@ -55,7 +55,7 @@ def __init__(
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
- self.tensorizer_config = tensorizer_config
+ self.load_config = load_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
@@ -70,11 +70,11 @@ def __init__(
parallel_config,
scheduler_config,
device_config,
+ load_config=load_config,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
- tensorizer_config=tensorizer_config,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
From e95cd879598b834f85e70ebcd23db316ae430540 Mon Sep 17 00:00:00 2001
From: Cade Daniel
Date: Tue, 16 Apr 2024 13:09:21 -0700
Subject: [PATCH 053/413] [Speculative decoding 6/9] Integrate speculative
decoding with LLMEngine (#3894)
---
tests/core/block/e2e/test_correctness.py | 70 ++++
tests/core/utils.py | 17 +-
.../output_processor/test_multi_step.py | 270 +++++++++++++
tests/spec_decode/e2e/test_correctness.py | 127 +++++-
tests/spec_decode/test_multi_step_worker.py | 4 +-
tests/spec_decode/test_spec_decode_worker.py | 32 +-
tests/spec_decode/utils.py | 2 +-
vllm/core/block/block_table.py | 1 -
vllm/core/scheduler.py | 8 +-
vllm/engine/async_llm_engine.py | 4 +-
vllm/engine/llm_engine.py | 371 +++---------------
vllm/engine/output_processor/__init__.py | 0
vllm/engine/output_processor/interfaces.py | 69 ++++
vllm/engine/output_processor/multi_step.py | 126 ++++++
vllm/engine/output_processor/single_step.py | 276 +++++++++++++
vllm/engine/output_processor/stop_checker.py | 101 +++++
vllm/engine/output_processor/util.py | 16 +
vllm/executor/cpu_executor.py | 3 +-
vllm/executor/executor_base.py | 5 +-
vllm/executor/gpu_executor.py | 79 +++-
vllm/executor/neuron_executor.py | 5 +-
vllm/executor/ray_gpu_executor.py | 3 +-
vllm/sequence.py | 13 +
vllm/spec_decode/batch_expansion.py | 23 +-
vllm/spec_decode/multi_step_worker.py | 16 +-
vllm/spec_decode/spec_decode_worker.py | 44 ++-
vllm/spec_decode/util.py | 26 ++
vllm/worker/cpu_worker.py | 8 +-
vllm/worker/neuron_worker.py | 11 +-
vllm/worker/worker.py | 11 +-
vllm/worker/worker_base.py | 13 +-
31 files changed, 1347 insertions(+), 407 deletions(-)
create mode 100644 tests/engine/output_processor/test_multi_step.py
create mode 100644 vllm/engine/output_processor/__init__.py
create mode 100644 vllm/engine/output_processor/interfaces.py
create mode 100644 vllm/engine/output_processor/multi_step.py
create mode 100644 vllm/engine/output_processor/single_step.py
create mode 100644 vllm/engine/output_processor/stop_checker.py
create mode 100644 vllm/engine/output_processor/util.py
diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py
index 94b65401e1dd4..0ee78a9b0a8ea 100644
--- a/tests/core/block/e2e/test_correctness.py
+++ b/tests/core/block/e2e/test_correctness.py
@@ -230,6 +230,76 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
assert baseline_token_ids == test_token_ids
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [
+ {
+ # Use a small model for a fast test.
+ "model": "facebook/opt-125m",
+
+        # Skip cuda graph creation for fast test.
+ "enforce_eager": True,
+ "enable_chunked_prefill": True,
+ "max_num_batched_tokens": 2,
+ "max_num_seqs": 2,
+ },
+ ])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [
+ {
+ "use_v2_block_manager": False,
+ },
+])
+@pytest.mark.parametrize("test_llm_kwargs", [
+ {
+ "use_v2_block_manager": True,
+ "num_lookahead_slots": 0,
+ },
+ {
+ "use_v2_block_manager": True,
+ "num_lookahead_slots": 5,
+ },
+])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
+ test_llm_generator, batch_size):
+ """Verify that chunked prefill works with BlockManagerV2, with and without
+ lookahead scheduling.
+ """
+ output_len = 32
+ temperature = 0.0
+
+ prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ ]
+
+ prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+ sampling_params = SamplingParams(
+ max_tokens=output_len,
+ ignore_eos=True,
+ temperature=temperature,
+ )
+
+ print('Getting token ids with BlockManagerV1')
+ baseline_token_ids = get_token_ids_from_llm_generator(
+ baseline_llm_generator, prompts, sampling_params)
+
+ print('Getting token ids with BlockManagerV2')
+ test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+ prompts, sampling_params)
+
+ for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+ test_token_ids):
+ assert expected_token_ids == actual_token_ids
+
+ assert baseline_token_ids == test_token_ids
+
+
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
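The prompt-batching idiom used by this test (and again in the spec decode e2e test later in this patch) cycles a short prompt list until it reaches batch_size. A standalone sketch of just that pattern, independent of vLLM; all names below are illustrative:

from itertools import cycle

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
batch_size = 6

# Repeat the prompt list, wrapping around, until exactly batch_size prompts.
batched_prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

assert len(batched_prompts) == batch_size
assert batched_prompts[4] == prompts[0]  # wrapped around after four prompts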
diff --git a/tests/core/utils.py b/tests/core/utils.py
index fbbdb07cb8e6e..22c1d3826dff4 100644
--- a/tests/core/utils.py
+++ b/tests/core/utils.py
@@ -1,5 +1,5 @@
import time
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
from vllm import SamplingParams
from vllm.lora.request import LoRARequest
@@ -31,14 +31,17 @@ def create_dummy_prompt(
def create_seq_group(
- seq_prompt_len=1024,
- seq_output_lens=(128, ),
- request_id='0',
- seq_id_start=0,
-) -> SequenceGroup:
+ seq_prompt_len: int = 1024,
+ seq_output_lens: Iterable[int] = (128, ),
+ request_id: str = '0',
+ seq_id_start: int = 0,
+ sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
assert len(seq_output_lens) > 0
+ if sampling_params is None:
+ sampling_params = SamplingParams()
+
prompt_token_ids = [0] * seq_prompt_len
seqs = []
@@ -60,7 +63,7 @@ def create_seq_group(
seq_group = SequenceGroup(
request_id=request_id,
seqs=seqs,
- sampling_params=SamplingParams(),
+ sampling_params=sampling_params,
arrival_time=time.time(),
)
diff --git a/tests/engine/output_processor/test_multi_step.py b/tests/engine/output_processor/test_multi_step.py
new file mode 100644
index 0000000000000..6da3da091db78
--- /dev/null
+++ b/tests/engine/output_processor/test_multi_step.py
@@ -0,0 +1,270 @@
+import random
+from unittest.mock import MagicMock
+
+import pytest
+from transformers import PreTrainedTokenizer
+
+from tests.core.utils import create_seq_group
+from vllm.core.scheduler import Scheduler
+from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
+from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
+ SequenceStatus)
+from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
+
+
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [1, 12])
+@pytest.mark.skip_global_cleanup
+def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
+ """Verify multi-step decoding appends token ids correctly.
+
+ We append token ids and verify all the token ids were appended correctly.
+ Note that ignore_eos=True.
+ """
+ detokenizer = MagicMock(spec=Detokenizer)
+ scheduler = MagicMock(spec=Scheduler)
+ stop_checker = MagicMock(spec=StopChecker)
+ seq_counter = Counter()
+
+ output_processor = MultiStepOutputProcessor(
+ detokenizer=detokenizer,
+ scheduler=scheduler,
+ seq_counter=seq_counter,
+ get_tokenizer_for_seq=lambda _: mock_tokenizer(),
+ stop_checker=stop_checker,
+ )
+
+ seq_group = create_seq_group(
+ seq_prompt_len=1024,
+ seq_output_lens=[seq_output_len],
+ sampling_params=SamplingParams(max_tokens=seq_output_len +
+ num_new_tokens,
+ ignore_eos=True),
+ )
+
+ seq = seq_group.get_seqs()[0]
+ seq.status = SequenceStatus.RUNNING
+
+ new_token_ids = list(range(num_new_tokens))
+
+ outputs = [
+ SequenceGroupOutput(
+ samples=[
+ SequenceOutput(
+ parent_seq_id=seq.seq_id,
+ output_token=output_token,
+ logprobs={output_token: Logprob(0.0)},
+ )
+ ],
+ prompt_logprobs=None,
+ ) for output_token in new_token_ids
+ ]
+
+ assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
+ output_processor.process_outputs(seq_group, outputs)
+ assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids
+
+
+@pytest.mark.parametrize("seq_prompt_len", [1024])
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
+@pytest.mark.parametrize("max_tokens", [128 + 3])
+@pytest.mark.skip_global_cleanup
+def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
+ seq_output_len: int, max_tokens: int):
+ """Verify tokens after max_tokens are dropped and not appended to the
+ sequence.
+ """
+ detokenizer = MagicMock(spec=Detokenizer)
+ scheduler = MagicMock(spec=Scheduler)
+ stop_checker = MagicMock(spec=StopChecker)
+ seq_counter = Counter()
+
+ output_processor = MultiStepOutputProcessor(
+ detokenizer=detokenizer,
+ scheduler=scheduler,
+ seq_counter=seq_counter,
+ get_tokenizer_for_seq=lambda _: mock_tokenizer(),
+ stop_checker=stop_checker,
+ )
+
+ seq_group = create_seq_group(
+ seq_prompt_len=seq_prompt_len,
+ seq_output_lens=[seq_output_len],
+ sampling_params=SamplingParams(max_tokens=max_tokens, ),
+ )
+
+ seq = seq_group.get_seqs()[0]
+ seq.status = SequenceStatus.RUNNING
+
+ new_token_ids = list(range(num_new_tokens))
+
+ outputs = [
+ SequenceGroupOutput(
+ samples=[
+ SequenceOutput(
+ parent_seq_id=seq.seq_id,
+ output_token=output_token,
+ logprobs={output_token: Logprob(0.0)},
+ )
+ ],
+ prompt_logprobs=None,
+ ) for output_token in new_token_ids
+ ]
+
+ assert seq.get_len() == seq_prompt_len + seq_output_len
+ output_processor.process_outputs(seq_group, outputs)
+
+ # Expect the processed sequence to not go over max tokens in len.
+ assert seq.get_len() == seq_prompt_len + max_tokens
+
+ # Expect the correct tokens were appended.
+ expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
+ assert seq.get_token_ids(
+ )[-len(expected_appended_tokens):] == expected_appended_tokens
+
+
+@pytest.mark.parametrize("seq_prompt_len", [1024])
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [12])
+@pytest.mark.parametrize("seed", list(range(6)))
+@pytest.mark.skip_global_cleanup
+def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
+ seq_output_len: int, seed: int):
+ """Verify the eos token id is included in the sequence, but subsequent
+ tokens are dropped (not appended to sequence).
+ """
+ random.seed(seed)
+ detokenizer = MagicMock(spec=Detokenizer)
+ scheduler = MagicMock(spec=Scheduler)
+ stop_checker = MagicMock(spec=StopChecker)
+ seq_counter = Counter()
+
+ eos_token_id = 100
+
+ output_processor = MultiStepOutputProcessor(
+ detokenizer=detokenizer,
+ scheduler=scheduler,
+ seq_counter=seq_counter,
+ get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
+ stop_checker=stop_checker,
+ )
+
+ seq_group = create_seq_group(
+ seq_prompt_len=seq_prompt_len,
+ seq_output_lens=[seq_output_len],
+ sampling_params=SamplingParams(
+ # Ensure enough space.
+ max_tokens=seq_output_len + num_new_tokens, ),
+ )
+
+ seq = seq_group.get_seqs()[0]
+ seq.status = SequenceStatus.RUNNING
+
+ new_token_ids = list(range(num_new_tokens))
+ assert eos_token_id not in new_token_ids
+ eos_index = random.randint(0, len(new_token_ids) - 1)
+ new_token_ids[eos_index] = eos_token_id
+
+ outputs = [
+ SequenceGroupOutput(
+ samples=[
+ SequenceOutput(
+ parent_seq_id=seq.seq_id,
+ output_token=output_token,
+ logprobs={output_token: Logprob(0.0)},
+ )
+ ],
+ prompt_logprobs=None,
+ ) for output_token in new_token_ids
+ ]
+
+ assert seq.get_len() == seq_prompt_len + seq_output_len
+ output_processor.process_outputs(seq_group, outputs)
+
+ # Expect the processed sequence to not go beyond provided eos.
+ assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)
+
+ # Expect the correct tokens were appended.
+ expected_appended_tokens = new_token_ids[:eos_index + 1]
+ assert seq.get_token_ids(
+ )[-len(expected_appended_tokens):] == expected_appended_tokens
+
+
+@pytest.mark.parametrize("seq_prompt_len", [1024])
+@pytest.mark.parametrize("seq_output_len", [128])
+@pytest.mark.parametrize("num_new_tokens", [12])
+@pytest.mark.parametrize("seed", list(range(6)))
+@pytest.mark.skip_global_cleanup
+def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
+ seq_output_len: int, seed: int):
+ """When sampling parameters dictate that we should ignore the eos token id,
+ ensure all token ids are appended even if the eos token id is emitted.
+ """
+ random.seed(seed)
+ detokenizer = MagicMock(spec=Detokenizer)
+ scheduler = MagicMock(spec=Scheduler)
+ stop_checker = MagicMock(spec=StopChecker)
+ seq_counter = Counter()
+
+ eos_token_id = 100
+
+ output_processor = MultiStepOutputProcessor(
+ detokenizer=detokenizer,
+ scheduler=scheduler,
+ seq_counter=seq_counter,
+ get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
+ stop_checker=stop_checker,
+ )
+
+ seq_group = create_seq_group(
+ seq_prompt_len=seq_prompt_len,
+ seq_output_lens=[seq_output_len],
+ sampling_params=SamplingParams(
+ # Ensure enough space.
+ max_tokens=seq_output_len + num_new_tokens,
+ ignore_eos=True,
+ ),
+ )
+
+ seq = seq_group.get_seqs()[0]
+ seq.status = SequenceStatus.RUNNING
+
+ new_token_ids = list(range(num_new_tokens))
+ assert eos_token_id not in new_token_ids
+ eos_index = random.randint(0, len(new_token_ids) - 1)
+ new_token_ids[eos_index] = eos_token_id
+
+ outputs = [
+ SequenceGroupOutput(
+ samples=[
+ SequenceOutput(
+ parent_seq_id=seq.seq_id,
+ output_token=output_token,
+ logprobs={output_token: Logprob(0.0)},
+ )
+ ],
+ prompt_logprobs=None,
+ ) for output_token in new_token_ids
+ ]
+
+ assert seq.get_len() == seq_prompt_len + seq_output_len
+ output_processor.process_outputs(seq_group, outputs)
+
+ # Expect the processed sequence to go beyond eos.
+ assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens
+
+ # Expect the correct tokens were appended.
+ expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
+ seq_output_len]
+ assert seq.get_token_ids(
+ )[-len(expected_appended_tokens):] == expected_appended_tokens
+
+
+def mock_tokenizer(eos_token_id=1000):
+ tokenizer = MagicMock(spec=PreTrainedTokenizer)
+ tokenizer.eos_token_id = eos_token_id
+ return tokenizer
diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py
index b5a6fcb7900a3..a8ebd66841eb2 100644
--- a/tests/spec_decode/e2e/test_correctness.py
+++ b/tests/spec_decode/e2e/test_correctness.py
@@ -1,4 +1,8 @@
+from itertools import cycle
+from typing import List, Tuple
+
import pytest
+from transformers import AutoTokenizer
from vllm import SamplingParams
@@ -7,18 +11,47 @@
"common_llm_kwargs",
[{
# Use a small model for a fast test.
- "model": "facebook/opt-125m",
- "speculative_model": "facebook/opt-125m",
- "num_speculative_tokens": 5,
+        # Note this model is repeated in the test body to initialize a tokenizer.
+ "model": "JackFram/llama-68m",
+
+ # Skip real loading for fast test.
+ "load_format": "dummy",
+
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+ "per_test_common_llm_kwargs",
+ [
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+ },
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 1,
+ },
+ {
+ # No spec decode.
+ },
+ ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("batch_size", [1])
+# NOTE: We should run more permutations of this test (more batch sizes, more
+# seeds), but because our spec decode generates gibberish token ids, the
+# likelihood of emitting an invalid token combination is nontrivial. This
+# causes divergence between vLLM detokenization and the HF tokenizer, for
+# example when two "utf-start" bytes are emitted.
@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_config(test_llm_generator):
- output_len = 1024
+def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
+ """Run generation with speculative decoding on a batch. Verify the engine
+ generates the correct number of tokens (via ignore_eos=True), and that the
+ detokenization matches HF transformers.
+ """
+ output_len = 32
temperature = 0.0
prompts = [
@@ -28,23 +61,91 @@ def test_spec_decode_config(test_llm_generator):
"The future of AI is",
]
+ prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+ sampling_params = SamplingParams(
+ max_tokens=output_len,
+ ignore_eos=True,
+ temperature=temperature,
+ skip_special_tokens=True,
+ spaces_between_special_tokens=False,
+ )
+
+ batch_tokens, batch_token_ids = get_output_from_llm_generator(
+ test_llm_generator, prompts, sampling_params)
+
+ # Expect a generation for each prompt in the batch.
+ assert len(batch_token_ids) == len(prompts)
+
+ # Expect each generation to have expected number of tokens (note
+ # ignore_eos=True).
+ assert all(len(token_ids) == output_len for token_ids in batch_token_ids)
+
+ # Expect detokenized string to match.
+ tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
+ for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids):
+ expected_tokens = tok.decode(actual_token_ids)
+ print(f"{actual_token_ids=}")
+ assert actual_tokens.strip() == expected_tokens.strip()
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ # Use a small model for a fast test.
+ "model": "JackFram/llama-68m",
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+
+ # Skip real loading for fast test.
+ "load_format": "dummy",
+
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True
+ }])
+@pytest.mark.parametrize(
+ "per_test_common_llm_kwargs",
+ [
+ {
+ # Expect failure as spec decode not supported by
+ # Ray backend.
+ "worker_use_ray": True,
+ },
+ ])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail(test_llm_generator):
+ """Verify that speculative decoding with Ray fails.
+ """
+ output_len = 128
+ temperature = 0.0
+
+ prompts = [
+ "Hello, my name is",
+ ]
+
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
- with pytest.raises(
- AssertionError,
- match="Speculative decoding not yet supported for GPU backend"):
- get_token_ids_from_llm_generator(test_llm_generator, prompts,
- sampling_params)
+ with pytest.raises(AssertionError,
+ match="Speculative decoding not yet supported for "):
+ get_output_from_llm_generator(test_llm_generator, prompts,
+ sampling_params)
-def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
+def get_output_from_llm_generator(
+ llm_generator, prompts,
+ sampling_params) -> Tuple[List[str], List[List[int]]]:
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
+ tokens = [output.outputs[0].text for output in outputs]
del llm
- return token_ids
+ return tokens, token_ids
diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py
index f4d44108b47c2..d6edbab579afd 100644
--- a/tests/spec_decode/test_multi_step_worker.py
+++ b/tests/spec_decode/test_multi_step_worker.py
@@ -125,7 +125,7 @@ def test_same_output_for_single_step():
zero_kv_cache(worker.cache_engine)
set_random_seed(seed)
expected_output = worker.execute_model(
- **single_step_execute_model_data.to_dict(), )
+ **single_step_execute_model_data.to_dict(), )[0]
actual_token_ids = [
output.samples[0].output_token for output in actual_output
@@ -219,7 +219,7 @@ def test_same_output_for_multi_step():
continuations=continuations,
final_seq_lens=final_seq_lens))
- single_step_output.append(
+ single_step_output.extend(
worker.execute_model(**execute_model_data.to_dict(), ))
# Append output tokens to new sequence data.
diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py
index 47aff8f575413..0a3110775e2d6 100644
--- a/tests/spec_decode/test_spec_decode_worker.py
+++ b/tests/spec_decode/test_spec_decode_worker.py
@@ -6,6 +6,7 @@
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.utils import set_random_seed
+from vllm.sequence import SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.metrics import (AsyncMetricsCollector,
SpecDecodeWorkerMetrics)
@@ -37,7 +38,8 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
execute_model_data, _, _ = create_batch(batch_size, k)
with pytest.raises(ValueError, match=exception_secret):
- worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
+ worker.execute_model(**execute_model_data.to_dict(),
+ num_lookahead_slots=k)
call_args_list = draft_worker.get_spec_proposals.call_args_list
assert len(call_args_list) == 1
@@ -102,7 +104,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
target_worker.execute_model.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
- worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
+ worker.execute_model(**execute_model_data.to_dict(),
+ num_lookahead_slots=k)
seen_contexts = []
@@ -189,13 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
- target_worker.execute_model.return_value = target_output[0]
+ target_worker.execute_model.return_value = [target_output[0]]
     exception_secret = 'artificial stop'
rejection_sampler.side_effect = ValueError(exception_secret)
with pytest.raises(ValueError, match=exception_secret):
- worker.execute_model(**execute_model_data.to_dict(), num_spec_tokens=k)
+ worker.execute_model(**execute_model_data.to_dict(),
+ num_lookahead_slots=k)
assert len(rejection_sampler.call_args_list) == 1
args, _ = rejection_sampler.call_args_list[0]
@@ -268,7 +272,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
- target_worker.execute_model.return_value = target_output[0]
+ target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0,
high=vocab_size,
@@ -283,7 +287,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
rejection_sampler.return_value = rejection_sampler_output
output = worker.execute_model(**execute_model_data.to_dict(),
- num_spec_tokens=k)
+ num_lookahead_slots=k)
expected_output = create_sampler_output_list(
rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)])
@@ -380,7 +384,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
target_output = create_sampler_output_list(target_token_ids,
target_token_probs)
- target_worker.execute_model.return_value = target_output[0]
+ target_worker.execute_model.return_value = [target_output[0]]
rejection_sampler_output = torch.randint(low=0,
high=vocab_size,
@@ -400,7 +404,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
mock_rejsample_metrics)
output = worker.execute_model(**execute_model_data.to_dict(),
- num_spec_tokens=k)
+ num_lookahead_slots=k)
assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
call_args_list = (
@@ -423,6 +427,8 @@ def test_k_equals_zero(k: int, batch_size: int):
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+ target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
@@ -435,7 +441,7 @@ def test_k_equals_zero(k: int, batch_size: int):
batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(),
- num_spec_tokens=k)
+ num_lookahead_slots=k)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
@@ -443,7 +449,7 @@ def test_k_equals_zero(k: int, batch_size: int):
0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(
- **execute_model_data.to_dict(), return_python_output=False)
+ **execute_model_data.to_dict())
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
@@ -462,6 +468,8 @@ def test_empty_input_batch(k: int, batch_size: int):
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
+ target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)]
+
draft_worker.device = 'cuda'
target_worker.device = 'cuda'
@@ -474,7 +482,7 @@ def test_empty_input_batch(k: int, batch_size: int):
batch_size, k, prev_output_token_len=0)
out = worker.execute_model(**execute_model_data.to_dict(),
- num_spec_tokens=k)
+ num_lookahead_slots=k)
assert len(out) == 1, f"expected only one token output when {k=}"
assert out[0].probs is None, "expect gpu tensor references to be None"
@@ -482,7 +490,7 @@ def test_empty_input_batch(k: int, batch_size: int):
0].sampled_tokens is None, "expect gpu tensor references to be None"
draft_worker.execute_model.assert_called_once_with(
- **execute_model_data.to_dict(), return_python_output=False)
+ **execute_model_data.to_dict())
target_worker.execute_model.assert_called_once_with(
**execute_model_data.to_dict())
diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py
index edba4c226b289..d04b6029493f4 100644
--- a/tests/spec_decode/utils.py
+++ b/tests/spec_decode/utils.py
@@ -212,7 +212,7 @@ def create_sampler_output_list(
SequenceOutput(
output_token=token_id,
parent_seq_id=seq_ids[seq_index],
- logprobs={token_id: 0},
+ logprobs={token_id: Logprob(0)},
)
],
prompt_logprobs=None,
diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py
index ba061bbc4fbcb..560267e55ea3a 100644
--- a/vllm/core/block/block_table.py
+++ b/vllm/core/block/block_table.py
@@ -104,7 +104,6 @@ def append_token_ids(self,
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
- assert token_ids, "can't append empty token ids"
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 18ddcd1d6d466..4198550621030 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -762,9 +762,7 @@ def _schedule_default(self) -> SchedulerOutputs:
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
swapped_in.blocks_to_copy),
ignored_seq_groups=prefills.ignored_seq_groups,
- num_lookahead_slots=(prefills.num_lookahead_slots +
- running_scheduled.num_lookahead_slots +
- swapped_in.num_lookahead_slots),
+ num_lookahead_slots=running_scheduled.num_lookahead_slots,
)
def _schedule_chunked_prefill(self):
@@ -850,9 +848,7 @@ def _schedule_chunked_prefill(self):
blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
swapped_in.blocks_to_copy),
ignored_seq_groups=prefills.ignored_seq_groups,
- num_lookahead_slots=(prefills.num_lookahead_slots +
- running_scheduled.num_lookahead_slots +
- swapped_in.num_lookahead_slots),
+ num_lookahead_slots=running_scheduled.num_lookahead_slots,
)
def _schedule(self) -> SchedulerOutputs:
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 1dbf58904541c..27192449bf15a 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -217,7 +217,9 @@ async def step_async(self) -> List[RequestOutput]:
else:
output = []
- return self._process_model_outputs(output, scheduler_outputs)
+ return self._process_model_outputs(
+ output, scheduler_outputs.scheduled_seq_groups,
+ scheduler_outputs.ignored_seq_groups)
async def encode_request_async(
self,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 563694946d16e..c3de57e249ff8 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1,5 +1,5 @@
import time
-from typing import Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional, Type, Union
from transformers import PreTrainedTokenizer
@@ -11,6 +11,10 @@
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
+from vllm.engine.output_processor.interfaces import (
+ SequenceGroupOutputProcessor)
+from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
@@ -18,8 +22,7 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
- SequenceGroup, SequenceGroupOutput, SequenceOutput,
- SequenceStatus)
+ SequenceGroup)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@@ -187,6 +190,21 @@ def __init__(
labels=dict(model_name=model_config.model))
self.stat_logger.info("cache_config", self.cache_config)
+ # Create sequence output processor, e.g. for beam search or
+ # speculative decoding.
+ self.output_processor = (
+ SequenceGroupOutputProcessor.create_output_processor(
+ self.scheduler_config,
+ self.detokenizer,
+ self.scheduler,
+ self.seq_counter,
+ self.get_tokenizer_for_seq,
+ stop_checker=StopChecker(
+ self.scheduler_config.max_model_len,
+ self.get_tokenizer_for_seq,
+ ),
+ ))
+
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@@ -412,240 +430,32 @@ def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
- def _check_beam_search_early_stopping(
- self,
- early_stopping: Union[bool, str],
- sampling_params: SamplingParams,
- best_running_seq: Sequence,
- current_worst_seq: Sequence,
- ) -> bool:
- assert sampling_params.use_beam_search
- length_penalty = sampling_params.length_penalty
- if early_stopping is True:
- return True
-
- current_worst_score = current_worst_seq.get_beam_search_score(
- length_penalty=length_penalty,
- eos_token_id=current_worst_seq.eos_token_id)
- if early_stopping is False:
- highest_attainable_score = best_running_seq.get_beam_search_score(
- length_penalty=length_penalty,
- eos_token_id=best_running_seq.eos_token_id)
- else:
- assert early_stopping == "never"
- if length_penalty > 0.0:
- # If length_penalty > 0.0, beam search will prefer longer
- # sequences. The highest attainable score calculation is
- # based on the longest possible sequence length in this case.
- max_possible_length = max(
- best_running_seq.get_prompt_len() +
- sampling_params.max_tokens,
- self.scheduler_config.max_model_len)
- highest_attainable_score = (
- best_running_seq.get_beam_search_score(
- length_penalty=length_penalty,
- eos_token_id=best_running_seq.eos_token_id,
- seq_len=max_possible_length))
- else:
- # Otherwise, beam search will prefer shorter sequences. The
- # highest attainable score calculation is based on the current
- # sequence length.
- highest_attainable_score = (
- best_running_seq.get_beam_search_score(
- length_penalty=length_penalty,
- eos_token_id=best_running_seq.eos_token_id))
- return current_worst_score >= highest_attainable_score
-
- def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
- outputs: SequenceGroupOutput) -> None:
-
- # Process prompt logprobs
- prompt_logprobs = outputs.prompt_logprobs
- if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
- self.detokenizer.decode_prompt_logprobs_inplace(
- seq_group, prompt_logprobs)
- seq_group.prompt_logprobs = prompt_logprobs
-
- # Process samples
- samples = outputs.samples
- parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
- existing_finished_seqs = seq_group.get_finished_seqs()
- parent_child_dict = {
- parent_seq.seq_id: []
- for parent_seq in parent_seqs
- }
- for sample in samples:
- parent_child_dict[sample.parent_seq_id].append(sample)
- # List of (child, parent)
- child_seqs: List[Tuple[Sequence, Sequence]] = []
-
- # Process the child samples for each parent sequence
- for parent in parent_seqs:
- child_samples: List[SequenceOutput] = parent_child_dict[
- parent.seq_id]
- if len(child_samples) == 0:
- # This parent sequence has no children samples. Remove
- # the parent sequence from the sequence group since it will
- # not be used in the future iterations.
- parent.status = SequenceStatus.FINISHED_ABORTED
- seq_group.remove(parent.seq_id)
- self.scheduler.free_seq(parent)
- continue
- # Fork the parent sequence if there are multiple child samples.
- for child_sample in child_samples[:-1]:
- new_child_seq_id = next(self.seq_counter)
- child = parent.fork(new_child_seq_id)
- child.append_token_id(child_sample.output_token,
- child_sample.logprobs)
- child_seqs.append((child, parent))
- # Continue the parent sequence for the last child sample.
- # We reuse the parent sequence here to reduce redundant memory
- # copies, especially when using non-beam search sampling methods.
- last_child_sample = child_samples[-1]
- parent.append_token_id(last_child_sample.output_token,
- last_child_sample.logprobs)
- child_seqs.append((parent, parent))
-
- for seq, _ in child_seqs:
- if seq_group.sampling_params.detokenize:
- new_char_count = self.detokenizer.decode_sequence_inplace(
- seq, seq_group.sampling_params)
- else:
- new_char_count = 0
- self._check_stop(seq, new_char_count, seq_group.sampling_params)
-
- # Non-beam search case
- if not seq_group.sampling_params.use_beam_search:
- # For newly created child sequences, add them to the sequence group
- # and fork them in block manager if they are not finished.
- for seq, parent in child_seqs:
- if seq is not parent:
- seq_group.add(seq)
- if not seq.is_finished():
- self.scheduler.fork_seq(parent, seq)
-
- # Free the finished and selected parent sequences' memory in block
- # manager. Keep them in the sequence group as candidate output.
- # NOTE: we need to fork the new sequences before freeing the
- # old sequences.
- for seq, parent in child_seqs:
- if seq is parent and seq.is_finished():
- self.scheduler.free_seq(seq)
- return
-
- # Beam search case
- # Select the child sequences to keep in the sequence group.
- selected_child_seqs = []
- unselected_child_seqs = []
- beam_width = seq_group.sampling_params.best_of
- length_penalty = seq_group.sampling_params.length_penalty
-
- # Select the newly finished sequences with the highest scores
- # to replace existing finished sequences.
- # Tuple of (seq, parent, is_new)
- existing_finished_seqs = [(seq, None, False)
- for seq in existing_finished_seqs]
- new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
- if seq.is_finished()]
- all_finished_seqs = existing_finished_seqs + new_finished_seqs
- # Sort the finished sequences by their scores.
- all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
- length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
- reverse=True)
- for seq, parent, is_new in all_finished_seqs[:beam_width]:
- if is_new:
- # A newly generated child sequence finishes and has a high
- # score, so we will add it into the sequence group.
- selected_child_seqs.append((seq, parent))
- for seq, parent, is_new in all_finished_seqs[beam_width:]:
- if is_new:
- # A newly generated child sequence finishes but has a low
- # score, so we will not add it into the sequence group.
- # Additionally, if this sequence is a continuation of a
- # parent sequence, we will need remove the parent sequence
- # from the sequence group.
- unselected_child_seqs.append((seq, parent))
- else:
- # An existing finished sequence has a low score, so we will
- # remove it from the sequence group.
- seq_group.remove(seq.seq_id)
-
- # select the top beam_width sequences from the running
- # sequences for the next iteration to continue the beam
- # search.
- running_child_seqs = [(seq, parent) for seq, parent in child_seqs
- if not seq.is_finished()]
- # Sort the running sequences by their scores.
- running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
- length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
- reverse=True)
-
- # Check if we can stop the beam search.
- if len(running_child_seqs) == 0:
- # No running sequences, stop the beam search.
- stop_beam_search = True
- elif len(all_finished_seqs) < beam_width:
- # Not enough finished sequences, continue the beam search.
- stop_beam_search = False
- else:
- # Check the early stopping criteria
- best_running_seq = running_child_seqs[0][0]
- current_worst_seq = all_finished_seqs[beam_width - 1][0]
- stop_beam_search = self._check_beam_search_early_stopping(
- seq_group.sampling_params.early_stopping,
- seq_group.sampling_params, best_running_seq, current_worst_seq)
-
- if stop_beam_search:
- # Stop the beam search and remove all the running sequences from
- # the sequence group.
- unselected_child_seqs.extend(running_child_seqs)
- else:
- # Continue the beam search and select the top beam_width sequences
- # to continue the beam search.
- selected_child_seqs.extend(running_child_seqs[:beam_width])
- # The remaining running sequences will not be used in the next
- # iteration. Again, if these sequences are continuations of
- # parent sequences, we will need to remove the parent sequences
- # from the sequence group.
- unselected_child_seqs.extend(running_child_seqs[beam_width:])
-
- # For newly created child sequences, add them to the sequence group
- # and fork them in block manager if they are not finished.
- for seq, parent in selected_child_seqs:
- if seq is not parent:
- seq_group.add(seq)
- if not seq.is_finished():
- self.scheduler.fork_seq(parent, seq)
-
- # Free the finished and selected parent sequences' memory in block
- # manager. Keep them in the sequence group as candidate output.
- for seq, parent in selected_child_seqs:
- if seq is parent and seq.is_finished():
- self.scheduler.free_seq(seq)
-
- # Remove the unselected parent sequences from the sequence group and
- # free their memory in block manager.
- for seq, parent in unselected_child_seqs:
- if seq is parent:
- # Remove the parent sequence if it is not selected for next
- # iteration
- seq_group.remove(seq.seq_id)
- self.scheduler.free_seq(seq)
-
def _process_model_outputs(
- self, output: SamplerOutput,
- scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
+ self, output: List[SamplerOutput],
+ scheduled_seq_groups: List[SequenceGroup],
+ ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
+ """Apply the model output to the sequences in the scheduled seq groups.
+
+ Returns RequestOutputs that can be returned to the client.
+ """
+
now = time.time()
+
+ # Organize outputs by [sequence group][step] instead of
+ # [step][sequence group].
+ output_by_sequence_group = create_output_by_sequence_group(
+ sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
+
# Update the scheduled sequence groups with the model outputs.
- scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
- for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
+ for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
+ output_by_sequence_group):
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:
- self._process_sequence_group_outputs(seq_group, outputs)
+ self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
@@ -657,13 +467,9 @@ def _process_model_outputs(
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
- for seq_group in scheduler_outputs.ignored_seq_groups:
+ for seq_group in ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
-
- # Log stats.
- if self.log_stats:
- self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs
def step(self) -> List[RequestOutput]:
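The regrouping comment in the hunk above refers to create_output_by_sequence_group from vllm/engine/output_processor/util.py, whose diff falls outside this excerpt. A hedged, self-contained sketch of the [step][sequence group] -> [sequence group][step] transposition it describes, using a hypothetical helper name:

from typing import Any, List

def regroup_by_sequence_group(sampler_outputs: List[List[Any]],
                              num_seq_groups: int) -> List[List[Any]]:
    # Input is indexed [step][sequence group]; output is [sequence group][step].
    by_seq_group: List[List[Any]] = [[] for _ in range(num_seq_groups)]
    for step_output in sampler_outputs:
        for group_idx, group_output in enumerate(step_output):
            by_seq_group[group_idx].append(group_output)
    return by_seq_group

# Two decode steps over three sequence groups:
assert regroup_by_sequence_group([[1, 2, 3], [4, 5, 6]], 3) == [[1, 4], [2, 5], [3, 6]]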
@@ -721,13 +527,23 @@ def step(self) -> List[RequestOutput]:
if not scheduler_outputs.is_empty():
output = self.model_executor.execute_model(
- seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
- scheduler_outputs.blocks_to_swap_out,
- scheduler_outputs.blocks_to_copy)
+ seq_group_metadata_list=seq_group_metadata_list,
+ blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+ blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+ blocks_to_copy=scheduler_outputs.blocks_to_copy,
+ num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
else:
output = []
- return self._process_model_outputs(output, scheduler_outputs)
+ request_outputs = self._process_model_outputs(
+ output, scheduler_outputs.scheduled_seq_groups,
+ scheduler_outputs.ignored_seq_groups)
+
+ # Log stats.
+ if self.log_stats:
+ self.stat_logger.log(self._get_stats(scheduler_outputs))
+
+ return request_outputs
def do_log_stats(self) -> None:
"""Forced log when no requests active."""
@@ -807,87 +623,6 @@ def _get_stats(self,
time_e2e_requests=time_e2e_requests,
)
- def _check_stop(self, seq: Sequence, new_char_count: int,
- sampling_params: SamplingParams) -> None:
- """Stop the finished sequences.
-
- new_char_count is the number of chars added to the
- sequence's output text for the newly generated token
- """
-
- # Check if the minimum number of tokens has been generated yet;
- # skip the stop string/token checks if not
- if seq.get_output_len() < sampling_params.min_tokens:
- return
-
- # Check if the sequence has generated the EOS token.
- if ((not sampling_params.ignore_eos)
- and seq.get_last_token_id() == seq.eos_token_id):
- seq.status = SequenceStatus.FINISHED_STOPPED
- return
-
- # Check if a stop token was encountered.
- # This assumes a single token produced per step.
- last_token_id = seq.get_last_token_id()
- if last_token_id in sampling_params.stop_token_ids:
- if new_char_count and (
- not sampling_params.include_stop_str_in_output):
- # Remove last token
- seq.output_text = seq.output_text[:-new_char_count]
- seq.status = SequenceStatus.FINISHED_STOPPED
- seq.stop_reason = last_token_id
- return
-
- # Check if any stop strings are matched.
- stop_str = self._check_stop_strings(seq, new_char_count,
- sampling_params)
- if stop_str is not None:
- seq.status = SequenceStatus.FINISHED_STOPPED
- seq.stop_reason = stop_str
- return
-
- # Check if the sequence has reached max_model_len.
- if seq.get_len() > self.scheduler_config.max_model_len:
- seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
- return
-
- # Check if the sequence has reached max_tokens.
- if seq.get_output_len() == sampling_params.max_tokens:
- seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
- return
-
- @staticmethod
- def _check_stop_strings(seq: Sequence, new_char_count: int,
- sampling_params: SamplingParams) -> Optional[str]:
- """Check if any stop strings are matched and truncate sequence
- output text accordingly.
-
- Returns the stop string if matched or else None.
- """
- if not new_char_count:
- return None
-
- for stop_str in sampling_params.stop:
- stop_string_len = len(stop_str)
- # Avoid searching already-searched text.
- stop_index = seq.output_text.find(
- stop_str, -new_char_count - stop_string_len)
- if stop_index == -1:
- continue
-
- if sampling_params.include_stop_str_in_output:
- # Truncate to end of stop string.
- stop_index += stop_string_len
- if stop_index >= len(seq.output_text):
- # No truncation required.
- return stop_str
-
- # Truncate the output text to either the beginning
- # or end of the stop string.
- seq.output_text = seq.output_text[:stop_index]
- return stop_str
- return None
-
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request)
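The stop handling removed from LLMEngine above is relocated into the new StopChecker (vllm/engine/output_processor/stop_checker.py; its diff is not part of this excerpt). A rough, standalone restatement of the stop-string rule, with a hypothetical function name and a (text, stop_str) return convention chosen for illustration:

def check_stop_strings(output_text, new_char_count, stop, include_stop_str_in_output):
    """Return (possibly truncated text, matched stop string or None)."""
    if not new_char_count:
        return output_text, None
    for stop_str in stop:
        stop_string_len = len(stop_str)
        # Only search text that could contain a newly completed stop string.
        stop_index = output_text.find(stop_str, -new_char_count - stop_string_len)
        if stop_index == -1:
            continue
        if include_stop_str_in_output:
            # Truncate to the end of the stop string.
            stop_index += stop_string_len
            if stop_index >= len(output_text):
                return output_text, stop_str  # no truncation required
        return output_text[:stop_index], stop_str
    return output_text, None

assert check_stop_strings("Hello world###", 3, ["###"], False) == ("Hello world", "###")
assert check_stop_strings("Hello world###", 3, ["###"], True) == ("Hello world###", "###")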
diff --git a/vllm/engine/output_processor/__init__.py b/vllm/engine/output_processor/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py
new file mode 100644
index 0000000000000..9ddac7a04cb36
--- /dev/null
+++ b/vllm/engine/output_processor/interfaces.py
@@ -0,0 +1,69 @@
+from abc import ABC, abstractmethod
+from typing import Callable, Iterable, List
+
+from transformers import PreTrainedTokenizer
+
+from vllm.config import SchedulerConfig
+from vllm.core.scheduler import Scheduler
+from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
+from vllm.transformers_utils.detokenizer import Detokenizer
+
+
+class SequenceGroupOutputProcessor(ABC):
+ """Interface for logic that processes new token ids in sequence groups,
+ managing detokenization, stop checking, and freeing/forking sequences with
+ the scheduler.
+
+ This is highly coupled with the LLMEngine and should be seen as an extension
+ of it. The logic is separated to simplify the LLMEngine class and allow
+ separate implementations for single-step decoding (which supports beam
+ search sequence forking) and multi-step decoding (which does not support
+ beam search, but does support speculative decoding).
+ """
+
+ @staticmethod
+ def create_output_processor(
+ scheduler_config: SchedulerConfig,
+ detokenizer: Detokenizer,
+ scheduler: Scheduler,
+ seq_counter: Iterable[int],
+ get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
+ stop_checker: "StopChecker",
+ ):
+ """Create an output processor.
+
+ This returns a single-step output processor if num_lookahead_slots is
+ zero, else returns a multi-step output processor.
+ """
+ if scheduler_config.num_lookahead_slots == 0:
+ # Importing here to avoid cycle.
+ from vllm.engine.output_processor.single_step import (
+ SingleStepOutputProcessor)
+ return SingleStepOutputProcessor(
+ scheduler_config,
+ detokenizer,
+ scheduler,
+ seq_counter,
+ stop_checker,
+ )
+ else:
+ # Importing here to avoid cycle.
+ from vllm.engine.output_processor.multi_step import (
+ MultiStepOutputProcessor)
+ return MultiStepOutputProcessor(
+ detokenizer,
+ scheduler,
+ seq_counter,
+ get_tokenizer_for_seq,
+ stop_checker,
+ )
+
+ @abstractmethod
+ def process_outputs(self, sequence_group: SequenceGroup,
+ outputs: List[SequenceGroupOutput]) -> None:
+ """Process new token ids for the sequence group. Handles logic such as
+ detokenization, stop checking, and freeing/forking sequences in the
+ scheduler.
+ """
+ pass
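A hedged usage sketch of the factory above, assuming a vLLM checkout that already contains this patch; the MagicMock objects stand in for real engine components and are illustrative only:

from unittest.mock import MagicMock

from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
from vllm.engine.output_processor.single_step import SingleStepOutputProcessor
from vllm.utils import Counter

def make_processor(num_lookahead_slots: int):
    scheduler_config = MagicMock()
    scheduler_config.num_lookahead_slots = num_lookahead_slots
    return SequenceGroupOutputProcessor.create_output_processor(
        scheduler_config,
        detokenizer=MagicMock(),
        scheduler=MagicMock(),
        seq_counter=Counter(),
        get_tokenizer_for_seq=lambda seq: MagicMock(),
        stop_checker=MagicMock(),
    )

# num_lookahead_slots == 0 selects single-step (beam-search capable) processing;
# any positive value selects the multi-step processor used by speculative decoding.
assert isinstance(make_processor(0), SingleStepOutputProcessor)
assert isinstance(make_processor(5), MultiStepOutputProcessor)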
diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py
new file mode 100644
index 0000000000000..50da0d35fcec1
--- /dev/null
+++ b/vllm/engine/output_processor/multi_step.py
@@ -0,0 +1,126 @@
+from typing import Callable, Iterable, List
+
+from transformers import PreTrainedTokenizer
+
+from vllm.core.scheduler import Scheduler
+from vllm.engine.output_processor.interfaces import (
+ SequenceGroupOutputProcessor)
+from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.logger import init_logger
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import (Logprob, Sequence, SequenceGroup,
+ SequenceGroupOutput, SequenceOutput, SequenceStatus)
+from vllm.transformers_utils.detokenizer import Detokenizer
+
+logger = init_logger(__name__)
+
+
+class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
+ """SequenceGroupOutputProcessor which handles logic related to
+ detokenization and stopping conditions. It specializes to "multi-step
+ decoding", where vLLM's worker may generate multiple tokens per invocation.
+ This is currently mutually exclusive with advanced sampling techniques like
+ beam search, which motivates the separation of this logic from the single
+ step output processor.
+
+ This class is responsible for things such as correctly appending all new
+ token ids to their sequence, detokenizing new token ids, truncating new
+ output tokens after an eos token, and correctly handling the case where the
+ number of new output tokens per sequence differs in a single batch.
+ """
+
+ def __init__(
+ self,
+ detokenizer: Detokenizer,
+ scheduler: Scheduler,
+ seq_counter: Iterable[int],
+ get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
+ stop_checker: StopChecker,
+ ):
+ self.detokenizer = detokenizer
+ self.scheduler = scheduler
+ self.seq_counter = seq_counter
+ self.get_tokenizer_for_seq = get_tokenizer_for_seq
+ self.stop_checker = stop_checker
+
+ def process_outputs(self, sequence_group: SequenceGroup,
+ outputs: List[SequenceGroupOutput]) -> None:
+ """Append new tokens in the outputs to sequences in the sequence group.
+
+ This only supports sequence groups of size 1. It supports greater than
+ one new token per sequence.
+
+ This applies logic like stop condition checking and detokenization,
+ including freeing finished sequences. It also handles cases where there
+ are tokens emitted after the EOS token.
+ """
+ seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
+
+ assert seqs, "expected running sequences"
+ assert len(seqs) == 1, (
+ "Beam search not supported in multi-step decoding.")
+ seq = seqs[0]
+
+ # Since there's only one sequence per sequence group, we can take the
+ # first sample.
+ samples = [outputs[step].samples[0] for step in range(len(outputs))]
+
+        # -1 means the output token is not valid (e.g. due to spec decode
+ # rejecting tokens).
+ valid_samples = [
+ sample for sample in samples if sample.output_token != -1
+ ]
+ assert valid_samples
+
+ self._process_seq_outputs(seq, valid_samples,
+ sequence_group.sampling_params)
+
+ def _process_seq_outputs(self, seq: Sequence,
+ valid_samples: List[SequenceOutput],
+ sampling_params: SamplingParams) -> None:
+ output_token_ids = [sample.output_token for sample in valid_samples]
+
+ # Truncate to max_tokens if necessary.
+ remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
+ len(output_token_ids))
+ if remaining_tokens < 0:
+ valid_samples = valid_samples[:remaining_tokens]
+ output_token_ids = output_token_ids[:remaining_tokens]
+
+ # Truncate any tokens after EOS. This is required as spec decode
+ # generates a fixed number of tokens without evaluating stopping
+ # conditions within the block. This can cause an eos token to be
+ # unintentionally ignored.
+ if not sampling_params.ignore_eos:
+ eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
+            # Avoid .index() calls, as exception throwing in the happy path
+            # is expensive.
+ for i in range(len(output_token_ids)):
+ if output_token_ids[i] == eos_token_id:
+ output_token_ids = output_token_ids[:i + 1]
+ valid_samples = valid_samples[:i + 1]
+ break
+
+ # Incrementally append tokens to the sequence, as if we had only one new
+ # token.
+ for output_token_id in output_token_ids:
+ seq.append_token_id(
+ token_id=output_token_id,
+ # TODO emit logprobs in multi-step decoding.
+ logprobs={output_token_id: Logprob(0.0)},
+ )
+
+ new_char_count = 0
+ if sampling_params.detokenize:
+ new_char_count = self.detokenizer.decode_sequence_inplace(
+ seq, sampling_params)
+
+ self.stop_checker.maybe_stop_sequence(
+ seq,
+ new_char_count=new_char_count,
+ sampling_params=sampling_params)
+ if seq.is_finished():
+ break
+
+ if seq.is_finished():
+ self.scheduler.free_seq(seq)
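The two truncation rules in _process_seq_outputs (cap the new tokens at the remaining max_tokens budget, then drop everything after the first EOS unless ignore_eos is set) can be exercised in isolation. A pure-Python sketch with a hypothetical helper name:

from typing import List, Optional

def truncate_new_tokens(output_token_ids: List[int],
                        num_already_generated: int,
                        max_tokens: int,
                        eos_token_id: Optional[int] = None,
                        ignore_eos: bool = False) -> List[int]:
    # Rule 1: drop tokens beyond the max_tokens budget.
    remaining = max_tokens - (num_already_generated + len(output_token_ids))
    if remaining < 0:
        output_token_ids = output_token_ids[:remaining]
    # Rule 2: keep the EOS token itself but drop anything emitted after it.
    if not ignore_eos and eos_token_id is not None:
        for i, token_id in enumerate(output_token_ids):
            if token_id == eos_token_id:
                output_token_ids = output_token_ids[:i + 1]
                break
    return output_token_ids

# 126 tokens already generated with a budget of 128: only two new tokens fit.
assert truncate_new_tokens([7, 8, 9, 10], 126, 128) == [7, 8]
# EOS (id 100) appears mid-stream: the token after it is dropped.
assert truncate_new_tokens([7, 100, 9], 0, 16, eos_token_id=100) == [7, 100]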
diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
new file mode 100644
index 0000000000000..1b7eb014f802b
--- /dev/null
+++ b/vllm/engine/output_processor/single_step.py
@@ -0,0 +1,276 @@
+from typing import Iterable, List, Tuple, Union
+
+from vllm.config import SchedulerConfig
+from vllm.core.scheduler import Scheduler
+from vllm.engine.output_processor.interfaces import (
+ SequenceGroupOutputProcessor)
+from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.logger import init_logger
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
+ SequenceOutput, SequenceStatus)
+from vllm.transformers_utils.detokenizer import Detokenizer
+
+logger = init_logger(__name__)
+
+
+class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
+ """SequenceGroupOutputProcessor which handles "output processing" logic,
+ which happens after the model returns generated token ids and before
+ scheduling of the next batch. Output processing logic includes
+ detokenization, and determining if a sequence is finished (e.g. via max len
+ or eos token).
+
+ The SingleStepOutputProcessor is specialized to the case where the model
+ emits at most a single token per invocation, which precludes configurations
+ such as speculative decoding or multi-step decoding. This enables beam
+ search sampling, which requires forking/finishing/freeing sequences in a way
+ that is currently difficult to schedule multiple steps ahead of time.
+ """
+
+ def __init__(
+ self,
+ scheduler_config: SchedulerConfig,
+ detokenizer: Detokenizer,
+ scheduler: Scheduler,
+ seq_counter: Iterable[int],
+ stop_checker: StopChecker,
+ ):
+ self.scheduler_config = scheduler_config
+ self.detokenizer = detokenizer
+ self.scheduler = scheduler
+ self.seq_counter = seq_counter
+ self.stop_checker = stop_checker
+
+ def process_outputs(self, sequence_group: SequenceGroup,
+ outputs: List[SequenceGroupOutput]) -> None:
+ """Append all new tokens to sequences in the sequence group. Fork any
+ surviving beam candidates; free any unsurviving ones.
+
+ Invokes detokenizer to detokenize new tokens, and also marks sequences
+ as finished if they meet stop conditions.
+ """
+ assert (len(outputs) == 1
+ ), f"{type(self)} does not support multiple outputs per step"
+ return self._process_sequence_group_outputs(sequence_group, outputs[0])
+
+ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+ outputs: SequenceGroupOutput) -> None:
+
+ # Process prompt logprobs
+ prompt_logprobs = outputs.prompt_logprobs
+ if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
+ self.detokenizer.decode_prompt_logprobs_inplace(
+ seq_group, prompt_logprobs)
+ seq_group.prompt_logprobs = prompt_logprobs
+
+ # Process samples
+ samples = outputs.samples
+ parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+ existing_finished_seqs = seq_group.get_finished_seqs()
+ parent_child_dict = {
+ parent_seq.seq_id: []
+ for parent_seq in parent_seqs
+ }
+ for sample in samples:
+ parent_child_dict[sample.parent_seq_id].append(sample)
+ # List of (child, parent)
+ child_seqs: List[Tuple[Sequence, Sequence]] = []
+
+ # Process the child samples for each parent sequence
+ for parent in parent_seqs:
+ child_samples: List[SequenceOutput] = parent_child_dict[
+ parent.seq_id]
+ if len(child_samples) == 0:
+ # This parent sequence has no children samples. Remove
+ # the parent sequence from the sequence group since it will
+ # not be used in the future iterations.
+ parent.status = SequenceStatus.FINISHED_ABORTED
+ seq_group.remove(parent.seq_id)
+ self.scheduler.free_seq(parent)
+ continue
+ # Fork the parent sequence if there are multiple child samples.
+ for child_sample in child_samples[:-1]:
+ new_child_seq_id = next(self.seq_counter)
+ child = parent.fork(new_child_seq_id)
+ child.append_token_id(child_sample.output_token,
+ child_sample.logprobs)
+ child_seqs.append((child, parent))
+ # Continue the parent sequence for the last child sample.
+ # We reuse the parent sequence here to reduce redundant memory
+ # copies, especially when using non-beam search sampling methods.
+ last_child_sample = child_samples[-1]
+ parent.append_token_id(last_child_sample.output_token,
+ last_child_sample.logprobs)
+ child_seqs.append((parent, parent))
+
+ for seq, _ in child_seqs:
+ if seq_group.sampling_params.detokenize:
+ new_char_count = self.detokenizer.decode_sequence_inplace(
+ seq, seq_group.sampling_params)
+ else:
+ new_char_count = 0
+ self.stop_checker.maybe_stop_sequence(seq, new_char_count,
+ seq_group.sampling_params)
+
+ # Non-beam search case
+ if not seq_group.sampling_params.use_beam_search:
+ # For newly created child sequences, add them to the sequence group
+ # and fork them in block manager if they are not finished.
+ for seq, parent in child_seqs:
+ if seq is not parent:
+ seq_group.add(seq)
+ if not seq.is_finished():
+ self.scheduler.fork_seq(parent, seq)
+
+ # Free the finished and selected parent sequences' memory in block
+ # manager. Keep them in the sequence group as candidate output.
+ # NOTE: we need to fork the new sequences before freeing the
+ # old sequences.
+ for seq, parent in child_seqs:
+ if seq is parent and seq.is_finished():
+ self.scheduler.free_seq(seq)
+ return
+
+ # Beam search case
+ # Select the child sequences to keep in the sequence group.
+ selected_child_seqs = []
+ unselected_child_seqs = []
+ beam_width = seq_group.sampling_params.best_of
+ length_penalty = seq_group.sampling_params.length_penalty
+
+ # Select the newly finished sequences with the highest scores
+ # to replace existing finished sequences.
+ # Tuple of (seq, parent, is_new)
+ existing_finished_seqs = [(seq, None, False)
+ for seq in existing_finished_seqs]
+ new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
+ if seq.is_finished()]
+ all_finished_seqs = existing_finished_seqs + new_finished_seqs
+ # Sort the finished sequences by their scores.
+ all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
+ length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+ reverse=True)
+ for seq, parent, is_new in all_finished_seqs[:beam_width]:
+ if is_new:
+ # A newly generated child sequence finishes and has a high
+ # score, so we will add it into the sequence group.
+ selected_child_seqs.append((seq, parent))
+ for seq, parent, is_new in all_finished_seqs[beam_width:]:
+ if is_new:
+ # A newly generated child sequence finishes but has a low
+ # score, so we will not add it into the sequence group.
+ # Additionally, if this sequence is a continuation of a
+                # parent sequence, we will need to remove the parent sequence
+ # from the sequence group.
+ unselected_child_seqs.append((seq, parent))
+ else:
+ # An existing finished sequence has a low score, so we will
+ # remove it from the sequence group.
+ seq_group.remove(seq.seq_id)
+
+        # Select the top beam_width sequences from the running
+ # sequences for the next iteration to continue the beam
+ # search.
+ running_child_seqs = [(seq, parent) for seq, parent in child_seqs
+ if not seq.is_finished()]
+ # Sort the running sequences by their scores.
+ running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
+ length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+ reverse=True)
+
+ # Check if we can stop the beam search.
+ if len(running_child_seqs) == 0:
+ # No running sequences, stop the beam search.
+ stop_beam_search = True
+ elif len(all_finished_seqs) < beam_width:
+ # Not enough finished sequences, continue the beam search.
+ stop_beam_search = False
+ else:
+ # Check the early stopping criteria
+ best_running_seq = running_child_seqs[0][0]
+ current_worst_seq = all_finished_seqs[beam_width - 1][0]
+ stop_beam_search = self._check_beam_search_early_stopping(
+ seq_group.sampling_params.early_stopping,
+ seq_group.sampling_params, best_running_seq, current_worst_seq)
+
+ if stop_beam_search:
+ # Stop the beam search and remove all the running sequences from
+ # the sequence group.
+ unselected_child_seqs.extend(running_child_seqs)
+ else:
+            # Continue the beam search and keep the top beam_width running
+            # sequences for the next iteration.
+ selected_child_seqs.extend(running_child_seqs[:beam_width])
+ # The remaining running sequences will not be used in the next
+ # iteration. Again, if these sequences are continuations of
+ # parent sequences, we will need to remove the parent sequences
+ # from the sequence group.
+ unselected_child_seqs.extend(running_child_seqs[beam_width:])
+
+ # For newly created child sequences, add them to the sequence group
+ # and fork them in block manager if they are not finished.
+ for seq, parent in selected_child_seqs:
+ if seq is not parent:
+ seq_group.add(seq)
+ if not seq.is_finished():
+ self.scheduler.fork_seq(parent, seq)
+
+ # Free the finished and selected parent sequences' memory in block
+ # manager. Keep them in the sequence group as candidate output.
+ for seq, parent in selected_child_seqs:
+ if seq is parent and seq.is_finished():
+ self.scheduler.free_seq(seq)
+
+ # Remove the unselected parent sequences from the sequence group and
+ # free their memory in block manager.
+ for seq, parent in unselected_child_seqs:
+ if seq is parent:
+ # Remove the parent sequence if it is not selected for next
+ # iteration
+ seq_group.remove(seq.seq_id)
+ self.scheduler.free_seq(seq)
+
+ def _check_beam_search_early_stopping(
+ self,
+ early_stopping: Union[bool, str],
+ sampling_params: SamplingParams,
+ best_running_seq: Sequence,
+ current_worst_seq: Sequence,
+ ) -> bool:
+ assert sampling_params.use_beam_search
+ length_penalty = sampling_params.length_penalty
+ if early_stopping is True:
+ return True
+
+ current_worst_score = current_worst_seq.get_beam_search_score(
+ length_penalty=length_penalty,
+ eos_token_id=current_worst_seq.eos_token_id)
+ if early_stopping is False:
+ highest_attainable_score = best_running_seq.get_beam_search_score(
+ length_penalty=length_penalty,
+ eos_token_id=best_running_seq.eos_token_id)
+ else:
+ assert early_stopping == "never"
+ if length_penalty > 0.0:
+ # If length_penalty > 0.0, beam search will prefer longer
+ # sequences. The highest attainable score calculation is
+ # based on the longest possible sequence length in this case.
+ max_possible_length = max(
+ best_running_seq.get_prompt_len() +
+ sampling_params.max_tokens,
+ self.scheduler_config.max_model_len)
+ highest_attainable_score = (
+ best_running_seq.get_beam_search_score(
+ length_penalty=length_penalty,
+ eos_token_id=best_running_seq.eos_token_id,
+ seq_len=max_possible_length))
+ else:
+ # Otherwise, beam search will prefer shorter sequences. The
+ # highest attainable score calculation is based on the current
+ # sequence length.
+ highest_attainable_score = (
+ best_running_seq.get_beam_search_score(
+ length_penalty=length_penalty,
+ eos_token_id=best_running_seq.eos_token_id))
+ return current_worst_score >= highest_attainable_score
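For reference, a minimal standalone sketch (not vLLM code) of the comparison performed by _check_beam_search_early_stopping above, assuming the usual beam-search score of cumulative_logprob / (seq_len ** length_penalty); the helper names are illustrative only.

def beam_score(cumulative_logprob: float, seq_len: int,
               length_penalty: float = 1.0) -> float:
    # Length-penalized beam-search score (higher is better).
    return cumulative_logprob / (seq_len ** length_penalty)

def should_stop(best_running: tuple, worst_finished: tuple,
                length_penalty: float = 1.0) -> bool:
    # Each tuple is (cumulative_logprob, seq_len). With early_stopping=False,
    # the best a running beam can attain is treated as its current score, so
    # the search stops once the worst kept finished beam already scores at
    # least as high.
    highest_attainable = beam_score(*best_running, length_penalty)
    current_worst = beam_score(*worst_finished, length_penalty)
    return current_worst >= highest_attainable

print(should_stop(best_running=(-12.0, 10), worst_finished=(-9.0, 10)))  # True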
diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py
new file mode 100644
index 0000000000000..66deb9b591746
--- /dev/null
+++ b/vllm/engine/output_processor/stop_checker.py
@@ -0,0 +1,101 @@
+from typing import Callable, Optional
+
+from transformers import PreTrainedTokenizer
+
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import Sequence, SequenceStatus
+
+
+class StopChecker:
+ """LLMEngine helper class which separates out the logic involving stop
+    checking. This checks things such as: whether the EOS token was emitted,
+    whether max_tokens has been consumed, whether a stop string has been
+    generated, or whether the maximum model length has been exceeded.
+ """
+
+ def __init__(self, max_model_len: int,
+ get_tokenizer_for_seq: Callable[[Sequence],
+ PreTrainedTokenizer]):
+ self.max_model_len = max_model_len
+ self.get_tokenizer_for_seq = get_tokenizer_for_seq
+
+ def maybe_stop_sequence(self, seq: Sequence, new_char_count: int,
+ sampling_params: SamplingParams) -> None:
+ """Stop the finished sequences.
+
+ new_char_count is the number of chars added to the
+ sequence's output text for the newly generated token
+ """
+
+ # Check if the minimum number of tokens has been generated yet;
+ # skip the stop string/token checks if not
+ if seq.get_output_len() < sampling_params.min_tokens:
+ return
+
+ # Check if the sequence has generated the EOS token.
+ if ((not sampling_params.ignore_eos)
+ and seq.get_last_token_id() == seq.eos_token_id):
+ seq.status = SequenceStatus.FINISHED_STOPPED
+ return
+
+ # Check if a stop token was encountered.
+ # This assumes a single token produced per step.
+ last_token_id = seq.get_last_token_id()
+ if last_token_id in sampling_params.stop_token_ids:
+ if new_char_count and (
+ not sampling_params.include_stop_str_in_output):
+ # Remove last token
+ seq.output_text = seq.output_text[:-new_char_count]
+ seq.status = SequenceStatus.FINISHED_STOPPED
+ seq.stop_reason = last_token_id
+ return
+
+ # Check if any stop strings are matched.
+ stop_str = self._check_stop_strings(seq, new_char_count,
+ sampling_params)
+ if stop_str is not None:
+ seq.status = SequenceStatus.FINISHED_STOPPED
+ seq.stop_reason = stop_str
+ return
+
+ # Check if the sequence has reached max_model_len.
+ if seq.get_len() > self.max_model_len:
+ seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+ return
+
+ # Check if the sequence has reached max_tokens.
+ if seq.get_output_len() == sampling_params.max_tokens:
+ seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+ return
+
+ @staticmethod
+ def _check_stop_strings(seq: Sequence, new_char_count: int,
+ sampling_params: SamplingParams) -> Optional[str]:
+ """Check if any stop strings are matched and truncate sequence
+ output text accordingly.
+
+ Returns the stop string if matched or else None.
+ """
+ if not new_char_count:
+ return None
+
+ for stop_str in sampling_params.stop:
+ stop_string_len = len(stop_str)
+ # Avoid searching already-searched text.
+ stop_index = seq.output_text.find(
+ stop_str, -new_char_count - stop_string_len)
+ if stop_index == -1:
+ continue
+
+ if sampling_params.include_stop_str_in_output:
+ # Truncate to end of stop string.
+ stop_index += stop_string_len
+ if stop_index >= len(seq.output_text):
+ # No truncation required.
+ return stop_str
+
+ # Truncate the output text to either the beginning
+ # or end of the stop string.
+ seq.output_text = seq.output_text[:stop_index]
+ return stop_str
+ return None
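A small standalone illustration (not vLLM code) of the stop-string search above: only the tail of the output that could contain a newly completed match is scanned.

from typing import List, Optional

def find_stop_string(output_text: str, new_char_count: int,
                     stop: List[str]) -> Optional[str]:
    if not new_char_count:
        return None
    for stop_str in stop:
        # Only the last new_char_count + len(stop_str) characters can hold a
        # match that was not already present before this step.
        start = -new_char_count - len(stop_str)
        if output_text.find(stop_str, start) != -1:
            return stop_str
    return None

print(find_stop_string("Hello world\n###", new_char_count=3, stop=["###"]))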
diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py
new file mode 100644
index 0000000000000..5fbb09a857a46
--- /dev/null
+++ b/vllm/engine/output_processor/util.py
@@ -0,0 +1,16 @@
+from typing import List
+
+from vllm.sequence import SamplerOutput
+
+
+def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
+ num_seq_groups: int):
+ """Helper method which transforms a 2d list organized by
+ [step][sequence group] into [sequence group][step].
+ """
+ output_by_sequence_group = [[] for _ in range(num_seq_groups)]
+ for step in sampler_outputs:
+ for i, sequence_group_output in enumerate(step):
+ output_by_sequence_group[i].append(sequence_group_output)
+
+ return output_by_sequence_group
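A quick standalone illustration of the transpose performed by create_output_by_sequence_group; plain strings stand in for the real SamplerOutput objects.

sampler_outputs = [
    ["group0_step0", "group1_step0"],  # step 0: one entry per sequence group
    ["group0_step1", "group1_step1"],  # step 1
]
num_seq_groups = 2
output_by_sequence_group = [[] for _ in range(num_seq_groups)]
for step in sampler_outputs:
    for i, sequence_group_output in enumerate(step):
        output_by_sequence_group[i].append(sequence_group_output)
print(output_by_sequence_group)
# [['group0_step0', 'group0_step1'], ['group1_step0', 'group1_step1']]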
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 426e2c41d8427..f925a6fc93dcd 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -74,7 +74,8 @@ def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
- blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+ blocks_to_copy: Dict[int, List[int]],
+ num_lookahead_slots: int) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 8cc04c5299ca1..1839b5603ff3e 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -72,8 +72,9 @@ def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
- blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
- """Executes one model step on the given sequences."""
+ blocks_to_copy: Dict[int, List[int]],
+ num_lookahead_slots: int) -> List[SamplerOutput]:
+ """Executes at least one model step on the given sequences."""
raise NotImplementedError
@abstractmethod
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 3a9537effe6d9..6e4a765e2ffd5 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -13,13 +13,17 @@
class GPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
- assert (not self.speculative_config
- ), "Speculative decoding not yet supported for GPU backend"
+ """Initialize the worker and load the model.
- # Instantiate the worker and load the model to GPU.
- self._init_worker()
+ If speculative decoding is enabled, we instead create the speculative
+ worker.
+ """
+ if self.speculative_config is None:
+ self._init_non_spec_worker()
+ else:
+ self._init_spec_worker()
- def _init_worker(self):
+ def _init_non_spec_worker(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker
@@ -46,6 +50,57 @@ def _init_worker(self):
self.driver_worker.init_device()
self.driver_worker.load_model()
+ def _init_spec_worker(self):
+ """Initialize a SpecDecodeWorker, using a draft model for proposals.
+ """
+ assert self.speculative_config is not None
+
+ from vllm.spec_decode.multi_step_worker import MultiStepWorker
+ from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
+ from vllm.worker.worker import Worker
+
+ distributed_init_method = get_distributed_init_method(
+ get_ip(), get_open_port())
+
+ target_worker = Worker(
+ model_config=self.model_config,
+ parallel_config=self.parallel_config,
+ scheduler_config=self.scheduler_config,
+ device_config=self.device_config,
+ cache_config=self.cache_config,
+ local_rank=0,
+ rank=0,
+ distributed_init_method=distributed_init_method,
+ lora_config=self.lora_config,
+ vision_language_config=self.vision_language_config,
+ is_driver_worker=True,
+ )
+
+ draft_worker = MultiStepWorker(
+ model_config=self.speculative_config.draft_model_config,
+ parallel_config=self.speculative_config.draft_parallel_config,
+ scheduler_config=self.scheduler_config,
+ device_config=self.device_config,
+ cache_config=self.cache_config,
+ local_rank=0,
+ rank=0,
+ distributed_init_method=distributed_init_method,
+ lora_config=self.lora_config,
+ vision_language_config=self.vision_language_config,
+ is_driver_worker=True,
+ )
+
+ spec_decode_worker = SpecDecodeWorker.from_workers(
+ proposer_worker=draft_worker, scorer_worker=target_worker)
+
+ assert self.parallel_config.world_size == 1, (
+ "GPUExecutor only supports single GPU.")
+
+ self.driver_worker = spec_decode_worker
+
+ # Load model handled in spec decode worker.
+ self.driver_worker.init_device()
+
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
@@ -63,16 +118,20 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
- def execute_model(self,
- seq_group_metadata_list: List[SequenceGroupMetadata],
- blocks_to_swap_in: Dict[int, int],
- blocks_to_swap_out: Dict[int, int],
- blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+ def execute_model(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ blocks_to_swap_in: Dict[int, int],
+ blocks_to_swap_out: Dict[int, int],
+ blocks_to_copy: Dict[int, List[int]],
+ num_lookahead_slots: int,
+ ) -> List[SamplerOutput]:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
+ num_lookahead_slots=num_lookahead_slots,
)
return output
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index 273b17a927efd..7cc187e297c9f 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -48,10 +48,13 @@ def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
- blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+ blocks_to_copy: Dict[int, List[int]],
+ num_lookahead_slots: int) -> List[SamplerOutput]:
assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
and blocks_to_copy == {}), (
"Cache operations are not supported for Neuron backend.")
+ assert num_lookahead_slots == 0, (
+ "lookahead not supported for Neuron backend.")
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list)
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 4065c0868d79a..5f859fdc9c078 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -242,7 +242,8 @@ def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
- blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+ blocks_to_copy: Dict[int, List[int]],
+ num_lookahead_slots: int = 0) -> SamplerOutput:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={
diff --git a/vllm/sequence.py b/vllm/sequence.py
index dcde81df19923..92362a9a5d2a3 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -693,3 +693,16 @@ def __len__(self):
def __eq__(self, other: object):
return isinstance(other,
self.__class__) and self.outputs == other.outputs
+
+ def __repr__(self) -> str:
+ """Show the shape of a tensor instead of its values to reduce noise.
+ """
+ sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
+ else self.sampled_token_probs.shape)
+ sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
+ self.sampled_token_ids.shape)
+ return (
+ f"SamplerOutput(outputs={self.outputs}, "
+ f"sampled_token_probs={sampled_token_probs_repr}, "
+ f"sampled_token_ids={sampled_token_ids_repr}, "
+ f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index e0b75837e8a39..88af1dd360155 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -6,10 +6,10 @@
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
-from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
- sampler_output_to_torch,
+from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors,
+ nvtx_range, sampler_output_to_torch,
split_batch_by_proposal_len)
-from vllm.worker.worker import Worker
+from vllm.worker.worker_base import WorkerBase
SeqId = int
TargetSeqId = int
@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree.
"""
- def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
+ def __init__(self, scorer_worker: WorkerBase, device: str,
+ vocab_size: int):
self._scorer_worker = scorer_worker
self._device = device
self._vocab_size = vocab_size
@@ -83,7 +84,9 @@ def score_proposals(
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
- return_python_output=False)
+ )
+ assert len(target_sampler_output) == 1, "expected single-step output"
+ target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch(
original_bs=len(seq_group_metadata_list),
@@ -142,6 +145,16 @@ def _contract_batch(self, original_bs: int,
This maps the scores of speculative tokens back to their original
sequences.
"""
+
+ # We mock the device tensors until PR 7/9 is merged (e2e correctness).
+ # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
+ maybe_mock_device_tensors(
+ sampler_output=target_sampler_output,
+ batch_size=len(non_spec_indices) + num_scoring_tokens,
+ vocab_size=self._vocab_size,
+ device=self._device,
+ )
+
(target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 73b6e201c67a9..ce63c329a40aa 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -6,7 +6,8 @@
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
-from vllm.spec_decode.util import sampler_output_to_torch
+from vllm.spec_decode.util import (maybe_mock_device_tensors,
+ sampler_output_to_torch)
from vllm.worker.worker import Worker
@@ -69,6 +70,9 @@ def execute_model_multi_step(
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
+ assert (len(model_output) == 1
+ ), "composing multistep workers not supported"
+ model_output = model_output[0]
self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
@@ -341,6 +345,16 @@ def _merge_outputs(
sampler_output = maybe_sampler_output
+ # We mock the device tensors until PR 7/9 is merged (e2e correctness).
+ # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
+ for step_output in sampler_output:
+ maybe_mock_device_tensors(
+ sampler_output=step_output,
+ batch_size=len(proposal_lens),
+ vocab_size=self._vocab_size,
+ device=self._device,
+ )
+
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 885bf537568e3..be3af7be93864 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -3,8 +3,9 @@
import torch
+from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
-from vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
+from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceGroupOutput, SequenceOutput)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals,
@@ -13,8 +14,9 @@
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
split_batch_by_proposal_len)
-from vllm.worker.worker import Worker
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
+
+logger = init_logger(__name__)
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
+ @classmethod
+ def from_workers(cls, proposer_worker: MultiStepWorker,
+ scorer_worker: WorkerBase) -> "SpecDecodeWorker":
+ return SpecDecodeWorker(
+ proposer_worker,
+ scorer_worker,
+ # TODO(cade) disable strict mode for speedup.
+ rejection_sampler=RejectionSampler(strict_mode=True),
+ )
+
def __init__(
self,
proposer_worker: MultiStepWorker,
- scorer_worker: Worker,
+ scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
):
@@ -87,6 +99,10 @@ def init_device(self) -> None:
self.scorer_worker.init_device()
self.proposer_worker.init_device()
+ # NOTE(cade): load_model is not part of the WorkerBase interface.
+ self.scorer_worker.load_model()
+ self.proposer_worker.load_model()
+
self._metrics.init_gpu_tensors(self.rank)
self.rejection_sampler.init_gpu_tensors(self.rank)
self.scorer = BatchExpansionTop1Scorer(
@@ -131,7 +147,7 @@ def execute_model(
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
- num_spec_tokens: int,
+ num_lookahead_slots: int,
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch.
"""
@@ -140,9 +156,11 @@ def execute_model(
"speculative decoding "
"requires non-None seq_group_metadata_list")
+ logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}")
+
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
- if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
+ if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
return self._run_no_spec(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
@@ -155,7 +173,7 @@ def execute_model(
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
- k=num_spec_tokens,
+ k=num_lookahead_slots,
)
@nvtx_range("spec_decode_worker._run_no_spec")
@@ -170,20 +188,24 @@ def _run_no_spec(
proposer and scorer model so that the KV cache is consistent between the
two.
"""
+ logger.info("run proposer worker no spec")
self.proposer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
- return_python_output=False)
+ )
+ logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
+ assert len(sampler_output) == 1
+ sampler_output = sampler_output[0]
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
@@ -209,11 +231,13 @@ def _run_speculative_decoding_step(
sequence.
"""
+ logger.info("get spec proposals")
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)
+ logger.info("score proposals")
proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list,
blocks_to_swap_in,
@@ -223,9 +247,11 @@ def _run_speculative_decoding_step(
proposals,
)
+ logger.info("verify proposals")
accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
proposal_scores, proposals, k)
+ logger.info("create output list")
return self._create_output_sampler_list(seq_group_metadata_list,
accepted_token_ids, k)
@@ -311,7 +337,7 @@ def _create_output_sampler_list(
parent_seq_id=seq_id,
output_token=token_id,
# TODO Add verifier logprobs.
- logprobs={token_id: 0.0},
+ logprobs={token_id: Logprob(0.0)},
)
],
prompt_logprobs=None,
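A standalone sketch (illustrative only, not vLLM code) of the dispatch in SpecDecodeWorker.execute_model above: batches with no lookahead slots, e.g. prefill, skip speculation entirely.

def dispatch(num_lookahead_slots: int, num_seq_groups: int) -> str:
    # Mirrors the branch above: with no lookahead slots (or an empty batch)
    # the proposer and scorer each run one normal step so their KV caches
    # stay consistent; otherwise a full propose/score/verify iteration runs.
    if num_lookahead_slots == 0 or num_seq_groups == 0:
        return "_run_no_spec"
    return "_run_speculative_decoding_step"

print(dispatch(num_lookahead_slots=0, num_seq_groups=8))  # _run_no_spec
print(dispatch(num_lookahead_slots=4, num_seq_groups=8))  # _run_speculative_decoding_step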
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
index 406568a4bc08c..eb6d4ca1da8e6 100644
--- a/vllm/spec_decode/util.py
+++ b/vllm/spec_decode/util.py
@@ -82,6 +82,32 @@ def sampler_output_to_torch(
return sampled_token_ids, sampled_token_probs
+def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
+ vocab_size: int, device: str) -> None:
+ """Helper method which mocks out the GPU tensors in SamplerOutput with dummy
+ values. This will be removed in PR 7/9.
+ https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
+ """
+ values = [
+ sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
+ ]
+ assert all(v is None for v in values) or not any(v is None for v in values)
+ if not any(v is None for v in values):
+ # Do nothing if the tensors are already created (usually in unit tests).
+ return
+
+ # Softmax to ensure valid probs.
+ sampler_output.sampled_token_probs = torch.nn.functional.softmax(
+ torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
+ dim=-1)
+
+ sampler_output.sampled_token_ids = torch.randint(low=10,
+ high=100,
+ size=(batch_size, ),
+ dtype=torch.long,
+ device=device)
+
+
@contextmanager
def nvtx_range(msg, *args, **kwargs):
"""
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 6610b9c4be876..afc4a1e1f4630 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -251,7 +251,7 @@ def execute_model(
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
- ) -> Optional[SamplerOutput]:
+ ) -> List[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
@@ -274,11 +274,13 @@ def execute_model(
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
- return {}
+ return []
output = self.model_runner.execute_model(seq_group_metadata_list,
self.cpu_cache)
- return output
+
+ # CPU worker only supports single-step execution.
+ return [output]
def init_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py
index 2f22f82c045db..142c6c97f5194 100644
--- a/vllm/worker/neuron_worker.py
+++ b/vllm/worker/neuron_worker.py
@@ -1,5 +1,5 @@
"""A Neuron worker class."""
-from typing import List, Optional, Tuple
+from typing import List, Tuple
import torch
import torch.distributed
@@ -73,15 +73,18 @@ def initialize_cache(self, num_gpu_blocks: int,
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
- ) -> Optional[SamplerOutput]:
+ ) -> List[SamplerOutput]:
num_seq_groups = len(seq_group_metadata_list)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
- return {}
+ return []
output = self.model_runner.execute_model(seq_group_metadata_list)
- return output
+
+ # Neuron worker only supports single-step output. Wrap the output in a
+        # list to conform to the interface.
+ return [output]
def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block.
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 6a79285f60579..e2b47530d41e4 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -210,7 +210,9 @@ def execute_model(
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
- ) -> Optional[SamplerOutput]:
+ num_lookahead_slots: int = 0,
+ ) -> List[SamplerOutput]:
+
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
@@ -235,11 +237,14 @@ def execute_model(
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
- return {}
+ return []
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache)
- return output
+
+ # Worker only supports single-step execution. Wrap the output in a list
+        # to conform to the interface.
+ return [output]
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index d8c9febb11584..a92f5aea76059 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -40,12 +40,13 @@ def initialize_cache(self, num_gpu_blocks: int,
raise NotImplementedError
@abstractmethod
- def execute_model(self,
- seq_group_metadata_list: List[SequenceGroupMetadata],
- blocks_to_swap_in: Dict[int, int],
- blocks_to_swap_out: Dict[int, int],
- blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
- """Executes one model step on the given sequences."""
+ def execute_model(
+ self, seq_group_metadata_list: List[SequenceGroupMetadata],
+ blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
+ int],
+ blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
+ """Executes at least one model step on the given sequences, unless no
+ sequences are provided."""
raise NotImplementedError
@abstractmethod
From d150e4f89f9da5b093511227449ad940a9a82a52 Mon Sep 17 00:00:00 2001
From: Cade Daniel
Date: Tue, 16 Apr 2024 17:56:01 -0700
Subject: [PATCH 054/413] [Misc] [CI] Fix CI failure caught after merge (#4126)
---
vllm/executor/gpu_executor.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 6e4a765e2ffd5..77c997f97956e 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -68,6 +68,7 @@ def _init_spec_worker(self):
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
+ load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@@ -82,6 +83,7 @@ def _init_spec_worker(self):
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
+ load_config=self.load_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
From 11d652bd4f6f81d09638399885099b78a4e3b9c8 Mon Sep 17 00:00:00 2001
From: Cade Daniel
Date: Tue, 16 Apr 2024 22:53:26 -0700
Subject: [PATCH 055/413] [CI] Move CPU/AMD tests to after wait (#4123)
---
.buildkite/test-template.j2 | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2
index 3ed23c62c005d..0e1acc9777d4b 100644
--- a/.buildkite/test-template.j2
+++ b/.buildkite/test-template.j2
@@ -3,13 +3,6 @@
{% set default_working_dir = "/vllm-workspace/tests" %}
steps:
- - label: "AMD Test"
- agents:
- queue: amd
- command: bash .buildkite/run-amd-test.sh
-
- - label: "CPU Test"
- command: bash .buildkite/run-cpu-test.sh
- label: ":docker: build image"
commands:
@@ -23,6 +16,14 @@ steps:
limit: 5
- wait
+ - label: "AMD Test"
+ agents:
+ queue: amd
+ command: bash .buildkite/run-amd-test.sh
+
+ - label: "CPU Test"
+ command: bash .buildkite/run-cpu-test.sh
+
{% for step in steps %}
- label: "{{ step.label }}"
agents:
From 8438e0569eaf8496aa3d41deb808f2c831b64ecf Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 17 Apr 2024 01:34:33 -0700
Subject: [PATCH 056/413] [Core] RayWorkerVllm --> WorkerWrapper to reduce
duplication (#4024)
[Core] replace the narrow-usage RayWorkerVllm with a general WorkerWrapper to reduce code duplication (#4024)
---
tests/distributed/test_pynccl.py | 7 +-
vllm/engine/ray_utils.py | 44 ++-------
vllm/executor/ray_gpu_executor.py | 156 ++++++++++++++++--------------
vllm/utils.py | 16 ++-
vllm/worker/cpu_worker.py | 5 +-
vllm/worker/neuron_worker.py | 4 +
vllm/worker/worker.py | 4 +
vllm/worker/worker_base.py | 56 +++++++++++
8 files changed, 176 insertions(+), 116 deletions(-)
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index b50eed1c8c722..d58f621d36b86 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -1,18 +1,18 @@
import multiprocessing
-import os
import pytest
import torch
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
+from vllm.utils import update_environment_variables
def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
- env = os.environ.copy()
+ env = {}
env['RANK'] = str(i)
env['LOCAL_RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
@@ -32,8 +32,7 @@ def update_env(fn):
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapper(env):
- import os
- os.environ.update(env)
+ update_environment_variables(env)
fn()
return wrapper
diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py
index 04d4ed83976d0..febae42b84549 100644
--- a/vllm/engine/ray_utils.py
+++ b/vllm/engine/ray_utils.py
@@ -1,55 +1,28 @@
import pickle
-from typing import Callable, List, Optional, Tuple
+from typing import List, Optional, Tuple
from vllm.config import ParallelConfig
from vllm.logger import init_logger
-from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
-from vllm.worker.worker import Worker
+from vllm.utils import get_ip, is_hip
+from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
try:
import ray
- class RayWorkerVllm:
+ class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
- def __init__(self, init_cached_hf_modules=False) -> None:
- if init_cached_hf_modules:
- from transformers.dynamic_module_utils import init_hf_modules
- init_hf_modules()
- self._worker: Optional[Worker] = None
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
        # Since the compiled DAG runs the main execution in a
        # different thread that calls cuda.set_device, this flag
        # indicates whether set_device has been called on that
        # thread.
self.compiled_dag_cuda_device_set = False
- def init_worker(self, worker_init_fn: Callable[[], Worker]):
- self._worker = worker_init_fn()
-
- @property
- def worker(self) -> Worker:
- assert self._worker is not None
- return self._worker
-
- def __getattr__(self, name):
- return getattr(self.worker, name)
-
- def execute_method(self, method, *args, **kwargs):
- try:
- executor = getattr(self, method)
- return executor(*args, **kwargs)
- except Exception as e:
- # exceptions in ray worker may cause deadlock
- # see https://github.com/vllm-project/vllm/issues/3455
- # print the error and inform the user to solve the error
- msg = (f"Error executing method {method}. "
- "This might cause deadlock in distributed execution.")
- logger.exception(msg)
- raise e
-
def get_node_ip(self) -> str:
return get_ip()
@@ -58,9 +31,6 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids
- def set_cuda_visible_devices(self, device_ids) -> None:
- set_cuda_visible_devices(device_ids)
-
def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
import torch
@@ -77,7 +47,7 @@ def execute_model_compiled_dag_remote(self, ignored):
"For distributed inference, please install Ray with "
"`pip install ray`.")
ray = None # type: ignore
- RayWorkerVllm = None # type: ignore
+ RayWorkerWrapper = None # type: ignore
def initialize_ray_cluster(
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 5f859fdc9c078..5a43f1fc28a84 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -1,17 +1,16 @@
import asyncio
-import copy
import os
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
-from vllm.engine.ray_utils import RayWorkerVllm, ray
+from vllm.engine.ray_utils import RayWorkerWrapper, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
- make_async, set_cuda_visible_devices)
+ make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -74,9 +73,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
- self.driver_dummy_worker: RayWorkerVllm = None
+ self.driver_dummy_worker: RayWorkerWrapper = None
# The remaining workers are the actual ray actors.
- self.workers: List[RayWorkerVllm] = []
+ self.workers: List[RayWorkerWrapper] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
@@ -97,13 +96,20 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
- )(RayWorkerVllm).remote(self.model_config.trust_remote_code)
+ )(RayWorkerWrapper).remote(
+ worker_module_name="vllm.worker.worker",
+ worker_class_name="Worker",
+ )
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
+ self.driver_worker = RayWorkerWrapper(
+ worker_module_name="vllm.worker.worker",
+ worker_class_name="Worker",
+ )
else:
# Else, added to the list of workers.
self.workers.append(worker)
@@ -115,82 +121,56 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
"GPU node.")
# Get the set of GPU IDs used on each node.
- driver_node_id, driver_gpu_ids = ray.get(
- self.driver_dummy_worker.get_node_and_gpu_ids.remote())
- worker_node_and_gpu_ids = ray.get(
- [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+ worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
+ use_dummy_driver=True)
node_workers = defaultdict(list)
node_gpus = defaultdict(list)
- node_workers[driver_node_id].append(0)
- node_gpus[driver_node_id].extend(driver_gpu_ids)
- for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
- start=1):
+ for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
# Set CUDA_VISIBLE_DEVICES for the driver and workers.
- set_cuda_visible_devices(node_gpus[driver_node_id])
- for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
- worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+ all_args_to_update_environment_variables = []
+ for (node_id, _) in worker_node_and_gpu_ids:
+ all_args_to_update_environment_variables.append([{
+ "CUDA_VISIBLE_DEVICES":
+ ",".join(map(str, node_gpus[node_id]))
+ }])
+ self._run_workers("update_environment_variables",
+ all_args=all_args_to_update_environment_variables)
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
- # Lazy import the Worker to avoid importing torch.cuda/xformers
- # before CUDA_VISIBLE_DEVICES is set in the Worker
- from vllm.worker.worker import Worker
-
- model_config = copy.deepcopy(self.model_config)
- parallel_config = copy.deepcopy(self.parallel_config)
- scheduler_config = copy.deepcopy(self.scheduler_config)
- load_config = copy.deepcopy(self.load_config)
- device_config = copy.deepcopy(self.device_config)
- lora_config = copy.deepcopy(self.lora_config)
- cache_config = copy.deepcopy(self.cache_config)
- vision_language_config = copy.deepcopy(self.vision_language_config)
-
- # Initialize the actual workers with the Worker class.
- for rank, (worker, (node_id, _)) in enumerate(
- zip(self.workers, worker_node_and_gpu_ids),
- start=1,
- ):
+ def collect_arg_helper_func(**kwargs):
+ # avoid writing `{"name": value}` manually
+ return kwargs
+
+ init_worker_all_kwargs = []
+
+ # Initialize the actual workers inside worker wrapper.
+ for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
local_rank = node_workers[node_id].index(rank)
- worker.init_worker.remote(
- lambda rank=rank, local_rank=local_rank: Worker(
- model_config=model_config,
- parallel_config=parallel_config,
- scheduler_config=scheduler_config,
- device_config=device_config,
- cache_config=cache_config,
- load_config=load_config,
+ init_worker_all_kwargs.append(
+ collect_arg_helper_func(
+ model_config=self.model_config,
+ parallel_config=self.parallel_config,
+ scheduler_config=self.scheduler_config,
+ device_config=self.device_config,
+ cache_config=self.cache_config,
+ load_config=self.load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
- lora_config=lora_config,
- vision_language_config=vision_language_config,
+ lora_config=self.lora_config,
+ vision_language_config=self.vision_language_config,
+ is_driver_worker=rank == 0,
))
-
- # Initialize the driver worker with the Worker class.
- driver_rank = 0
- driver_local_rank = node_workers[driver_node_id].index(driver_rank)
- self.driver_worker = Worker(
- model_config=self.model_config,
- parallel_config=self.parallel_config,
- scheduler_config=self.scheduler_config,
- device_config=self.device_config,
- cache_config=self.cache_config,
- local_rank=driver_local_rank,
- rank=driver_rank,
- distributed_init_method=distributed_init_method,
- lora_config=self.lora_config,
- vision_language_config=self.vision_language_config,
- load_config=self.load_config,
- is_driver_worker=True,
- )
+ self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device")
self._run_workers(
@@ -279,13 +259,35 @@ def _run_workers(
self,
method: str,
*args,
- driver_args: Optional[Tuple[Any, ...]] = None,
+ driver_args: Optional[Tuple[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
+ all_args: Optional[List[List[Any]]] = None,
+ all_kwargs: Optional[List[Dict[str, Any]]] = None,
+ use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
- """Runs the given method on all workers."""
+ """Runs the given method on all workers.
+ all_args and all_kwargs are used to pass heterogeneous arguments,
+ i.e. different arguments for each worker.
+ """
+ if driver_args is None:
+ driver_args = args
+ if driver_kwargs is None:
+ driver_kwargs = kwargs
+
+ # for mypy type checking
+ assert driver_args is not None
+ assert driver_kwargs is not None
+ if all_args is None:
+ all_args = [driver_args] + [args] * len(self.workers)
+ if all_kwargs is None:
+ all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
+
+ # for mypy type checking
+ assert all_args is not None
+ assert all_kwargs is not None
if max_concurrent_workers:
raise NotImplementedError(
@@ -299,8 +301,10 @@ def _run_workers(
else:
# Start the ray workers first.
ray_worker_outputs = [
- worker.execute_method.remote(method, *args, **kwargs)
- for worker in self.workers
+ worker.execute_method.remote(method, *worker_args,
+ **worker_kwargs)
+ for (worker, worker_args, worker_kwargs
+ ) in zip(self.workers, all_args[1:], all_kwargs[1:])
]
if driver_args is None:
@@ -309,9 +313,13 @@ def _run_workers(
driver_kwargs = kwargs
# Start the driver worker after all the ray workers.
- driver_worker_output = getattr(self.driver_worker,
- method)(*driver_args, **driver_kwargs)
-
+ if not use_dummy_driver:
+ driver_worker_output = self.driver_worker.execute_method(
+ method, *all_args[0], **all_kwargs[0])
+ else:
+ driver_worker_output = ray.get(
+ self.driver_dummy_worker.execute_method.remote(
+ method, *all_args[0], **all_kwargs[0]))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
@@ -386,8 +394,12 @@ async def _run_workers_async(
driver_kwargs = kwargs
# Run the driver worker asynchronously.
- driver_executor = make_async(getattr(self.driver_worker, method))
- coros.append(driver_executor(*driver_args, **driver_kwargs))
+ def helper():
+ return self.driver_worker.execute_method(method, *driver_args,
+ **driver_kwargs)
+
+ driver_executor = make_async(helper)
+ coros.append(driver_executor())
# Run the ray workers asynchronously.
for worker in self.workers:
diff --git a/vllm/utils.py b/vllm/utils.py
index aad62516ad1b9..e132575e7bf81 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -271,8 +271,12 @@ def get_open_port() -> int:
return s.getsockname()[1]
-def set_cuda_visible_devices(device_ids: List[int]) -> None:
- os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
+def update_environment_variables(envs: Dict[str, str]):
+ for k, v in envs.items():
+ if k in os.environ:
+ logger.warning(f"Overwriting environment variable {k} "
+ f"from '{os.environ[k]}' to '{v}'")
+ os.environ[k] = v
def chunk_list(lst, chunk_size):
@@ -505,3 +509,11 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
merged_dict[key].extend(value)
return dict(merged_dict)
+
+
+def init_cached_hf_modules():
+ """
+ Lazy initialization of the Hugging Face modules.
+ """
+ from transformers.dynamic_module_utils import init_hf_modules
+ init_hf_modules()
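A standalone sketch of the update_environment_variables helper added above, with the logger warning replaced by a plain print so the snippet has no vLLM dependency.

import os
from typing import Dict

def update_environment_variables(envs: Dict[str, str]) -> None:
    for k, v in envs.items():
        if k in os.environ:
            print(f"Overwriting environment variable {k} "
                  f"from '{os.environ[k]}' to '{v}'")
        os.environ[k] = v

update_environment_variables({"RANK": "0", "WORLD_SIZE": "2"})
print(os.environ["RANK"], os.environ["WORLD_SIZE"])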
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index afc4a1e1f4630..8468ace5a2fdc 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -138,7 +138,10 @@ def __init__(
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
-
+ if self.model_config.trust_remote_code:
+ # note: lazy import to avoid importing torch before initializing
+ from vllm.utils import init_cached_hf_modules
+ init_cached_hf_modules()
self.model_runner = CPUModelRunner(model_config,
parallel_config,
scheduler_config,
diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py
index 142c6c97f5194..d0e6aaed180e6 100644
--- a/vllm/worker/neuron_worker.py
+++ b/vllm/worker/neuron_worker.py
@@ -29,6 +29,10 @@ def __init__(
self.scheduler_config = scheduler_config
self.device_config = device_config
self.cache_config = cache_config
+ if self.model_config.trust_remote_code:
+ # note: lazy import to avoid importing torch before initializing
+ from vllm.utils import init_cached_hf_modules
+ init_cached_hf_modules()
self.model_runner = NeuronModelRunner(model_config, parallel_config,
scheduler_config, device_config)
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index e2b47530d41e4..b021866965401 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -60,6 +60,10 @@ def __init__(
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
+ if self.model_config.trust_remote_code:
+ # note: lazy import to avoid importing torch before initializing
+ from vllm.utils import init_cached_hf_modules
+ init_cached_hf_modules()
self.vision_language_config = vision_language_config
if self.vision_language_config:
assert not self.lora_config, (
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index a92f5aea76059..309aa6256acea 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -1,8 +1,14 @@
+import importlib
+import os
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
+from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import update_environment_variables
+
+logger = init_logger(__name__)
class WorkerBase(ABC):
@@ -82,3 +88,53 @@ def remove_lora(self, lora_id: int) -> bool:
def list_loras(self) -> List[int]:
raise ValueError(f"{type(self)} does not support LoRA")
+
+
+class WorkerWrapperBase:
+ """
+ The whole point of this class is to lazily initialize the worker.
+    We first instantiate the WorkerWrapper, which remembers the worker module
+    and class name. Then we call `update_environment_variables` to set up the
+    environment, and the real initialization happens in `init_worker`.
+ """
+
+ def __init__(self,
+ worker_module_name=None,
+ worker_class_name=None) -> None:
+ self.worker_module_name = worker_module_name
+ self.worker_class_name = worker_class_name
+ self.worker = None
+
+ def update_environment_variables(self, envs: Dict[str, str]) -> None:
+ key = 'CUDA_VISIBLE_DEVICES'
+ if key in envs and key in os.environ:
+ # overwriting CUDA_VISIBLE_DEVICES is desired behavior
+ # suppress the warning in `update_environment_variables`
+ del os.environ[key]
+ update_environment_variables(envs)
+
+ def init_worker(self, *args, **kwargs):
+ """
+ Actual initialization of the worker class.
+ Arguments are passed to the worker class constructor.
+ """
+ mod = importlib.import_module(self.worker_module_name)
+ worker_class = getattr(mod, self.worker_class_name)
+ self.worker = worker_class(*args, **kwargs)
+
+ def execute_method(self, method, *args, **kwargs):
+ try:
+ if hasattr(self, method):
+ executor = getattr(self, method)
+ else:
+ executor = getattr(self.worker, method)
+ return executor(*args, **kwargs)
+ except Exception as e:
+            # If the driver worker also executes methods, exceptions in the
+            # remaining workers may cause a deadlock in RPC frameworks such
+            # as Ray; see https://github.com/vllm-project/vllm/issues/3455.
+            # Print the error and inform the user how to resolve it.
+ msg = (f"Error executing method {method}. "
+ "This might cause deadlock in distributed execution.")
+ logger.exception(msg)
+ raise e
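A standalone sketch of the lazy-initialization pattern WorkerWrapperBase implements: only the module and class names are stored up front, and the import happens when init_worker is called. A stdlib class stands in for the real Worker here.

import importlib

class LazyWrapper:
    def __init__(self, module_name: str, class_name: str) -> None:
        self.module_name = module_name
        self.class_name = class_name
        self.obj = None

    def init_worker(self, *args, **kwargs) -> None:
        # Import only now, after the environment has been prepared.
        cls = getattr(importlib.import_module(self.module_name),
                      self.class_name)
        self.obj = cls(*args, **kwargs)

    def execute_method(self, method: str, *args, **kwargs):
        # Prefer methods on the wrapper itself, then fall back to the worker.
        target = self if hasattr(self, method) else self.obj
        return getattr(target, method)(*args, **kwargs)

w = LazyWrapper("collections", "Counter")
w.init_worker("speculative decoding")
print(w.execute_method("most_common", 1))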
From fe3b5bbc23a99533bc7d4a94ae073828ed025974 Mon Sep 17 00:00:00 2001
From: Elinx
Date: Wed, 17 Apr 2024 19:07:23 +0800
Subject: [PATCH 057/413] [Bugfix] fix output parsing error for trtllm backend
(#4137)
Co-authored-by: Roger Wang
---
benchmarks/backend_request_func.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index bab570252c929..f9d167590fe47 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -135,6 +135,7 @@ async def async_request_trt_llm(
"data:")
data = json.loads(chunk)
+ output.generated_text += data["text_output"]
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
@@ -149,7 +150,6 @@ async def async_request_trt_llm(
most_recent_timestamp = timestamp
output.latency = most_recent_timestamp - st
- output.generated_text = json.loads(data)["text_output"]
output.success = True
else:
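A standalone sketch of the parsing change in this fix: with a streaming backend the generated text arrives chunk by chunk, so it is accumulated per chunk instead of being re-parsed from the final message.

import json

chunks = ['{"text_output": "Hello"}', '{"text_output": ", world"}']
generated_text = ""
for chunk in chunks:
    data = json.loads(chunk)
    generated_text += data["text_output"]  # accumulate per streamed chunk
print(generated_text)  # Hello, world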
From a53222544c6385ee314a26fdf42eb14f5b4e5ad9 Mon Sep 17 00:00:00 2001
From: Shoichi Uchinami
Date: Thu, 18 Apr 2024 02:02:45 +0900
Subject: [PATCH 058/413] [Kernel] Add punica dimension for Swallow-MS-7B LoRA
(#4134)
---
csrc/punica/bgmv/bgmv_config.h | 1 +
tests/lora/test_punica.py | 1 +
2 files changed, 2 insertions(+)
diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index d2906914f927e..fec484d693055 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -60,6 +60,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
+ f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index 8b174f01d87d4..f3b9bd5912967 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -82,6 +82,7 @@ def _lora_ref_impl(
32768,
33024,
36864,
+ 43264,
49152,
64000,
64256,
From 533d2a1f3962c45ddebbdeb0ff1cb7cd54d5e771 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Thu, 18 Apr 2024 09:28:43 +0900
Subject: [PATCH 059/413] [Typing] Mypy typing part 2 (#4043)
Co-authored-by: SangBin Cho
---
.github/workflows/mypy.yaml | 8 +-
format.sh | 8 +-
vllm/engine/async_llm_engine.py | 44 +++++----
vllm/lora/worker_manager.py | 4 +-
.../guided_decoding/outlines_decoding.py | 4 +-
.../outlines_logits_processors.py | 6 +-
vllm/model_executor/model_loader/neuron.py | 16 ++--
.../model_executor/model_loader/tensorizer.py | 1 +
vllm/model_executor/sampling_metadata.py | 4 +
vllm/spec_decode/batch_expansion.py | 6 +-
vllm/spec_decode/interfaces.py | 4 +-
vllm/spec_decode/metrics.py | 1 +
vllm/spec_decode/multi_step_worker.py | 21 ++--
vllm/spec_decode/spec_decode_worker.py | 6 +-
vllm/worker/cpu_model_runner.py | 25 +++--
vllm/worker/cpu_worker.py | 11 ++-
vllm/worker/model_runner.py | 96 +++++++++++--------
vllm/worker/neuron_model_runner.py | 22 +++--
vllm/worker/worker.py | 11 ++-
vllm/worker/worker_base.py | 8 +-
20 files changed, 180 insertions(+), 126 deletions(-)
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index 6db0bb7645ecd..477ce9bc9ce85 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -41,10 +41,10 @@ jobs:
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
- # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
- # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
- # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
- # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
diff --git a/format.sh b/format.sh
index 1c195b899c742..84ee88b5b4c8a 100755
--- a/format.sh
+++ b/format.sh
@@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
-# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 27192449bf15a..c3020d2b38db0 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -2,8 +2,8 @@
import os
import time
from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
- Set, Tuple, Type, Union)
+from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
+ Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
@@ -52,7 +52,7 @@ class AsyncStream:
def __init__(self, request_id: str) -> None:
self.request_id = request_id
- self._queue = asyncio.Queue()
+ self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None:
@@ -312,15 +312,17 @@ def __init__(self,
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
- self.background_loop = None
+ self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
- self._background_loop_unshielded = None
+ self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
self.start_engine_loop = start_engine_loop
- self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None
+ # Lazy initialized fields
+ self._request_tracker: RequestTracker
+
@classmethod
def from_engine_args(
cls,
@@ -361,11 +363,13 @@ def from_engine_args(
@property
def is_running(self) -> bool:
return (self.background_loop is not None
+ and self._background_loop_unshielded is not None
and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
- return self.errored or (self.background_loop is not None
+ return self.errored or (self.background_loop is not None and
+ self._background_loop_unshielded is not None
and self._background_loop_unshielded.done())
@property
@@ -381,7 +385,7 @@ def _error_callback(self, exc: Exception) -> None:
async def get_tokenizer(self) -> "PreTrainedTokenizer":
if self.engine_use_ray:
- return await self.engine.get_tokenizer.remote()
+ return await self.engine.get_tokenizer.remote() # type: ignore
else:
return self.engine.get_tokenizer()
@@ -434,7 +438,8 @@ async def engine_step(self) -> bool:
# TODO: Maybe add add_request_batch to reduce Ray overhead
try:
if self.engine_use_ray:
- await self.engine.add_request.remote(**new_request)
+ await self.engine.add_request.remote( # type: ignore
+ **new_request)
else:
await self.engine.add_request_async(**new_request)
except ValueError as e:
@@ -449,7 +454,7 @@ async def engine_step(self) -> bool:
await self._engine_abort(finished_requests)
if self.engine_use_ray:
- request_outputs = await self.engine.step.remote()
+ request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async()
@@ -462,7 +467,7 @@ async def engine_step(self) -> bool:
async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
- await self.engine.abort_request.remote(request_ids)
+ await self.engine.abort_request.remote(request_ids) # type: ignore
else:
self.engine.abort_request(request_ids)
@@ -525,11 +530,12 @@ async def add_request(
arrival_time = time.time()
if self.engine_use_ray:
- prompt_token_ids = await self.engine.encode_request_async.remote(
- request_id=request_id,
- prompt=prompt,
- prompt_token_ids=prompt_token_ids,
- lora_request=lora_request)
+ prompt_token_ids = await (
+ self.engine.encode_request_async.remote( # type: ignore
+ request_id=request_id,
+ prompt=prompt,
+ prompt_token_ids=prompt_token_ids,
+ lora_request=lora_request))
else:
prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id,
@@ -676,13 +682,13 @@ def _abort(self, request_id: str) -> None:
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
if self.engine_use_ray:
- return await self.engine.get_model_config.remote()
+ return await self.engine.get_model_config.remote() # type: ignore
else:
return self.engine.get_model_config()
async def do_log_stats(self) -> None:
if self.engine_use_ray:
- await self.engine.do_log_stats.remote()
+ await self.engine.do_log_stats.remote() # type: ignore
else:
self.engine.do_log_stats()
@@ -695,7 +701,7 @@ async def check_health(self) -> None:
if self.engine_use_ray:
try:
- await self.engine.check_health.remote()
+ await self.engine.check_health.remote() # type: ignore
except ray.exceptions.RayActorError as e:
raise RuntimeError("Engine is dead.") from e
else:
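
The async engine hunks above combine two typing idioms: fields that may legitimately be None are declared Optional and narrowed with explicit is-not-None checks, while fields that are only created later (such as the request tracker) get a bare annotation with no assignment. A minimal, self-contained sketch of both idioms, with a placeholder RequestTracker class standing in for vLLM's real one, might look like this:

import asyncio
from typing import Any, Optional


class RequestTracker:
    # Placeholder standing in for vllm's RequestTracker.
    pass


class EngineSketch:
    def __init__(self) -> None:
        # Optional fields: legitimately None until the background loop starts.
        self.background_loop: Optional[asyncio.Future] = None
        self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
        # Lazily initialized field: annotated but not assigned here, so mypy
        # knows its type without forcing every caller to handle None.
        self._request_tracker: RequestTracker

    async def start(self) -> None:
        self._request_tracker = RequestTracker()
        self._background_loop_unshielded = asyncio.get_running_loop().create_task(
            asyncio.sleep(0))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    @property
    def is_running(self) -> bool:
        # Both Optional fields are narrowed before .done() is called.
        return (self.background_loop is not None
                and self._background_loop_unshielded is not None
                and not self._background_loop_unshielded.done())


async def _demo() -> None:
    engine = EngineSketch()
    await engine.start()
    print(engine.is_running)  # True


asyncio.run(_demo())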
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index a0868defbd3ca..5356b79537b05 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -107,12 +107,12 @@ def create_lora_manager(
self._lora_manager: LoRAModelManager = lora_manager
return lora_manager.model
- def set_active_loras(self, lora_requests: List[LoRARequest],
+ def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)
- def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
+ def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
index bd4564a36e1ed..53efebb604048 100644
--- a/vllm/model_executor/guided_decoding/outlines_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -55,7 +55,7 @@ class GuidedDecodingMode(Enum):
async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
- tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
+ tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
@@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor(
def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest]
-) -> Tuple[str, GuidedDecodingMode]:
+) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
if request.guided_json:
json = request.guided_json
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index 28041695546dc..95a67b612f08b 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -21,7 +21,7 @@
from typing import Callable, DefaultDict, Dict, List, Optional, Union
import torch
-from outlines.fsm.fsm import CFGFSM, RegexFSM
+from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
@@ -29,6 +29,10 @@
class BaseLogitsProcessor:
+ def __init__(self):
+ # Child classes should initialize this in their __init__.
+ self.fsm: FSM
+
def init_state(self):
"""Initialize the FSM states."""
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index 43d17ad373b87..07e23aca6cc5f 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -1,7 +1,7 @@
"""Utilities for selecting and loading neuron models."""
import importlib
import os
-from typing import Optional, Type
+from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
@@ -27,7 +27,7 @@
}
# Models supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {
+_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
"LlamaForSampling", "LlamaForCausalLM"),
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
@@ -43,11 +43,13 @@ def __init__(
) -> None:
super().__init__()
self.config = config
- self.model = None
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()
+ # Lazy initialized
+ self.model: nn.Module
+
def forward(
self,
input_ids: torch.Tensor,
@@ -74,17 +76,17 @@ def sample(
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
- neuronx_module_path, neuronx_model_cls, hf_model_cls = (
+ neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
- neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
+ neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
- hf_model_cls = getattr(transformers, hf_model_cls)
+ hf_model_cls = getattr(transformers, hf_model_cls_name)
from transformers_neuronx.module import save_pretrained_split
hf_model = hf_model_cls.from_pretrained(model_name_or_path,
@@ -96,7 +98,7 @@ def load_weights(self, model_name_or_path: str, **kwargs):
self.model.to_neuron()
-def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _NEURON_SUPPORTED_MODELS:
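
The typed registry above maps an architecture name to a tuple of module path and class names; the classes are resolved at load time with importlib and getattr, and the resolved classes are kept in separately named variables so mypy never sees a name change its type. A standalone sketch of the same lookup pattern, using a hypothetical registry built from the standard library:

import importlib
from typing import Dict, Tuple

# Hypothetical registry: name -> (module path, class name).
_SUPPORTED: Dict[str, Tuple[str, str]] = {
    "JSONDecoder": ("json.decoder", "JSONDecoder"),
}


def load_class(name: str):
    if name not in _SUPPORTED:
        raise ValueError(f"Unsupported name: {name}")
    module_path, class_name = _SUPPORTED[name]
    module = importlib.import_module(module_path)
    # getattr returns the class object; the string name and the class live
    # in different variables, so each keeps a single static type.
    return getattr(module, class_name)


decoder_cls = load_class("JSONDecoder")
print(decoder_cls.__name__)  # JSONDecoder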
diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py
index ad554844384eb..16be0ecf9ce07 100644
--- a/vllm/model_executor/model_loader/tensorizer.py
+++ b/vllm/model_executor/model_loader/tensorizer.py
@@ -167,6 +167,7 @@ def __post_init__(self):
decryption_params = DecryptionParams.from_key(key)
self.deserializer_params['encryption'] = decryption_params
+ @staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Tensorizer CLI arguments"""
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 534cb75c2fd2f..31032c4cead20 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -113,6 +113,8 @@ def from_sampling_metadata(
get_num_triton_sampler_splits(vocab_size))
sample_indices_start_idx = 0
+ assert sampling_metadata.seq_groups is not None
+ assert sampling_metadata.seq_data is not None
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
@@ -147,6 +149,7 @@ def from_sampling_metadata(
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
+ assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
top_ps += [top_p] * (prompt_len - 1)
@@ -172,6 +175,7 @@ def from_sampling_metadata(
is_prompt = i < sampling_metadata.num_prompts
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
+ assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
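
The added assert statements exist for the type checker rather than as new runtime invariants: they narrow Optional fields so mypy accepts the indexing that follows. A minimal sketch of that narrowing, with a hypothetical dataclass:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Metadata:
    # Optional because decode-only batches may not carry prompt lengths.
    prompt_lens: Optional[List[int]] = None


def first_prompt_len(meta: Metadata) -> int:
    # Without the assert, mypy reports that "None" is not indexable.
    assert meta.prompt_lens is not None
    return meta.prompt_lens[0]


print(first_prompt_len(Metadata(prompt_lens=[7, 3])))  # 7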
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index 88af1dd360155..bbc5b1778854f 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -106,7 +106,7 @@ def score_proposals(
def _expand_batch(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
- proposal_token_ids_list: List[TokenId],
+ proposal_token_ids_list: List[List[TokenId]],
proposal_lens_list: List[int],
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
"""Given the input sequences and potentially multiple corresponding
@@ -218,7 +218,7 @@ def _create_scoring_model_input(
def _create_target_seq_group_metadata(
self,
input_seq_group_metadata: SequenceGroupMetadata,
- proposal_token_ids: List[TokenId], # shape: [batch_size, k]
+ proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
batch_index: int,
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
@@ -360,7 +360,7 @@ def _get_token_ids_to_score(
[0, 1, 2]
[0, 1, 2, 3]
"""
- empty_token_ids = []
+ empty_token_ids: List[TokenId] = []
token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([
diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py
index 2a72974d01bdc..f0715120192e5 100644
--- a/vllm/spec_decode/interfaces.py
+++ b/vllm/spec_decode/interfaces.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
import torch
@@ -73,5 +73,5 @@ def score_proposals(
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
proposals: SpeculativeProposals,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ ) -> SpeculativeScores:
raise NotImplementedError
diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index 5df8fc4316d48..d1e72b6640548 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -112,6 +112,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
Returns a CUDA event recording when the copy is complete.
"""
+ assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._copy_stream):
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index ce63c329a40aa..8b722476853fa 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -26,7 +26,8 @@ class MultiStepWorker(Worker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._proposer: Optional[DraftModelTop1Proposer] = None
+ # Lazy initialization.
+ self._proposer: DraftModelTop1Proposer
def init_device(self):
super().init_device()
@@ -338,10 +339,10 @@ def _merge_outputs(
self._vocab_size,
dtype=torch.float32,
device=self._device)
- proposal_lens = torch.zeros(len(proposal_lens),
- dtype=torch.long,
- device=self._device)
- return proposal_tokens, proposal_probs, proposal_lens
+ proposal_lens_tensor = torch.zeros(len(proposal_lens),
+ dtype=torch.long,
+ device=self._device)
+ return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
@@ -376,9 +377,9 @@ def _merge_outputs(
proposal_tokens, proposal_probs = (entire_proposal_tokens,
entire_proposal_probs)
- proposal_lens = torch.zeros(batch_size,
- dtype=torch.long,
- device=self._device)
- proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
+ proposal_lens_tensor = torch.zeros(batch_size,
+ dtype=torch.long,
+ device=self._device)
+ proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len
- return proposal_tokens, proposal_probs, proposal_lens
+ return proposal_tokens, proposal_probs, proposal_lens_tensor
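
The rename to proposal_lens_tensor exists because the original name was first bound to a List[int] and then rebound to a torch.Tensor, which mypy reports as an incompatible assignment. A small sketch of the fix, assuming torch is installed:

from typing import List

import torch

proposal_lens: List[int] = [0, 0, 0]

# Rebinding the same name to a tensor would be flagged by mypy:
#   proposal_lens = torch.zeros(3, dtype=torch.long)  # incompatible assignment

# Fix: the converted value gets its own, separately typed name.
proposal_lens_tensor: torch.Tensor = torch.zeros(len(proposal_lens),
                                                 dtype=torch.long)
print(proposal_lens_tensor.shape)  # torch.Size([3])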
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index be3af7be93864..68a2a774ef4b7 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -89,7 +89,8 @@ def __init__(
self.probs_dtype = self.rejection_sampler.probs_dtype
self.token_id_dtype = self.rejection_sampler.token_id_dtype
- self.scorer: SpeculativeScorer = None
+ # Lazy initialization.
+ self.scorer: SpeculativeScorer
def init_device(self) -> None:
"""Initialize both scorer and proposer models.
@@ -233,6 +234,9 @@ def _run_speculative_decoding_step(
logger.info("get spec proposals")
# Generate proposals using draft worker.
+ assert blocks_to_swap_in is not None
+ assert blocks_to_swap_out is not None
+ assert blocks_to_copy is not None
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index d378e3a90e1e7..7377c8931cefa 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -1,6 +1,7 @@
from typing import Dict, List, Optional, Tuple
import torch
+from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -48,14 +49,15 @@ def __init__(
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
- self.model = None
- self.block_size = None # Set after initial profiling.
-
self.kv_cache_dtype = kv_cache_dtype
self.attn_backend = get_attn_backend(
self.model_config.dtype if model_config is not None else None)
+ # Lazy initialization.
+ self.model: nn.Module # Set after load_model.
+ self.block_size: int # Set after initial profiling.
+
def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
@@ -245,7 +247,11 @@ def _prepare_sample(
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
- categorized_sample_indices = {t: [] for t in SamplingType}
+ categorized_sample_indices: Dict[SamplingType,
+ List[Tuple[int, int]]] = {
+ t: []
+ for t in SamplingType
+ }
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
@@ -262,10 +268,9 @@ def _prepare_sample(
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
- sampling_params.sampling_type].append([
- categorized_sample_indices_start_idx,
- categorized_sampled_token_indices_start_idx
- ])
+ sampling_params.sampling_type].append(
+ (categorized_sample_indices_start_idx,
+ categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
@@ -328,7 +333,7 @@ def _prepare_sample(
def prepare_input_tensors(
self,
- seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+ seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
SamplingMetadata]:
if self.is_driver_worker:
@@ -381,7 +386,7 @@ def prepare_input_tensors(
@torch.inference_mode()
def execute_model(
self,
- seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+ seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 8468ace5a2fdc..3652830b7d519 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -1,5 +1,5 @@
"""A CPU worker class."""
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed
@@ -152,8 +152,8 @@ def __init__(
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
- self.cache_engine = None
- self.cpu_cache = None
+ self.cache_engine: CPUCacheEngine
+ self.cpu_cache: List[torch.Tensor]
def init_device(self) -> None:
self.init_distributed_environment()
@@ -257,13 +257,13 @@ def execute_model(
) -> List[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
- num_seq_groups = len(seq_group_metadata_list)
+ num_seq_groups: int = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
assert len(blocks_to_swap_in) == 0
assert len(blocks_to_swap_out) == 0
- data = {
+ data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_copy": blocks_to_copy,
}
@@ -273,6 +273,7 @@ def execute_model(
num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"]
+ assert blocks_to_copy is not None
self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model.
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 42c06a1b19361..31e08789dfd1f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -128,23 +128,17 @@ def __init__(
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
- self.model = None
- self.block_size = None # Set after initial profiling.
- self.lora_manager = None
+ # Set after load_model.
+ self.lora_manager: LRUCacheWorkerLoRAManager = None
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
- self.graph_memory_pool = None # Set during graph capture.
+ self.graph_memory_pool: Optional[Tuple[
+ int, int]] = None # Set during graph capture.
self.max_context_len_to_capture = (
self.model_config.max_context_len_to_capture
if self.model_config is not None else 0)
- # When using CUDA graph, the input block tables must be padded to
- # max_context_len_to_capture. However, creating the block table in
- # Python can be expensive. To optimize this, we cache the block table
- # in numpy and only copy the actual input content at every iteration.
- # The shape of the cached block table will be
- # (max batch size to capture, max context len to capture / block size).
- self.graph_block_tables = None # Set after initial profiling.
+
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
self.vision_language_config = vision_language_config
@@ -152,6 +146,17 @@ def __init__(
self.attn_backend = get_attn_backend(
self.model_config.dtype if model_config is not None else None)
+ # Lazy initialization
+ self.model: torch.nn.Module # Set after load_model
+ self.block_size: int # Set after initial profiling.
+ # When using CUDA graph, the input block tables must be padded to
+ # max_context_len_to_capture. However, creating the block table in
+ # Python can be expensive. To optimize this, we cache the block table
+ # in numpy and only copy the actual input content at every iteration.
+ # The shape of the cached block table will be
+ # (max batch size to capture, max context len to capture / block size).
+ self.graph_block_tables: torch.Tensor # Set after initial profiling.
+
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(
@@ -489,16 +494,16 @@ def _prepare_decode(
lora_index_mapping.append(0)
batch_size = graph_batch_size
- context_lens = torch.tensor(context_lens,
- dtype=torch.int,
- device=self.device)
+ context_lens_tensor = torch.tensor(context_lens,
+ dtype=torch.int,
+ device=self.device)
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
- assert context_lens.shape[0] == len(input_tokens)
- assert context_lens.shape[0] == len(input_positions)
- assert context_lens.shape[0] == len(slot_mapping)
+ assert context_lens_tensor.shape[0] == len(input_tokens)
+ assert context_lens_tensor.shape[0] == len(input_positions)
+ assert context_lens_tensor.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
@@ -527,7 +532,7 @@ def _prepare_decode(
max_prompt_len=None,
subquery_start_loc=None,
seq_start_loc=None,
- context_lens=context_lens,
+ context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
@@ -551,7 +556,11 @@ def _prepare_sample(
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
- categorized_sample_indices = {t: [] for t in SamplingType}
+ categorized_sample_indices: Dict[SamplingType,
+ List[Tuple[int, int]]] = {
+ t: []
+ for t in SamplingType
+ }
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
@@ -569,10 +578,9 @@ def _prepare_sample(
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
- sampling_params.sampling_type].append([
- categorized_sample_indices_start_idx,
- categorized_sampled_token_indices_start_idx
- ])
+ sampling_params.sampling_type].append(
+ (categorized_sample_indices_start_idx,
+ categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
@@ -596,15 +604,16 @@ def _prepare_sample(
categorized_sample_indices[
sampling_params.sampling_type].extend(
- zip(
- range(
- categorized_sample_indices_start_idx,
- categorized_sample_indices_start_idx +
- num_seqs),
- range(
- categorized_sampled_token_indices_start_idx,
- categorized_sampled_token_indices_start_idx +
- num_seqs)))
+ list(
+ zip(
+ range(
+ categorized_sample_indices_start_idx,
+ categorized_sample_indices_start_idx +
+ num_seqs),
+ range(
+ categorized_sampled_token_indices_start_idx,
+ categorized_sampled_token_indices_start_idx
+ + num_seqs))))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
@@ -641,9 +650,9 @@ def _prepare_sample(
def prepare_input_tensors(
self,
- seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+ seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
- Set[int], LoRAMapping, torch.Tensor]:
+ Set[LoRARequest], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
prefill_reqs = []
decode_reqs = []
@@ -741,6 +750,7 @@ def prepare_input_tensors(
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
+ assert decode_attn_metadata is not None
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
@@ -809,7 +819,7 @@ def prepare_input_tensors(
@torch.inference_mode()
def execute_model(
self,
- seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+ seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
@@ -923,7 +933,7 @@ def remove_all_loras(self) -> bool:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_all_loras()
- def set_active_loras(self, lora_requests: List[LoRARequest],
+ def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
@@ -1065,10 +1075,16 @@ class CUDAGraphRunner:
def __init__(self, model: nn.Module):
self.model = model
- self.graph = None
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
+ self._graph: Optional[torch.cuda.CUDAGraph] = None
+
+ @property
+ def graph(self):
+ assert self._graph is not None
+ return self._graph
+
def capture(
self,
input_ids: torch.Tensor,
@@ -1078,7 +1094,7 @@ def capture(
memory_pool,
**kwargs,
) -> None:
- assert self.graph is None
+ assert self._graph is None
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
@@ -1095,8 +1111,8 @@ def capture(
# Capture the graph.
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
- self.graph = torch.cuda.CUDAGraph()
- with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
+ self._graph = torch.cuda.CUDAGraph()
+ with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
with _maybe_pynccl():
hidden_states = self.model(
input_ids,
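
CUDAGraphRunner now stores the graph in a private Optional field and exposes it through a property that asserts the graph has been captured, so callers never have to handle None. The same guard pattern, sketched with a plain object standing in for torch.cuda.CUDAGraph:

from typing import Optional


class Runner:
    def __init__(self) -> None:
        self._graph: Optional[object] = None  # stands in for torch.cuda.CUDAGraph

    @property
    def graph(self) -> object:
        # Callers never see None; forgetting to capture fails loudly here.
        assert self._graph is not None
        return self._graph

    def capture(self) -> None:
        assert self._graph is None  # capture only once
        self._graph = object()


runner = Runner()
runner.capture()
print(runner.graph is not None)  # True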
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index f70a7193effeb..487df334d73e3 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -1,6 +1,7 @@
from typing import Dict, List, Optional, Tuple
import torch
+from torch import nn
from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
@@ -34,9 +35,11 @@ def __init__(
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
- self.model = None
self.pin_memory = is_pin_memory_available()
+ # Lazy initialization.
+ self.model: nn.Module # initialize after load_model.
+
def load_model(self) -> None:
self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config,
@@ -147,7 +150,11 @@ def _prepare_sample(
selected_token_indices: List[int] = []
generators: List[torch.Generator] = []
selected_token_start_idx = 0
- categorized_sample_indices = {t: [] for t in SamplingType}
+ categorized_sample_indices: Dict[SamplingType,
+ List[Tuple[int, int]]] = {
+ t: []
+ for t in SamplingType
+ }
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
@@ -165,10 +172,9 @@ def _prepare_sample(
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[
- sampling_params.sampling_type].append([
- categorized_sample_indices_start_idx,
- categorized_sampled_token_indices_start_idx
- ])
+ sampling_params.sampling_type].append(
+ (categorized_sample_indices_start_idx,
+ categorized_sampled_token_indices_start_idx))
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
@@ -237,7 +243,7 @@ def _prepare_sample(
def prepare_input_tensors(
self,
- seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+ seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
@@ -259,7 +265,7 @@ def prepare_input_tensors(
@torch.inference_mode()
def execute_model(
self,
- seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+ seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_block_ids, sampling_metadata
) = self.prepare_input_tensors(seq_group_metadata_list)
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index b021866965401..2203570b37ad6 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -1,7 +1,7 @@
"""A GPU worker class."""
import gc
import os
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
import torch
import torch.distributed
@@ -82,8 +82,8 @@ def __init__(
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
- self.cache_engine = None
- self.gpu_cache = None
+ self.cache_engine: CacheEngine
+ self.gpu_cache: List[torch.Tensor]
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
@@ -223,7 +223,7 @@ def execute_model(
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
- data = {
+ data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
@@ -237,6 +237,9 @@ def execute_model(
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
+ assert blocks_to_swap_in is not None
+ assert blocks_to_swap_out is not None
+ assert blocks_to_copy is not None
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 309aa6256acea..13e062fe64b29 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -1,7 +1,7 @@
import importlib
import os
from abc import ABC, abstractmethod
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -56,7 +56,7 @@ def execute_model(
raise NotImplementedError
@abstractmethod
- def get_cache_block_size_bytes() -> int:
+ def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
@@ -71,7 +71,7 @@ def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
@abstractmethod
- def list_loras(self) -> List[int]:
+ def list_loras(self) -> Set[int]:
raise NotImplementedError
@@ -86,7 +86,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
- def list_loras(self) -> List[int]:
+ def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")
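
Two of the worker_base fixes are easy to overlook: get_cache_block_size_bytes was declared without self, so it could never be called on an instance, and list_loras now returns a Set[int] to match the worker implementations. A compact sketch of the corrected interface, with a hypothetical DummyWorker:

from abc import ABC, abstractmethod
from typing import Set


class WorkerBaseSketch(ABC):
    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:  # `self` was missing before
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:  # previously annotated as List[int]
        raise NotImplementedError


class DummyWorker(WorkerBaseSketch):
    def get_cache_block_size_bytes(self) -> int:
        return 16 * 1024

    def list_loras(self) -> Set[int]:
        return {1, 2}


print(DummyWorker().get_cache_block_size_bytes())  # 16384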
From 6dc1fc9cfed74f63f3e90c60e4d6ad7065bd4529 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 17 Apr 2024 22:28:52 -0700
Subject: [PATCH 060/413] [Core] nccl integrity check and test (#4155)
[Core] Add integrity check during initialization; add test for it (#4155)
---
.buildkite/test-pipeline.yaml | 1 +
tests/distributed/test_pynccl_library.py | 43 ++++++++++++++++
.../device_communicators/pynccl.py | 38 +++++---------
vllm/utils.py | 51 +++++++++++++++++++
4 files changed, 107 insertions(+), 26 deletions(-)
create mode 100644 tests/distributed/test_pynccl_library.py
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index f39c3302ac2e9..2263dee20fbed 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -33,6 +33,7 @@ steps:
num_gpus: 2 # only support 1 or 2 for now.
commands:
- pytest -v -s test_pynccl.py
+ - pytest -v -s test_pynccl_library.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
diff --git a/tests/distributed/test_pynccl_library.py b/tests/distributed/test_pynccl_library.py
new file mode 100644
index 0000000000000..ec60a5ed3114d
--- /dev/null
+++ b/tests/distributed/test_pynccl_library.py
@@ -0,0 +1,43 @@
+import multiprocessing
+import tempfile
+
+
+def target_fn(env, filepath):
+ from vllm.utils import update_environment_variables
+ update_environment_variables(env)
+ from vllm.utils import nccl_integrity_check
+ nccl_integrity_check(filepath)
+
+
+def test_library_file():
+ # note: don't import vllm.distributed.device_communicators.pynccl
+ # before running this test, otherwise the library file will be loaded
+ # and it might interfere with the test
+ from vllm.utils import find_nccl_library
+ so_file = find_nccl_library()
+ with open(so_file, 'rb') as f:
+ content = f.read()
+ try:
+ # corrupt the library file, should raise an exception
+ with open(so_file, 'wb') as f:
+ f.write(content[:len(content) // 2])
+ p = multiprocessing.Process(target=target_fn, args=({}, so_file))
+ p.start()
+ p.join()
+ assert p.exitcode != 0
+
+ # move the library file to a tmp path
+ # test VLLM_NCCL_SO_PATH
+ fd, path = tempfile.mkstemp()
+ with open(path, 'wb') as f:
+ f.write(content)
+ p = multiprocessing.Process(target=target_fn,
+ args=({
+ "VLLM_NCCL_SO_PATH": path
+ }, path))
+ p.start()
+ p.join()
+ assert p.exitcode == 0
+ finally:
+ with open(so_file, 'wb') as f:
+ f.write(content)
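
The test corrupts the library on disk and runs the check in a child process, so a hard crash only kills the child and is observed through its exit code. A minimal sketch of that isolation pattern, with a deliberately failing stand-in for the NCCL check:

import multiprocessing
import sys


def risky_target() -> None:
    # Stand-in for code that may crash the interpreter (e.g. loading a
    # corrupted shared library); here it simply exits with a failure code.
    sys.exit(1)


if __name__ == "__main__":
    p = multiprocessing.Process(target=risky_target)
    p.start()
    p.join()
    # The parent survives and can inspect how the child ended.
    assert p.exitcode != 0
    print("child failed as expected, parent is still alive")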
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 0a8bb860efa1c..c57a4f59d442c 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -21,8 +21,7 @@
import ctypes
import datetime
-import glob
-import os
+import platform
# ===================== import region =====================
import torch
@@ -30,40 +29,27 @@
from torch.distributed import ReduceOp
from vllm.logger import init_logger
+from vllm.utils import find_nccl_library, nccl_integrity_check
logger = init_logger(__name__)
-so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
-
-# check if we have vllm-managed nccl
-vllm_nccl_path = None
-if torch.version.cuda is not None:
- cuda_major = torch.version.cuda.split(".")[0]
- path = os.path.expanduser(
- f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
- files = glob.glob(path)
- vllm_nccl_path = files[0] if files else None
-
-# manually load the nccl library
-if so_file:
- logger.info(
- f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}")
-else:
- if torch.version.cuda is not None:
- so_file = vllm_nccl_path or "libnccl.so.2"
- elif torch.version.hip is not None:
- so_file = "librccl.so.1"
- else:
- raise ValueError("NCCL only supports CUDA and ROCm backends.")
- logger.info(f"Loading nccl from library {so_file}")
+so_file = find_nccl_library()
try:
+ # Check the library in another process first.
+ # If it core dumps there, it will not crash the current process.
+ nccl_integrity_check(so_file)
nccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
f"Failed to load NCCL library from {so_file} ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
- "Otherwise please set the environment variable VLLM_NCCL_SO_PATH"
+ "Otherwise, the nccl library might not exist, be corrupted "
+ f"or it does not support the current platform {platform.platform()}."
+ f"One solution is to download libnccl2 version 2.18 from "
+ f"https://developer.download.nvidia.com/compute/cuda/repos/ "
+ f"and extract the libnccl.so.2 file. If you already have the "
+ f"library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.")
raise e
diff --git a/vllm/utils.py b/vllm/utils.py
index e132575e7bf81..49e7033c23e6a 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1,6 +1,7 @@
import asyncio
import enum
import gc
+import glob
import os
import socket
import subprocess
@@ -517,3 +518,53 @@ def init_cached_hf_modules():
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
+
+
+def nccl_integrity_check(filepath):
+ """
+ When the library is corrupted, we cannot catch the resulting
+ exception in Python; it will crash the process. Instead, we use the
+ exit code of `ldd` to check whether the library is corrupted. If it
+ is not, we return the version of the library.
+ """
+ exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
+ if exit_code != 0:
+ raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
+ import ctypes
+
+ nccl = ctypes.CDLL(filepath)
+ version = ctypes.c_int()
+ nccl.ncclGetVersion.restype = ctypes.c_int
+ nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
+ result = nccl.ncclGetVersion(ctypes.byref(version))
+ assert result == 0
+ return version.value
+
+
+def find_nccl_library():
+ so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
+
+ # check if we have vllm-managed nccl
+ vllm_nccl_path = None
+ if torch.version.cuda is not None:
+ cuda_major = torch.version.cuda.split(".")[0]
+ path = os.path.expanduser(
+ f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*")
+ files = glob.glob(path)
+ vllm_nccl_path = files[0] if files else None
+
+ # manually load the nccl library
+ if so_file:
+ logger.info(
+ f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}"
+ )
+ else:
+ if torch.version.cuda is not None:
+ so_file = vllm_nccl_path or "libnccl.so.2"
+ elif torch.version.hip is not None:
+ so_file = "librccl.so.1"
+ else:
+ raise ValueError("NCCL only supports CUDA and ROCm backends.")
+ logger.info(f"Found nccl from library {so_file}")
+ return so_file
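
Used together, the two helpers let callers validate the library before handing it to ctypes.CDLL; the short usage sketch below assumes an environment where an NCCL (or RCCL) library is actually present:

from vllm.utils import find_nccl_library, nccl_integrity_check

so_file = find_nccl_library()             # honors VLLM_NCCL_SO_PATH if set
version = nccl_integrity_check(so_file)   # raises RuntimeError if `ldd` fails
print(f"NCCL library {so_file} reports version {version}")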
From 66ded030677c7a0ca696f8d64e41637f4a358c00 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Thu, 18 Apr 2024 08:16:26 +0100
Subject: [PATCH 061/413] Allow model to be served under multiple names (#2894)
Co-authored-by: Alexandre Payot
---
vllm/entrypoints/openai/api_server.py | 8 ++++----
vllm/entrypoints/openai/cli_args.py | 10 +++++++---
vllm/entrypoints/openai/serving_chat.py | 8 ++++----
vllm/entrypoints/openai/serving_completion.py | 6 +++---
vllm/entrypoints/openai/serving_engine.py | 15 ++++++++-------
5 files changed, 26 insertions(+), 21 deletions(-)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 32282bfd8d12b..d6673976bb775 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -150,18 +150,18 @@ async def authentication(request: Request, call_next):
logger.info(f"args: {args}")
if args.served_model_name is not None:
- served_model = args.served_model_name
+ served_model_names = args.served_model_name
else:
- served_model = args.model
+ served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
- openai_serving_chat = OpenAIServingChat(engine, served_model,
+ openai_serving_chat = OpenAIServingChat(engine, served_model_names,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
- engine, served_model, args.lora_modules)
+ engine, served_model_names, args.lora_modules)
app.root_path = args.root_path
uvicorn.run(app,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index cc71931b97955..5c361b4d184ee 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -54,11 +54,15 @@ def make_arg_parser():
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument("--served-model-name",
+ nargs="+",
type=str,
default=None,
- help="The model name used in the API. If not "
- "specified, the model name will be the same as "
- "the huggingface name.")
+ help="The model name(s) used in the API. If multiple "
+ "names are provided, the server will respond to any "
+ "of the provided names. The model name in the model "
+ "field of a response will be the first name in this "
+ "list. If not specified, the model name will be the "
+ "same as the `--model` argument.")
parser.add_argument(
"--lora-modules",
type=str,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index c9ed4a9de20f4..f35eab15bc487 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -24,12 +24,12 @@ class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
- served_model: str,
+ served_model_names: List[str],
response_role: str,
lora_modules: Optional[List[LoRA]] = None,
chat_template=None):
super().__init__(engine=engine,
- served_model=served_model,
+ served_model_names=served_model_names,
lora_modules=lora_modules)
self.response_role = response_role
self._load_chat_template(chat_template)
@@ -109,7 +109,7 @@ async def chat_completion_stream_generator(
result_generator: AsyncIterator[RequestOutput], request_id: str
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
- model_name = request.model
+ model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
first_iteration = True
@@ -251,7 +251,7 @@ async def chat_completion_full_generator(
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
- model_name = request.model
+ model_name = self.served_model_names[0]
created_time = int(time.time())
final_res: RequestOutput = None
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index a71f2d6a4426a..b7e2530a69b51 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -53,10 +53,10 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
- served_model: str,
+ served_model_names: List[str],
lora_modules: Optional[List[LoRA]] = None):
super().__init__(engine=engine,
- served_model=served_model,
+ served_model_names=served_model_names,
lora_modules=lora_modules)
async def create_completion(self, request: CompletionRequest,
@@ -79,7 +79,7 @@ async def create_completion(self, request: CompletionRequest,
return self.create_error_response(
"suffix is not currently supported")
- model_name = request.model
+ model_name = self.served_model_names[0]
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 77a568b564039..b5a7a977ebbab 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -29,10 +29,10 @@ class OpenAIServing:
def __init__(self,
engine: AsyncLLMEngine,
- served_model: str,
+ served_model_names: List[str],
lora_modules=Optional[List[LoRA]]):
self.engine = engine
- self.served_model = served_model
+ self.served_model_names = served_model_names
if lora_modules is None:
self.lora_requests = []
else:
@@ -74,13 +74,14 @@ async def _post_init(self):
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
- ModelCard(id=self.served_model,
- root=self.served_model,
+ ModelCard(id=served_model_name,
+ root=self.served_model_names[0],
permission=[ModelPermission()])
+ for served_model_name in self.served_model_names
]
lora_cards = [
ModelCard(id=lora.lora_name,
- root=self.served_model,
+ root=self.served_model_names[0],
permission=[ModelPermission()])
for lora in self.lora_requests
]
@@ -150,7 +151,7 @@ def create_streaming_error_response(
return json_str
async def _check_model(self, request) -> Optional[ErrorResponse]:
- if request.model == self.served_model:
+ if request.model in self.served_model_names:
return
if request.model in [lora.lora_name for lora in self.lora_requests]:
return
@@ -160,7 +161,7 @@ async def _check_model(self, request) -> Optional[ErrorResponse]:
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
- if request.model == self.served_model:
+ if request.model in self.served_model_names:
return
for lora in self.lora_requests:
if request.model == lora.lora_name:
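
With --served-model-name taking several aliases, every alias is accepted by _check_model and listed by the models endpoint, while responses always report the first name. A small standalone sketch of that lookup, using hypothetical alias names:

from typing import List

served_model_names: List[str] = ["llama-3-8b-instruct", "llama-3", "default"]


def check_model(requested: str) -> bool:
    # Any alias is accepted; unknown names yield a 404 error response.
    return requested in served_model_names


response_model_name = served_model_names[0]  # name reported in responses
print(check_model("llama-3"), response_model_name)  # True llama-3-8b-instruct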
From 53b018edcbc601f0eea9f65f13a9a9620c4be8dc Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 18 Apr 2024 03:21:55 -0400
Subject: [PATCH 062/413] [Bugfix] Get available quantization methods from
quantization registry (#4098)
---
benchmarks/benchmark_latency.py | 3 ++-
benchmarks/benchmark_throughput.py | 4 +++-
tests/models/test_marlin.py | 7 +++----
vllm/config.py | 7 ++++---
vllm/engine/arg_utils.py | 3 ++-
vllm/model_executor/layers/quantization/__init__.py | 7 ++++---
6 files changed, 18 insertions(+), 13 deletions(-)
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index aadbc441713fc..44da3bad8d840 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -9,6 +9,7 @@
from tqdm import tqdm
from vllm import LLM, SamplingParams
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
def main(args: argparse.Namespace):
@@ -101,7 +102,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
- choices=['awq', 'gptq', 'squeezellm', None],
+ choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 6df1e1d628e6c..6bb889d1eceba 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -10,6 +10,8 @@
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+
def sample_requests(
dataset_path: str,
@@ -267,7 +269,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
- choices=['awq', 'gptq', 'squeezellm', None],
+ choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py
index 2305db3510060..4fe6daec02520 100644
--- a/tests/models/test_marlin.py
+++ b/tests/models/test_marlin.py
@@ -16,13 +16,12 @@
import pytest
import torch
-from vllm.model_executor.layers.quantization import (
- _QUANTIZATION_CONFIG_REGISTRY)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (
- capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())
+marlin_not_supported = (capability <
+ QUANTIZATION_METHODS["marlin"].get_min_capability())
@dataclass
diff --git a/vllm/config.py b/vllm/config.py
index 5a29620e85ac6..2912d6ccc2c5b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -9,6 +9,7 @@
from transformers import PretrainedConfig
from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
is_neuron)
@@ -118,8 +119,8 @@ def _verify_tokenizer_mode(self) -> None:
self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
- supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
- rocm_not_supported_quantization = ["awq", "marlin"]
+ supported_quantization = [*QUANTIZATION_METHODS]
+ rocm_supported_quantization = ["gptq", "squeezellm"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
@@ -155,7 +156,7 @@ def _verify_quantization(self) -> None:
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
if is_hip(
- ) and self.quantization in rocm_not_supported_quantization:
+ ) and self.quantization not in rocm_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index c61c0cc67d7a2..2999ab0a7e72a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -7,6 +7,7 @@
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig, VisionLanguageConfig)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple
@@ -286,7 +287,7 @@ def add_cli_args(
parser.add_argument('--quantization',
'-q',
type=str,
- choices=['awq', 'gptq', 'squeezellm', None],
+ choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
'None, we first check the `quantization_config` '
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index ad988d48755b0..a3b89a66469eb 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -7,7 +7,7 @@
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
-_QUANTIZATION_CONFIG_REGISTRY = {
+QUANTIZATION_METHODS = {
"awq": AWQConfig,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
@@ -16,12 +16,13 @@
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
- if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
+ if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
- return _QUANTIZATION_CONFIG_REGISTRY[quantization]
+ return QUANTIZATION_METHODS[quantization]
__all__ = [
"QuantizationConfig",
"get_quantization_config",
+ "QUANTIZATION_METHODS",
]
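
After this change, CLI choices and config validation are both derived from the registry instead of a hard-coded list. A short usage sketch, assuming vLLM is installed so the registry can be imported:

from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     get_quantization_config)

# CLI choices can be generated from the registry plus an explicit None.
choices = [*QUANTIZATION_METHODS, None]
print(choices)  # e.g. ['awq', 'gptq', 'squeezellm', ..., None]

# Unknown methods raise early instead of failing later during model load.
marlin_config_cls = get_quantization_config("marlin")
print(marlin_config_cls.get_min_capability())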
From e8cc7967ff8a6f8432747a9e87ab451d36e1ff57 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Moskal?=
Date: Thu, 18 Apr 2024 00:51:28 -0700
Subject: [PATCH 063/413] [Bugfix][Kernel] allow non-power-of-two head sizes in
prefix prefill (#4128)
---
tests/kernels/test_prefix_prefill.py | 2 +-
vllm/attention/ops/prefix_prefill.py | 44 +++++++++++++++++-----------
2 files changed, 28 insertions(+), 18 deletions(-)
diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py
index 6494fb34af98f..ad31b0a7c2a19 100644
--- a/tests/kernels/test_prefix_prefill.py
+++ b/tests/kernels/test_prefix_prefill.py
@@ -10,7 +10,7 @@
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
-HEAD_SIZES = [128]
+HEAD_SIZES = [128, 96]
DTYPES = [torch.float16]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 70f09224f1cf6..4896cf3909c6e 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -47,7 +47,8 @@ def _fwd_kernel(
stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr,
- BLOCK_DMODEL: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr, # head size
+ BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
@@ -59,26 +60,30 @@ def _fwd_kernel(
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+ cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
- offs_d = tl.arange(0, BLOCK_DMODEL)
+ offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
- q = tl.load(
- Q + off_q,
- mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
- other=0.0)
+ dim_mask = tl.where(
+ tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+
+ q = tl.load(Q + off_q,
+ mask=dim_mask[None, :] &
+ (offs_m[:, None] < cur_batch_query_len),
+ other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
- acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
@@ -99,7 +104,8 @@ def _fwd_kernel(
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
- mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
+ mask=dim_mask[:, None] &
+ ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -126,7 +132,8 @@ def _fwd_kernel(
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
- mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+ mask=dim_mask[None, :] &
+ ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
other=0.0)
p = p.to(v.dtype)
@@ -142,16 +149,15 @@ def _fwd_kernel(
k_ptrs = K + off_k
v_ptrs = V + off_v
- block_mask = tl.where(
- block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+ block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
- mask=(start_n + offs_n[None, :]) <
- cur_batch_seq_len - cur_batch_ctx_len,
+ mask=dim_mask[:, None] &
+ ((start_n + offs_n[None, :]) < cur_batch_query_len),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -179,8 +185,8 @@ def _fwd_kernel(
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
- mask=(start_n + offs_n[:, None]) <
- cur_batch_seq_len - cur_batch_ctx_len,
+ mask=dim_mask[None, :] &
+ ((start_n + offs_n[:, None]) < cur_batch_query_len),
other=0.0)
p = p.to(v.dtype)
@@ -195,7 +201,8 @@ def _fwd_kernel(
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
- mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+ mask=dim_mask[None, :] &
+ (offs_m[:, None] < cur_batch_query_len))
return
@triton.jit
@@ -636,7 +643,8 @@ def context_attention_fwd(q,
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
- assert Lk in {16, 32, 64, 128}
+ # round up Lk to a power of 2 - this is required for Triton block size
+ Lk_padded = 2**((Lk - 1).bit_length())
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
@@ -646,6 +654,7 @@ def context_attention_fwd(q,
num_warps = 8 if Lk <= 64 else 8
if alibi_slopes is not None:
+ assert Lk == Lk_padded
_fwd_kernel_alibi[grid](
q,
k,
@@ -738,6 +747,7 @@ def context_attention_fwd(q,
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
+ BLOCK_DMODEL_PADDED=Lk_padded,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
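The change above pads the head dimension to the next power of two (a Triton block-size requirement) and uses dim_mask to zero out the padded lanes on every load and store. A minimal sketch of the rounding behind Lk_padded, in plain Python with an illustrative helper name:

    # Mirrors the Lk_padded computation in context_attention_fwd above.
    def next_power_of_two(n: int) -> int:
        return 2 ** ((n - 1).bit_length())

    for head_size in (16, 64, 80, 96, 128):
        print(head_size, "->", next_power_of_two(head_size))
    # 80 and 96 round up to 128; exact powers of two map to themselves.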
From 705578ae14b648782a8a321dd0903c163bd77375 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Thu, 18 Apr 2024 10:55:48 -0700
Subject: [PATCH 064/413] [Docs] document that Meta Llama 3 is supported
(#4175)
---
README.md | 2 +-
docs/source/models/supported_models.rst | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 8434c11883341..947d50d4ad764 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
-- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 5e5ce871f61dd..951fc3aac0c75 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -80,8 +80,8 @@ Alongside each architecture, we include some popular models that use it.
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
-
* - :code:`LlamaForCausalLM`
- - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
- - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
+ - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
+ - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
- ✅︎
* - :code:`MiniCPMForCausalLM`
- MiniCPM
From e1bb2fd52dea0bbc772bdf35fd27664c5daec7b2 Mon Sep 17 00:00:00 2001
From: James Whedbee
Date: Thu, 18 Apr 2024 16:12:55 -0500
Subject: [PATCH 065/413] [Bugfix] Support logprobs when using guided_json and
other constrained decoding fields (#4149)
---
tests/entrypoints/test_openai_server.py | 30 +++++++++++++++++++++++
vllm/entrypoints/openai/serving_engine.py | 4 ++-
2 files changed, 33 insertions(+), 1 deletion(-)
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 14e6ee0ffe9d9..0dd30eec30086 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -723,6 +723,36 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
+@pytest.mark.parametrize("guided_decoding_backend",
+ ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
+ guided_decoding_backend: str):
+ messages = [{
+ "role": "system",
+ "content": "you are a helpful assistant"
+ }, {
+ "role":
+ "user",
+ "content":
+ "The best language for type-safe systems programming is "
+ }]
+ chat_completion = await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=messages,
+ max_tokens=10,
+ logprobs=True,
+ top_logprobs=5,
+ extra_body=dict(guided_choice=TEST_CHOICE,
+ guided_decoding_backend=guided_decoding_backend))
+ top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
+
+ # -9999.0 is the minimum logprob returned by OpenAI
+ assert all(
+ isinstance(logprob, float) and logprob >= -9999.0
+ for token_dict in top_logprobs
+ for token, logprob in token_dict.items())
+
+
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index b5a7a977ebbab..8e5ee88d7f3a9 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -116,7 +116,9 @@ def _create_logprobs(
if num_output_top_logprobs:
logprobs.top_logprobs.append({
- p.decoded_token: p.logprob
+ # Convert float("-inf") to the
+ # JSON-serializable float that OpenAI uses
+ p.decoded_token: max(p.logprob, -9999.0)
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
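The clamp above exists because fully masked tokens under constrained decoding carry a logprob of float("-inf"), which is not representable in standard JSON; clamping to -9999.0 matches the floor OpenAI reports. A minimal sketch of the idea, separate from the vLLM code path and using a made-up token map:

    import json

    raw_top_logprobs = {"the": -0.03, "a": float("-inf")}  # hypothetical values
    safe = {tok: max(lp, -9999.0) for tok, lp in raw_top_logprobs.items()}
    print(json.dumps(safe))  # -Infinity is not valid JSON, so clamp before serializing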
From 87fa80c91f5b24a2ee1805b80c3eca8fd6600cd5 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 18 Apr 2024 14:36:39 -0700
Subject: [PATCH 066/413] [Misc] Bump transformers to latest version (#4176)
---
requirements-common.txt | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/requirements-common.txt b/requirements-common.txt
index c1614d2537b25..3de0f98e7c0c3 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -5,7 +5,8 @@ sentencepiece # Required for LLaMA tokenizer.
numpy
requests
py-cpuinfo
-transformers >= 4.39.1 # Required for StarCoder2 & Llava.
+transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
+tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
From cd2f63fb362b2c53b993e7edf6565ab6a5f9f260 Mon Sep 17 00:00:00 2001
From: Liangfu Chen
Date: Thu, 18 Apr 2024 15:26:01 -0700
Subject: [PATCH 067/413] [CI/CD] add neuron docker and ci test scripts (#3571)
---
.buildkite/run-neuron-test.sh | 37 ++++++++++++++++++++++++++++++++
.buildkite/test-template.j2 | 5 +++++
Dockerfile.neuron | 36 +++++++++++++++++++++++++++++++
setup.py | 3 ++-
vllm/engine/async_llm_engine.py | 4 ++--
vllm/executor/neuron_executor.py | 22 ++++++++++++++++++-
6 files changed, 103 insertions(+), 4 deletions(-)
create mode 100644 .buildkite/run-neuron-test.sh
create mode 100644 Dockerfile.neuron
diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh
new file mode 100644
index 0000000000000..8ba03b78e8dbf
--- /dev/null
+++ b/.buildkite/run-neuron-test.sh
@@ -0,0 +1,37 @@
+# This script builds the Neuron docker image and runs the API server inside the container.
+# It serves as a sanity check for compilation and basic model usage.
+set -e
+
+# Try building the docker image
+aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
+docker build -t neuron -f Dockerfile.neuron .
+
+# Setup cleanup
+remove_docker_container() { docker rm -f neuron || true; }
+trap remove_docker_container EXIT
+remove_docker_container
+
+# Run the image
+docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \
+ --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 &
+
+# Wait for the server to start
+wait_for_server_to_start() {
+ timeout=300
+ counter=0
+
+ while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
+ sleep 1
+ counter=$((counter + 1))
+ if [ $counter -ge $timeout ]; then
+ echo "Timeout after $timeout seconds"
+ break
+ fi
+ done
+}
+wait_for_server_to_start
+
+# Test a simple prompt
+curl -X POST -H "Content-Type: application/json" \
+ localhost:8000/generate \
+ -d '{"prompt": "San Francisco is a"}'
diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2
index 0e1acc9777d4b..fb1086db77823 100644
--- a/.buildkite/test-template.j2
+++ b/.buildkite/test-template.j2
@@ -21,6 +21,11 @@ steps:
queue: amd
command: bash .buildkite/run-amd-test.sh
+ - label: "Neuron Test"
+ agents:
+ queue: neuron
+ command: bash .buildkite/run-neuron-test.sh
+
- label: "CPU Test"
command: bash .buildkite/run-cpu-test.sh
diff --git a/Dockerfile.neuron b/Dockerfile.neuron
new file mode 100644
index 0000000000000..fe42b4ef393f1
--- /dev/null
+++ b/Dockerfile.neuron
@@ -0,0 +1,36 @@
+# default base image
+ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04"
+
+FROM $BASE_IMAGE
+
+RUN echo "Base image is $BASE_IMAGE"
+
+# Install some basic utilities
+RUN apt-get update && apt-get install python3 python3-pip -y
+
+### Mount Point ###
+# When launching the container, mount the code directory to /app
+ARG APP_MOUNT=/app
+VOLUME [ ${APP_MOUNT} ]
+WORKDIR ${APP_MOUNT}
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
+RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
+RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
+
+COPY ./vllm /app/vllm/vllm
+COPY ./setup.py /app/vllm/setup.py
+COPY ./requirements-common.txt /app/vllm/requirements-common.txt
+COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt
+
+RUN cd /app/vllm \
+ && python3 -m pip install -U -r requirements-neuron.txt
+
+ENV VLLM_BUILD_WITH_NEURON 1
+RUN cd /app/vllm \
+ && pip install -e . \
+ && cd ..
+
+CMD ["/bin/bash"]
diff --git a/setup.py b/setup.py
index 19a9150ad2e64..4b672e1af8494 100644
--- a/setup.py
+++ b/setup.py
@@ -204,7 +204,8 @@ def _is_neuron() -> bool:
subprocess.run(["neuron-ls"], capture_output=True, check=True)
except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
torch_neuronx_installed = False
- return torch_neuronx_installed
+ return torch_neuronx_installed or os.environ.get("VLLM_BUILD_WITH_NEURON",
+ False)
def _is_cpu() -> bool:
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index c3020d2b38db0..c436ece83f65a 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -335,8 +335,8 @@ def from_engine_args(
engine_config = engine_args.create_engine_config()
if engine_config.device_config.device_type == "neuron":
- raise NotImplementedError("Neuron is not supported for "
- "async engine yet.")
+ from vllm.executor.neuron_executor import NeuronExecutorAsync
+ executor_class = NeuronExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index 7cc187e297c9f..5a137d1bdcb3b 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -1,9 +1,10 @@
from typing import Dict, List, Set, Tuple
-from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import make_async
logger = init_logger(__name__)
@@ -73,3 +74,22 @@ def check_health(self) -> None:
# NeuronExecutor will always be healthy as long as
# it's running.
return
+
+
+class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
+
+ async def execute_model_async(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ blocks_to_swap_in: Dict[int, int],
+ blocks_to_swap_out: Dict[int, int],
+ blocks_to_copy: Dict[int, List[int]],
+ ) -> SamplerOutput:
+ output = await make_async(self.driver_worker.execute_model)(
+ seq_group_metadata_list=seq_group_metadata_list, )
+ return output
+
+ async def check_health_async(self) -> None:
+ # NeuronExecutor will always be healthy as long as
+ # it's running.
+ return
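NeuronExecutorAsync reuses the synchronous driver worker by wrapping execute_model with make_async. A rough sketch of that pattern, assuming make_async offloads the blocking call to an executor thread (the actual vllm.utils implementation may differ):

    import asyncio
    from functools import partial
    from typing import Any, Callable

    def make_async_sketch(func: Callable[..., Any]) -> Callable[..., Any]:
        # Wrap a blocking function so it can be awaited without
        # blocking the event loop.
        async def _wrapper(*args: Any, **kwargs: Any) -> Any:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(None, partial(func, *args, **kwargs))
        return _wrapper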
From 8f9c28fd40a9daa744d08cacb0bd5ac2247a97d1 Mon Sep 17 00:00:00 2001
From: Adam Tilghman
Date: Thu, 18 Apr 2024 15:32:47 -0700
Subject: [PATCH 068/413] [Bugfix] Fix CustomAllreduce nvlink topology
detection (#3974)
[Bugfix] Fix CustomAllreduce pcie nvlink topology detection (#3974) (#4159)
---
vllm/distributed/device_communicators/custom_all_reduce.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index f83caef879da3..7602897d3dd8f 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -145,8 +145,10 @@ def _is_full_nvlink(rank, world_size):
for i in range(world_size):
if i != rank:
try:
- link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i)
- if not link_state:
+ peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+ p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+ handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+ if p2p_status != pynvml.NVML_P2P_STATUS_OK:
return False
except pynvml.NVMLError as error:
logger.info(
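The fix above asks NVML for the pairwise P2P NVLink capability instead of the per-link NvLink state, which could misreport on PCIe-only topologies. A standalone sketch of the same full-NVLink check (requires pynvml; error handling trimmed):

    import pynvml

    def is_full_nvlink(device_ids):
        # True only if every GPU pair reports NVLink P2P capability.
        pynvml.nvmlInit()
        try:
            handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
            for i, a in enumerate(handles):
                for b in handles[i + 1:]:
                    status = pynvml.nvmlDeviceGetP2PStatus(
                        a, b, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                    if status != pynvml.NVML_P2P_STATUS_OK:
                        return False
            return True
        finally:
            pynvml.nvmlShutdown()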
From 8a7a3e4436d7284df4c0913f074d77d640a9c6c3 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 18 Apr 2024 16:15:12 -0700
Subject: [PATCH 069/413] [Core] add an option to log every function call
for debugging hang/crash in distributed inference (#4079)
Co-authored-by: Simon Mo
---
.buildkite/test-pipeline.yaml | 2 +-
.github/ISSUE_TEMPLATE/400-bug report.yml | 2 +
tests/test_logger.py | 27 ++++++++++++
vllm/executor/ray_gpu_executor.py | 12 ++++--
vllm/logger.py | 52 +++++++++++++++++++++++
vllm/utils.py | 13 +++++-
vllm/worker/worker_base.py | 20 +++++++--
7 files changed, 120 insertions(+), 8 deletions(-)
create mode 100644 tests/test_logger.py
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 2263dee20fbed..0f920c7ec1442 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -40,7 +40,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
- label: Engine Test
- command: pytest -v -s engine tokenization test_sequence.py test_config.py
+ command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
- label: Entrypoints Test
commands:
diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml
index f1124dfa78bbc..c87f8fddcb776 100644
--- a/.github/ISSUE_TEMPLATE/400-bug report.yml
+++ b/.github/ISSUE_TEMPLATE/400-bug report.yml
@@ -57,6 +57,8 @@ body:
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
+
+ If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1`. All the function calls in vllm will be recorded. Inspect these log files to identify which function crashes or hangs.
placeholder: |
A clear and concise description of what the bug is.
diff --git a/tests/test_logger.py b/tests/test_logger.py
new file mode 100644
index 0000000000000..601f72b50811c
--- /dev/null
+++ b/tests/test_logger.py
@@ -0,0 +1,27 @@
+import os
+import sys
+import tempfile
+
+from vllm.logger import enable_trace_function_call
+
+
+def f1(x):
+ return f2(x)
+
+
+def f2(x):
+ return x
+
+
+def test_trace_function_call():
+ fd, path = tempfile.mkstemp()
+ cur_dir = os.path.dirname(__file__)
+ enable_trace_function_call(path, cur_dir)
+ f1(1)
+ with open(path, 'r') as f:
+ content = f.read()
+
+ assert "f1" in content
+ assert "f2" in content
+ sys.settrace(None)
+ os.remove(path)
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 5a43f1fc28a84..f779b0f8a5113 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -10,7 +10,7 @@
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
- make_async)
+ get_vllm_instance_id, make_async)
if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -133,12 +133,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
- # Set CUDA_VISIBLE_DEVICES for the driver and workers.
+ VLLM_INSTANCE_ID = get_vllm_instance_id()
+
+ # Set environment variables for the driver and workers.
all_args_to_update_environment_variables = []
for (node_id, _) in worker_node_and_gpu_ids:
all_args_to_update_environment_variables.append([{
"CUDA_VISIBLE_DEVICES":
- ",".join(map(str, node_gpus[node_id]))
+ ",".join(map(str, node_gpus[node_id])),
+ "VLLM_INSTANCE_ID":
+ VLLM_INSTANCE_ID,
+ "VLLM_TRACE_FUNCTION":
+ os.getenv("VLLM_TRACE_FUNCTION", "0"),
}])
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
diff --git a/vllm/logger.py b/vllm/logger.py
index af9575085ef37..046f0e9099a4b 100644
--- a/vllm/logger.py
+++ b/vllm/logger.py
@@ -1,9 +1,11 @@
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
+import datetime
import logging
import os
import sys
+from functools import partial
from typing import Optional
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
@@ -65,3 +67,53 @@ def init_logger(name: str):
logger.addHandler(_default_handler)
logger.propagate = False
return logger
+
+
+logger = init_logger(__name__)
+
+
+def _trace_calls(log_path, root_dir, frame, event, arg=None):
+ if event in ['call', 'return']:
+ # Extract the filename, line number, function name, and the code object
+ filename = frame.f_code.co_filename
+ lineno = frame.f_lineno
+ func_name = frame.f_code.co_name
+ if not filename.startswith(root_dir):
+ # only log the functions in the vllm root_dir
+ return
+ # Log every function call or return
+ try:
+ with open(log_path, 'a') as f:
+ if event == 'call':
+ f.write(f"{datetime.datetime.now()} Call to"
+ f" {func_name} in {filename}:{lineno}\n")
+ else:
+ f.write(f"{datetime.datetime.now()} Return from"
+ f" {func_name} in {filename}:{lineno}\n")
+ except NameError:
+ # modules are deleted during shutdown
+ pass
+ return partial(_trace_calls, log_path, root_dir)
+
+
+def enable_trace_function_call(log_file_path: str,
+ root_dir: Optional[str] = None):
+ """
+ Enable tracing of every function call in code under `root_dir`.
+ This is useful for debugging hangs or crashes.
+ `log_file_path` is the path to the log file.
+ `root_dir` is the root directory of the code to trace. If None, it is the
+ vllm root directory.
+
+ Note that this call is thread-level: only threads that call this function
+ will have the trace enabled. Other threads will not be affected.
+ """
+ logger.warning(
+ "VLLM_TRACE_FUNCTION is enabled. It will record every"
+ " function executed by Python. This will slow down the code. It "
+ "is suggested to be used for debugging hang or crashes only.")
+ logger.info(f"Trace frame log is saved to {log_file_path}")
+ if root_dir is None:
+ # by default, this is the vllm root directory
+ root_dir = os.path.dirname(os.path.dirname(__file__))
+ sys.settrace(partial(_trace_calls, log_file_path, root_dir))
diff --git a/vllm/utils.py b/vllm/utils.py
index 49e7033c23e6a..fbe86dacaeb99 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -163,6 +163,17 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex)
+@lru_cache(maxsize=None)
+def get_vllm_instance_id():
+ """
+ If the environment variable VLLM_INSTANCE_ID is set, return it.
+ Otherwise, return a random UUID.
+ Instance id represents an instance of vLLM. All processes in the same
+ instance should have the same instance id.
+ """
+ return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}")
+
+
@lru_cache(maxsize=None)
def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071
@@ -274,7 +285,7 @@ def get_open_port() -> int:
def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items():
- if k in os.environ:
+ if k in os.environ and os.environ[k] != v:
logger.warning(f"Overwriting environment variable {k} "
f"from '{os.environ[k]}' to '{v}'")
os.environ[k] = v
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 13e062fe64b29..783dff3a43404 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -1,12 +1,15 @@
+import datetime
import importlib
import os
+import tempfile
+import threading
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple
-from vllm.logger import init_logger
+from vllm.logger import enable_trace_function_call, init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
-from vllm.utils import update_environment_variables
+from vllm.utils import get_vllm_instance_id, update_environment_variables
logger = init_logger(__name__)
@@ -115,9 +118,20 @@ def update_environment_variables(self, envs: Dict[str, str]) -> None:
def init_worker(self, *args, **kwargs):
"""
- Actual initialization of the worker class.
+ Actual initialization of the worker class, and set up
+ function tracing if required.
Arguments are passed to the worker class constructor.
"""
+ if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
+ tmp_dir = tempfile.gettempdir()
+ filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
+ f"_thread_{threading.get_ident()}_"
+ f"at_{datetime.datetime.now()}.log").replace(" ", "_")
+ log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
+ filename)
+ os.makedirs(os.path.dirname(log_path), exist_ok=True)
+ enable_trace_function_call(log_path)
+
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs)
From a134ef6f5e6c24d3cd459c63557e5db276db25b2 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Thu, 18 Apr 2024 21:13:36 -0700
Subject: [PATCH 070/413] Support eos_token_id from generation_config.json
(#4182)
---
vllm/engine/llm_engine.py | 19 +++++++++++++++++--
vllm/sampling_params.py | 14 +++++++++++++-
2 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index c3de57e249ff8..88b344f50767f 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1,7 +1,7 @@
import time
from typing import Iterable, List, Optional, Type, Union
-from transformers import PreTrainedTokenizer
+from transformers import GenerationConfig, PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
@@ -34,6 +34,17 @@
_LOCAL_LOGGING_INTERVAL_SEC = 5
+def _load_generation_config_dict(model_config: ModelConfig):
+ try:
+ return GenerationConfig.from_pretrained(
+ model_config.model,
+ revision=model_config.revision,
+ ).to_diff_dict()
+ except OSError:
+ # Not found.
+ return {}
+
+
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
@@ -124,6 +135,8 @@ def __init__(
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
self.seq_counter = Counter()
+ self.generation_config_fields = _load_generation_config_dict(
+ model_config)
self.model_executor = executor_class(
model_config=model_config,
@@ -391,6 +404,8 @@ def add_request(
# inject the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id
+ sampling_params.update_from_generation_config(
+ self.generation_config_fields)
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -435,7 +450,7 @@ def _process_model_outputs(
scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups.
-
+
Returns RequestOutputs that can be returned to the client.
"""
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 53a38b25bfdac..dc0e60344d858 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -2,7 +2,7 @@
import copy
from enum import IntEnum
from functools import cached_property
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
import torch
from pydantic import Field
@@ -271,6 +271,18 @@ def _verify_greedy_sampling(self) -> None:
raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.")
+ def update_from_generation_config(
+ self, generation_config: Dict[str, Any]) -> None:
+ """Update if there are non-default values from generation_config"""
+ # Update eos_token_id for generation
+ if eos_ids := generation_config.get("eos_token_id"):
+ # it can be either int or list of int
+ if isinstance(eos_ids, int):
+ eos_ids = [eos_ids]
+ original_stop_token_ids = set(self.stop_token_ids)
+ original_stop_token_ids.update(eos_ids)
+ self.stop_token_ids = list(original_stop_token_ids)
+
@cached_property
def sampling_type(self) -> SamplingType:
if self.use_beam_search:
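update_from_generation_config folds a model's generation_config eos_token_id, which can be an int or a list of ints, into stop_token_ids so sampling also stops on those tokens. A minimal sketch of the merge with made-up token ids:

    def merge_eos_into_stop(stop_token_ids, generation_config):
        # Mirrors SamplingParams.update_from_generation_config above.
        eos_ids = generation_config.get("eos_token_id")
        if not eos_ids:
            return stop_token_ids
        if isinstance(eos_ids, int):
            eos_ids = [eos_ids]
        return list(set(stop_token_ids) | set(eos_ids))

    print(merge_eos_into_stop([7], {"eos_token_id": [2, 9]}))  # [2, 7, 9], order not guaranteed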
From d17c8477f1fb337ea9fcf439bcab4d323058c1b4 Mon Sep 17 00:00:00 2001
From: Jee Li
Date: Fri, 19 Apr 2024 15:59:54 +0800
Subject: [PATCH 071/413] [Bugfix] Fix LoRA loading check (#4138)
Co-authored-by: simon-mo
---
tests/lora/conftest.py | 6 ++++++
tests/lora/test_lora_checkpoints.py | 22 ++++++++++++++++++++--
vllm/lora/models.py | 4 +++-
3 files changed, 29 insertions(+), 3 deletions(-)
diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index 2dabfb6b4337c..a3ffc53d8cd1d 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -143,6 +143,12 @@ def baichuan_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
+@pytest.fixture(scope="session")
+def baichuan_zero_lora_files():
+ # all the lora_B weights are initialized to zero.
+ return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
+
+
@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py
index 35ad7342944cd..d4d1665b624ea 100644
--- a/tests/lora/test_lora_checkpoints.py
+++ b/tests/lora/test_lora_checkpoints.py
@@ -3,9 +3,16 @@
from vllm.lora.models import LoRAModel
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
-@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
-def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
+
+@pytest.mark.parametrize("lora_name", lora_lst)
+def test_load_checkpoints(
+ lora_name,
+ baichuan_lora_files,
+ baichuan_zero_lora_files,
+ chatglm3_lora_files,
+):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
@@ -26,6 +33,17 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
+ elif lora_name == "baichuan7B-zero":
+ # Test that the target_modules contain a prefix
+ # such as "model.layers.0.self_attn.W_pack", and
+ # the test should pass.
+ LoRAModel.from_local_checkpoint(
+ baichuan_zero_lora_files,
+ expected_lora_modules,
+ lora_model_id=1,
+ device="cpu",
+ embedding_modules=embedding_modules,
+ embedding_padding_modules=embed_padding_modules)
else:
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 62f1502458008..6bb9fee27d535 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -212,7 +212,9 @@ def from_local_checkpoint(
target_modules = config["target_modules"]
unexpected_modules = []
for module in target_modules:
- if module not in expected_lora_modules:
+ # Compatible with more modules, such as: layers.11.self_attn.k_proj
+ part_name = module.split(".")[-1]
+ if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of expected_lora_modules
if unexpected_modules:
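The relaxed check above compares only the last path component of each LoRA target module against expected_lora_modules, so fully qualified names such as layers.11.self_attn.k_proj still pass. A small sketch with an illustrative expected set:

    expected_lora_modules = {"k_proj", "q_proj", "v_proj", "W_pack"}  # illustrative
    target_modules = ["layers.11.self_attn.k_proj", "model.layers.0.self_attn.W_pack"]

    unexpected = [m for m in target_modules
                  if m.split(".")[-1] not in expected_lora_modules]
    assert not unexpected  # both names match on their final component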
From 221d93ecbf51102df69deaf153d35df6d93370f6 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Fri, 19 Apr 2024 01:00:22 -0700
Subject: [PATCH 072/413] Bump version of 0.4.1 (#4177)
---
vllm/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vllm/__init__.py b/vllm/__init__.py
index 2c1fd40573240..5ca4680227598 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -9,7 +9,7 @@
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
-__version__ = "0.4.0.post1"
+__version__ = "0.4.1"
__all__ = [
"LLM",
From 8f20fc04bf7384089395caa021766cd352d0cf0b Mon Sep 17 00:00:00 2001
From: Uranus <109661872+UranusSeven@users.noreply.github.com>
Date: Fri, 19 Apr 2024 16:18:33 +0800
Subject: [PATCH 073/413] [Misc] fix docstrings (#4191)
Co-authored-by: Zhong Wang
---
vllm/sequence.py | 9 +++------
1 file changed, 3 insertions(+), 6 deletions(-)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 92362a9a5d2a3..7dcacab6f2ab6 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -160,7 +160,7 @@ def reset_state_for_recompute(self) -> None:
self._stage = SequenceStage.PREFILL
def get_num_uncomputed_tokens(self) -> int:
- """Return the number of prefil tokens that are not computed."""
+ """Return the number of prefill tokens that are not computed."""
# we use `get_len()` which includes prompt_len + output_len instead
# of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output.
@@ -345,12 +345,9 @@ def fork(self, new_seq_id: int) -> "Sequence":
def get_num_new_tokens(self) -> int:
"""Get the number of new tokens to be computed.
- Args:
- remainig_token_budget: The remaining token budgets.
Returns:
- The new number of tokens to be computed. I.e., 1 for decode, prompt
- size for prefill. If there's not enough remainig_token_budget, it
- can return the chunked number of new tokens.
+ The new number of tokens to be computed. I.e., 1 for decode, or
+ the remaining prompt size for prefill.
"""
if self.data.stage == SequenceStage.DECODE:
return 1
From 7be4f5628fc9999bf8a6025edd8f098353e0724b Mon Sep 17 00:00:00 2001
From: Ronen Schaffer
Date: Fri, 19 Apr 2024 18:08:26 +0300
Subject: [PATCH 074/413] [Bugfix][Core] Restore logging of stats in the async
engine (#4150)
---
vllm/engine/async_llm_engine.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index c436ece83f65a..ca4ba66f09cb8 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -217,10 +217,16 @@ async def step_async(self) -> List[RequestOutput]:
else:
output = []
- return self._process_model_outputs(
+ request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups)
+ # Log stats.
+ if self.log_stats:
+ self.stat_logger.log(self._get_stats(scheduler_outputs))
+
+ return request_outputs
+
async def encode_request_async(
self,
request_id: str, # pylint: disable=unused-argument
From 15b86408a89d5b998409e7fbe7850e937cc837da Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 19 Apr 2024 12:44:51 -0700
Subject: [PATCH 075/413] [Misc] add nccl in collect env (#4211)
---
.github/ISSUE_TEMPLATE/200-installation.yml | 1 +
.github/ISSUE_TEMPLATE/300-usage.yml | 1 +
.github/ISSUE_TEMPLATE/400-bug report.yml | 1 +
.github/ISSUE_TEMPLATE/700-performance discussion.yml | 1 +
collect_env.py | 2 ++
5 files changed, 6 insertions(+)
diff --git a/.github/ISSUE_TEMPLATE/200-installation.yml b/.github/ISSUE_TEMPLATE/200-installation.yml
index 4c6c96187cc6c..df41ade8c3c01 100644
--- a/.github/ISSUE_TEMPLATE/200-installation.yml
+++ b/.github/ISSUE_TEMPLATE/200-installation.yml
@@ -18,6 +18,7 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py
```
+ It is suggested to download and execute the latest script, as vllm frequently updates the diagnostic information needed to respond to issues accurately and quickly.
value: |
```text
The output of `python collect_env.py`
diff --git a/.github/ISSUE_TEMPLATE/300-usage.yml b/.github/ISSUE_TEMPLATE/300-usage.yml
index 88227b4b2e7b9..54763af1058f6 100644
--- a/.github/ISSUE_TEMPLATE/300-usage.yml
+++ b/.github/ISSUE_TEMPLATE/300-usage.yml
@@ -18,6 +18,7 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py
```
+ It is suggested to download and execute the latest script, as vllm frequently updates the diagnostic information needed to respond to issues accurately and quickly.
value: |
```text
The output of `python collect_env.py`
diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml
index c87f8fddcb776..08120ad8e5a60 100644
--- a/.github/ISSUE_TEMPLATE/400-bug report.yml
+++ b/.github/ISSUE_TEMPLATE/400-bug report.yml
@@ -18,6 +18,7 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py
```
+ It is suggested to download and execute the latest script, as vllm frequently updates the diagnostic information needed to respond to issues accurately and quickly.
value: |
```text
The output of `python collect_env.py`
diff --git a/.github/ISSUE_TEMPLATE/700-performance discussion.yml b/.github/ISSUE_TEMPLATE/700-performance discussion.yml
index 9e8e7b4aa3530..4f8843420a94e 100644
--- a/.github/ISSUE_TEMPLATE/700-performance discussion.yml
+++ b/.github/ISSUE_TEMPLATE/700-performance discussion.yml
@@ -39,6 +39,7 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py
```
+ It is suggested to download and execute the latest script, as vllm frequently updates the diagnostic information needed to respond to issues accurately and quickly.
value: |
```text
The output of `python collect_env.py`
diff --git a/collect_env.py b/collect_env.py
index 8982fba024274..1ecfeb8e22e2f 100644
--- a/collect_env.py
+++ b/collect_env.py
@@ -63,6 +63,7 @@
"magma",
"triton",
"optree",
+ "nccl",
}
DEFAULT_PIP_PATTERNS = {
@@ -73,6 +74,7 @@
"triton",
"optree",
"onnx",
+ "nccl",
}
From bc9df1571b8002738eb8db70a07f552e32feb75f Mon Sep 17 00:00:00 2001
From: Chirag Jain
Date: Sat, 20 Apr 2024 05:43:56 +0530
Subject: [PATCH 076/413] Pass `tokenizer_revision` when getting tokenizer in
openai serving (#4214)
---
vllm/entrypoints/openai/serving_engine.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 8e5ee88d7f3a9..376b581052d85 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -68,6 +68,7 @@ async def _post_init(self):
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
+ tokenizer_revision=engine_model_config.tokenizer_revision,
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")
From 138485a82de50f90536ea0a650dd2f6bba1927e9 Mon Sep 17 00:00:00 2001
From: Ayush Rautwar <42046470+ayusher@users.noreply.github.com>
Date: Fri, 19 Apr 2024 23:49:22 -0400
Subject: [PATCH 077/413] [Bugfix] Add fix for JSON whitespace (#4189)
Co-authored-by: Ubuntu
---
tests/entrypoints/test_openai_server.py | 27 ++++++++++---------
.../outlines_logits_processors.py | 5 ++++
2 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 0dd30eec30086..85a7ef464c032 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -754,19 +754,20 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
- resp = await client.chat.completions.create(
- model=MODEL_NAME,
- messages=[{
- "role":
- "user",
- "content": ('what is 1+1? please respond with a JSON object, '
- 'the format is {"result": 2}')
- }],
- response_format={"type": "json_object"})
-
- content = resp.choices[0].message.content
- loaded = json.loads(content)
- assert loaded == {"result": 2}, loaded
+ for _ in range(2):
+ resp = await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=[{
+ "role":
+ "user",
+ "content": ('what is 1+1? please respond with a JSON object, '
+ 'the format is {"result": 2}')
+ }],
+ response_format={"type": "json_object"})
+
+ content = resp.choices[0].message.content
+ loaded = json.loads(content)
+ assert loaded == {"result": 2}, loaded
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index 95a67b612f08b..25ab5bf8b6a9c 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -131,6 +131,11 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
fsm = CFGFSM(cfg, tokenizer)
self.fsm = fsm
+ def init_state(self):
+ """Initialize state with a CFGFSM copy."""
+ super().init_state()
+ self.fsm = self.fsm.copy()
+
@lru_cache
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
From 682789d4026429a04cc32acb88064265441080dd Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Sat, 20 Apr 2024 04:51:33 +0100
Subject: [PATCH 078/413] Fix missing docs and out of sync `EngineArgs` (#4219)
Co-authored-by: Harry Mellor
---
docs/source/conf.py | 2 +
docs/source/models/engine_args.rst | 134 ++----------------------
vllm/engine/arg_utils.py | 159 ++++++++++++++++-------------
3 files changed, 98 insertions(+), 197 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 19cc8557a7541..cfa956b143ba3 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -11,12 +11,14 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.
import logging
+import os
import sys
from typing import List
from sphinx.ext import autodoc
logger = logging.getLogger(__name__)
+sys.path.append(os.path.abspath("../.."))
# -- Project information -----------------------------------------------------
diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst
index 235cb4e128c99..92bc7e0e843ed 100644
--- a/docs/source/models/engine_args.rst
+++ b/docs/source/models/engine_args.rst
@@ -5,133 +5,17 @@ Engine Arguments
Below, you can find an explanation of every engine argument for vLLM:
-.. option:: --model
-
- Name or path of the huggingface model to use.
-
-.. option:: --tokenizer
-
- Name or path of the huggingface tokenizer to use.
-
-.. option:: --revision
-
- The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
-
-.. option:: --tokenizer-revision
-
- The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
-
-.. option:: --tokenizer-mode {auto,slow}
-
- The tokenizer mode.
-
- * "auto" will use the fast tokenizer if available.
- * "slow" will always use the slow tokenizer.
-
-.. option:: --trust-remote-code
-
- Trust remote code from huggingface.
-
-.. option:: --download-dir
-
- Directory to download and load the weights, default to the default cache dir of huggingface.
-
-.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}
-
- The format of the model weights to load.
-
- * "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.
- * "pt" will load the weights in the pytorch bin format.
- * "safetensors" will load the weights in the safetensors format.
- * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
- * "dummy" will initialize the weights with random values, mainly for profiling.
- * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. `_ See `examples/tensorize_vllm_model.py `_ to serialize a vLLM model, and for more information.
-
-.. option:: --dtype {auto,half,float16,bfloat16,float,float32}
-
- Data type for model weights and activations.
-
- * "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
- * "half" for FP16. Recommended for AWQ quantization.
- * "float16" is the same as "half".
- * "bfloat16" for a balance between precision and range.
- * "float" is shorthand for FP32 precision.
- * "float32" for FP32 precision.
-
-.. option:: --max-model-len
-
- Model context length. If unspecified, will be automatically derived from the model config.
-
-.. option:: --worker-use-ray
-
- Use Ray for distributed serving, will be automatically set when using more than 1 GPU.
-
-.. option:: --pipeline-parallel-size (-pp)
-
- Number of pipeline stages.
-
-.. option:: --tensor-parallel-size (-tp)
-
- Number of tensor parallel replicas.
-
-.. option:: --max-parallel-loading-workers
-
- Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models.
-
-.. option:: --block-size {8,16,32}
-
- Token block size for contiguous chunks of tokens.
-
-.. option:: --enable-prefix-caching
-
- Enables automatic prefix caching
-
-.. option:: --seed
-
- Random seed for operations.
-
-.. option:: --swap-space
-
- CPU swap space size (GiB) per GPU.
-
-.. option:: --gpu-memory-utilization
-
- The fraction of GPU memory to be used for the model executor, which can range from 0 to 1.
- For example, a value of 0.5 would imply 50% GPU memory utilization.
- If unspecified, will use the default value of 0.9.
-
-.. option:: --max-num-batched-tokens
-
- Maximum number of batched tokens per iteration.
-
-.. option:: --max-num-seqs
-
- Maximum number of sequences per iteration.
-
-.. option:: --max-paddings
-
- Maximum number of paddings in a batch.
-
-.. option:: --disable-log-stats
-
- Disable logging statistics.
-
-.. option:: --quantization (-q) {awq,squeezellm,None}
-
- Method used to quantize the weights.
+.. argparse::
+ :module: vllm.engine.arg_utils
+ :func: _engine_args_parser
+ :prog: -m vllm.entrypoints.openai.api_server
Async Engine Arguments
----------------------
-Below are the additional arguments related to the asynchronous engine:
-
-.. option:: --engine-use-ray
- Use Ray to start the LLM engine in a separate process as the server process.
-
-.. option:: --disable-log-requests
-
- Disable logging requests.
-
-.. option:: --max-log-len
+Below are the additional arguments related to the asynchronous engine:
- Max number of prompt characters or prompt ID numbers being printed in log. Defaults to unlimited.
\ No newline at end of file
+.. argparse::
+ :module: vllm.engine.arg_utils
+ :func: _async_engine_args_parser
+ :prog: -m vllm.entrypoints.openai.api_server
\ No newline at end of file
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 2999ab0a7e72a..53f129598270c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -82,57 +82,55 @@ def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine."""
- # NOTE: If you update any of the arguments below, please also
- # make sure to update docs/source/models/engine_args.rst
-
# Model arguments
parser.add_argument(
'--model',
type=str,
default='facebook/opt-125m',
- help='name or path of the huggingface model to use')
+ help='Name or path of the huggingface model to use.')
parser.add_argument(
'--tokenizer',
type=str,
default=EngineArgs.tokenizer,
- help='name or path of the huggingface tokenizer to use')
+ help='Name or path of the huggingface tokenizer to use.')
parser.add_argument(
'--revision',
type=str,
default=None,
- help='the specific model version to use. It can be a branch '
+ help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=str,
default=None,
- help='the specific revision to use for the model code on '
+ help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=str,
default=None,
- help='the specific tokenizer version to use. It can be a branch '
+ help='The specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
- parser.add_argument('--tokenizer-mode',
- type=str,
- default=EngineArgs.tokenizer_mode,
- choices=['auto', 'slow'],
- help='tokenizer mode. "auto" will use the fast '
- 'tokenizer if available, and "slow" will '
- 'always use the slow tokenizer.')
+ parser.add_argument(
+ '--tokenizer-mode',
+ type=str,
+ default=EngineArgs.tokenizer_mode,
+ choices=['auto', 'slow'],
+ help='The tokenizer mode.\n\n* "auto" will use the '
+ 'fast tokenizer if available.\n* "slow" will '
+ 'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
- help='trust remote code from huggingface')
+ help='Trust remote code from huggingface.')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir,
- help='directory to download and load the weights, '
+ help='Directory to download and load the weights, '
'default to the default cache dir of '
- 'huggingface')
+ 'huggingface.')
parser.add_argument(
'--load-format',
type=str,
@@ -140,19 +138,19 @@ def add_cli_args(
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
],
- help='The format of the model weights to load. '
- '"auto" will try to load the weights in the safetensors format '
+ help='The format of the model weights to load.\n\n'
+ '* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
- 'is not available. '
- '"pt" will load the weights in the pytorch bin format. '
- '"safetensors" will load the weights in the safetensors format. '
- '"npcache" will load the weights in pytorch format and store '
- 'a numpy cache to speed up the loading. '
- '"dummy" will initialize the weights with random values, '
- 'which is mainly for profiling.'
- '"tensorizer" will load the weights using tensorizer from CoreWeave'
- 'which assumes tensorizer_uri is set to the location of the '
- 'serialized weights.')
+ 'is not available.\n'
+ '* "pt" will load the weights in the pytorch bin format.\n'
+ '* "safetensors" will load the weights in the safetensors format.\n'
+ '* "npcache" will load the weights in pytorch format and store '
+ 'a numpy cache to speed up the loading.\n'
+ '* "dummy" will initialize the weights with random values, '
+ 'which is mainly for profiling.\n'
+ '* "tensorizer" will load the weights using tensorizer from '
+ 'CoreWeave which assumes tensorizer_uri is set to the location of '
+ 'the serialized weights.')
parser.add_argument(
'--dtype',
type=str,
@@ -160,10 +158,14 @@ def add_cli_args(
choices=[
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
],
- help='data type for model weights and activations. '
- 'The "auto" option will use FP16 precision '
- 'for FP32 and FP16 models, and BF16 precision '
- 'for BF16 models.')
+ help='Data type for model weights and activations.\n\n'
+ '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+ 'BF16 precision for BF16 models.\n'
+ '* "half" for FP16. Recommended for AWQ quantization.\n'
+ '* "float16" is the same as "half".\n'
+ '* "bfloat16" for a balance between precision and range.\n'
+ '* "float" is shorthand for FP32 precision.\n'
+ '* "float32" for FP32 precision.')
parser.add_argument(
'--kv-cache-dtype',
type=str,
@@ -172,7 +174,7 @@ def add_cli_args(
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
- 'supported for common inference criteria. ')
+ 'supported for common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
@@ -183,58 +185,59 @@ def add_cli_args(
'default to 1.0, which may cause accuracy issues. '
'FP8_E5M2 (without scaling) is only supported on cuda version'
'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
- 'supported for common inference criteria. ')
+ 'supported for common inference criteria.')
parser.add_argument('--max-model-len',
type=int,
default=EngineArgs.max_model_len,
- help='model context length. If unspecified, '
- 'will be automatically derived from the model.')
+ help='Model context length. If unspecified, will '
+ 'be automatically derived from the model config.')
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
help='Which engine will be used for guided decoding'
- ' (JSON schema / regex etc)')
+ ' (JSON schema / regex etc).')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
- help='use Ray for distributed serving, will be '
- 'automatically set when using more than 1 GPU')
+ help='Use Ray for distributed serving, will be '
+ 'automatically set when using more than 1 GPU.')
parser.add_argument('--pipeline-parallel-size',
'-pp',
type=int,
default=EngineArgs.pipeline_parallel_size,
- help='number of pipeline stages')
+ help='Number of pipeline stages.')
parser.add_argument('--tensor-parallel-size',
'-tp',
type=int,
default=EngineArgs.tensor_parallel_size,
- help='number of tensor parallel replicas')
+ help='Number of tensor parallel replicas.')
parser.add_argument(
'--max-parallel-loading-workers',
type=int,
default=EngineArgs.max_parallel_loading_workers,
- help='load model sequentially in multiple batches, '
+ help='Load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor '
- 'parallel and large models')
+ 'parallel and large models.')
parser.add_argument(
'--ray-workers-use-nsight',
action='store_true',
- help='If specified, use nsight to profile ray workers')
+ help='If specified, use nsight to profile Ray workers.')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 128],
- help='token block size')
+ help='Token block size for contiguous chunks of '
+ 'tokens.')
parser.add_argument('--enable-prefix-caching',
action='store_true',
- help='Enables automatic prefix caching')
+ help='Enables automatic prefix caching.')
parser.add_argument('--use-v2-block-manager',
action='store_true',
- help='Use BlockSpaceMangerV2')
+ help='Use BlockSpaceMangerV2.')
parser.add_argument(
'--num-lookahead-slots',
type=int,
@@ -247,18 +250,19 @@ def add_cli_args(
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
- help='random seed')
+ help='Random seed for operations.')
parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space,
- help='CPU swap space size (GiB) per GPU')
+ help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
'--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
- help='the fraction of GPU memory to be used for '
- 'the model executor, which can range from 0 to 1.'
- 'If unspecified, will use the default value of 0.9.')
+ help='The fraction of GPU memory to be used for the model '
+ 'executor, which can range from 0 to 1. For example, a value of '
+ '0.5 would imply 50%% GPU memory utilization. If unspecified, '
+ 'will use the default value of 0.9.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
@@ -268,21 +272,21 @@ def add_cli_args(
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
- help='maximum number of batched tokens per '
- 'iteration')
+ help='Maximum number of batched tokens per '
+ 'iteration.')
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
- help='maximum number of sequences per iteration')
+ help='Maximum number of sequences per iteration.')
parser.add_argument(
'--max-logprobs',
type=int,
default=EngineArgs.max_logprobs,
- help=('max number of log probs to return logprobs is specified in'
- ' SamplingParams'))
+ help=('Max number of log probs to return when logprobs is'
+ ' specified in SamplingParams.'))
parser.add_argument('--disable-log-stats',
action='store_true',
- help='disable logging statistics')
+ help='Disable logging statistics.')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
@@ -303,13 +307,13 @@ def add_cli_args(
parser.add_argument('--max-context-len-to-capture',
type=int,
default=EngineArgs.max_context_len_to_capture,
- help='maximum context length covered by CUDA '
+ help='Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce',
action='store_true',
default=EngineArgs.disable_custom_all_reduce,
- help='See ParallelConfig')
+ help='See ParallelConfig.')
parser.add_argument('--tokenizer-pool-size',
type=int,
default=EngineArgs.tokenizer_pool_size,
@@ -402,7 +406,7 @@ def add_cli_args(
'--enable-chunked-prefill',
action='store_true',
help='If set, the prefill requests can be chunked based on the '
- 'max_num_batched_tokens')
+ 'max_num_batched_tokens.')
parser.add_argument(
'--speculative-model',
@@ -416,7 +420,7 @@ def add_cli_args(
type=int,
default=None,
help='The number of speculative tokens to sample from '
- 'the draft model in speculative decoding')
+ 'the draft model in speculative decoding.')
parser.add_argument('--model-loader-extra-config',
type=str,
@@ -534,20 +538,31 @@ class AsyncEngineArgs(EngineArgs):
max_log_len: Optional[int] = None
@staticmethod
- def add_cli_args(
- parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
- parser = EngineArgs.add_cli_args(parser)
+ def add_cli_args(parser: argparse.ArgumentParser,
+ async_args_only: bool = False) -> argparse.ArgumentParser:
+ if not async_args_only:
+ parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
action='store_true',
- help='use Ray to start the LLM engine in a '
+ help='Use Ray to start the LLM engine in a '
'separate process as the server process.')
parser.add_argument('--disable-log-requests',
action='store_true',
- help='disable logging requests')
+ help='Disable logging requests.')
parser.add_argument('--max-log-len',
type=int,
default=None,
- help='max number of prompt characters or prompt '
- 'ID numbers being printed in log. '
- 'Default: unlimited.')
+ help='Max number of prompt characters or prompt '
+ 'ID numbers being printed in log.'
+ '\n\nDefault: Unlimited')
return parser
+
+
+# These functions are used by sphinx to build the documentation
+def _engine_args_parser():
+ return EngineArgs.add_cli_args(argparse.ArgumentParser())
+
+
+def _async_engine_args_parser():
+ return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
+ async_args_only=True)
From a22cdea371bb26b4bdba112d4602736b48ca4a3a Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Fri, 19 Apr 2024 21:28:57 -0700
Subject: [PATCH 079/413] [Kernel][FP8] Initial support with dynamic per-tensor
scaling (#4118)
Provide initial support for FP8 computation. This PR is inspired by HuggingFace TGI: huggingface/text-generation-inference#1726
This feature can be enabled with --quantization fp8 or -q fp8 when launching an engine.
Algorithm:
We still load a model checkpoint in FP16/BF16. After the weights are loaded, Fp8LinearMethod calculates the per-tensor scaling factor of the weights and quantizes them accordingly. The scaling factor is then stored for future use. Meanwhile, the per-tensor scaling factor for activations is calculated in every forward pass.
Initial Results:
Currently tested Mistral-7B on 1xH100. With prompt length ~5 and decoding length 128:
BF16: 1.47s
FP8: 1.66s
I'll try larger models next and look for further performance bottlenecks. Meanwhile, you're welcome to try this code.
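For reference, a minimal usage sketch (assumptions: facebook/opt-125m as a small stand-in model, a GPU satisfying the FP8 capability check below):

from vllm import LLM, SamplingParams

# Sketch: weights load in FP16/BF16 and are quantized to FP8 after loading;
# activation scales are computed dynamically on every forward pass.
llm = LLM(model="facebook/opt-125m", quantization="fp8")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)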
---
tests/quantization/test_fp8.py | 24 +++
vllm/entrypoints/llm.py | 9 +-
vllm/model_executor/layers/linear.py | 8 +
.../layers/quantization/__init__.py | 2 +
.../model_executor/layers/quantization/fp8.py | 138 ++++++++++++++++++
vllm/model_executor/model_loader/loader.py | 4 +
.../model_loader/weight_utils.py | 9 +-
7 files changed, 189 insertions(+), 5 deletions(-)
create mode 100644 tests/quantization/test_fp8.py
create mode 100644 vllm/model_executor/layers/quantization/fp8.py
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
new file mode 100644
index 0000000000000..fa10e60de10a7
--- /dev/null
+++ b/tests/quantization/test_fp8.py
@@ -0,0 +1,24 @@
+"""Tests whether FP8 computation is enabled correctly.
+
+Run `pytest tests/quantization/test_fp8.py --forked`.
+"""
+import pytest
+import torch
+
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
+
+capability = torch.cuda.get_device_capability()
+capability = capability[0] * 10 + capability[1]
+
+
+@pytest.mark.skipif(
+ capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
+ reason="FP8 is not supported on this GPU type.")
+def test_load_fp16_model(vllm_runner) -> None:
+ llm = vllm_runner("facebook/opt-125m", quantization="fp8")
+
+ model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+ fc1 = model.model.decoder.layers[0].fc1
+ assert isinstance(fc1.linear_method, Fp8LinearMethod)
+ assert fc1.weight.dtype == torch.float8_e4m3fn
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 9e08c253dc539..961de5d5063fa 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -42,10 +42,11 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
- we support "awq", "gptq" and "squeezellm". If None, we first check
- the `quantization_config` attribute in the model config file. If
- that is None, we assume the model weights are not quantized and use
- `dtype` to determine the data type of the weights.
+ we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
+ If None, we first check the `quantization_config` attribute in the
+ model config file. If that is None, we assume the model weights are
+ not quantized and use `dtype` to determine the data type of
+ the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 3ca870742efc5..d466d8807fc64 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -3,6 +3,7 @@
import torch
import torch.nn.functional as F
+from torch import nn
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -48,6 +49,13 @@ def apply_weights(self,
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
+ def process_weights_after_loading(self, layer: nn.Module) -> None:
+ """Process the weight after loading.
+
+        This can be used, for example, to transpose weights for computation.
+ """
+ return
+
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization.
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index a3b89a66469eb..0344d6e4e3e45 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -3,12 +3,14 @@
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm.model_executor.layers.quantization.fp8 import FP8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS = {
"awq": AWQConfig,
+ "fp8": FP8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
new file mode 100644
index 0000000000000..9dc0e86e1243d
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -0,0 +1,138 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+ set_weight_attrs)
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+
+
+class FP8Config(QuantizationConfig):
+ """Config class for FP8."""
+
+ @classmethod
+ def get_name(cls) -> str:
+ return "fp8"
+
+ @classmethod
+ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+ return [torch.bfloat16, torch.half]
+
+ @classmethod
+ def get_min_capability(cls) -> int:
+ # TODO: PyTorch 2.3.0+ is required to run FP8 on
+ # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
+ # be included: https://github.com/pytorch/pytorch/pull/118881
+ return 90
+
+ @classmethod
+ def get_config_filenames(cls) -> List[str]:
+ return []
+
+ @classmethod
+ def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
+ return cls()
+
+ def get_linear_method(self) -> "Fp8LinearMethod":
+ return Fp8LinearMethod(self)
+
+ def get_scaled_act_names(self) -> List[str]:
+ return []
+
+
+class Fp8LinearMethod(LinearMethodBase):
+ """Linear method for FP8.
+ We now support common FP16/BF16 model checkpoints ONLY. The weight
+ scaling factor will be initialized after the model weights are loaded.
+
+ Limitations:
+ 1. Only support per-tensor quantization due to torch._scaled_mm support.
+ 2. Only support float8_e4m3fn data type due to the limitation of
+ torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
+
+ Args:
+ quant_config: The quantization config.
+ """
+
+ def __init__(self, quant_config: FP8Config):
+ self.quant_config = quant_config
+
+ def create_weights(
+ self,
+ layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_size_per_partition: int,
+ input_size: int,
+ output_size: int,
+ params_dtype: torch.dtype,
+ **extra_weight_attrs,
+ ):
+ weight = Parameter(torch.empty(output_size_per_partition,
+ input_size_per_partition,
+ dtype=params_dtype),
+ requires_grad=False)
+ layer.register_parameter("weight", weight)
+ set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+ set_weight_attrs(weight, extra_weight_attrs)
+
+ w_scale = Parameter(
+ torch.empty(1, dtype=torch.float32),
+ requires_grad=False,
+ )
+ layer.register_parameter("weight_scaling_factor", w_scale)
+
+ def process_weights_after_loading(self, layer: Module) -> None:
+ # Although the linear_method is propagated to all layers,
+ # only linear layers invoke "create_weights". So we check
+        # whether "weight_scaling_factor" is registered to determine
+ # whether the layer is a linear layer that requires quantization.
+ if not hasattr(layer, "weight_scaling_factor"):
+ return
+
+ qweight, weight_scale = per_tensor_quantize(layer.weight)
+ # torch._scaled_mm requires column-major in the second
+ # input (weight), so we transpose the quantized weight.
+ layer.weight = Parameter(qweight.t(), requires_grad=False)
+ layer.weight_scaling_factor.data.copy_(weight_scale)
+
+ def apply_weights(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ qinput, x_scale = per_tensor_quantize(x)
+ output, _ = torch._scaled_mm(
+ qinput,
+ layer.weight,
+ out_dtype=x.dtype,
+ scale_a=x_scale,
+ scale_b=layer.weight_scaling_factor,
+ bias=bias,
+ )
+ return output
+
+
+def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]:
+ """Quantize a tensor using per-tensor static scaling factor.
+
+ Args:
+ tensor: The input tensor.
+ """
+ finfo = torch.finfo(torch.float8_e4m3fn)
+ # Calculate the scale as dtype max divided by absmax.
+ # Since .abs() creates a new tensor, we use aminmax to get
+ # the min and max first and then calculate the absmax.
+ min_val, max_val = tensor.aminmax()
+ amax = min_val.abs().max(max_val.abs())
+ scale = finfo.max / amax.clamp(min=1e-12)
+ # scale and clamp the tensor to bring it to
+    # the representable range of the float8 data type
+ # (as default cast is unsaturated)
+ qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
+ # Return both float8 data and the inverse scale (as float),
+    # as both are required as inputs to torch._scaled_mm
+ qweight = qweight.to(torch.float8_e4m3fn)
+ scale = scale.float().reciprocal()
+ return qweight, scale
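To make the scaling concrete: float8_e4m3fn has a maximum of 448.0, so a tensor whose absolute maximum is 2.0 gets scale = 448.0 / 2.0 = 224.0; the values are multiplied by 224.0, clamped, and cast to float8, and the returned inverse scale 1 / 224.0 is what torch._scaled_mm multiplies back in. A quick round-trip sanity check (a sketch, assuming this module is importable):

import torch

from vllm.model_executor.layers.quantization.fp8 import per_tensor_quantize

# Sketch: quantize a random FP16 weight and check the reconstruction error.
w = torch.randn(16, 16, dtype=torch.float16)
qw, inv_scale = per_tensor_quantize(w)
assert qw.dtype == torch.float8_e4m3fn
w_hat = qw.to(torch.float16) * inv_scale  # dequantize with the inverse scale
print((w.float() - w_hat).abs().max())    # small quantization error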
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 3b1d125ef8a67..6c8cb2935f37e 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -228,6 +228,10 @@ def load_model(self, *, model_config: ModelConfig,
model,
"fall_back_to_pt_during_load",
True)), )
+ for _, module in model.named_modules():
+ linear_method = getattr(module, "linear_method", None)
+ if linear_method is not None:
+ linear_method.process_weights_after_loading(module)
return model.eval()
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 1798db0136868..9995f2afe3cf7 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -134,11 +134,18 @@ def get_quant_config(model_config: ModelConfig,
tqdm_class=DisabledTqdm)
else:
hf_folder = model_name_or_path
+
+ possible_config_filenames = quant_cls.get_config_filenames()
+
+ # If the quantization config is not found, use the default config.
+ if not possible_config_filenames:
+ return quant_cls()
+
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_config_files = [
f for f in config_files if any(
- f.endswith(x) for x in quant_cls.get_config_filenames())
+ f.endswith(x) for x in possible_config_filenames)
]
if len(quant_config_files) == 0:
raise ValueError(
From 91528575ec0e6fd251bac08973a3abf23d4c318c Mon Sep 17 00:00:00 2001
From: nunjunj <106306814+nunjunj@users.noreply.github.com>
Date: Sat, 20 Apr 2024 00:11:57 -0700
Subject: [PATCH 080/413] [Frontend] multiple sampling params support (#3570)
---
tests/entrypoints/test_llm_generate.py | 41 ++++++++++++++++++++++++++
vllm/entrypoints/llm.py | 30 ++++++++++++-------
2 files changed, 61 insertions(+), 10 deletions(-)
create mode 100644 tests/entrypoints/test_llm_generate.py
diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py
new file mode 100644
index 0000000000000..5e8b7ca4d9977
--- /dev/null
+++ b/tests/entrypoints/test_llm_generate.py
@@ -0,0 +1,41 @@
+import pytest
+
+from vllm import LLM, SamplingParams
+
+
+def test_multiple_sampling_params():
+
+ llm = LLM(model="facebook/opt-125m",
+ max_num_batched_tokens=4096,
+ tensor_parallel_size=1)
+
+ prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ ]
+
+ sampling_params = [
+ SamplingParams(temperature=0.01, top_p=0.95),
+ SamplingParams(temperature=0.3, top_p=0.95),
+ SamplingParams(temperature=0.7, top_p=0.95),
+ SamplingParams(temperature=0.99, top_p=0.95),
+ ]
+
+ # Multiple SamplingParams should be matched with each prompt
+ outputs = llm.generate(prompts, sampling_params=sampling_params)
+ assert len(prompts) == len(outputs)
+
+    # An exception is raised if the number of params does not match the number of prompts
+ with pytest.raises(ValueError):
+ outputs = llm.generate(prompts, sampling_params=sampling_params[:3])
+
+ # Single SamplingParams should be applied to every prompt
+ single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
+ outputs = llm.generate(prompts, sampling_params=single_sampling_params)
+ assert len(prompts) == len(outputs)
+
+ # sampling_params is None, default params should be applied
+ outputs = llm.generate(prompts, sampling_params=None)
+ assert len(prompts) == len(outputs)
\ No newline at end of file
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 961de5d5063fa..f745dbd736d17 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -127,7 +127,8 @@ def set_tokenizer(
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
- sampling_params: Optional[SamplingParams] = None,
+ sampling_params: Optional[Union[SamplingParams,
+ List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
@@ -142,7 +143,10 @@ def generate(
Args:
prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If
- None, we use the default sampling parameters.
+ None, we use the default sampling parameters.
+ When it is a single value, it is applied to every prompt.
+ When it is a list, the list must have the same length as the
+ prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
@@ -163,27 +167,33 @@ def generate(
and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
+
+ if prompts is not None:
+ num_requests = len(prompts)
+ else:
+ assert prompt_token_ids is not None
+ num_requests = len(prompt_token_ids)
+
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
+ elif isinstance(sampling_params,
+ list) and len(sampling_params) != num_requests:
+ raise ValueError("The lengths of prompts and sampling_params "
+ "must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
- if prompts is not None:
- num_requests = len(prompts)
- else:
- assert prompt_token_ids is not None
- num_requests = len(prompt_token_ids)
-
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(
prompt,
- sampling_params,
+ sampling_params[i]
+ if isinstance(sampling_params, list) else sampling_params,
token_ids,
lora_request=lora_request,
# Get ith image while maintaining the batch dim.
@@ -232,4 +242,4 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
- return outputs
+ return outputs
\ No newline at end of file
From cc74b2b232070f74d8765a5eefa49ae93ee45490 Mon Sep 17 00:00:00 2001
From: Noam Gat
Date: Sat, 20 Apr 2024 11:33:16 +0300
Subject: [PATCH 081/413] Updating lm-format-enforcer version and adding links
to decoding libraries in docs (#4222)
---
requirements-common.txt | 2 +-
vllm/engine/arg_utils.py | 6 +++++-
vllm/model_executor/layers/quantization/fp8.py | 4 ++--
3 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/requirements-common.txt b/requirements-common.txt
index 3de0f98e7c0c3..3cc7bba8f84db 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -12,7 +12,7 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
-lm-format-enforcer == 0.9.3
+lm-format-enforcer == 0.9.8
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 53f129598270c..8939a0d537281 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -197,7 +197,11 @@ def add_cli_args(
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
help='Which engine will be used for guided decoding'
- ' (JSON schema / regex etc).')
+                        ' (JSON schema / regex etc.) by default. Currently supports '
+ 'https://github.com/outlines-dev/outlines and '
+ 'https://github.com/noamgat/lm-format-enforcer.'
+ ' Can be overridden per request via guided_decoding_backend'
+ ' parameter.')
# Parallel arguments
parser.add_argument('--worker-use-ray',
action='store_true',
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 9dc0e86e1243d..8df82e0e18edd 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.nn import Module
@@ -114,7 +114,7 @@ def apply_weights(self,
return output
-def per_tensor_quantize(tensor: torch.Tensor) -> tuple[torch.Tensor, float]:
+def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
"""Quantize a tensor using per-tensor static scaling factor.
Args:
From fe7d648fe56c138811e9b2b02937b55a84830454 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Sun, 21 Apr 2024 17:15:28 +0100
Subject: [PATCH 082/413] Don't show default value for flags in `EngineArgs`
(#4223)
Co-authored-by: Harry Mellor
---
docs/source/models/engine_args.rst | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst
index 92bc7e0e843ed..bdf566d3ebbd1 100644
--- a/docs/source/models/engine_args.rst
+++ b/docs/source/models/engine_args.rst
@@ -9,6 +9,7 @@ Below, you can find an explanation of every engine argument for vLLM:
:module: vllm.engine.arg_utils
:func: _engine_args_parser
:prog: -m vllm.entrypoints.openai.api_server
+ :nodefaultconst:
Async Engine Arguments
----------------------
@@ -18,4 +19,5 @@ Below are the additional arguments related to the asynchronous engine:
.. argparse::
:module: vllm.engine.arg_utils
:func: _async_engine_args_parser
- :prog: -m vllm.entrypoints.openai.api_server
\ No newline at end of file
+ :prog: -m vllm.entrypoints.openai.api_server
+ :nodefaultconst:
\ No newline at end of file
From 7f2593b164c2ff115ba4fb9ce95fe63bdd824b85 Mon Sep 17 00:00:00 2001
From: xiaoji <44150358+YeFD@users.noreply.github.com>
Date: Mon, 22 Apr 2024 00:57:08 +0800
Subject: [PATCH 083/413] [Doc]: Update the doc of adding new models (#4236)
---
docs/source/models/adding_model.rst | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst
index a82c2cef10e83..cbc8099e6f70f 100644
--- a/docs/source/models/adding_model.rst
+++ b/docs/source/models/adding_model.rst
@@ -95,7 +95,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a
5. Register your model
----------------------
-Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py `_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py `_.
+Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py `_.
6. Out-of-Tree Model Integration
--------------------------------------------
From a37d815b83849b5a96a182929dd6f3bd35f68fb8 Mon Sep 17 00:00:00 2001
From: GeauxEric
Date: Sun, 21 Apr 2024 15:06:46 -0700
Subject: [PATCH 084/413] Make initialization of tokenizer and detokenizer
optional (#3748)
Co-authored-by: Yun Ding
Co-authored-by: Roger Wang
---
tests/engine/test_skip_tokenizer_init.py | 23 ++++++++++++++++
vllm/config.py | 7 ++++-
vllm/engine/arg_utils.py | 7 ++++-
vllm/engine/llm_engine.py | 29 +++++++++++++++------
vllm/engine/output_processor/single_step.py | 5 ++--
vllm/entrypoints/llm.py | 9 +++++++
6 files changed, 68 insertions(+), 12 deletions(-)
create mode 100644 tests/engine/test_skip_tokenizer_init.py
diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py
new file mode 100644
index 0000000000000..baa463a316902
--- /dev/null
+++ b/tests/engine/test_skip_tokenizer_init.py
@@ -0,0 +1,23 @@
+import pytest
+
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_skip_tokenizer_initialization(model: str):
+ # This test checks if the flag skip_tokenizer_init skips the initialization
+ # of tokenizer and detokenizer. The generated output is expected to contain
+ # token ids.
+ llm = LLM(model=model, skip_tokenizer_init=True)
+ sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+ with pytest.raises(ValueError) as err:
+ llm.generate("abc", sampling_params)
+ assert "prompts must be None if" in str(err.value)
+ outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
+ sampling_params=sampling_params)
+ assert len(outputs) > 0
+ completions = outputs[0].outputs
+ assert len(completions) > 0
+ assert completions[0].text == ""
+ assert completions[0].token_ids
diff --git a/vllm/config.py b/vllm/config.py
index 2912d6ccc2c5b..97ede0faa21ab 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -66,6 +66,8 @@ class ModelConfig:
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
+ skip_tokenizer_init: If true, skip initialization of tokenizer and
+ detokenizer.
"""
def __init__(
@@ -85,6 +87,7 @@ def __init__(
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
+ skip_tokenizer_init: bool = False,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -99,6 +102,7 @@ def __init__(
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
self.max_logprobs = max_logprobs
+ self.skip_tokenizer_init = skip_tokenizer_init
self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision)
@@ -106,7 +110,8 @@ def __init__(
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
- self._verify_tokenizer_mode()
+ if not self.skip_tokenizer_init:
+ self._verify_tokenizer_mode()
self._verify_quantization()
self._verify_cuda_graph()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 8939a0d537281..5de20633ffdd6 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -16,6 +16,7 @@ class EngineArgs:
"""Arguments for vLLM engine."""
model: str
tokenizer: Optional[str] = None
+ skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None
@@ -93,6 +94,10 @@ def add_cli_args(
type=str,
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use.')
+ parser.add_argument(
+ '--skip-tokenizer-init',
+ action='store_true',
+ help='Skip initialization of tokenizer and detokenizer')
parser.add_argument(
'--revision',
type=str,
@@ -453,7 +458,7 @@ def create_engine_config(self, ) -> EngineConfig:
self.code_revision, self.tokenizer_revision, self.max_model_len,
self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture,
- self.max_logprobs)
+ self.max_logprobs, self.skip_tokenizer_init)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 88b344f50767f..d96025ea1fb6a 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -100,6 +100,7 @@ def __init__(
f"model={model_config.model!r}, "
f"speculative_config={speculative_config!r}, "
f"tokenizer={model_config.tokenizer!r}, "
+ f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"revision={model_config.revision}, "
f"tokenizer_revision={model_config.tokenizer_revision}, "
@@ -132,8 +133,14 @@ def __init__(
self.decoding_config = decoding_config or DecodingConfig()
self.log_stats = log_stats
- self._init_tokenizer()
- self.detokenizer = Detokenizer(self.tokenizer)
+ if not self.model_config.skip_tokenizer_init:
+ self.tokenizer: BaseTokenizerGroup
+ self._init_tokenizer()
+ self.detokenizer = Detokenizer(self.tokenizer)
+ else:
+ self.detokenizer = None
+ self.tokenizer = None
+
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)
@@ -187,9 +194,10 @@ def __init__(
parallel_config.disable_custom_all_reduce,
})
- # Ping the tokenizer to ensure liveness if it runs in a
- # different process.
- self.tokenizer.ping()
+ if self.tokenizer:
+ # Ping the tokenizer to ensure liveness if it runs in a
+ # different process.
+ self.tokenizer.ping()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
@@ -296,7 +304,7 @@ def _init_tokenizer(self, **tokenizer_init_kwargs):
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
- self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
+ self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs)
def _verify_args(self) -> None:
@@ -393,8 +401,13 @@ def add_request(
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
- eos_token_id = self.tokenizer.get_lora_tokenizer(
- lora_request).eos_token_id
+ eos_token_id = None
+ if self.tokenizer:
+ eos_token_id = self.tokenizer.get_lora_tokenizer(
+ lora_request).eos_token_id
+ else:
+ logger.warning("Use None for EOS token id because tokenizer is "
+ "not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index 1b7eb014f802b..b32937327ba7f 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -59,7 +59,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
- if prompt_logprobs is not None and seq_group.sampling_params.detokenize:
+ if prompt_logprobs is not None and \
+ seq_group.sampling_params.detokenize and self.detokenizer:
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs)
seq_group.prompt_logprobs = prompt_logprobs
@@ -105,7 +106,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
- if seq_group.sampling_params.detokenize:
+ if seq_group.sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, seq_group.sampling_params)
else:
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index f745dbd736d17..b022707794a78 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -32,6 +32,9 @@ class LLM:
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
+ skip_tokenizer_init: If true, skip initialization of tokenizer and
+            detokenizer. Expects valid prompt_token_ids and None for the
+            prompt from the input.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
@@ -76,6 +79,7 @@ def __init__(
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
+ skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
@@ -96,6 +100,7 @@ def __init__(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
+ skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
@@ -160,6 +165,10 @@ def generate(
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
+ if self.llm_engine.model_config.skip_tokenizer_init \
+ and prompts is not None:
+ raise ValueError("prompts must be None if skip_tokenizer_init "
+ "is True")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
From 95e5b087cfed33bf20891852b2d7a0cac2341519 Mon Sep 17 00:00:00 2001
From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Date: Mon, 22 Apr 2024 00:57:24 -0400
Subject: [PATCH 085/413] [AMD][Hardware][Misc][Bugfix] xformer cleanup and
light navi logic and CI fixes and refactoring (#4129)
---
.buildkite/test-pipeline.yaml | 2 -
Dockerfile.rocm | 5 +-
patch_xformers.rocm.sh | 33 ----
.../commonpy_xformers-0.0.23.rocm.patch | 13 --
rocm_patch/flashpy_xformers-0.0.23.rocm.patch | 152 ------------------
vllm/attention/backends/rocm_flash_attn.py | 31 ++--
6 files changed, 19 insertions(+), 217 deletions(-)
delete mode 100644 patch_xformers.rocm.sh
delete mode 100644 rocm_patch/commonpy_xformers-0.0.23.rocm.patch
delete mode 100644 rocm_patch/flashpy_xformers-0.0.23.rocm.patch
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 0f920c7ec1442..f7c1569696249 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -15,10 +15,8 @@ steps:
commands:
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
- label: Core Test
command: pytest -v -s core
diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index b1c5fac9d78ef..3f84b949481d1 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
-ARG FA_BRANCH="3d2b6f5"
+ARG FA_BRANCH="ae7928c"
RUN echo "FA_BRANCH is $FA_BRANCH"
# whether to build flash-attention
@@ -92,13 +92,10 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
COPY ./ /app/vllm
RUN python3 -m pip install --upgrade pip numba
-RUN python3 -m pip install xformers==0.0.23 --no-deps
RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
- && if [ "$BUILD_FA" = "1" ]; then \
- bash patch_xformers.rocm.sh; fi \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \
&& cd ..
diff --git a/patch_xformers.rocm.sh b/patch_xformers.rocm.sh
deleted file mode 100644
index de427b24d306f..0000000000000
--- a/patch_xformers.rocm.sh
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/bin/bash
-set -e
-
-XFORMERS_VERSION="0.0.23"
-
-export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
-
-if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
- echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
- exit 1
-fi
-
-export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
-export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
-
-echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
-echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
-
-if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
- echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
- patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
- echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
-else
- echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
-fi
-
-if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
- echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
- patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
- echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
-else
- echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
-fi
diff --git a/rocm_patch/commonpy_xformers-0.0.23.rocm.patch b/rocm_patch/commonpy_xformers-0.0.23.rocm.patch
deleted file mode 100644
index 4d7495cf13e1d..0000000000000
--- a/rocm_patch/commonpy_xformers-0.0.23.rocm.patch
+++ /dev/null
@@ -1,13 +0,0 @@
---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000
-+++ common.py 2023-11-28 16:14:19.846233146 +0000
-@@ -298,8 +298,8 @@
- dtype = d.query.dtype
- if device_type not in cls.SUPPORTED_DEVICES:
- reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
-- if device_type == "cuda" and not _built_with_cuda:
-- reasons.append("xFormers wasn't build with CUDA support")
-+ #if device_type == "cuda" and not _built_with_cuda:
-+ # reasons.append("xFormers wasn't build with CUDA support")
- if device_type == "cuda":
- device_capability = torch.cuda.get_device_capability(d.device)
- if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
diff --git a/rocm_patch/flashpy_xformers-0.0.23.rocm.patch b/rocm_patch/flashpy_xformers-0.0.23.rocm.patch
deleted file mode 100644
index ac846728a7a91..0000000000000
--- a/rocm_patch/flashpy_xformers-0.0.23.rocm.patch
+++ /dev/null
@@ -1,152 +0,0 @@
---- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
-+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
-@@ -36,44 +36,44 @@
-
- FLASH_VERSION = "0.0.0"
- try:
-- try:
-- from ... import _C_flashattention # type: ignore[attr-defined]
-- from ..._cpp_lib import _build_metadata
--
-- if _build_metadata is not None:
-- FLASH_VERSION = _build_metadata.flash_version
-- except ImportError:
-- import flash_attn
-- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
--
-- FLASH_VERSION = flash_attn.__version__
-- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
-- if (
-- flash_ver_parsed != (2, 3, 6)
-- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
-- ):
-- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
-+ #try:
-+ # from ... import _C_flashattention # type: ignore[attr-defined]
-+ # from ..._cpp_lib import _build_metadata
-+
-+ # if _build_metadata is not None:
-+ # FLASH_VERSION = _build_metadata.flash_version
-+ #except ImportError:
-+ import flash_attn
-+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-+
-+ FLASH_VERSION = flash_attn.__version__
-+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
-+ # if (
-+ # flash_ver_parsed != (2, 3, 6)
-+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
-+ # ):
-+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
-
- # create library so that flash-attn goes through the PyTorch Dispatcher
-- _flash_lib = torch.library.Library("xformers_flash", "DEF")
--
-- _flash_lib.define(
-- "flash_fwd(Tensor query, Tensor key, Tensor value, "
-- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
-- "int max_seqlen_q, int max_seqlen_k, "
-- "float p, float softmax_scale, "
-- "bool is_causal, int window_left, "
-- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
-- )
-+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
-
-- _flash_lib.define(
-- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
-- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
-- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-- "int max_seqlen_q, int max_seqlen_k, "
-- "float p, float softmax_scale, bool is_causal, "
-- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
-- )
-+ #_flash_lib.define(
-+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
-+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
-+ # "int max_seqlen_q, int max_seqlen_k, "
-+ # "float p, float softmax_scale, "
-+ # "bool is_causal, int window_left, "
-+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
-+ #)
-+
-+ #_flash_lib.define(
-+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
-+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
-+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-+ # "int max_seqlen_q, int max_seqlen_k, "
-+ # "float p, float softmax_scale, bool is_causal, "
-+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
-+ #)
-
- def _flash_fwd(
- query,
-@@ -111,8 +111,8 @@
- p,
- softmax_scale,
- is_causal,
-- window_left, # window_size_left
-- window_right, # window_size_right
-+ # window_left, # window_size_left
-+ # window_right, # window_size_right
- return_softmax,
- None, # rng
- )
-@@ -134,15 +134,15 @@
- out,
- cu_seq_lens_q,
- cu_seq_lens_k,
-- seqused_k,
-+ # seqused_k,
- max_seq_len_q,
- max_seq_len_k,
- p,
- softmax_scale,
- False,
- is_causal,
-- window_left,
-- window_right,
-+ # window_left,
-+ # window_right,
- return_softmax,
- None,
- )
-@@ -184,8 +184,8 @@
- p,
- softmax_scale,
- is_causal,
-- window_left,
-- window_right,
-+ # window_left,
-+ # window_right,
- None,
- rng_state,
- )
-@@ -208,15 +208,15 @@
- softmax_scale,
- False, # zero_tensors
- is_causal,
-- window_left,
-- window_right,
-+ # window_left,
-+ # window_right,
- None,
- rng_state,
- )
- return dq, dk, dv
-
-- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
-- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
-+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
-+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
- except ImportError:
- pass
-
-@@ -400,7 +400,7 @@
- implementation.
- """
-
-- OPERATOR = get_operator("xformers_flash", "flash_fwd")
-+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
- SUPPORTED_DEVICES: Set[str] = {"cuda"}
- CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
- SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index c42660fb8f74f..dbaa71fd16add 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -154,25 +154,30 @@ def __init__(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
- self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
+ self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = (os.environ.get(
"VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
- if self.use_naive_attn:
- # AMD Radeon 7900 series (gfx1100) currently does not support
- # xFormers nor FlashAttention. As a temporary workaround, we use
- # naive PyTorch implementation of attention.
- self.attn_fuc = _naive_attention
- logger.debug("Using naive attention in ROCmBackend")
- elif self.use_triton_flash_attn:
+ if self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend")
else:
- from flash_attn import flash_attn_varlen_func # noqa: F401
- self.attn_func = flash_attn_varlen_func
- logger.debug("Using CK FA in ROCmBackend")
+            # If not using Triton, navi3x does not use flash-attn either.
+ if torch.cuda.get_device_capability()[0] == 11:
+ self.use_naive_attn = True
+ else:
+ try:
+ from flash_attn import flash_attn_varlen_func # noqa: F401
+ self.attn_func = flash_attn_varlen_func
+ logger.debug("Using CK FA in ROCmBackend")
+ except ModuleNotFoundError:
+ self.use_naive_attn = True
+
+ if self.use_naive_attn:
+ self.attn_func = _naive_attention
+ logger.debug("Using naive attention in ROCmBackend")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -247,13 +252,13 @@ def forward(
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
- if self.use_naive_attn or self.use_triton_flash_attn:
+ if self.use_triton_flash_attn or self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
if self.use_naive_attn:
- out = self.attn_fuc(
+ out = self.attn_func(
query,
key,
value,
From 747b1a7147515c08491ef4aa1b23ea23329966ed Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 21 Apr 2024 23:04:16 -0700
Subject: [PATCH 086/413] [Core][Distributed] fix _is_full_nvlink detection
(#4233)
---
.../device_communicators/custom_all_reduce.py | 48 ++++++++++++-------
1 file changed, 30 insertions(+), 18 deletions(-)
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 7602897d3dd8f..58cbe77baf7e0 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,5 +1,6 @@
+import os
from contextlib import contextmanager
-from typing import Optional
+from typing import List, Optional
import torch
import torch.distributed as dist
@@ -53,14 +54,20 @@ def init_custom_ar() -> None:
return False
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
- full_nvlink = _is_full_nvlink(rank, world_size)
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
+ device_ids = list(
+ map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
+ else:
+ device_ids = list(range(num_dev))
+ # this checks hardware and driver support for NVLink
+ full_nvlink = _is_full_nvlink(device_ids)
if world_size > 2 and not full_nvlink:
logger.warn(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
- # test P2P capability
+ # test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
@@ -138,23 +145,28 @@ def _nvml():
pynvml.nvmlShutdown()
-# query if the set of gpus are fully connected by nvlink (1 hop)
@_nvml()
-def _is_full_nvlink(rank, world_size):
- handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
- for i in range(world_size):
- if i != rank:
- try:
- peer_handle = pynvml.nvmlDeviceGetHandleByIndex(i)
- p2p_status = pynvml.nvmlDeviceGetP2PStatus(
- handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
- if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+def _is_full_nvlink(device_ids: List[int]) -> bool:
+ """
+    Query whether the set of GPUs is fully connected by NVLink (1 hop).
+ Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
+ so it works on real physical device ids.
+ """
+ handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+ for i, handle in enumerate(handles):
+ for j, peer_handle in enumerate(handles):
+ if i < j:
+ try:
+ p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+ handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+ if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+ return False
+ except pynvml.NVMLError as error:
+ logger.error(
+ "NVLink detection failed. This is normal if your"
+                        " machine is not equipped with NVLink.",
+ exc_info=error)
return False
- except pynvml.NVMLError as error:
- logger.info(
- f"NVLink detection failed with message \"{str(error)}\". "
- "This is normal if your machine has no NVLink equipped")
- return False
return True
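For illustration, a small sketch of why physical device ids matter here (assumptions: at least two visible GPUs and pynvml installed): NVML ignores CUDA_VISIBLE_DEVICES, so the ids parsed from that variable are passed through unchanged rather than remapped to logical ranks.

import os

import torch

from vllm.distributed.device_communicators.custom_all_reduce import (
    _is_full_nvlink)

# Sketch: a job launched with CUDA_VISIBLE_DEVICES="2,3" sees logical torch
# devices 0 and 1, but NVML addresses the physical GPUs 2 and 3.
visible = os.environ.get("CUDA_VISIBLE_DEVICES")
device_ids = (list(map(int, visible.split(",")))
              if visible else list(range(torch.cuda.device_count())))
print(_is_full_nvlink(device_ids))  # True only if every pair has an NVLink path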
From 296cdf8ac7853b671d10f3df094288ea4f35bb38 Mon Sep 17 00:00:00 2001
From: Isotr0py <41363108+Isotr0py@users.noreply.github.com>
Date: Mon, 22 Apr 2024 15:44:16 +0800
Subject: [PATCH 087/413] [Misc] Add vision language model support to CPU
backend (#3968)
---
vllm/executor/cpu_executor.py | 1 +
vllm/worker/cpu_model_runner.py | 60 ++++++++++++++++++++-------------
vllm/worker/cpu_worker.py | 24 ++++++++-----
3 files changed, 53 insertions(+), 32 deletions(-)
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index f925a6fc93dcd..35249cd7302cb 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -45,6 +45,7 @@ def _init_worker(self):
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
+ vision_language_config=self.vision_language_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 7377c8931cefa..a82373d3d1626 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -5,7 +5,7 @@
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
- ParallelConfig, SchedulerConfig)
+ ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@@ -29,6 +29,7 @@ def __init__(
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
+ vision_language_config: Optional[VisionLanguageConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
*args,
@@ -38,6 +39,7 @@ def __init__(
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.lora_config = lora_config
+ self.vision_language_config = vision_language_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
@@ -59,13 +61,14 @@ def __init__(
self.block_size: int # Set after initial profiling.
def load_model(self) -> None:
- self.model = get_model(model_config=self.model_config,
- load_config=self.load_config,
- device_config=self.device_config,
- vision_language_config=None,
- lora_config=self.lora_config,
- parallel_config=self.parallel_config,
- scheduler_config=self.scheduler_config)
+ self.model = get_model(
+ model_config=self.model_config,
+ load_config=self.load_config,
+ device_config=self.device_config,
+ vision_language_config=self.vision_language_config,
+ lora_config=self.lora_config,
+ parallel_config=self.parallel_config,
+ scheduler_config=self.scheduler_config)
def _prepare_prompt(
self,
@@ -76,6 +79,7 @@ def _prepare_prompt(
input_positions: List[int] = []
slot_mapping: List[int] = []
prompt_lens: List[int] = []
+ multi_modal_input_list: List[torch.Tensor] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
@@ -96,6 +100,10 @@ def _prepare_prompt(
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, prompt_len)))
+ if seq_group_metadata.multi_modal_data:
+ multi_modal_input_list.append(
+ seq_group_metadata.multi_modal_data.data)
+
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
@@ -118,6 +126,15 @@ def _prepare_prompt(
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
+ if multi_modal_input_list:
+ assert self.vision_language_config, (
+ "Multi-modal inputs are only supported by "
+ "vision language models.")
+ multi_modal_input = torch.cat(multi_modal_input_list,
+ dim=0).to(self.device)
+ else:
+ multi_modal_input = None
+
num_prompt_tokens = len(input_tokens)
input_tokens = torch.tensor(input_tokens,
@@ -144,12 +161,8 @@ def _prepare_prompt(
slot_mapping=slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
)
- return (
- input_tokens,
- input_positions,
- attn_metadata,
- prompt_lens,
- )
+ return (input_tokens, input_positions, attn_metadata, prompt_lens,
+ multi_modal_input)
def _prepare_decode(
self,
@@ -336,14 +349,16 @@ def prepare_input_tensors(
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
SamplingMetadata]:
+ multi_modal_input = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
- (input_tokens, input_positions, attn_metadata,
- prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+ (input_tokens, input_positions, attn_metadata, prompt_lens,
+ multi_modal_input
+ ) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
@@ -376,12 +391,8 @@ def prepare_input_tensors(
perform_sampling=False,
)
- return (
- input_tokens,
- input_positions,
- attn_metadata,
- sampling_metadata,
- )
+ return (input_tokens, input_positions, attn_metadata,
+ sampling_metadata, multi_modal_input)
@torch.inference_mode()
def execute_model(
@@ -389,7 +400,8 @@ def execute_model(
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
- (input_tokens, input_positions, attn_metadata, sampling_metadata
+ (input_tokens, input_positions, attn_metadata, sampling_metadata,
+ multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)
model_executable = self.model
@@ -399,6 +411,8 @@ def execute_model(
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
}
+ if self.vision_language_config:
+ execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index 3652830b7d519..83ededd742533 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -6,7 +6,8 @@
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
- ModelConfig, ParallelConfig, SchedulerConfig)
+ ModelConfig, ParallelConfig, SchedulerConfig,
+ VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment)
@@ -122,6 +123,7 @@ def __init__(
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
+ vision_language_config: Optional[VisionLanguageConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
@@ -135,21 +137,25 @@ def __init__(
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
+ self.vision_language_config = vision_language_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
+
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
- self.model_runner = CPUModelRunner(model_config,
- parallel_config,
- scheduler_config,
- device_config,
- load_config=self.load_config,
- lora_config=self.lora_config,
- kv_cache_dtype=kv_cache_dtype,
- is_driver_worker=is_driver_worker)
+ self.model_runner = CPUModelRunner(
+ model_config,
+ parallel_config,
+ scheduler_config,
+ device_config,
+ load_config=self.load_config,
+ lora_config=self.lora_config,
+ vision_language_config=self.vision_language_config,
+ kv_cache_dtype=kv_cache_dtype,
+ is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: CPUCacheEngine
From e73ed0f1c624f85d348c0709c256a0ae6627986b Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 22 Apr 2024 00:54:16 -0700
Subject: [PATCH 088/413] [Bugfix] Fix type annotations in CPU model runner
(#4256)
---
vllm/worker/cpu_model_runner.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index a82373d3d1626..bf0a6c84e6f07 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -73,7 +73,8 @@ def load_model(self) -> None:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
- ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]:
+ ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
+ Optional[torch.Tensor]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
@@ -347,8 +348,8 @@ def _prepare_sample(
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
- ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
- SamplingMetadata]:
+ ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
+ Optional[torch.Tensor]]:
multi_modal_input = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
From 077f0a2e8a873340b1a2cf54d6c9043754eb7514 Mon Sep 17 00:00:00 2001
From: Tao He
Date: Mon, 22 Apr 2024 17:19:51 +0800
Subject: [PATCH 089/413] [Frontend] Enable support for CPU backend in
AsyncLLMEngine. (#3993)
Signed-off-by: Tao He
---
vllm/engine/async_llm_engine.py | 5 +++++
vllm/executor/cpu_executor.py | 27 +++++++++++++++++++++++++--
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index ca4ba66f09cb8..3a2f7db679358 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -343,6 +343,11 @@ def from_engine_args(
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
+ elif engine_config.device_config.device_type == "cpu":
+ assert not engine_config.parallel_config.worker_use_ray, (
+ "Ray is not supported with the CPU backend.")
+ from vllm.executor.cpu_executor import CPUExecutorAsync
+ executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 35249cd7302cb..8d6a1fff91fd8 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -4,11 +4,12 @@
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
-from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
-from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+ make_async)
logger = init_logger(__name__)
@@ -100,6 +101,28 @@ def check_health(self) -> None:
return
+class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
+
+ async def execute_model_async(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ blocks_to_swap_in: Dict[int, int],
+ blocks_to_swap_out: Dict[int, int],
+ blocks_to_copy: Dict[int, List[int]],
+ ) -> SamplerOutput:
+ output = await make_async(self.driver_worker.execute_model)(
+ seq_group_metadata_list=seq_group_metadata_list,
+ blocks_to_swap_in=blocks_to_swap_in,
+ blocks_to_swap_out=blocks_to_swap_out,
+ blocks_to_copy=blocks_to_copy)
+ return output
+
+ async def check_health_async(self) -> None:
+ # CPUExecutor will always be healthy as long as
+ # it's running.
+ return
+
+
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
if config.dtype == torch.float16:
logger.warning("float16 is not supported on CPU, casting to bfloat16.")
From 15436806912d7ad9371c8bcf6a46857590c107d2 Mon Sep 17 00:00:00 2001
From: alexm-nm <59768536+alexm-nm@users.noreply.github.com>
Date: Mon, 22 Apr 2024 12:10:48 -0400
Subject: [PATCH 090/413] [Bugfix] Ensure download_weights_from_hf(..) inside
loader is using the revision parameter (#4217)
---
vllm/model_executor/model_loader/loader.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 6c8cb2935f37e..64cd186506bdb 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -172,7 +172,7 @@ def _prepare_weights(self, model_name_or_path: str,
if not is_local:
hf_folder = download_weights_from_hf(model_name_or_path,
self.load_config.download_dir,
- allow_patterns)
+ allow_patterns, revision)
else:
hf_folder = model_name_or_path
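The effect of the one-line fix, illustrated here with huggingface_hub's snapshot_download, which exposes the same revision/allow_patterns semantics; the repo id and patterns below are examples, not vLLM defaults.

from huggingface_hub import snapshot_download

# With the revision forwarded, a pinned commit/branch/tag is honoured;
# without it, the download falls back to the default branch.
local_folder = snapshot_download(
    "facebook/opt-125m",
    revision="main",
    allow_patterns=["*.bin", "*.json"],
)
print(local_folder)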
From 3d925165f2b18379640a63fbb42de95440d63b64 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Mon, 22 Apr 2024 17:36:54 +0100
Subject: [PATCH 091/413] Add example scripts to documentation (#4225)
Co-authored-by: Harry Mellor
---
.gitignore | 2 +
docs/source/conf.py | 9 ++-
docs/source/generate_examples.py | 61 +++++++++++++++++++
.../examples/examples_index.template.rst | 8 +++
docs/source/index.rst | 1 +
...nt.py => openai_chat_completion_client.py} | 0
6 files changed, 80 insertions(+), 1 deletion(-)
create mode 100644 docs/source/generate_examples.py
create mode 100644 docs/source/getting_started/examples/examples_index.template.rst
rename examples/{openai_chatcompletion_client.py => openai_chat_completion_client.py} (100%)
diff --git a/.gitignore b/.gitignore
index b1513ef0ddb0c..e077366d1e4a1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -70,6 +70,8 @@ instance/
# Sphinx documentation
docs/_build/
+docs/source/getting_started/examples/*.rst
+!**/*.template.rst
# PyBuilder
.pybuilder/
diff --git a/docs/source/conf.py b/docs/source/conf.py
index cfa956b143ba3..aac8cbb63ebeb 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -48,7 +48,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns: List[str] = []
+exclude_patterns: List[str] = ["**/*.template.rst"]
# Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ "
@@ -73,6 +73,13 @@
# so a file named "default.css" will overwrite the builtin "default.css".
# html_static_path = ['_static']
+
+# Generate additional rst documentation here.
+def setup(app):
+ from docs.source.generate_examples import generate_examples
+ generate_examples()
+
+
# Mock out external dependencies here.
autodoc_mock_imports = [
"cpuinfo",
diff --git a/docs/source/generate_examples.py b/docs/source/generate_examples.py
new file mode 100644
index 0000000000000..79b49a186236a
--- /dev/null
+++ b/docs/source/generate_examples.py
@@ -0,0 +1,61 @@
+import re
+from pathlib import Path
+
+
+def fix_case(text: str) -> str:
+ subs = [
+ ("api", "API"),
+ ("llm", "LLM"),
+ ("vllm", "vLLM"),
+ ("openai", "OpenAI"),
+ ("multilora", "MultiLoRA"),
+ ]
+ for sub in subs:
+ text = re.sub(*sub, text, flags=re.IGNORECASE)
+ return text
+
+
+def underline(title: str, character: str = "=") -> str:
+ return f"{title}\n{character * len(title)}"
+
+
+def generate_title(filename: str) -> str:
+ # Turn filename into a title
+ title = filename.replace("_", " ").title()
+ # Handle acronyms and names
+ title = fix_case(title)
+ # Underline title
+ title = underline(title)
+ return title
+
+
+def generate_examples():
+ root_dir = Path(__file__).parent.parent.parent.resolve()
+
+ # Source paths
+ script_dir = root_dir / "examples"
+ script_paths = sorted(script_dir.glob("*.py"))
+
+ # Destination paths
+ doc_dir = root_dir / "docs/source/getting_started/examples"
+ doc_paths = [doc_dir / f"{path.stem}.rst" for path in script_paths]
+
+ # Generate the example docs for each example script
+ for script_path, doc_path in zip(script_paths, doc_paths):
+ script_url = f"https://github.com/vllm-project/vllm/blob/main/examples/{script_path.name}"
+ # Make script_path relative to doc_path and call it include_path
+ include_path = '../../../..' / script_path.relative_to(root_dir)
+ content = (f"{generate_title(doc_path.stem)}\n\n"
+ f"Source {script_url}.\n\n"
+ f".. literalinclude:: {include_path}\n"
+ " :language: python\n"
+ " :linenos:\n")
+ with open(doc_path, "w+") as f:
+ f.write(content)
+
+ # Generate the toctree for the example scripts
+ with open(doc_dir / "examples_index.template.rst") as f:
+ examples_index = f.read()
+ with open(doc_dir / "examples_index.rst", "w+") as f:
+ example_docs = "\n ".join(path.stem for path in script_paths)
+ f.write(examples_index.replace(r"%EXAMPLE_DOCS%", example_docs))
diff --git a/docs/source/getting_started/examples/examples_index.template.rst b/docs/source/getting_started/examples/examples_index.template.rst
new file mode 100644
index 0000000000000..1b34cccbae15a
--- /dev/null
+++ b/docs/source/getting_started/examples/examples_index.template.rst
@@ -0,0 +1,8 @@
+Examples
+=================================
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Scripts
+
+ %EXAMPLE_DOCS%
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5d5d52696ba34..e8daa5f052754 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -65,6 +65,7 @@ Documentation
getting_started/neuron-installation
getting_started/cpu-installation
getting_started/quickstart
+ getting_started/examples/examples_index
.. toctree::
:maxdepth: 1
diff --git a/examples/openai_chatcompletion_client.py b/examples/openai_chat_completion_client.py
similarity index 100%
rename from examples/openai_chatcompletion_client.py
rename to examples/openai_chat_completion_client.py
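A quick way to see what the new docs generator produces for a script name; the helpers are re-stated inline (adapted from generate_examples.py above) so the snippet runs on its own.

import re


def fix_case(text: str) -> str:
    for pattern, repl in [("api", "API"), ("llm", "LLM"), ("vllm", "vLLM"),
                          ("openai", "OpenAI"), ("multilora", "MultiLoRA")]:
        text = re.sub(pattern, repl, text, flags=re.IGNORECASE)
    return text


def underline(title: str, character: str = "=") -> str:
    return f"{title}\n{character * len(title)}"


title = fix_case("openai_chat_completion_client".replace("_", " ").title())
print(underline(title))
# -> "OpenAI Chat Completion Client" underlined with a row of '=' of equal length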
From ad8d696a99ca1eee19f1404e16e8e82df592ff85 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Tue, 23 Apr 2024 06:11:06 +0900
Subject: [PATCH 092/413] [Core] Scheduler perf fix (#4270)
---
tests/core/test_scheduler.py | 18 +++++++++---------
vllm/core/scheduler.py | 7 ++-----
2 files changed, 11 insertions(+), 14 deletions(-)
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index 9588a1bead5f6..a2511238506b0 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -540,7 +540,7 @@ def test_decode_schedule_preempted():
curr_loras = None
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
- scheduler._allocate_and_set_running(seq_group, 60)
+ scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
running.append(seq_group)
scheduler.block_manager.can_append_slots = MagicMock()
@@ -581,7 +581,7 @@ def test_decode_swap_beam_search():
budget = create_token_budget()
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
- scheduler._allocate_and_set_running(seq_group, 60)
+ scheduler._allocate_and_set_running(seq_group)
running.append(seq_group)
append_new_token_seq_group(60, seq_group, 1)
budget.add_num_seqs(seq_group.request_id,
@@ -629,7 +629,7 @@ def test_schedule_decode_blocks_to_copy_update():
running = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
- scheduler._allocate_and_set_running(seq_group, 60)
+ scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
running.append(seq_group)
@@ -659,7 +659,7 @@ def test_schedule_swapped_simple():
curr_loras = None
blocks_to_swap_out = {}
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
- scheduler._allocate_and_set_running(seq_group, 60)
+ scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
@@ -687,7 +687,7 @@ def test_schedule_swapped_max_token_budget():
blocks_to_swap_out = {}
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
- scheduler._allocate_and_set_running(seq_group, 60)
+ scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
@@ -721,7 +721,7 @@ def test_schedule_swapped_max_seqs():
blocks_to_swap_out = {}
for i in range(4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
- scheduler._allocate_and_set_running(seq_group, 60)
+ scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
@@ -759,7 +759,7 @@ def test_schedule_swapped_max_loras():
lora_name=str(i),
lora_int_id=i + 1,
lora_local_path="abc"))
- scheduler._allocate_and_set_running(seq_group, 60)
+ scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
@@ -783,7 +783,7 @@ def test_schedule_swapped_cannot_swap_in():
blocks_to_swap_out = {}
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
- scheduler._allocate_and_set_running(seq_group, 60)
+ scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
@@ -808,7 +808,7 @@ def test_schedule_swapped_blocks_to_copy():
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
- scheduler._allocate_and_set_running(seq_group, 60)
+ scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
blocks_to_swap_out = {}
scheduler._swap_out(seq_group, blocks_to_swap_out)
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 4198550621030..8d7db09bbea08 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -297,7 +297,6 @@ def num_decoding_tokens_per_seq(self) -> int:
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
- logger.debug(f"add_seq_group {seq_group.request_id}")
self.waiting.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
@@ -427,7 +426,6 @@ def _schedule_running(
swapped_out.append(seq_group)
break
else:
- logger.debug(f"append slot for {seq_group}")
self._append_slots(seq_group, blocks_to_copy)
is_prefill = seq_group.is_prefill()
if is_prefill:
@@ -659,7 +657,7 @@ def _schedule_prefills(
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
- self._allocate_and_set_running(seq_group, num_new_tokens)
+ self._allocate_and_set_running(seq_group)
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=num_new_tokens))
@@ -952,8 +950,7 @@ def free_finished_seq_groups(self) -> None:
self.running = deque(seq_group for seq_group in self.running
if not seq_group.is_finished())
- def _allocate_and_set_running(self, seq_group: SequenceGroup,
- num_new_tokens: int) -> None:
+ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING
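Besides dropping the unused num_new_tokens argument, the commit removes per-step debug logging. A standalone illustration (not vLLM code) of why that matters: an f-string log argument is formatted even when the record is filtered out, while %-style arguments are only formatted on emit.

import logging
import timeit


class Expensive:
    def __repr__(self) -> str:
        return ",".join(str(i) for i in range(200))


logging.basicConfig(level=logging.INFO)  # DEBUG records are filtered out
logger = logging.getLogger("demo")
obj = Expensive()

eager = timeit.timeit(lambda: logger.debug(f"append slot for {obj}"), number=10_000)
lazy = timeit.timeit(lambda: logger.debug("append slot for %s", obj), number=10_000)
print(f"f-string: {eager:.3f}s  %-style: {lazy:.3f}s")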
From ceaf4ed0030c7816537359b8efc750474149ce0f Mon Sep 17 00:00:00 2001
From: Zhanghao Wu
Date: Mon, 22 Apr 2024 15:34:31 -0700
Subject: [PATCH 093/413] [Doc] Update the SkyPilot doc with serving and
Llama-3 (#4276)
---
docs/source/serving/run_on_sky.rst | 287 ++++++++++++++++++++++++++---
1 file changed, 264 insertions(+), 23 deletions(-)
diff --git a/docs/source/serving/run_on_sky.rst b/docs/source/serving/run_on_sky.rst
index 2c88d24dc5d0b..bd33c76cec3de 100644
--- a/docs/source/serving/run_on_sky.rst
+++ b/docs/source/serving/run_on_sky.rst
@@ -1,7 +1,7 @@
.. _on_cloud:
-Running on clouds with SkyPilot
-===============================
+Deploying and scaling up with SkyPilot
+================================================
.. raw:: html
@@ -9,51 +9,75 @@ Running on clouds with SkyPilot
-vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot `__, an open-source framework for running LLMs on any cloud.
+vLLM can be **run and scaled to multiple service replicas on clouds and Kubernetes** with `SkyPilot `__, an open-source framework for running LLMs on any cloud. More examples for various open models, such as Llama-3, Mixtral, etc, can be found in `SkyPilot AI gallery `__.
-To install SkyPilot and setup your cloud credentials, run:
+
+Prerequisites
+-------------
+
+- Go to the `HuggingFace model page `__ and request access to the model :code:`meta-llama/Meta-Llama-3-8B-Instruct`.
+- Check that you have installed SkyPilot (`docs `__).
+- Check that :code:`sky check` shows clouds or Kubernetes are enabled.
.. code-block:: console
- $ pip install skypilot
- $ sky check
+ pip install skypilot-nightly
+ sky check
+
+
+Run on a single instance
+------------------------
See the vLLM SkyPilot YAML for serving, `serving.yaml `__.
.. code-block:: yaml
resources:
- accelerators: A100
+ accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB} # We can use cheaper accelerators for 8B model.
+ use_spot: True
+ disk_size: 512 # Ensure model checkpoints can fit.
+ disk_tier: best
+ ports: 8081 # Expose to internet traffic.
envs:
- MODEL_NAME: decapoda-research/llama-13b-hf
- TOKENIZER: hf-internal-testing/llama-tokenizer
+ MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
+ HF_TOKEN: # Change to your own huggingface token, or use --env to pass.
setup: |
- conda create -n vllm python=3.9 -y
+ conda create -n vllm python=3.10 -y
conda activate vllm
- git clone https://github.com/vllm-project/vllm.git
- cd vllm
- pip install .
- pip install gradio
+
+ pip install vllm==0.4.0.post1
+ # Install Gradio for web UI.
+ pip install gradio openai
+ pip install flash-attn==2.5.7
run: |
conda activate vllm
echo 'Starting vllm api server...'
- python -u -m vllm.entrypoints.api_server \
- --model $MODEL_NAME \
- --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
- --tokenizer $TOKENIZER 2>&1 | tee api_server.log &
+ python -u -m vllm.entrypoints.openai.api_server \
+ --port 8081 \
+ --model $MODEL_NAME \
+ --trust-remote-code \
+ --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
+ 2>&1 | tee api_server.log &
+
echo 'Waiting for vllm api server to start...'
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
+
echo 'Starting gradio server...'
- python vllm/examples/gradio_webserver.py
+ git clone https://github.com/vllm-project/vllm.git || true
+ python vllm/examples/gradio_openai_chatbot_webserver.py \
+ -m $MODEL_NAME \
+ --port 8811 \
+ --model-url http://localhost:8081/v1 \
+ --stop-token-ids 128009,128001
-Start the serving the LLaMA-13B model on an A100 GPU:
+Start serving the Llama-3 8B model on any of the candidate GPUs listed (L4, A10g, ...):
.. code-block:: console
- $ sky launch serving.yaml
+ HF_TOKEN="your-huggingface-token" sky launch serving.yaml --env HF_TOKEN
Check the output of the command. There will be a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
@@ -61,9 +85,226 @@ Check the output of the command. There will be a shareable gradio link (like the
(task, pid=7431) Running on public URL: https://.gradio.live
-**Optional**: Serve the 65B model instead of the default 13B and use more GPU:
+**Optional**: Serve the 70B model instead of the default 8B and use more GPUs:
+
+.. code-block:: console
+
+ HF_TOKEN="your-huggingface-token" sky launch serving.yaml --gpus A100:8 --env HF_TOKEN --env MODEL_NAME=meta-llama/Meta-Llama-3-70B-Instruct
+
+
+Scale up to multiple replicas
+-----------------------------
+
+SkyPilot can scale up the service to multiple service replicas with built-in autoscaling, load-balancing and fault-tolerance. You can do it by adding a ``service`` section to the YAML file.
+
+.. code-block:: yaml
+
+ service:
+ replicas: 2
+ # An actual request for readiness probe.
+ readiness_probe:
+ path: /v1/chat/completions
+ post_data:
+ model: $MODEL_NAME
+ messages:
+ - role: user
+ content: Hello! What is your name?
+ max_tokens: 1
+
+.. raw:: html
+
+
+ Click to see the full recipe YAML
+
+
+.. code-block:: yaml
+
+ service:
+ replicas: 2
+ # An actual request for readiness probe.
+ readiness_probe:
+ path: /v1/chat/completions
+ post_data:
+ model: $MODEL_NAME
+ messages:
+ - role: user
+ content: Hello! What is your name?
+ max_tokens: 1
+
+ resources:
+ accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB} # We can use cheaper accelerators for 8B model.
+ use_spot: True
+ disk_size: 512 # Ensure model checkpoints can fit.
+ disk_tier: best
+ ports: 8081 # Expose to internet traffic.
+
+ envs:
+ MODEL_NAME: meta-llama/Meta-Llama-3-8B-Instruct
+ HF_TOKEN: # Change to your own huggingface token, or use --env to pass.
+
+ setup: |
+ conda create -n vllm python=3.10 -y
+ conda activate vllm
+
+ pip install vllm==0.4.0.post1
+ # Install Gradio for web UI.
+ pip install gradio openai
+ pip install flash-attn==2.5.7
+
+ run: |
+ conda activate vllm
+ echo 'Starting vllm api server...'
+ python -u -m vllm.entrypoints.openai.api_server \
+ --port 8081 \
+ --model $MODEL_NAME \
+ --trust-remote-code \
+ --tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
+ 2>&1 | tee api_server.log &
+
+ echo 'Waiting for vllm api server to start...'
+ while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
+
+ echo 'Starting gradio server...'
+ git clone https://github.com/vllm-project/vllm.git || true
+ python vllm/examples/gradio_openai_chatbot_webserver.py \
+ -m $MODEL_NAME \
+ --port 8811 \
+ --model-url http://localhost:8081/v1 \
+ --stop-token-ids 128009,128001
+
+.. raw:: html
+
+
+
+Start serving the Llama-3 8B model on multiple replicas:
+
+.. code-block:: console
+
+ HF_TOKEN="your-huggingface-token" sky serve up -n vllm serving.yaml --env HF_TOKEN
+
+
+Wait until the service is ready:
.. code-block:: console
- sky launch -c vllm-serve-new -s serve.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf
+ watch -n10 sky serve status vllm
+
+
+.. raw:: html
+
+
+ Example outputs:
+
+.. code-block:: console
+
+ Services
+ NAME VERSION UPTIME STATUS REPLICAS ENDPOINT
+ vllm 1 35s READY 2/2 xx.yy.zz.100:30001
+
+ Service Replicas
+ SERVICE_NAME ID VERSION IP LAUNCHED RESOURCES STATUS REGION
+ vllm 1 1 xx.yy.zz.121 18 mins ago 1x GCP({'L4': 1}) READY us-east4
+ vllm 2 1 xx.yy.zz.245 18 mins ago 1x GCP({'L4': 1}) READY us-east4
+
+.. raw:: html
+
+
+
+After the service is READY, you can find a single endpoint for the service and access the service with the endpoint:
+
+.. code-block:: console
+
+ ENDPOINT=$(sky serve status --endpoint 8081 vllm)
+ curl -L http://$ENDPOINT/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "meta-llama/Meta-Llama-3-8B-Instruct",
+ "messages": [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant."
+ },
+ {
+ "role": "user",
+ "content": "Who are you?"
+ }
+ ],
+ "stop_token_ids": [128009, 128001]
+ }'
+
+To enable autoscaling, you could specify additional configs in `services`:
+
+.. code-block:: yaml
+
+ services:
+ replica_policy:
+ min_replicas: 0
+ max_replicas: 3
+ target_qps_per_replica: 2
+
+This will scale the service up when the QPS exceeds 2 for each replica.
+
+
+**Optional**: Connect a GUI to the endpoint
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+
+It is also possible to access the Llama-3 service with a separate GUI frontend, so the user requests sent to the GUI will be load-balanced across replicas.
+
+.. raw:: html
+
+
+ Click to see the full GUI YAML
+
+.. code-block:: yaml
+
+ envs:
+ MODEL_NAME: meta-llama/Meta-Llama-3-70B-Instruct
+ ENDPOINT: x.x.x.x:3031 # Address of the API server running vllm.
+
+ resources:
+ cpus: 2
+
+ setup: |
+ conda activate vllm
+ if [ $? -ne 0 ]; then
+ conda create -n vllm python=3.10 -y
+ conda activate vllm
+ fi
+
+ # Install Gradio for web UI.
+ pip install gradio openai
+
+ run: |
+ conda activate vllm
+ export PATH=$PATH:/sbin
+ WORKER_IP=$(hostname -I | cut -d' ' -f1)
+ CONTROLLER_PORT=21001
+ WORKER_PORT=21002
+
+ echo 'Starting gradio server...'
+ git clone https://github.com/vllm-project/vllm.git || true
+ python vllm/examples/gradio_openai_chatbot_webserver.py \
+ -m $MODEL_NAME \
+ --port 8811 \
+ --model-url http://$ENDPOINT/v1 \
+ --stop-token-ids 128009,128001 | tee ~/gradio.log
+
+.. raw:: html
+
+
+
+1. Start the chat web UI:
+
+.. code-block:: console
+
+ sky launch -c gui ./gui.yaml --env ENDPOINT=$(sky serve status --endpoint vllm)
+
+
+2. Then, we can access the GUI at the returned gradio link:
+
+.. code-block:: console
+
+ | INFO | stdout | Running on public URL: https://6141e84201ce0bb4ed.gradio.live
+
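For readers who prefer Python over curl, the same chat request can be sent with the openai client; the base_url below is a placeholder for the output of sky serve status --endpoint 8081 vllm, and openai>=1.0 is assumed.

from openai import OpenAI

client = OpenAI(base_url="http://xx.yy.zz.100:30001/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who are you?"},
    ],
    # Passed through to vLLM's OpenAI-compatible server.
    extra_body={"stop_token_ids": [128009, 128001]},
)
print(response.choices[0].message.content)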
From c1b4e4157c0b4154f950adaea85a259fc629c758 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 22 Apr 2024 17:21:48 -0700
Subject: [PATCH 094/413] [Core][Distributed] use absolute path for library
file (#4271)
---
vllm/utils.py | 32 ++++++++++++++++++++++++++++++--
1 file changed, 30 insertions(+), 2 deletions(-)
diff --git a/vllm/utils.py b/vllm/utils.py
index fbe86dacaeb99..15c8818cc4506 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -553,6 +553,34 @@ def nccl_integrity_check(filepath):
return version.value
+@lru_cache(maxsize=None)
+def find_library(lib_name: str) -> str:
+ """
+ Find the library file in the system.
+ `lib_name` is full filename, with both prefix and suffix.
+ This function resolves `lib_name` to the full path of the library.
+ """
+ # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
+ # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
+ # `/sbin/ldconfig` should exist in all Linux systems.
+ # `/sbin/ldconfig` searches the library in the system
+ libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
+ # each line looks like the following:
+ # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
+ locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
+ # `LD_LIBRARY_PATH` searches the library in the user-defined paths
+ env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
+ if not locs and env_ld_library_path:
+ locs = [
+ os.path.join(dir, lib_name)
+ for dir in env_ld_library_path.split(":")
+ if os.path.exists(os.path.join(dir, lib_name))
+ ]
+ if not locs:
+ raise ValueError(f"Cannot find {lib_name} in the system.")
+ return locs[0]
+
+
def find_nccl_library():
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
@@ -572,9 +600,9 @@ def find_nccl_library():
)
else:
if torch.version.cuda is not None:
- so_file = vllm_nccl_path or "libnccl.so.2"
+ so_file = vllm_nccl_path or find_library("libnccl.so.2")
elif torch.version.hip is not None:
- so_file = "librccl.so.1"
+ so_file = find_library("librccl.so.1")
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Found nccl from library {so_file}")
From 34128a697ed2dd4d88a9829f35445fbfc5b85c64 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Tue, 23 Apr 2024 02:53:01 +0100
Subject: [PATCH 095/413] Fix `autodoc` directives (#4272)
Co-authored-by: Harry Mellor
---
docs/source/dev/engine/async_llm_engine.rst | 5 ++---
docs/source/dev/engine/llm_engine.rst | 6 +++---
docs/source/dev/sampling_params.rst | 3 ++-
3 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/docs/source/dev/engine/async_llm_engine.rst b/docs/source/dev/engine/async_llm_engine.rst
index 47db1e0a401b1..93fc310cb543b 100644
--- a/docs/source/dev/engine/async_llm_engine.rst
+++ b/docs/source/dev/engine/async_llm_engine.rst
@@ -1,7 +1,6 @@
-
AsyncLLMEngine
=================================
-.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine
- :members: generate, abort
+.. autoclass:: vllm.AsyncLLMEngine
+ :members:
:show-inheritance:
diff --git a/docs/source/dev/engine/llm_engine.rst b/docs/source/dev/engine/llm_engine.rst
index 1de6d7adc87c6..0b8c1e219d7c9 100644
--- a/docs/source/dev/engine/llm_engine.rst
+++ b/docs/source/dev/engine/llm_engine.rst
@@ -1,6 +1,6 @@
LLMEngine
=================================
-.. autoclass:: vllm.engine.llm_engine.LLMEngine
- :members: add_request, abort_request, step
- :show-inheritance:
\ No newline at end of file
+.. autoclass:: vllm.LLMEngine
+ :members:
+ :show-inheritance:
diff --git a/docs/source/dev/sampling_params.rst b/docs/source/dev/sampling_params.rst
index 844859b3ec1f0..ef3d1509bda6d 100644
--- a/docs/source/dev/sampling_params.rst
+++ b/docs/source/dev/sampling_params.rst
@@ -1,4 +1,5 @@
Sampling Params
===============
-.. automodule:: vllm.sampling_params.SamplingParams
\ No newline at end of file
+.. autoclass:: vllm.SamplingParams
+ :members:
From 0ae11f78ab89556d5d867abb98f8a132f7507269 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Tue, 23 Apr 2024 13:32:44 +0900
Subject: [PATCH 096/413] [Mypy] Part 3 fix typing for nested directories for
most of directory (#4161)
---
.github/workflows/mypy.yaml | 29 ++++++++++---------
format.sh | 26 ++++++++---------
pyproject.toml | 6 ++--
vllm/attention/backends/abstract.py | 2 +-
vllm/attention/backends/rocm_flash_attn.py | 1 +
vllm/attention/backends/torch_sdpa.py | 3 +-
vllm/attention/backends/xformers.py | 1 +
vllm/core/block/block_table.py | 1 +
vllm/core/block/common.py | 6 ++--
vllm/core/block/interfaces.py | 6 ++--
.../device_communicators/custom_all_reduce.py | 16 +++++-----
.../device_communicators/pynccl.py | 22 ++++++++------
.../device_communicators/pynccl_utils.py | 5 +++-
vllm/engine/output_processor/interfaces.py | 5 ++--
vllm/engine/output_processor/multi_step.py | 5 ++--
vllm/engine/output_processor/single_step.py | 9 +++---
vllm/engine/output_processor/util.py | 4 ++-
vllm/entrypoints/openai/api_server.py | 6 ++--
vllm/entrypoints/openai/protocol.py | 13 +++++----
vllm/entrypoints/openai/serving_chat.py | 2 +-
vllm/entrypoints/openai/serving_completion.py | 5 +++-
vllm/entrypoints/openai/serving_engine.py | 19 +++++++-----
vllm/lora/lora.py | 2 +-
vllm/model_executor/layers/ops/sample.py | 3 +-
.../model_executor/layers/rotary_embedding.py | 3 +-
vllm/transformers_utils/configs/jais.py | 6 ++--
.../tokenizer_group/__init__.py | 2 +-
.../tokenizer_group/ray_tokenizer_group.py | 2 ++
.../transformers_utils/tokenizers/baichuan.py | 4 +--
29 files changed, 126 insertions(+), 88 deletions(-)
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index 477ce9bc9ce85..9f1855696e20a 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -32,19 +32,20 @@ jobs:
pip install types-setuptools
- name: Mypy
run: |
- mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/attention --config-file pyproject.toml
+ # TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
- mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
- mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
- mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
- mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
- mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
- mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
-
- mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
- mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
- mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
- mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
- # TODO(sang): Follow up
- # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+ mypy vllm/distributed --config-file pyproject.toml
+ mypy vllm/entrypoints --config-file pyproject.toml
+ mypy vllm/executor --config-file pyproject.toml
+ mypy vllm/usage --config-file pyproject.toml
+ mypy vllm/*.py --config-file pyproject.toml
+ mypy vllm/transformers_utils --config-file pyproject.toml
+ mypy vllm/engine --config-file pyproject.toml
+ mypy vllm/worker --config-file pyproject.toml
+ mypy vllm/spec_decode --config-file pyproject.toml
+ # TODO(sang): Fix nested dir
+ mypy vllm/model_executor/*.py --config-file pyproject.toml
+ # TODO(sang): Fix nested dir
+ # mypy vllm/lora/*.py --config-file pyproject.toml
diff --git a/format.sh b/format.sh
index 84ee88b5b4c8a..bd2e9e89e1806 100755
--- a/format.sh
+++ b/format.sh
@@ -94,21 +94,19 @@ echo 'vLLM yapf: Done'
# Run mypy
echo 'vLLM mypy:'
-mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/attention --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
-
-# TODO(sang): Follow up
-mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
-mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
-# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/distributed --config-file pyproject.toml
+mypy vllm/entrypoints --config-file pyproject.toml
+mypy vllm/executor --config-file pyproject.toml
+mypy vllm/usage --config-file pyproject.toml
+mypy vllm/*.py --config-file pyproject.toml
+mypy vllm/transformers_utils --config-file pyproject.toml
+mypy vllm/engine --config-file pyproject.toml
+mypy vllm/worker --config-file pyproject.toml
+mypy vllm/spec_decode --config-file pyproject.toml
+mypy vllm/model_executor/*.py --config-file pyproject.toml
+# mypy vllm/lora/*.py --config-file pyproject.toml
CODESPELL_EXCLUDES=(
diff --git a/pyproject.toml b/pyproject.toml
index b870a4b85897b..a171d45b4e064 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,15 +46,17 @@ ignore = [
python_version = "3.8"
ignore_missing_imports = true
- check_untyped_defs = true
+check_untyped_defs = true
+follow_imports = "skip"
files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
+ # Ignore triton kernels in ops.
+ 'vllm/attention/ops/.*\.py$'
]
-
[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt"
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 7a4ccecf702f4..be747c9900368 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -116,7 +116,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
- attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
+ attn_metadata: AttentionMetadata,
kv_scale: float,
) -> torch.Tensor:
raise NotImplementedError
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index dbaa71fd16add..7c5863a030ff5 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -248,6 +248,7 @@ def forward(
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
+ assert prefill_meta.prompt_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index d21b54b16db4b..55a7ce59ac6e0 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -106,7 +106,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
- attn_metadata: TorchSDPAMetadata,
+ attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -136,6 +136,7 @@ def forward(
kv_scale)
if attn_metadata.is_prompt:
+ assert attn_metadata.prompt_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index b745a04a143b4..572a4dc79a719 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -288,6 +288,7 @@ def _run_memory_efficient_xformers_forward(
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
+ assert attn_metadata.prompt_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py
index 560267e55ea3a..f1b65b2514f76 100644
--- a/vllm/core/block/block_table.py
+++ b/vllm/core/block/block_table.py
@@ -104,6 +104,7 @@ def append_token_ids(self,
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
+ assert self._blocks is not None
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py
index 50c70533c4fbc..f11234a0bf2dd 100644
--- a/vllm/core/block/common.py
+++ b/vllm/core/block/common.py
@@ -99,7 +99,7 @@ def __init__(
refcounter: RefCounter,
allocator: BlockAllocator,
):
- self._copy_on_writes = defaultdict(list)
+ self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
self._refcounter = refcounter
self._allocator = allocator
@@ -138,6 +138,8 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
prev_block=block.prev_block).block_id
# Track src/dst copy.
+ assert src_block_id is not None
+ assert block_id is not None
self._copy_on_writes[src_block_id].append(block_id)
return block_id
@@ -180,6 +182,6 @@ def recurse(block: Block, lst: List[Block]) -> None:
recurse(block.prev_block, lst)
lst.append(block)
- all_blocks = []
+ all_blocks: List[Block] = []
recurse(last_block, all_blocks)
return all_blocks
diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py
index fbceacf0ec417..50ce922118124 100644
--- a/vllm/core/block/interfaces.py
+++ b/vllm/core/block/interfaces.py
@@ -52,8 +52,7 @@ def __call__(
class BlockAllocator(ABC):
@abstractmethod
- def allocate_mutable(self, prev_block: Optional[Block],
- device: Device) -> Block:
+ def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod
@@ -98,8 +97,7 @@ class NoFreeBlocksError(ValueError):
class DeviceAwareBlockAllocator(BlockAllocator):
@abstractmethod
- def allocate_mutable(self, prev_block: Optional[Block],
- device: Device) -> Block:
+ def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 58cbe77baf7e0..9dbb427d91ff1 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,6 +1,6 @@
import os
from contextlib import contextmanager
-from typing import List, Optional
+from typing import Any, List, Optional
import torch
import torch.distributed as dist
@@ -18,7 +18,7 @@
logger = init_logger(__name__)
-_CA_HANDLE = None
+_CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
@@ -51,7 +51,7 @@ def init_custom_ar() -> None:
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
- return False
+ return
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
if "CUDA_VISIBLE_DEVICES" in os.environ:
@@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
- return
+ return None
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
@@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)
+ return None
+
@contextmanager
def _nvml():
@@ -224,14 +226,14 @@ def _get_ipc_meta(self, inp: torch.Tensor):
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
- all_data = [None] * self.world_size
+ all_data: List[Optional[Any]] = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)
handles = []
offsets = []
for i in range(len(all_data)):
- handles.append(all_data[i][0])
- offsets.append(all_data[i][1])
+ handles.append(all_data[i][0]) # type: ignore
+ offsets.append(all_data[i][1]) # type: ignore
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index c57a4f59d442c..0707afe922f40 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -107,9 +107,10 @@ def ncclGetUniqueId() -> NcclUniqueId:
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]
+ncclDataType_t = ctypes.c_int
-# enums
-class ncclDataType_t(ctypes.c_int):
+
+class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
@@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int):
ncclNumTypes = 10
@classmethod
- def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
+ def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
@@ -148,7 +149,10 @@ def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
raise ValueError(f"Unsupported dtype: {dtype}")
-class ncclRedOp_t(ctypes.c_int):
+ncclRedOp_t = ctypes.c_int
+
+
+class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
@@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int):
ncclNumOps = 5
@classmethod
- def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
+ def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
@@ -180,8 +184,8 @@ def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
- ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
- ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
+ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
+ ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
]
# equivalent to c declaration:
@@ -251,8 +255,8 @@ def all_reduce(self,
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
- ncclDataType_t.from_torch(tensor.dtype),
- ncclRedOp_t.from_torch(op), self.comm,
+ ncclDataTypeEnum.from_torch(tensor.dtype),
+ ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))
assert result == 0
diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py
index aeb73015733d1..916dc814af7eb 100644
--- a/vllm/distributed/device_communicators/pynccl_utils.py
+++ b/vllm/distributed/device_communicators/pynccl_utils.py
@@ -30,6 +30,7 @@ def is_initialized() -> bool:
def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
try:
+ assert comm is not None
comm.stream = stream
yield
finally:
@@ -52,6 +53,7 @@ def init_process_group(world_size: int,
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
+ assert comm is not None
comm.all_reduce(input_, op)
@@ -62,8 +64,9 @@ def destroy_process_group() -> None:
def get_world_size() -> int:
"""Returns the world size."""
+ assert comm is not None
return comm.world_size
-def get_nccl_backend():
+def get_nccl_backend() -> Optional["NCCLCommunicator"]:
return comm
diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py
index 9ddac7a04cb36..f307ea4da3011 100644
--- a/vllm/engine/output_processor/interfaces.py
+++ b/vllm/engine/output_processor/interfaces.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Callable, Iterable, List
+from typing import Callable, List
from transformers import PreTrainedTokenizer
@@ -8,6 +8,7 @@
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
class SequenceGroupOutputProcessor(ABC):
@@ -27,7 +28,7 @@ def create_output_processor(
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
- seq_counter: Iterable[int],
+ seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker",
):
diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py
index 50da0d35fcec1..39e99d06ed875 100644
--- a/vllm/engine/output_processor/multi_step.py
+++ b/vllm/engine/output_processor/multi_step.py
@@ -1,4 +1,4 @@
-from typing import Callable, Iterable, List
+from typing import Callable, List
from transformers import PreTrainedTokenizer
@@ -11,6 +11,7 @@
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
logger = init_logger(__name__)
@@ -33,7 +34,7 @@ def __init__(
self,
detokenizer: Detokenizer,
scheduler: Scheduler,
- seq_counter: Iterable[int],
+ seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker,
):
diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index b32937327ba7f..7e9d652446703 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Tuple, Union
+from typing import Dict, List, Tuple, Union
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@@ -10,6 +10,7 @@
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.utils import Counter
logger = init_logger(__name__)
@@ -33,7 +34,7 @@ def __init__(
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
- seq_counter: Iterable[int],
+ seq_counter: Counter,
stop_checker: StopChecker,
):
self.scheduler_config = scheduler_config
@@ -69,7 +70,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
- parent_child_dict = {
+ parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
}
@@ -92,7 +93,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
continue
# Fork the parent sequence if there are multiple child samples.
for child_sample in child_samples[:-1]:
- new_child_seq_id = next(self.seq_counter)
+ new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py
index 5fbb09a857a46..d076fee8c2a36 100644
--- a/vllm/engine/output_processor/util.py
+++ b/vllm/engine/output_processor/util.py
@@ -8,7 +8,9 @@ def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
- output_by_sequence_group = [[] for _ in range(num_seq_groups)]
+ output_by_sequence_group: List[List[SamplerOutput]] = [
+ [] for _ in range(num_seq_groups)
+ ]
for step in sampler_outputs:
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index d6673976bb775..37d76b8e74055 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -18,6 +18,7 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ ChatCompletionResponse,
CompletionRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -26,8 +27,8 @@
TIMEOUT_KEEP_ALIVE = 5 # seconds
-openai_serving_chat: OpenAIServingChat = None
-openai_serving_completion: OpenAIServingCompletion = None
+openai_serving_chat: OpenAIServingChat
+openai_serving_completion: OpenAIServingCompletion
logger = init_logger(__name__)
@@ -95,6 +96,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
+ assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump())
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index cf779d44c816b..d9763d024eb83 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -4,7 +4,8 @@
from typing import Dict, List, Literal, Optional, Union
import torch
-from pydantic import BaseModel, Field, conint, model_validator
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import Annotated
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
@@ -30,7 +31,7 @@ class ModelPermission(BaseModel):
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
- is_blocking: str = False
+ is_blocking: bool = False
class ModelCard(BaseModel):
@@ -56,7 +57,7 @@ class UsageInfo(BaseModel):
class ResponseFormat(BaseModel):
# type must be "json_object" or "text"
- type: str = Literal["text", "json_object"]
+ type: Literal["text", "json_object"]
class ChatCompletionRequest(BaseModel):
@@ -152,6 +153,7 @@ def to_sampling_params(self) -> SamplingParams:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
+ assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
@@ -213,7 +215,7 @@ class CompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
- n: Optional[int] = 1
+ n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
@@ -235,7 +237,7 @@ class CompletionRequest(BaseModel):
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
- truncate_prompt_tokens: Optional[conint(ge=1)] = None
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
@@ -289,6 +291,7 @@ def to_sampling_params(self):
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
+ assert self.logit_bias is not None
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index f35eab15bc487..d502dd0a4eb75 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -115,12 +115,12 @@ async def chat_completion_stream_generator(
first_iteration = True
# Send response for each token for each request.n (index)
+ assert request.n is not None
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
try:
async for res in result_generator:
- res: RequestOutput
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index b7e2530a69b51..211b2e0424c3e 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -185,6 +185,7 @@ async def completion_stream_generator(
model_name: str,
num_prompts: int,
) -> AsyncGenerator[str, None]:
+ assert request.n is not None
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
@@ -202,6 +203,7 @@ async def completion_stream_generator(
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
+ assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
# only return the prompt
delta_text = res.prompt
@@ -279,7 +281,7 @@ def request_output_to_completion_response(
created_time: int,
model_name: str,
) -> CompletionResponse:
- choices = []
+ choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
for final_res in final_res_batch:
@@ -289,6 +291,7 @@ def request_output_to_completion_response(
prompt_text = final_res.prompt
for output in final_res.outputs:
+ assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 376b581052d85..610e807cae4c7 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -4,7 +4,9 @@
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple, Union
-from pydantic import conint
+from pydantic import Field
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing_extensions import Annotated
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -45,7 +47,8 @@ def __init__(self,
]
self.max_model_len = 0
- self.tokenizer = None
+ # Lazy initialized
+ self.tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
try:
event_loop = asyncio.get_running_loop()
@@ -92,7 +95,7 @@ async def show_available_models(self) -> ModelList:
def _create_logprobs(
self,
token_ids: List[int],
- top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
+ top_logprobs: List[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
@@ -108,6 +111,7 @@ def _create_logprobs(
token = self.tokenizer.decode(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(None)
+ assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append(None)
else:
token_logprob = step_top_logprobs[token_id].logprob
@@ -116,6 +120,7 @@ def _create_logprobs(
logprobs.token_logprobs.append(token_logprob)
if num_output_top_logprobs:
+ assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
@@ -155,9 +160,9 @@ def create_streaming_error_response(
async def _check_model(self, request) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
- return
+ return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
- return
+ return None
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
@@ -165,7 +170,7 @@ async def _check_model(self, request) -> Optional[ErrorResponse]:
def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
if request.model in self.served_model_names:
- return
+ return None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora
@@ -177,7 +182,7 @@ def _validate_prompt_and_tokenize(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
- truncate_prompt_tokens: Optional[conint(ge=1)] = None
+ truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py
index 21c2196eb2739..fefad16700fe3 100644
--- a/vllm/lora/lora.py
+++ b/vllm/lora/lora.py
@@ -33,7 +33,7 @@ def __init__(
def optimize(self) -> "LoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
if self.scaling == 1:
- return
+ return self
self.lora_b *= self.scaling
self.scaling = 1
return self
diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py
index a19e9461f41f7..d08ae6064aa2a 100644
--- a/vllm/model_executor/layers/ops/sample.py
+++ b/vllm/model_executor/layers/ops/sample.py
@@ -29,8 +29,8 @@ def _multi_split_sample(
sampled_tokens_size: Tuple[int, int],
sampled_logprobs_size: Tuple[int, int],
sample_indices: torch.Tensor,
+ logprobs: torch.Tensor,
*,
- logprobs: Optional[torch.Tensor] = None,
modify_greedy_probs: bool = False,
save_logprobs: bool = False,
):
@@ -167,6 +167,7 @@ def sample(
sampled_logprobs_size = (0, 0)
logprobs = probs
+ assert logprobs is not None
if _save_modified_probs:
sampled_modified_probs_size = sampled_tokens_size
else:
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 6519781c8a8eb..a5225148d7828 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -108,7 +108,8 @@ def _forward(
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
- self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
+ self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
+ positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py
index 94f438716f8bf..b06a946f34a47 100644
--- a/vllm/transformers_utils/configs/jais.py
+++ b/vllm/transformers_utils/configs/jais.py
@@ -222,13 +222,15 @@ def _alibi_scaling_validation(self):
f"got {alibi_scaling_type}")
if (alibi_scaling_factor is not None
and not isinstance(alibi_scaling_factor, float)
- or alibi_scaling_factor <= 1.0):
+ or (alibi_scaling_factor is not None
+ and alibi_scaling_factor <= 1.0)):
raise ValueError(
f"`alibi_scaling`'s factor field must be a float > 1.0,"
f"got {alibi_scaling_factor}")
if (alibi_dynamic_scaling is not None
and not isinstance(alibi_dynamic_scaling, int)
- or alibi_dynamic_scaling <= 1):
+ or (alibi_dynamic_scaling is not None
+ and alibi_dynamic_scaling <= 1)):
raise ValueError(
f"`alibi_scaling`'s `train_seq_len` field must be an"
f"integer > 1, got {alibi_dynamic_scaling}")
diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py
index a3b979e8fbc13..69380d67f9b94 100644
--- a/vllm/transformers_utils/tokenizer_group/__init__.py
+++ b/vllm/transformers_utils/tokenizer_group/__init__.py
@@ -11,7 +11,7 @@
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
else:
- RayTokenizerGroupPool = None
+ RayTokenizerGroupPool = None # type: ignore
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
index c00b02fdbbbc0..f3cdc00564dbb 100644
--- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -89,6 +89,7 @@ def encode(self,
This is blocking.
"""
self._ensure_queue_initialized()
+ assert self._idle_actors is not None
if self._idle_actors.empty():
raise RuntimeError("No idle actors available.")
@@ -120,6 +121,7 @@ async def encode_async(
This is non-blocking.
"""
self._ensure_queue_initialized()
+ assert self._idle_actors is not None
actor = await self._idle_actors.get()
try:
diff --git a/vllm/transformers_utils/tokenizers/baichuan.py b/vllm/transformers_utils/tokenizers/baichuan.py
index 79894035cb1f1..76daabc41e0a2 100644
--- a/vllm/transformers_utils/tokenizers/baichuan.py
+++ b/vllm/transformers_utils/tokenizers/baichuan.py
@@ -114,9 +114,9 @@ def _convert_id_to_token(self, index):
token = self.sp_model.IdToPiece(index)
return token
- def convert_tokens_to_string(self, tokens):
+ def convert_tokens_to_string(self, tokens: List[str]):
"""Converts a sequence of tokens (string) in a single string."""
- current_sub_tokens = []
+ current_sub_tokens: List[str] = []
out_string = ""
prev_is_special = False
for i, token in enumerate(tokens):
From 8f2ea22bde67161895152e7f7ad602ac96ea326e Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Tue, 23 Apr 2024 00:49:08 -0700
Subject: [PATCH 097/413] [Core] Some simplification of WorkerWrapper changes
(#4183)
---
vllm/executor/ray_gpu_executor.py | 90 ++++++++++++++-----------------
vllm/worker/worker_base.py | 9 ++--
2 files changed, 45 insertions(+), 54 deletions(-)
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index f779b0f8a5113..d0b5e682bb6f7 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -2,6 +2,7 @@
import os
import pickle
from collections import defaultdict
+from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from vllm.engine.ray_utils import RayWorkerWrapper, ray
@@ -136,16 +137,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
- all_args_to_update_environment_variables = []
- for (node_id, _) in worker_node_and_gpu_ids:
- all_args_to_update_environment_variables.append([{
- "CUDA_VISIBLE_DEVICES":
- ",".join(map(str, node_gpus[node_id])),
- "VLLM_INSTANCE_ID":
- VLLM_INSTANCE_ID,
- "VLLM_TRACE_FUNCTION":
- os.getenv("VLLM_TRACE_FUNCTION", "0"),
- }])
+ all_args_to_update_environment_variables = [({
+ "CUDA_VISIBLE_DEVICES":
+ ",".join(map(str, node_gpus[node_id])),
+ "VLLM_INSTANCE_ID":
+ VLLM_INSTANCE_ID,
+ "VLLM_TRACE_FUNCTION":
+ os.getenv("VLLM_TRACE_FUNCTION", "0"),
+ }, ) for (node_id, _) in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
@@ -156,10 +155,9 @@ def collect_arg_helper_func(**kwargs):
# avoid writing `{"name": value}` manually
return kwargs
- init_worker_all_kwargs = []
-
# Initialize the actual workers inside worker wrapper.
- for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
+ init_worker_all_kwargs = []
+ for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank)
init_worker_all_kwargs.append(
collect_arg_helper_func(
@@ -265,40 +263,40 @@ def _run_workers(
self,
method: str,
*args,
- driver_args: Optional[Tuple[Any]] = None,
+ driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
- all_args: Optional[List[List[Any]]] = None,
+ all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
- """Runs the given method on all workers.
- all_args and all_kwargs are used to pass heterogeneous arguments,
- i.e. different arguments for each worker.
+ """Runs the given method on all workers. Can be used in the following
+ ways:
+
+ - args/kwargs: All workers share the same args/kwargs
+ - args/kwargs and driver_args/driver_kwargs: Driver worker has
+ different args
+ - all_args/all_kwargs: args/kwargs for each worker are specified
+ individually
"""
- if driver_args is None:
- driver_args = args
- if driver_kwargs is None:
- driver_kwargs = kwargs
-
- # for mypy type checking
- assert driver_args is not None
- assert driver_kwargs is not None
- if all_args is None:
- all_args = [driver_args] + [args] * len(self.workers)
- if all_kwargs is None:
- all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
-
- # for mypy type checking
- assert all_args is not None
- assert all_kwargs is not None
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
+ if driver_args is None:
+ driver_args = args if all_args is None else all_args[0]
+ if driver_kwargs is None:
+ driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
+ count = len(self.workers)
+ all_worker_args = repeat(args, count) if all_args is None \
+ else islice(all_args, 1, None)
+ all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
+ else islice(all_kwargs, 1, None)
+
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
@@ -310,22 +308,17 @@ def _run_workers(
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
- ) in zip(self.workers, all_args[1:], all_kwargs[1:])
+ ) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
- if driver_args is None:
- driver_args = args
- if driver_kwargs is None:
- driver_kwargs = kwargs
-
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
- method, *all_args[0], **all_kwargs[0])
+ method, *driver_args, **driver_kwargs)
else:
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
- method, *all_args[0], **all_kwargs[0]))
+ method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
@@ -383,6 +376,10 @@ def _check_if_any_actor_is_dead(self):
class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.driver_executor = make_async(self.driver_worker.execute_method)
+
async def _run_workers_async(
self,
method: str,
@@ -399,13 +396,8 @@ async def _run_workers_async(
if driver_kwargs is None:
driver_kwargs = kwargs
- # Run the driver worker asynchronously.
- def helper():
- return self.driver_worker.execute_method(method, *driver_args,
- **driver_kwargs)
-
- driver_executor = make_async(helper)
- coros.append(driver_executor())
+ coros.append(
+ self.driver_executor(method, *driver_args, **driver_kwargs))
# Run the ray workers asynchronously.
for worker in self.workers:
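
The reworked `_run_workers` above distributes arguments lazily: index 0 of `all_args`/`all_kwargs` goes to the driver, the remaining entries (via `islice`) go to the Ray workers, and `repeat` covers the homogeneous case. A standalone sketch of the same dispatch pattern, with a hypothetical `dispatch` helper and no Ray dependency:

from itertools import islice, repeat
from typing import Any, Dict, List, Optional, Tuple


def dispatch(num_workers: int,
             *args,
             all_args: Optional[List[Tuple[Any, ...]]] = None,
             all_kwargs: Optional[List[Dict[str, Any]]] = None,
             **kwargs):
    """Yield (driver_args, driver_kwargs) first, then one pair per worker."""
    driver_args = args if all_args is None else all_args[0]
    driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
    # Workers get either the shared args/kwargs or their own entries 1..n.
    worker_args = repeat(args, num_workers) if all_args is None \
        else islice(all_args, 1, None)
    worker_kwargs = repeat(kwargs, num_workers) if all_kwargs is None \
        else islice(all_kwargs, 1, None)
    yield driver_args, driver_kwargs
    yield from zip(worker_args, worker_kwargs)


# Homogeneous call: every process gets the same args/kwargs.
print(list(dispatch(2, "ping", timeout=1)))
# Heterogeneous call: entry 0 is for the driver, the rest for the workers.
print(list(dispatch(2, all_args=[("drv",), ("w0",), ("w1",)])))
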
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 783dff3a43404..bcd04e0f98db6 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -108,7 +108,8 @@ def __init__(self,
self.worker_class_name = worker_class_name
self.worker = None
- def update_environment_variables(self, envs: Dict[str, str]) -> None:
+ @staticmethod
+ def update_environment_variables(envs: Dict[str, str]) -> None:
key = 'CUDA_VISIBLE_DEVICES'
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
@@ -138,10 +139,8 @@ def init_worker(self, *args, **kwargs):
def execute_method(self, method, *args, **kwargs):
try:
- if hasattr(self, method):
- executor = getattr(self, method)
- else:
- executor = getattr(self.worker, method)
+ target = self if self.worker is None else self.worker
+ executor = getattr(target, method)
return executor(*args, **kwargs)
except Exception as e:
# if the driver worker also execute methods,
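
The simplified `execute_method` above routes calls to the wrapper itself only until the wrapped worker has been created, after which everything is forwarded to the worker. A small self-contained sketch of that delegation, with made-up `_Wrapper`/`_Worker` classes:

class _Worker:
    def update_environment_variables(self, envs):
        print("worker handles env update:", envs)


class _Wrapper:
    def __init__(self):
        self.worker = None  # created later by init_worker()

    def update_environment_variables(self, envs):
        print("wrapper handles env update:", envs)

    def init_worker(self):
        self.worker = _Worker()

    def execute_method(self, method, *args, **kwargs):
        # Before init_worker() the wrapper itself is the target; afterwards
        # every call is forwarded to the real worker.
        target = self if self.worker is None else self.worker
        return getattr(target, method)(*args, **kwargs)


w = _Wrapper()
w.execute_method("update_environment_variables", {"X": "1"})  # handled by wrapper
w.init_worker()
w.execute_method("update_environment_variables", {"X": "2"})  # handled by worker
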
From 050f285ff6e7bbe898ee751770b2571972166bef Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Tue, 23 Apr 2024 17:02:11 +0900
Subject: [PATCH 098/413] [Core] Scheduling optimization 2 (#4280)
---
tests/core/test_scheduler.py | 3 ++-
vllm/core/scheduler.py | 10 ++++++++--
vllm/sequence.py | 5 +++++
3 files changed, 15 insertions(+), 3 deletions(-)
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index a2511238506b0..ab471d206618b 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -563,7 +563,8 @@ def cannot_append_second_group(seq_group, num_lookahead_slots):
assert len(output.preempted) == 2
# Verify budgets are updated.
assert budget.num_batched_tokens == 1
- assert budget.num_curr_seqs == 1
+    # NOTE: When enable_chunking is False, the num_seqs budget is not updated.
+ # assert budget.num_curr_seqs == 1
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == {}
# Nothing is copied.
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 8d7db09bbea08..99f7a34d336a4 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -395,12 +395,12 @@ def _schedule_running(
# We can have up to 1 running prefill at any given time in running
# queue, which means we can guarantee chunk size is at least 1.
assert num_running_tokens != 0
- num_running_seqs = seq_group.get_max_num_running_seqs()
running_queue.popleft()
while not self._can_append_slots(seq_group):
budget.subtract_num_batched_tokens(seq_group.request_id,
num_running_tokens)
+ num_running_seqs = seq_group.get_max_num_running_seqs()
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
@@ -439,7 +439,13 @@ def _schedule_running(
token_chunk_size=1))
budget.add_num_batched_tokens(seq_group.request_id,
num_running_tokens)
- budget.add_num_seqs(seq_group.request_id, num_running_seqs)
+            # OPTIMIZATION: Note that get_max_num_running_seqs is
+            # expensive. For the default scheduling case where
+            # enable_chunking is False, num_seqs is updated before running
+            # this method, so we don't have to update it again here.
+ if enable_chunking:
+ num_running_seqs = seq_group.get_max_num_running_seqs()
+ budget.add_num_seqs(seq_group.request_id, num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 7dcacab6f2ab6..b296b37a84f15 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -508,6 +508,11 @@ def get_num_uncomputed_tokens(self) -> int:
return num_uncomputed_tokens
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
+ # Optimization. We don't need to call get_seqs if we don't need to
+ # filter by states.
+ if status is None:
+ return len(self.seqs_dict)
+
return len(self.get_seqs(status))
def num_unfinished_seqs(self) -> int:
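
The `num_seqs` fast path above avoids building a filtered sequence list when no status filter is requested. A hedged standalone sketch of the same idea, using a hypothetical `Status` enum instead of `SequenceStatus`:

from enum import Enum
from typing import Dict, Optional


class Status(Enum):
    RUNNING = 1
    FINISHED = 2


def num_seqs(seqs_dict: Dict[int, Status],
             status: Optional[Status] = None) -> int:
    # Fast path: no filtering requested, so the dict size is the answer and
    # no intermediate list has to be built.
    if status is None:
        return len(seqs_dict)
    return sum(1 for s in seqs_dict.values() if s == status)


seqs = {0: Status.RUNNING, 1: Status.RUNNING, 2: Status.FINISHED}
assert num_seqs(seqs) == 3
assert num_seqs(seqs, Status.RUNNING) == 2
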
From 62b8aebc6f06b5c8fafa1f27893cd4f9bb11fa8b Mon Sep 17 00:00:00 2001
From: Cade Daniel
Date: Tue, 23 Apr 2024 01:02:36 -0700
Subject: [PATCH 099/413] [Speculative decoding 7/9] Speculative decoding
end-to-end correctness tests. (#3951)
---
tests/samplers/test_rejection_sampler.py | 8 +-
tests/samplers/test_sampler.py | 3 +-
tests/spec_decode/e2e/__init__.py | 0
tests/spec_decode/e2e/conftest.py | 45 +-
tests/spec_decode/e2e/test_compatibility.py | 169 ++++++
tests/spec_decode/e2e/test_correctness.py | 540 ++++++++++++++++--
tests/spec_decode/test_metrics.py | 4 +-
tests/spec_decode/test_multi_step_worker.py | 4 +-
tests/spec_decode/test_spec_decode_worker.py | 40 +-
tests/spec_decode/utils.py | 7 +-
vllm/config.py | 67 ++-
vllm/engine/arg_utils.py | 18 +-
vllm/engine/llm_engine.py | 38 +-
vllm/engine/metrics.py | 23 +-
vllm/executor/gpu_executor.py | 1 +
.../layers/rejection_sampler.py | 7 +
vllm/model_executor/layers/sampler.py | 182 +++++-
vllm/spec_decode/batch_expansion.py | 70 ++-
vllm/spec_decode/interfaces.py | 4 +-
vllm/spec_decode/metrics.py | 31 +-
vllm/spec_decode/multi_step_worker.py | 29 +-
vllm/spec_decode/spec_decode_worker.py | 49 +-
22 files changed, 1164 insertions(+), 175 deletions(-)
create mode 100644 tests/spec_decode/e2e/__init__.py
create mode 100644 tests/spec_decode/e2e/test_compatibility.py
diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py
index d2c3a798d3087..13b5b80cccfdc 100644
--- a/tests/samplers/test_rejection_sampler.py
+++ b/tests/samplers/test_rejection_sampler.py
@@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids,
)
+ # Bonus tokens are currently disabled. Verify they're set to -1.
+ # See https://github.com/vllm-project/vllm/issues/4212
+ expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
+
if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
# Expect all bonus tokens to be included.
- assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
+ assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
elif which_tokens_accepted == "no_tokens_accepted":
# Expect first token to be equal to recovered tokens.
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
@@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
torch.ones_like(output_token_ids[:, 1:]) * -1)
elif which_tokens_accepted == "some_tokens_accepted":
recovered_plus_bonus = torch.cat(
- (recovered_token_ids, bonus_token_ids), dim=-1)
+ (recovered_token_ids, expected_bonus_token_ids), dim=-1)
# Assert first rejected token is a recovered token or bonus token.
assert torch.equal(
recovered_plus_bonus[torch.arange(0, batch_size),
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index dbbe13b8da060..52a2b0ca52aaa 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
def mock_sample(probs, *args, **kwargs):
nonlocal sample_probs
sample_probs = probs
- return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
+ return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
+ for prob in probs], None)
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
diff --git a/tests/spec_decode/e2e/__init__.py b/tests/spec_decode/e2e/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index 1d99cb5d32219..59fb8311fc5b7 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -1,3 +1,5 @@
+from typing import List, Tuple
+
import pytest
from tests.conftest import cleanup
@@ -6,28 +8,34 @@
@pytest.fixture
-def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
- baseline_llm_kwargs, seed):
- return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+def baseline_llm_generator(request, common_llm_kwargs,
+ per_test_common_llm_kwargs, baseline_llm_kwargs,
+ seed):
+ return create_llm_generator("baseline", request, common_llm_kwargs,
+ per_test_common_llm_kwargs,
baseline_llm_kwargs, seed)
@pytest.fixture
-def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
- return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
- test_llm_kwargs, seed)
+ return create_llm_generator("test", request, common_llm_kwargs,
+ per_test_common_llm_kwargs, test_llm_kwargs,
+ seed)
-def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
- distinct_llm_kwargs, seed):
+def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
+ per_test_common_llm_kwargs, distinct_llm_kwargs,
+ seed):
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**distinct_llm_kwargs,
}
+ test_name = request.node.name
def generator_inner():
+ print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = LLM(**kwargs)
set_random_seed(seed)
@@ -36,6 +44,23 @@ def generator_inner():
del llm
cleanup()
- for llm in generator_inner():
- yield llm
+ def generator_outer():
+ for llm in generator_inner():
+ yield llm
+ del llm
+
+ return generator_outer
+
+
+def get_output_from_llm_generator(
+ llm_generator, prompts,
+ sampling_params) -> Tuple[List[str], List[List[int]]]:
+ tokens = []
+ token_ids = []
+ for llm in llm_generator():
+ outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
+ token_ids = [output.outputs[0].token_ids for output in outputs]
+ tokens = [output.outputs[0].text for output in outputs]
del llm
+
+ return tokens, token_ids
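
After this change the fixtures return a generator *function* instead of a generator, so `get_output_from_llm_generator` controls the full LLM lifetime per call (create, generate, delete, cleanup). A vLLM-free sketch of the same pattern, using hypothetical `make_llm_generator`/`_FakeLLM` stand-ins:

class _FakeOutput:
    def __init__(self, text: str):
        self.text = text
        self.token_ids = [ord(c) for c in text]


class _FakeLLM:
    def generate(self, prompts):
        return [_FakeOutput(p.upper()) for p in prompts]


def make_llm_generator(create_llm, cleanup):
    """Return a zero-arg factory; each call builds the LLM lazily, then tears it down."""
    def generator_outer():
        llm = create_llm()
        try:
            yield llm
        finally:
            del llm
            cleanup()
    return generator_outer


def get_outputs(llm_generator, prompts):
    tokens, token_ids = [], []
    for llm in llm_generator():  # note the call: the fixture hands back a factory
        outputs = llm.generate(prompts)
        tokens = [o.text for o in outputs]
        token_ids = [o.token_ids for o in outputs]
    return tokens, token_ids


print(get_outputs(make_llm_generator(_FakeLLM, cleanup=lambda: None),
                  ["hello", "world"]))
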
diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py
new file mode 100644
index 0000000000000..fde950c14382c
--- /dev/null
+++ b/tests/spec_decode/e2e/test_compatibility.py
@@ -0,0 +1,169 @@
+import pytest
+
+from vllm import SamplingParams
+
+from .conftest import get_output_from_llm_generator
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ "model": "JackFram/llama-68m",
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True
+ }])
+@pytest.mark.parametrize(
+ "per_test_common_llm_kwargs",
+ [
+ {
+ # Expect failure as spec decode not supported by
+ # Ray backend.
+ "worker_use_ray": True,
+ },
+ ])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_ray(test_llm_generator):
+ """Verify that speculative decoding with Ray fails.
+ """
+ output_len = 128
+ temperature = 0.0
+
+ prompts = [
+ "Hello, my name is",
+ ]
+
+ sampling_params = SamplingParams(
+ max_tokens=output_len,
+ ignore_eos=True,
+ temperature=temperature,
+ )
+
+ with pytest.raises(AssertionError,
+ match="Speculative decoding not yet supported for "):
+ get_output_from_llm_generator(test_llm_generator, prompts,
+ sampling_params)
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ "model": "JackFram/llama-68m",
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True
+ }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+ {
+ "enable_chunked_prefill": True,
+ },
+])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
+ """Verify that speculative decoding with chunked prefill fails.
+ """
+ output_len = 128
+ temperature = 0.0
+
+ prompts = [
+ "Hello, my name is",
+ ]
+
+ sampling_params = SamplingParams(
+ max_tokens=output_len,
+ ignore_eos=True,
+ temperature=temperature,
+ )
+
+ with pytest.raises(ValueError,
+ match="Speculative decoding and chunked prefill"):
+ get_output_from_llm_generator(test_llm_generator, prompts,
+ sampling_params)
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ "model": "meta-llama/Llama-2-7b-chat-hf",
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True
+ }])
+@pytest.mark.parametrize(
+ "per_test_common_llm_kwargs",
+ [
+ {
+ # Speculative max model len > overridden max model len should raise.
+ "max_model_len": 128,
+ "speculative_max_model_len": 129,
+ },
+ {
+ # Speculative max model len > draft max model len should raise.
+ # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
+ "speculative_max_model_len": 2048 + 1,
+ },
+ {
+ # Speculative max model len > target max model len should raise.
+ # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
+ "speculative_max_model_len": 4096 + 1,
+ },
+ ])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
+ """Verify that speculative decoding validates speculative_max_model_len.
+ """
+ output_len = 128
+ temperature = 0.0
+
+ prompts = [
+ "Hello, my name is",
+ ]
+
+ sampling_params = SamplingParams(
+ max_tokens=output_len,
+ ignore_eos=True,
+ temperature=temperature,
+ )
+
+ with pytest.raises(ValueError, match="cannot be larger than"):
+ get_output_from_llm_generator(test_llm_generator, prompts,
+ sampling_params)
+
+
+@pytest.mark.parametrize("common_llm_kwargs", [{
+ "model": "JackFram/llama-68m",
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+}])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
+ """Verify that speculative decoding with block manager v1 fails.
+ """
+ output_len = 128
+ temperature = 0.0
+
+ prompts = [
+ "Hello, my name is",
+ ]
+
+ sampling_params = SamplingParams(
+ max_tokens=output_len,
+ ignore_eos=True,
+ temperature=temperature,
+ )
+
+ with pytest.raises(ValueError,
+ match="Speculative decoding requires usage of the V2"):
+ get_output_from_llm_generator(test_llm_generator, prompts,
+ sampling_params)
diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py
index a8ebd66841eb2..0536cc4ecde76 100644
--- a/tests/spec_decode/e2e/test_correctness.py
+++ b/tests/spec_decode/e2e/test_correctness.py
@@ -1,11 +1,42 @@
+"""The tests in this file verify end-to-end speculative decoding correctness.
+
+This docstring details important information on the testing methodology.
+
+Most of the tests rely on "greedy equality", where we expect the output of
+speculative decoding on a sequence to exactly match the output of normal non-
+speculative decoding.
+
+Since speculative decoding with rejection sampling guarantees that the output
+distribution matches the target model's output distribution (up to hardware
+numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
+equality. This gives us good coverage of temp=0.
+
+For temp>0, we rely on unit tests on the rejection sampler to verify that the
+output distribution is the same with spec decode vs. no spec decode (this would
+be prohibitively expensive to run with a real model).
+
+NOTE: Speculative decoding's distribution equality requires that the measured
+distributions of the target model and proposal model be deterministic given the
+same input. vLLM largely guarantees this.
+
+@cadedaniel has seen cases where the output probabilities of a draft/target
+model change slightly with certain batch sizes or prompts, even with Torch
+determinism flags set. It is unclear if this is a bug in vLLM, due to non-
+determinism in on-device batched operations, a bug in vLLM's spec decode
+implementation, or the "hardware numerics" limitations. Either way, rejection
+sampling ensures the output distribution matches the target model, but it breaks
+greedy-equality tests for those batch sizes/prompts.
+"""
+
from itertools import cycle
-from typing import List, Tuple
import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
+from .conftest import get_output_from_llm_generator
+
@pytest.mark.parametrize(
"common_llm_kwargs",
@@ -14,9 +45,6 @@
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",
- # Skip real loading for fast test.
- "load_format": "dummy",
-
# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -31,22 +59,15 @@
"num_speculative_tokens": 5,
},
{
- "speculative_model": "JackFram/llama-68m",
- "num_speculative_tokens": 1,
- },
- {
- # No spec decode.
+ # Verify the detokenizer assertions in the test work when spec
+ # decode is disabled.
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("batch_size", [1])
-# NOTE: We should run more permutations of this test (more BS, more seeds). But
-# because our spec decode generates gibberish token ids, the likelihood of
-# emitting an invalid token combination is nontrivial. This causes divergence in
-# behavior of vLLM detokenization vs. hf tokenizer, for example when two "utf-
-# start" bytes are emitted.
+@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
+def test_spec_decode_e2e_with_detokenization(test_llm_generator,
+ batch_size: int):
"""Run generation with speculative decoding on a batch. Verify the engine
generates the correct number of tokens (via ignore_eos=True), and that the
detokenization matches HF transformers.
@@ -67,8 +88,6 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
- skip_special_tokens=True,
- spaces_between_special_tokens=False,
)
batch_tokens, batch_token_ids = get_output_from_llm_generator(
@@ -77,9 +96,10 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
# Expect a generation for each prompt in the batch.
assert len(batch_token_ids) == len(prompts)
- # Expect each generation to have expected number of tokens (note
- # ignore_eos=True).
- assert all(len(token_ids) == output_len for token_ids in batch_token_ids)
+    # Expect each generation to have the expected number of tokens (note ignore_eos
+ # is True).
+ assert [len(token_ids)
+ for token_ids in batch_token_ids] == ([output_len] * batch_size)
# Expect detokenized string to match.
tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
@@ -92,14 +112,111 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
- # Use a small model for a fast test.
- "model": "JackFram/llama-68m",
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True,
+
+ # Print spec metrics.
+ "disable_log_stats": False,
+ }])
+@pytest.mark.parametrize(
+ "per_test_common_llm_kwargs",
+ [
+ # Try two different tiny base models.
+ # Note that one is equal to the draft model, another isn't.
+ {
+ "model": "JackFram/llama-68m",
+ },
+ {
+ "model": "JackFram/llama-160m",
+ },
+ ])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+ {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
+ },
+])
+@pytest.mark.parametrize(
+ "output_len",
+ [
+ # Use long output len for the small model test.
+ 1536,
+ ])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
+ baseline_llm_generator, test_llm_generator, batch_size: int,
+ output_len: int):
+ """Verify greedy equality on a tiny model with batch size of one.
+
+ Since this test is cheaper than other e2e correctness tests, we generate
+ with a higher output_len.
+ """
+ run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len=output_len,
+ force_output_len=True)
- # Skip real loading for fast test.
- "load_format": "dummy",
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True,
+
+ # Print spec metrics.
+ "disable_log_stats": False,
+ }])
+@pytest.mark.parametrize(
+ "per_test_common_llm_kwargs",
+ [
+ # Try two different tiny base models.
+ # Note that one is equal to the draft model, another isn't.
+ {
+ "model": "JackFram/llama-68m",
+ },
+ {
+ "model": "JackFram/llama-160m",
+ },
+ ])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+ },
+])
+@pytest.mark.parametrize(
+ "output_len",
+ [
+ # Use small output len for fast test.
+ 256,
+ ])
+@pytest.mark.parametrize("batch_size", [64])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
+ baseline_llm_generator, test_llm_generator, batch_size: int,
+ output_len: int):
+ """Verify greedy equality on a tiny model and large batch size.
+ """
+ run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len=output_len,
+ force_output_len=True)
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -109,43 +226,372 @@ def test_spec_decode_e2e_logical_flow(test_llm_generator, batch_size: int):
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
+ # Try two different tiny base models.
+ # Note that one is equal to the draft model, another isn't.
{
- # Expect failure as spec decode not supported by
- # Ray backend.
- "worker_use_ray": True,
+ "model": "JackFram/llama-68m",
+ },
+ {
+ "model": "JackFram/llama-160m",
},
])
-@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+ },
+])
+@pytest.mark.parametrize("max_output_len", [
+ 256,
+])
+@pytest.mark.parametrize("batch_size", [32])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
+ baseline_llm_generator, test_llm_generator, batch_size: int,
+ max_output_len: int):
+ """Verify greedy equality on a tiny model, with a large batch size, and when
+ sampling respects the EOS token.
+ """
+ run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len,
+ force_output_len=False)
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ # A "real" model (not tiny).
+ "model": "meta-llama/Llama-2-7b-chat-hf",
+
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True,
+
+ # Print spec metrics.
+ "disable_log_stats": False,
+ }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+ },
+])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize(
+ "output_len",
+ [
+ # Use decently long output len for a high quality test.
+ 256,
+ ])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
+ baseline_llm_generator, test_llm_generator, batch_size: int,
+ output_len: int):
+ """Verify greedy equality on a "real" model and batch size of 1. This is
+ separate from large BS tests to make identifying the source of bugs easier.
+ """
+ run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len=output_len,
+ force_output_len=True)
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ # A "real" model (not tiny).
+ "model": "meta-llama/Llama-2-7b-chat-hf",
+
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True,
+
+ # Print spec metrics.
+ "disable_log_stats": False,
+ }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+ },
+])
+@pytest.mark.parametrize("batch_size", [32])
+@pytest.mark.parametrize(
+ "output_len",
+ [
+ # Use smaller output len for fast test.
+ 64,
+ ])
@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_xfail(test_llm_generator):
- """Verify that speculative decoding with Ray fails.
+def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
+ baseline_llm_generator, test_llm_generator, batch_size: int,
+ output_len: int):
+ """Verify greedy equality with a "real" model on a nontrivial batch size.
+ This is the closest test to a real production workload.
+ """
+ run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len=output_len,
+ force_output_len=True)
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ "block_size": 8,
+ # 2 for small prompt, 256//8 for generated.
+ "num_gpu_blocks_override": 2 + 256 // 8,
+ "max_model_len": (2 + 256 // 8) * 8,
+
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True
+ }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+ {
+ "model": "JackFram/llama-160m",
+ },
+])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+ },
+])
+@pytest.mark.parametrize(
+ "output_len",
+ [
+ # Use small output len for fast test.
+ 256,
+ ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_e2e_greedy_correctness_with_preemption(
+ baseline_llm_generator, test_llm_generator, batch_size: int,
+ output_len: int):
+ """Verify greedy equality, even when some sequences are preempted mid-
+ generation.
+ """
+ run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len=output_len,
+ force_output_len=True)
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ "model": "JackFram/llama-160m",
+
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True
+ }])
+@pytest.mark.parametrize(
+ "per_test_common_llm_kwargs",
+ [
+ # As of this writing, vLLM only compiles with these 3 block sizes by
+ # default.
+ {
+ "block_size": 8,
+ },
+ {
+ "block_size": 16,
+ },
+ {
+ "block_size": 32,
+ },
+ ])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+ },
+])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+ "output_len",
+ [
+ # Use smaller output len for fast test.
+ 32,
+ ])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_different_block_size(baseline_llm_generator,
+ test_llm_generator, batch_size: int,
+ output_len: int):
+ """Verify greedy equality over different block sizes.
+ """
+ run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len=output_len,
+ force_output_len=True)
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ "model": "JackFram/llama-160m",
+
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True
+ }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+ "test_llm_kwargs",
+ [
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": 5,
+
+ # Artificially limit the draft model max model len; this forces vLLM
+ # to skip speculation once the sequences grow beyond 32-k tokens.
+ "speculative_max_model_len": 32,
+ },
+ ])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize(
+ "output_len",
+ [
+ # This must be a good bit larger than speculative_max_model_len so that
+ # we can test the case where all seqs are skipped, but still small to
+ # ensure fast test.
+ 64,
+ ])
+@pytest.mark.parametrize("seed", [1])
+def test_skip_speculation(baseline_llm_generator, test_llm_generator,
+ batch_size: int, output_len: int):
+ """Verify greedy equality when some (or all) sequences skip speculation.
+ We do this by setting the max model len of the draft model to an
+ artificially low value, such that when the sequences grow beyond it, they
+ are skipped in speculative decoding.
+ """
+ run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len=output_len,
+ force_output_len=True)
+
+
+@pytest.mark.parametrize(
+ "common_llm_kwargs",
+ [{
+ "model": "JackFram/llama-68m",
+
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
+
+ # Required for spec decode.
+ "use_v2_block_manager": True
+ }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+ "test_llm_kwargs",
+ [
+ {
+ "speculative_model": "JackFram/llama-68m",
+ "num_speculative_tokens": k,
+ }
+ # Try a range of common k, as well as large speculation.
+ for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+ ])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+ "output_len",
+ [
+ # Use smaller output len for fast test.
+ 32,
+ ])
+@pytest.mark.parametrize("seed", [1])
+def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
+ output_len: int):
+ """Verify that speculative decoding produces exact equality to without spec
+ decode with many different values of k.
+ """
+ run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len=output_len,
+ force_output_len=True)
+
+
+def run_greedy_equality_correctness_test(baseline_llm_generator,
+ test_llm_generator,
+ batch_size,
+ max_output_len,
+ force_output_len: bool,
+ print_tokens: bool = False):
+ """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
+ the same when temperature is zero.
"""
- output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ "San Francisco is know for its",
+ "Facebook was created in 2004 by",
+ "Curious George is a",
+ "Python 3.11 brings improvements to its",
]
+ prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    # If the test requires that we generate max_output_len tokens, then set the
+    # sampling params to ignore the EOS token.
+ ignore_eos = force_output_len
+
sampling_params = SamplingParams(
- max_tokens=output_len,
- ignore_eos=True,
+ max_tokens=max_output_len,
+ ignore_eos=ignore_eos,
temperature=temperature,
)
- with pytest.raises(AssertionError,
- match="Speculative decoding not yet supported for "):
- get_output_from_llm_generator(test_llm_generator, prompts,
- sampling_params)
+ spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
+ test_llm_generator, prompts, sampling_params)
+ (baseline_batch_tokens,
+ baseline_batch_token_ids) = get_output_from_llm_generator(
+ baseline_llm_generator, prompts, sampling_params)
-def get_output_from_llm_generator(
- llm_generator, prompts,
- sampling_params) -> Tuple[List[str], List[List[int]]]:
- for llm in llm_generator:
- outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
- token_ids = [output.outputs[0].token_ids for output in outputs]
- tokens = [output.outputs[0].text for output in outputs]
- del llm
+ assert len(baseline_batch_token_ids) == len(prompts)
+ assert len(spec_batch_token_ids) == len(prompts)
- return tokens, token_ids
+ for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
+ spec_tokens) in enumerate(
+ zip(baseline_batch_token_ids, baseline_batch_tokens,
+ spec_batch_token_ids, spec_batch_tokens)):
+ if print_tokens:
+ print(f'{i=} {baseline_tokens=}')
+ print(f'{i=} {spec_tokens=}')
+ print(f'{i=} {baseline_token_ids=}')
+ print(f'{i=} {spec_token_ids=}')
+ assert baseline_token_ids == spec_token_ids
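
The helper above reduces every correctness test to a token-level comparison between the baseline and speculative runs. The comparison step in isolation, as a hedged sketch with made-up token ids:

from typing import List


def assert_greedy_equality(baseline_ids: List[List[int]],
                           spec_ids: List[List[int]],
                           print_tokens: bool = False) -> None:
    """Compare per-prompt token ids from a baseline run and a spec-decode run."""
    assert len(baseline_ids) == len(spec_ids)
    for i, (base, spec) in enumerate(zip(baseline_ids, spec_ids)):
        if print_tokens:
            print(f"{i=} {base=} {spec=}")
        assert base == spec, f"prompt {i} diverged"


assert_greedy_equality([[1, 2, 3], [4, 5]], [[1, 2, 3], [4, 5]])
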
diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py
index 36e91672069dc..312878804b86e 100644
--- a/tests/spec_decode/test_metrics.py
+++ b/tests/spec_decode/test_metrics.py
@@ -119,7 +119,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
num_draft_tokens = 0
k = 5
- num_possible_tokens = AsyncMetricsCollector.get_max_num_accepted_tokens(
+ max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens(
num_draft_tokens, k)
rej_sampler = MagicMock()
@@ -153,7 +153,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
assert (metrics.draft_acceptance_rate == num_accepted_tokens /
num_draft_tokens)
assert (metrics.system_efficiency == num_emitted_tokens /
- num_possible_tokens)
+ max_num_emitted_tokens)
else:
assert math.isnan(metrics.draft_acceptance_rate)
assert math.isnan(metrics.system_efficiency)
diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py
index d6edbab579afd..e7aaa1ff4eff8 100644
--- a/tests/spec_decode/test_multi_step_worker.py
+++ b/tests/spec_decode/test_multi_step_worker.py
@@ -344,8 +344,8 @@ def test_draft_proposals_no_speculations():
assert torch.is_tensor(proposals.proposal_token_ids)
assert torch.is_tensor(proposals.proposal_probs)
- assert proposals.proposal_token_ids.shape == torch.Size([0, k])
- assert proposals.proposal_probs.shape[:-1] == torch.Size([0, k])
+ assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
+ assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])
assert proposals.proposal_lens.shape == torch.Size([batch_size])
assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]
diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py
index 0a3110775e2d6..d24d726c9c0cf 100644
--- a/tests/spec_decode/test_spec_decode_worker.py
+++ b/tests/spec_decode/test_spec_decode_worker.py
@@ -1,4 +1,5 @@
import random
+from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
@@ -62,8 +63,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
"""
- draft_worker = mock_worker(cls=MultiStepWorker)
- target_worker = mock_worker()
+ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
+ target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -144,8 +145,10 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
"""
vocab_size = 32_000
- draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
- target_worker = mock_worker(vocab_size=vocab_size)
+ draft_worker = mock_worker(cls=MultiStepWorker,
+ vocab_size=vocab_size,
+ use_spec=False)
+ target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -202,17 +205,16 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
num_lookahead_slots=k)
assert len(rejection_sampler.call_args_list) == 1
- args, _ = rejection_sampler.call_args_list[0]
- (actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs,
- actual_proposal_token_ids) = args
+ _, kwargs = rejection_sampler.call_args_list[0]
+ actual = SimpleNamespace(**kwargs)
- assert torch.equal(actual_bonus_token_ids,
+ assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(
- actual_proposal_scores,
+ actual.target_probs,
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
- assert torch.equal(actual_proposal_token_ids, proposal_token_ids)
- assert torch.equal(actual_proposal_probs, proposal_probs)
+ assert torch.equal(actual.draft_token_ids, proposal_token_ids)
+ assert torch.equal(actual.draft_probs, proposal_probs)
@pytest.mark.parametrize('k', [1, 2, 6])
@@ -224,8 +226,10 @@ def test_correctly_formats_output(k: int, batch_size: int):
"""
vocab_size = 32_000
- draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
- target_worker = mock_worker(vocab_size=vocab_size)
+ draft_worker = mock_worker(cls=MultiStepWorker,
+ vocab_size=vocab_size,
+ use_spec=False)
+ target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -336,8 +340,10 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
"""
vocab_size = 32_000
- draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
- target_worker = mock_worker(vocab_size=vocab_size)
+ draft_worker = mock_worker(cls=MultiStepWorker,
+ vocab_size=vocab_size,
+ use_spec=False)
+ target_worker = mock_worker(vocab_size=vocab_size, use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
@@ -500,8 +506,8 @@ def test_init_device():
"""Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
well as other GPU initialization.
"""
- draft_worker = mock_worker(cls=MultiStepWorker)
- target_worker = mock_worker()
+ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
+ target_worker = mock_worker(use_spec=False)
rejection_sampler = MagicMock(spec=RejectionSampler)
rejection_sampler.token_id_dtype = torch.int64
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py
index d04b6029493f4..4f8295d25cf41 100644
--- a/tests/spec_decode/utils.py
+++ b/tests/spec_decode/utils.py
@@ -63,11 +63,14 @@ def create_execute_model_data(
def mock_worker(cls=None,
vocab_size: int = 30_000,
max_model_len: int = 2048,
- rank: int = 0) -> MagicMock:
+ rank: int = 0,
+ use_spec: bool = True) -> MagicMock:
if cls is None:
cls = Worker
- worker = MagicMock(spec=cls)
+ spec = cls if use_spec else None
+
+ worker = MagicMock(spec=spec)
worker.vocab_size = vocab_size
worker.max_model_len = max_model_len
worker.rank = rank
diff --git a/vllm/config.py b/vllm/config.py
index 97ede0faa21ab..2ff42de08f8f7 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -655,6 +655,9 @@ def maybe_create_spec_config(
target_dtype: str,
speculative_model: Optional[str],
num_speculative_tokens: Optional[int],
+ speculative_max_model_len: Optional[int],
+ enable_chunked_prefill: bool,
+ use_v2_block_manager: bool,
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
@@ -672,6 +675,15 @@ def maybe_create_spec_config(
model, if provided.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided.
+ speculative_max_model_len (Optional[int]): The maximum model len of
+ the speculative model. Used when testing the ability to skip
+ speculation for some sequences.
+ enable_chunked_prefill (bool): Whether vLLM is configured to use
+                chunked prefill or not. Used for raising an error since it is not
+ yet compatible with spec decode.
+ use_v2_block_manager (bool): Whether vLLM is configured to use the
+ v2 block manager or not. Used for raising an error since the v2
+ block manager is required with spec decode.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
@@ -690,12 +702,21 @@ def maybe_create_spec_config(
assert (speculative_model is not None
and num_speculative_tokens is not None)
+ if enable_chunked_prefill:
+ raise ValueError(
+ "Speculative decoding and chunked prefill are "
+ f"currently mutually exclusive ({enable_chunked_prefill=}).")
+
+ if not use_v2_block_manager:
+ raise ValueError(
+ "Speculative decoding requires usage of the V2 "
+ "block manager. Enable it with --use-v2-block-manager.")
+
# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
draft_revision = None
draft_code_revision = None
draft_quantization = None
- draft_max_model_len = None
draft_model_config = ModelConfig(
model=speculative_model,
@@ -707,7 +728,7 @@ def maybe_create_spec_config(
revision=draft_revision,
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
- max_model_len=draft_max_model_len,
+ max_model_len=None,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_context_len_to_capture=target_model_config.
@@ -715,6 +736,13 @@ def maybe_create_spec_config(
max_logprobs=target_model_config.max_logprobs,
)
+ draft_model_config.max_model_len = (
+ SpeculativeConfig._maybe_override_draft_max_model_len(
+ speculative_max_model_len,
+ draft_model_config.max_model_len,
+ target_model_config.max_model_len,
+ ))
+
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
@@ -725,6 +753,41 @@ def maybe_create_spec_config(
num_speculative_tokens,
)
+ @staticmethod
+ def _maybe_override_draft_max_model_len(
+ speculative_max_model_len: Optional[int],
+ draft_max_model_len: int,
+ target_max_model_len: int,
+ ) -> int:
+ """Determine the max sequence len for the draft model. This is usually
+ the draft_max_model_len, but may be the target_max_model_len if it is
+ less than the draft_max_model_len, or may be speculative_max_model_len
+ if it is specified.
+
+ This is necessary so that sequences do not exceed the capacity of the
+ draft model or the target model.
+
+ speculative_max_model_len is mainly used for testing that sequences can
+ skip speculation.
+ """
+
+ if speculative_max_model_len is not None:
+
+ if speculative_max_model_len > draft_max_model_len:
+ raise ValueError(f"{speculative_max_model_len=} cannot be "
+ f"larger than {draft_max_model_len=}")
+
+ if speculative_max_model_len > target_max_model_len:
+ raise ValueError(f"{speculative_max_model_len=} cannot be "
+ f"larger than {target_max_model_len=}")
+
+ return speculative_max_model_len
+
+ return min(
+ draft_max_model_len,
+ target_max_model_len,
+ )
+
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig) -> ParallelConfig:
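
`_maybe_override_draft_max_model_len` picks the draft model's sequence cap: an explicit `speculative_max_model_len` wins after validation, otherwise the tighter of the draft and target limits applies. A standalone sketch of that selection rule, under a hypothetical `pick_draft_max_model_len` name:

from typing import Optional


def pick_draft_max_model_len(override: Optional[int],
                             draft_max: int,
                             target_max: int) -> int:
    """Mirror the selection logic: a valid override wins, else min(draft, target)."""
    if override is not None:
        if override > draft_max or override > target_max:
            raise ValueError(f"{override=} cannot be larger than "
                             f"{draft_max=} / {target_max=}")
        return override
    return min(draft_max, target_max)


assert pick_draft_max_model_len(None, 2048, 4096) == 2048  # draft is the bottleneck
assert pick_draft_max_model_len(32, 2048, 4096) == 32      # explicit override
# pick_draft_max_model_len(4097, 2048, 4096) would raise ValueError
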
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 5de20633ffdd6..6a6ac49ae3211 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -73,6 +73,7 @@ class EngineArgs:
# Speculative decoding configuration.
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
+ speculative_max_model_len: Optional[int] = None
def __post_init__(self):
if self.tokenizer is None:
@@ -237,7 +238,7 @@ def add_cli_args(
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
- choices=[8, 16, 32, 128],
+ choices=[8, 16, 32],
help='Token block size for contiguous chunks of '
'tokens.')
@@ -420,17 +421,25 @@ def add_cli_args(
parser.add_argument(
'--speculative-model',
type=str,
- default=None,
+ default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
- default=None,
+ default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
+ parser.add_argument(
+ '--speculative-max-model-len',
+            type=int,
+ default=EngineArgs.speculative_max_model_len,
+ help='The maximum sequence length supported by the '
+ 'draft model. Sequences over this length will skip '
+ 'speculation.')
+
parser.add_argument('--model-loader-extra-config',
type=str,
default=EngineArgs.model_loader_extra_config,
@@ -481,6 +490,9 @@ def create_engine_config(self, ) -> EngineConfig:
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
+ speculative_max_model_len=self.speculative_max_model_len,
+ enable_chunked_prefill=self.enable_chunked_prefill,
+ use_v2_block_manager=self.use_v2_block_manager,
)
scheduler_config = SchedulerConfig(
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d96025ea1fb6a..19e58fb1722cf 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -22,7 +22,7 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
- SequenceGroup)
+ SequenceGroup, SequenceStage)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@@ -480,9 +480,12 @@ def _process_model_outputs(
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
- # If uncomputed tokens > 0, it means prefill is chunked.
- # We don't need to process outputs in that case.
- if seq_group.get_num_uncomputed_tokens() == 0:
+
+ # If all sequences in the sequence group are in DECODE, then we can
+ # process the output tokens. Otherwise, they are (chunked) prefill
+ # samples and should not be processed.
+ stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
+ if all(stage == SequenceStage.DECODE for stage in stages):
self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups.
@@ -569,7 +572,8 @@ def step(self) -> List[RequestOutput]:
# Log stats.
if self.log_stats:
- self.stat_logger.log(self._get_stats(scheduler_outputs))
+ self.stat_logger.log(
+ self._get_stats(scheduler_outputs, model_output=output))
return request_outputs
@@ -578,9 +582,18 @@ def do_log_stats(self) -> None:
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
- def _get_stats(self,
- scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
- """Get Stats to be Logged to Prometheus."""
+ def _get_stats(
+ self,
+ scheduler_outputs: Optional[SchedulerOutputs],
+ model_output: Optional[List[SamplerOutput]] = None) -> Stats:
+ """Get Stats to be Logged to Prometheus.
+
+ Args:
+ scheduler_outputs: Optional, used to populate metrics related to
+                the scheduled batch.
+ model_output: Optional, used to emit speculative decoding metrics
+ which are created by the workers.
+ """
now = time.time()
# KV Cache Usage in %.
@@ -637,6 +650,14 @@ def _get_stats(self,
time_to_first_tokens = time_last_iters if prompt_run else []
time_per_output_tokens = [] if prompt_run else time_last_iters
+ # Spec decode, if enabled, emits specialized metrics from the worker in
+ # sampler output.
+ if model_output and (model_output[0].spec_decode_worker_metrics
+ is not None):
+ spec_decode_metrics = model_output[0].spec_decode_worker_metrics
+ else:
+ spec_decode_metrics = None
+
return Stats(
now=now,
num_running=num_running,
@@ -649,6 +670,7 @@ def _get_stats(self,
time_to_first_tokens=time_to_first_tokens,
time_per_output_tokens=time_per_output_tokens,
time_e2e_requests=time_e2e_requests,
+ spec_decode_metrics=spec_decode_metrics,
)
def add_lora(self, lora_request: LoRARequest) -> bool:
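
`_get_stats` now receives the sampler outputs so it can surface speculative decoding metrics, which the worker attaches to the first `SamplerOutput` when spec decode is enabled. A minimal sketch of just that extraction rule, using a lightweight `FakeSamplerOutput` stand-in:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeSamplerOutput:
    spec_decode_worker_metrics: Optional[dict] = None


def extract_spec_metrics(
        model_output: Optional[List[FakeSamplerOutput]]) -> Optional[dict]:
    # Metrics live on the first sampler output, and only when spec decode is
    # enabled; otherwise report None so the Stats object stays unchanged.
    if model_output and model_output[0].spec_decode_worker_metrics is not None:
        return model_output[0].spec_decode_worker_metrics
    return None


assert extract_spec_metrics(None) is None
assert extract_spec_metrics([FakeSamplerOutput()]) is None
assert extract_spec_metrics(
    [FakeSamplerOutput({"draft_acceptance_rate": 0.8})]) == {
        "draft_acceptance_rate": 0.8}
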
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 04e27e69ce0f3..25e96f6c7eaf7 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -1,6 +1,6 @@
import time
from dataclasses import dataclass
-from typing import Dict, List, Protocol
+from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -8,6 +8,9 @@
from vllm.logger import init_logger
+if TYPE_CHECKING:
+ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
+
logger = init_logger(__name__)
disable_created_metrics()
@@ -118,6 +121,8 @@ class Stats:
time_per_output_tokens: List[float]
time_e2e_requests: List[float]
+ spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
class SupportsMetricsInfo(Protocol):
@@ -235,3 +240,19 @@ def log(self, stats: Stats) -> None:
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
+
+ if stats.spec_decode_metrics is not None:
+ logger.info(
+ self._format_spec_decode_metrics_str(
+ stats.spec_decode_metrics))
+
+ def _format_spec_decode_metrics_str(
+ self, metrics: "SpecDecodeWorkerMetrics") -> str:
+
+ return ("Speculative metrics: "
+ f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
+ f"System efficiency: {metrics.system_efficiency:.3f}, "
+ f"Number of speculative tokens: {metrics.num_spec_tokens}, "
+ f"Number of accepted tokens: {metrics.accepted_tokens}, "
+ f"Number of draft tokens tokens: {metrics.draft_tokens}, "
+ f"Number of emitted tokens tokens: {metrics.emitted_tokens}.")
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 77c997f97956e..d413a7d27ff37 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -83,6 +83,7 @@ def _init_spec_worker(self):
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
+ # TODO allow draft-model specific load config.
load_config=self.load_config,
local_rank=0,
rank=0,
diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py
index ecd2bd0fce3a3..5edbbf2c70a49 100644
--- a/vllm/model_executor/layers/rejection_sampler.py
+++ b/vllm/model_executor/layers/rejection_sampler.py
@@ -144,6 +144,7 @@ def _batch_modified_rejection_sampling(
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
+ # NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape(
batch_size, k)
@@ -307,6 +308,12 @@ def _create_output(
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
+ # We disable bonus tokens because it causes corrupt KV cache for
+ # proposal methods that require KV cache. We can fix it by "prefilling"
+ # the bonus token in the proposer. The following issue tracks the fix.
+ # https://github.com/vllm-project/vllm/issues/4212
+ output_with_bonus_tokens[:, -1] = -1
+
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
recovered_token_ids.mul(after_false_mask))
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 03bf38caebe0e..c4b11cb33a677 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -35,6 +35,14 @@ class Sampler(nn.Module):
in logits for each token in the input prompt.
"""
+ def __init__(self):
+ super().__init__()
+
+ # Whether or not the SamplerOutput should have on-device tensors
+ # containing the sampled token ids and probabilities. This is used by
+ # speculative decoding.
+ self.include_gpu_probs_tensor = False
+
def forward(
self,
logits: torch.Tensor,
@@ -79,13 +87,45 @@ def forward(
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
- sample_results = _sample(probs, logprobs, sampling_metadata,
- sampling_tensors)
+ sample_results, maybe_sampled_tokens_tensor = _sample(
+ probs,
+ logprobs,
+ sampling_metadata,
+ sampling_tensors,
+ include_gpu_probs_tensor=self.include_gpu_probs_tensor,
+ modify_greedy_probs=self._should_modify_greedy_probs_inplace,
+ )
+
+ if self.include_gpu_probs_tensor:
+ assert maybe_sampled_tokens_tensor is not None
+ sampled_tokens_tensor = maybe_sampled_tokens_tensor
+ on_device_tensors = (probs, sampled_tokens_tensor)
+ else:
+ on_device_tensors = None
+
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
- return _build_sampler_output(sample_results, sampling_metadata,
- prompt_logprobs, sample_logprobs)
+ return _build_sampler_output(sample_results,
+ sampling_metadata,
+ prompt_logprobs,
+ sample_logprobs,
+ on_device_tensors=on_device_tensors)
+
+ @property
+ def _should_modify_greedy_probs_inplace(self) -> bool:
+ """Whether or not the sampler should modify the probability distribution
+ of greedily-sampled tokens such that multinomial sampling would sample
+ the greedily-sampled token.
+
+ In other words, if True then we set the probability of the greedily-
+ sampled token to 1.
+
+ This is used by speculative decoding, which requires that the sampling
+ method be encoded into the probability distribution.
+ """
+ # Modify greedy probs if include_gpu_probs_tensor is set.
+ return self.include_gpu_probs_tensor
def _get_bin_counts_and_mask(
@@ -359,7 +399,9 @@ def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
-) -> List[Tuple[List[int], List[int]]]:
+ include_gpu_probs_tensor: bool,
+ modify_greedy_probs: bool,
+) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
@@ -371,6 +413,15 @@ def _sample_with_torch(
sample_metadata = {}
multinomial_samples = {}
+ # Create output tensor for sampled token ids.
+ if include_gpu_probs_tensor:
+ sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
+ 1,
+ dtype=torch.long,
+ device=logprobs.device)
+ else:
+ sampled_token_ids_tensor = None
+
# Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
@@ -383,9 +434,25 @@ def _sample_with_torch(
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices)
+ long_sample_indices = sample_indices.long()
+
if sampling_type == SamplingType.GREEDY:
- greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+ greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
+
+ if include_gpu_probs_tensor:
+ # Store sampled tokens in output tensor.
+ sampled_token_ids_tensor[
+ long_sample_indices] = greedy_samples.unsqueeze(-1)
+
+ if modify_greedy_probs:
+ # If required, modify the probabilities such that sampling from
+ # the modified distribution would always sample the argmax
+ # token id.
+ _modify_greedy_probs_inplace(logprobs, probs,
+ long_sample_indices,
+ greedy_samples)
+
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):
@@ -397,15 +464,23 @@ def _sample_with_torch(
"seq_groups": seq_groups,
"generators": sampling_metadata.generators,
}
+
multinomial_samples[sampling_type] = _multinomial(
- probs[sample_indices.long()], max_best_of_in_batch,
+ probs[long_sample_indices], max_best_of_in_batch,
**seeded_args)
+
+ if include_gpu_probs_tensor:
+ # Store sampled tokens in output tensor.
+ sampled_token_ids_tensor[
+ long_sample_indices] = multinomial_samples[sampling_type]
+
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
# GPU<->CPU sync happens in the loop below.
+ # This also converts the sample output to Python objects.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
@@ -427,7 +502,7 @@ def _sample_with_torch(
sample_results_dict[i]
for i in range(len(sampling_metadata.seq_groups))
]
- return sample_results
+ return sample_results, sampled_token_ids_tensor
def _sample_with_triton_kernel(
@@ -511,12 +586,17 @@ def _sample_with_triton_kernel(
def _sample(
- probs: torch.Tensor,
- logprobs: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- sampling_tensors: SamplingTensors,
-) -> List[Tuple[List[int], List[int]]]:
- return _sample_with_torch(probs, logprobs, sampling_metadata)
+ probs: torch.Tensor, logprobs: torch.Tensor,
+ sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
+ include_gpu_probs_tensor: bool, modify_greedy_probs: bool
+) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+ return _sample_with_torch(
+ probs,
+ logprobs,
+ sampling_metadata,
+ include_gpu_probs_tensor=include_gpu_probs_tensor,
+ modify_greedy_probs=modify_greedy_probs,
+ )
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
@@ -680,12 +760,73 @@ def _get_logprobs(
return result_prompt_logprobs, result_sample_logprobs
+def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
+ sample_indices: torch.Tensor,
+ greedy_samples: torch.Tensor) -> None:
+ """Modify the probability distributions of the greedily-sampled tokens such
+ that each sampled token has a "probability" of 1.0. This is required by
+ speculative decoding, which depends on the sampling method being encoded
+ within the probability distribution for correctness.
+
+ # Why do we only need to do this for greedy sampling?
+
+ vLLM's sampler performs the following steps for greedy or multinomial
+ (random) sampling:
+ 1. Get logits from model.
+ 2. Modify logits according to per-sequence sampling parameters.
+ - Multiply by temperature, top-k and top-p masking, penalize tokens
+ according to their frequency, etc.
+ 3. Sample a token.
+ - Random sampling simply samples from the modified probability
+ distribution.
+ - Greedy sampling performs `argmax` to obtain the token with the
+ highest likelihood.
+
+ Ignoring greedy sampling for a moment, the computed probability distribution
+ has a useful property: if we repeatedly sample from it ourselves, each token
+ appears with the same frequency with which the Sampler would have returned it.
+ In other words, for tokens sampled with vLLM's random SamplingType, the
+ computed probability distribution encodes the sampling methodology completely.
+
+ Greedy sampling does not normally have this property. vLLM modifies logits
+ according to sampling params, then performs `argmax`, then returns the
+ sampled token and the computed probability distribution. If we sample from
+ the distribution, we'll find the likelihood of the greedily-sampled token
+ is not always 1.0.
+
+ Since lossless speculative decoding requires that the sampling methodology
+ be encoded within the probability distribution, we are motivated to modify
+ the probability distribution such that the sampled token has probability 1
+ when speculative decoding is used.
+
+ NOTE: Alternatively, we could use an extremely low temperature to achieve
+ greedy sampling using multinomial computation and unite the codepaths. This
+ has implications on the overall design of the sampler, e.g. how to record
+ accurate logprobs for the user, so this improvement is deferred to later.
+ """
+ logprobs[sample_indices, :] = -float('inf')
+ logprobs[sample_indices, greedy_samples] = 0.0
+ probs[sample_indices, :] = 0
+ probs[sample_indices, greedy_samples] = 1.0
+
+
def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
+ on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput:
+ """Construct Python objects with the output of sampling.
+
+ Args:
+ on_device_tensors: Tuple containing on-device tensors with the
+ probabilities used in sampling and the sampled token ids. This
+ allows post-processing without copies to CPU/serialization, e.g. in
+ speculative decoding rejection sampling.
+ """
+
sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
@@ -701,4 +842,15 @@ def _build_sampler_output(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
- return SamplerOutput(outputs=sampler_output)
+
+ # If not specified, store None values in SamplerOutput.
+ if on_device_tensors is not None:
+ sampled_token_probs, sampled_token_ids = on_device_tensors
+ else:
+ sampled_token_probs, sampled_token_ids = (None, None)
+
+ return SamplerOutput(
+ outputs=sampler_output,
+ sampled_token_probs=sampled_token_probs,
+ sampled_token_ids=sampled_token_ids,
+ )
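The rewrite of the greedy path is small but central to lossless speculative decoding: after `_modify_greedy_probs_inplace` runs, multinomial sampling from the stored distribution always reproduces the greedy choice. A self-contained sketch of the same in-place update on a toy distribution (plain PyTorch, not the vLLM code path):

    import torch

    logits = torch.tensor([[2.0, 0.5, 0.1]])
    logprobs = torch.log_softmax(logits, dim=-1)
    probs = logprobs.exp()
    sample_indices = torch.tensor([0])
    greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)

    # The same in-place rewrite as the patch: the argmax token gets probability
    # 1.0, so multinomial sampling from `probs` always returns the greedy token.
    logprobs[sample_indices, :] = -float("inf")
    logprobs[sample_indices, greedy_samples] = 0.0
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0

    assert torch.multinomial(probs[0], num_samples=1).item() == greedy_samples.item()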
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
index bbc5b1778854f..c29b838f854c0 100644
--- a/vllm/spec_decode/batch_expansion.py
+++ b/vllm/spec_decode/batch_expansion.py
@@ -6,8 +6,8 @@
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
-from vllm.spec_decode.util import (get_all_seq_ids, maybe_mock_device_tensors,
- nvtx_range, sampler_output_to_torch,
+from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
+ sampler_output_to_torch,
split_batch_by_proposal_len)
from vllm.worker.worker_base import WorkerBase
@@ -72,10 +72,16 @@ def score_proposals(
proposal_lens_list = proposals.proposal_lens.tolist()
proposal_token_ids_list = proposals.proposal_token_ids.tolist()
+ # Filter the list to ignore -1 proposals.
+ proposal_token_ids_list_without_skips = [
+ proposals for proposals in proposal_token_ids_list
+ if -1 not in proposals
+ ]
+
(spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens) = self._expand_batch(
seq_group_metadata_list=seq_group_metadata_list,
- proposal_token_ids_list=proposal_token_ids_list,
+ proposal_token_ids_list=proposal_token_ids_list_without_skips,
proposal_lens_list=proposal_lens_list,
)
@@ -89,7 +95,7 @@ def score_proposals(
target_sampler_output = target_sampler_output[0]
all_tokens, all_probs = self._contract_batch(
- original_bs=len(seq_group_metadata_list),
+ contracted_bs=len(seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
@@ -128,14 +134,21 @@ def _expand_batch(
select_proposal_len_zero=True)
target_seq_group_metadata_list = self._create_scoring_model_input(
- spec_seqs, proposal_token_ids_list)
+ seq_group_metadata_list=spec_seqs,
+ proposal_token_ids=proposal_token_ids_list,
+ # NOTE: We determine the seq ids in the expanded batch using the
+ # full seq_group_metadata_list, instead of only spec_seqs.
+ target_seq_ids_iter=self._create_target_seq_id_iterator(
+ seq_ids=get_all_seq_ids(seq_group_metadata_list)),
+ )
+
num_scoring_tokens = len(target_seq_group_metadata_list)
target_seq_group_metadata_list.extend(non_spec_seqs)
return (spec_indices, non_spec_indices, target_seq_group_metadata_list,
num_scoring_tokens)
- def _contract_batch(self, original_bs: int,
+ def _contract_batch(self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int],
@@ -144,42 +157,41 @@ def _contract_batch(self, original_bs: int,
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
- """
-
- # We mock the device tensors until PR 7/9 is merged (e2e correctness).
- # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
- maybe_mock_device_tensors(
- sampler_output=target_sampler_output,
- batch_size=len(non_spec_indices) + num_scoring_tokens,
- vocab_size=self._vocab_size,
- device=self._device,
- )
+ contracted_bs is the original batch size, i.e. the batch size that the
+ target_sampler_output will be contracted to.
+ """
(target_token_ids, target_probs, non_spec_target_token_ids,
non_spec_target_probs) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
- batch_size, k = proposals.proposal_token_ids.shape
+ expanded_batch_size, k = proposals.proposal_token_ids.shape
+
+ # The number of tokens in the expanded batch used for speculation is
+ # equal to the total expanded batch size minus the number of samples for
+ # non-speculative sequences.
+ non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
+ spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.squeeze().reshape(
- batch_size, k + 1)
- target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
+ spec_expanded_bs, k + 1)
+ target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size)
- all_tokens = torch.full(size=(original_bs, k + 1),
+ all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1,
device=self._device,
dtype=torch.long)
- all_probs = torch.zeros(original_bs,
+ all_probs = torch.zeros(contracted_bs,
k + 1,
self._vocab_size,
device=self._device,
dtype=torch.float32)
if non_spec_indices:
- all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
+ all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
if spec_indices:
@@ -189,20 +201,22 @@ def _contract_batch(self, original_bs: int,
return all_tokens, all_probs
def _create_scoring_model_input(
- self,
- seq_group_metadata_list: List[SequenceGroupMetadata],
- proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
+ target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
"""Given the original input sequences and proposed tokens from the draft
model, create a list of target sequences that can be used for scoring.
+
+ target_seq_ids_iter provides sequence ids for the expanded batch,
+ fulfilling the requirement that no seq id in the expanded batch is equal
+ to any seq id in the original batch.
"""
if not seq_group_metadata_list:
return []
- target_seq_ids_iter = self._create_target_seq_id_iterator(
- get_all_seq_ids(seq_group_metadata_list))
-
target_seq_group_metadata = list(
chain.from_iterable(
self._create_target_seq_group_metadata(
diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py
index f0715120192e5..dd040779922e9 100644
--- a/vllm/spec_decode/interfaces.py
+++ b/vllm/spec_decode/interfaces.py
@@ -24,9 +24,9 @@ class SpeculativeProposals:
def __repr__(self):
return (f"SpeculativeProposals("
- f"proposal_token_ids={self.proposal_token_ids.shape}, "
+ f"proposal_token_ids={self.proposal_token_ids}, "
f"proposal_probs={self.proposal_probs.shape}, "
- f"proposal_lens={self.proposal_lens.shape})")
+ f"proposal_lens={self.proposal_lens})")
@dataclass
diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index d1e72b6640548..ab1d96c558de7 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -147,15 +147,16 @@ def _collect_rejsample_metrics(
emitted_tokens = self._aggregate_num_emitted_tokens.item()
draft_tokens = self._aggregate_num_draft_tokens
- num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
+ max_num_emitted_tokens = self.get_max_num_emitted_tokens(
+ draft_tokens, k)
if draft_tokens > 0:
draft_acceptance_rate = accepted_tokens / draft_tokens
else:
draft_acceptance_rate = float("nan")
- if num_possible_tokens > 0:
- system_efficiency = emitted_tokens / num_possible_tokens
+ if max_num_emitted_tokens > 0:
+ system_efficiency = emitted_tokens / max_num_emitted_tokens
else:
system_efficiency = float("nan")
@@ -169,8 +170,22 @@ def _collect_rejsample_metrics(
)
@staticmethod
- def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
- # Divide by k since batch size can be variable.
- total_num_spec_seqs = draft_tokens / k
- num_accepted_per_seq_if_all_accepted = k + 1
- return int(total_num_spec_seqs / num_accepted_per_seq_if_all_accepted)
+ def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
+ """Calculate the number of emitted tokens, assuming all tokens are
+ accepted.
+
+ This is equal to the number of sequences that have been speculated on,
+ times (speculation len + 1). The +1 comes from the bonus token.
+ """
+ # Determine the number of sequences that have been speculated on. Since
+ # the batch size can be variable, we divide by k.
+ assert draft_tokens % k == 0
+ total_num_spec_seqs = draft_tokens // k
+
+ # A single sequence may emit k accepted tokens and one bonus token in
+ # the best case.
+ num_emitted_per_seq_if_all_accepted = k + 1
+
+ # The max num of emitted tokens is the number of speculated sequences
+ # times the max emitted per seq.
+ return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
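The renamed `get_max_num_emitted_tokens` makes the denominator of system efficiency explicit. A quick worked example of the same arithmetic, under the diff's assumption that every speculated sequence proposes exactly k tokens:

    def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
        # Mirrors the logic in the diff: number of speculated sequences
        # times (k accepted tokens + 1 bonus token).
        assert draft_tokens % k == 0
        return (draft_tokens // k) * (k + 1)

    # 30 draft tokens with k=5 means 6 sequences were speculated on,
    # so at most 6 * (5 + 1) = 36 tokens can be emitted.
    assert get_max_num_emitted_tokens(draft_tokens=30, k=5) == 36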
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
index 8b722476853fa..7cf338bbae5f0 100644
--- a/vllm/spec_decode/multi_step_worker.py
+++ b/vllm/spec_decode/multi_step_worker.py
@@ -6,8 +6,7 @@
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
-from vllm.spec_decode.util import (maybe_mock_device_tensors,
- sampler_output_to_torch)
+from vllm.spec_decode.util import sampler_output_to_torch
from vllm.worker.worker import Worker
@@ -329,12 +328,15 @@ def _merge_outputs(
"""
if maybe_sampler_output is None:
# If no speculative tokens, the sampler output will be None.
- # In this case we return empty tensors.
- proposal_tokens = torch.zeros(0,
- max_proposal_len,
- dtype=torch.long,
- device=self._device)
- proposal_probs = torch.zeros(0,
+ # In this case we return empty proposals.
+ proposal_tokens = torch.full(size=(
+ batch_size,
+ max_proposal_len,
+ ),
+ fill_value=-1,
+ dtype=torch.long,
+ device=self._device)
+ proposal_probs = torch.zeros(batch_size,
max_proposal_len,
self._vocab_size,
dtype=torch.float32,
@@ -345,17 +347,6 @@ def _merge_outputs(
return proposal_tokens, proposal_probs, proposal_lens_tensor
sampler_output = maybe_sampler_output
-
- # We mock the device tensors until PR 7/9 is merged (e2e correctness).
- # https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
- for step_output in sampler_output:
- maybe_mock_device_tensors(
- sampler_output=step_output,
- batch_size=len(proposal_lens),
- vocab_size=self._vocab_size,
- device=self._device,
- )
-
proposal_tokens, proposal_probs = sampler_output_to_torch(
sampler_output)
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 68a2a774ef4b7..2c6642f5a3c81 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -111,6 +111,32 @@ def init_device(self) -> None:
device=self.device,
vocab_size=self._vocab_size)
+ self._configure_model_sampler_for_spec_decode()
+
+ def _configure_model_sampler_for_spec_decode(self):
+ """Configure model sampler to emit GPU tensors. This allows spec decode
+ to keep data on device without transferring to CPU and serializing,
+ which significantly reduces overhead of rejection sampling.
+
+ NOTE(cade): This breaks abstraction boundaries pretty badly. The better
+ design is to have the "move to CPU and serialize" sampling decision be
+ done outside of the model/sampler; this way the "last-mile" worker
+ object which interfaces with the scheduler can serialize and incur the
+ performance hit as necessary. This allows us to run the worker several
+ iterations in a row without incurring the "move to CPU and serialize"
+ performance penalty.
+
+ Since this requires a large change to vLLM, we defer it to later and
+ temporarily accept this broken abstraction boundary.
+
+ NOTE(cade): This will require a special check if the proposer worker
+ does not have a sampler (e.g. ngram speculation).
+ """
+ (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
+ ) = True
+ (self.proposer_worker.model_runner.model.sampler.
+ include_gpu_probs_tensor) = True
+
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of cache blocks to use.
@@ -286,15 +312,26 @@ def _verify_tokens(
select_proposal_len_zero=True)
original_indices = spec_indices + non_spec_indices
- proposal_probs = proposal_scores.probs[spec_indices, :-1]
- bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
+ # Get probabilities of target model, excluding bonus token.
+ proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1]
+
+ # Get non-speculative sampled tokens from target model.
non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
+ # Get bonus tokens from target model.
+ bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
+
+ # Get probabilities according to proposal method.
+ proposal_probs = proposals.proposal_probs[spec_indices]
+
+ # Get proposed tokens.
+ proposal_token_ids = proposals.proposal_token_ids[spec_indices]
+
accepted_token_ids = self.rejection_sampler(
- proposal_probs,
- bonus_token_ids,
- proposals.proposal_probs,
- proposals.proposal_token_ids,
+ target_probs=proposal_verifier_probs,
+ bonus_token_ids=bonus_token_ids,
+ draft_probs=proposal_probs,
+ draft_token_ids=proposal_token_ids,
)
# Append output tokens from non-speculative sequences to
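The `_verify_tokens` change above mostly renames the rejection-sampler inputs so that target-model and draft-model quantities are no longer conflated. The expected shapes, sketched with hypothetical sizes (the slicing mirrors the diff; the return shape described in the comment is an assumption about how accepted ids are consumed downstream):

    import torch

    # Hypothetical sizes: 4 speculative sequences, k = 3 proposed tokens, vocab of 32000.
    batch_size, k, vocab_size = 4, 3, 32000

    # Target-model scores cover the k proposed positions plus one bonus position.
    scores_probs = torch.rand(batch_size, k + 1, vocab_size)
    scores_token_ids = torch.randint(0, vocab_size, (batch_size, k + 1))

    proposal_verifier_probs = scores_probs[:, :-1]       # [batch_size, k, vocab_size]
    bonus_token_ids = scores_token_ids[:, -1:]           # [batch_size, 1]
    draft_probs = torch.rand(batch_size, k, vocab_size)  # from the proposer
    draft_token_ids = torch.randint(0, vocab_size, (batch_size, k))

    # rejection_sampler(target_probs=proposal_verifier_probs,
    #                   bonus_token_ids=bonus_token_ids,
    #                   draft_probs=draft_probs,
    #                   draft_token_ids=draft_token_ids)
    # is then expected to return accepted token ids per sequence,
    # padded with -1 for rejected positions.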
From d3c8180ac4143f4affd2ef26855058e96b72b5f5 Mon Sep 17 00:00:00 2001
From: Jack Gordley
Date: Tue, 23 Apr 2024 12:06:29 +0100
Subject: [PATCH 100/413] [Bugfix] Fixing max token error message for openai
compatible server (#4016)
---
vllm/entrypoints/openai/serving_engine.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 610e807cae4c7..31da27a447c6c 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -206,6 +206,12 @@ def _validate_prompt_and_tokenize(
token_num = len(input_ids)
if request.max_tokens is None:
+ if token_num >= self.max_model_len:
+ raise ValueError(
+ f"This model's maximum context length is "
+ f"{self.max_model_len} tokens. However, you requested "
+ f"{token_num} tokens in the messages, "
+ f"Please reduce the length of the messages.", )
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
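The added check rejects requests whose prompt alone already fills the context window, before the server derives a default `max_tokens`. Roughly, the control flow is the following (a standalone sketch with hypothetical numbers, not the server code itself):

    from typing import Optional

    def resolve_max_tokens(token_num: int, max_model_len: int,
                           max_tokens: Optional[int] = None) -> int:
        # Same ordering as the patch: validate the prompt length first,
        # then fall back to "fill the remaining context".
        if max_tokens is None:
            if token_num >= max_model_len:
                raise ValueError(
                    f"This model's maximum context length is {max_model_len} "
                    f"tokens. However, you requested {token_num} tokens in the "
                    f"messages. Please reduce the length of the messages.")
            max_tokens = max_model_len - token_num
        if token_num + max_tokens > max_model_len:
            raise ValueError("prompt plus max_tokens exceeds the context length")
        return max_tokens

    assert resolve_max_tokens(token_num=1000, max_model_len=4096) == 3096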
From d87f39e9a9dd149f5dd7a58b4d98b21f713827b6 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Wed, 24 Apr 2024 00:28:35 +0800
Subject: [PATCH 101/413] [Bugfix] Add init_cached_hf_modules to
RayWorkerWrapper (#4286)
---
vllm/executor/ray_gpu_executor.py | 2 ++
vllm/worker/worker_base.py | 7 ++++++-
2 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index d0b5e682bb6f7..e69f104e7d5a4 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -100,6 +100,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
)(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
+ trust_remote_code=self.model_config.trust_remote_code,
)
worker_ip = ray.get(worker.get_node_ip.remote())
@@ -110,6 +111,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
+ trust_remote_code=self.model_config.trust_remote_code,
)
else:
# Else, added to the list of workers.
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index bcd04e0f98db6..b5dade0a770a0 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -103,10 +103,15 @@ class WorkerWrapperBase:
def __init__(self,
worker_module_name=None,
- worker_class_name=None) -> None:
+ worker_class_name=None,
+ trust_remote_code: bool = False) -> None:
self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name
self.worker = None
+ if trust_remote_code:
+ # note: lazy import to avoid importing torch before initializing
+ from vllm.utils import init_cached_hf_modules
+ init_cached_hf_modules()
@staticmethod
def update_environment_variables(envs: Dict[str, str]) -> None:
From d86285a4a4b79b883620d2878c0b52b22ad4c640 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 23 Apr 2024 09:45:52 -0700
Subject: [PATCH 102/413] [Core][Logging] Add last frame information for better
debugging (#4278)
---
vllm/logger.py | 18 ++++++++++++++++--
1 file changed, 16 insertions(+), 2 deletions(-)
diff --git a/vllm/logger.py b/vllm/logger.py
index 046f0e9099a4b..341fc473585d7 100644
--- a/vllm/logger.py
+++ b/vllm/logger.py
@@ -83,13 +83,27 @@ def _trace_calls(log_path, root_dir, frame, event, arg=None):
return
# Log every function call or return
try:
+ last_frame = frame.f_back
+ if last_frame is not None:
+ last_filename = last_frame.f_code.co_filename
+ last_lineno = last_frame.f_lineno
+ last_func_name = last_frame.f_code.co_name
+ else:
+ # initial frame
+ last_filename = ""
+ last_lineno = 0
+ last_func_name = ""
with open(log_path, 'a') as f:
if event == 'call':
f.write(f"{datetime.datetime.now()} Call to"
- f" {func_name} in {filename}:{lineno}\n")
+ f" {func_name} in {filename}:{lineno}"
+ f" from {last_func_name} in {last_filename}:"
+ f"{last_lineno}\n")
else:
f.write(f"{datetime.datetime.now()} Return from"
- f" {func_name} in {filename}:{lineno}\n")
+ f" {func_name} in {filename}:{lineno}"
+ f" to {last_func_name} in {last_filename}:"
+ f"{last_lineno}\n")
except NameError:
# modules are deleted during shutdown
pass
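The extra `frame.f_back` lookup is what adds caller information to each trace line. A minimal standalone illustration of the same frame walk, using `sys.settrace` rather than the vLLM hook itself:

    import sys

    def trace(frame, event, arg):
        if event == "call":
            caller = frame.f_back
            if caller is not None:
                print(f"call to {frame.f_code.co_name} in "
                      f"{frame.f_code.co_filename}:{frame.f_lineno} from "
                      f"{caller.f_code.co_name} in "
                      f"{caller.f_code.co_filename}:{caller.f_lineno}")
        return trace

    def inner():
        pass

    def outer():
        inner()

    sys.settrace(trace)
    outer()
    sys.settrace(None)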
From 62b5166bd4c458c3a8f6eda89d3ef9db14a4c2c8 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Tue, 23 Apr 2024 09:51:41 -0700
Subject: [PATCH 103/413] [CI] Add ccache for wheel builds job (#4281)
---
.github/workflows/publish.yml | 3 +++
1 file changed, 3 insertions(+)
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index fc97e33c19af2..4b9fc3d04d872 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -56,6 +56,9 @@ jobs:
- name: Checkout
uses: actions/checkout@v3
+ - name: Setup ccache
+ uses: hendrikmuhs/ccache-action@v1.2
+
- name: Set up Linux Env
if: ${{ runner.os == 'Linux' }}
run: |
From 2b7949c1c2e34de41d9cfc84dd0e377cc6bd58c2 Mon Sep 17 00:00:00 2001
From: James Fleming
Date: Tue, 23 Apr 2024 13:59:33 -0400
Subject: [PATCH 104/413] AQLM CUDA support (#3287)
Co-authored-by: mgoin
---
CMakeLists.txt | 1 +
benchmarks/kernels/benchmark_aqlm.py | 302 ++++++++
csrc/ops.h | 15 +
csrc/pybind.cpp | 2 +
csrc/quantization/aqlm/gemm_kernels.cu | 712 ++++++++++++++++++
examples/aqlm_example.py | 46 ++
tests/models/test_aqlm.py | 95 +++
vllm/model_executor/layers/linear.py | 41 +-
.../layers/quantization/__init__.py | 2 +
.../layers/quantization/aqlm.py | 373 +++++++++
.../model_executor/layers/quantization/awq.py | 4 +-
.../layers/quantization/gptq.py | 3 +-
.../layers/quantization/marlin.py | 3 +-
.../layers/quantization/squeezellm.py | 4 +-
14 files changed, 1592 insertions(+), 11 deletions(-)
create mode 100644 benchmarks/kernels/benchmark_aqlm.py
create mode 100644 csrc/quantization/aqlm/gemm_kernels.cu
create mode 100644 examples/aqlm_example.py
create mode 100644 tests/models/test_aqlm.py
create mode 100644 vllm/model_executor/layers/quantization/aqlm.py
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1845151181284..b2d0cf3e568b7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -173,6 +173,7 @@ set(VLLM_EXT_SRC
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
+ "csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
"csrc/custom_all_reduce.cu")
diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py
new file mode 100644
index 0000000000000..9602d20bcbc74
--- /dev/null
+++ b/benchmarks/kernels/benchmark_aqlm.py
@@ -0,0 +1,302 @@
+import argparse
+import os
+import sys
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from vllm._C import ops
+from vllm.model_executor.layers.quantization.aqlm import (
+ dequantize_weight, generic_dequantize_gemm, get_int_dtype,
+ optimized_dequantize_gemm)
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+
+def torch_mult(
+ input: torch.Tensor, # [..., in_features]
+ weights: torch.Tensor,
+ scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+) -> torch.Tensor:
+ output = F.linear(input, weights)
+ return output
+
+
+def dequant_out_scale(
+ input: torch.Tensor, # [..., in_features]
+ codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
+ codebooks: torch.
+ Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
+ scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ output_partition_sizes: torch.IntTensor,
+ bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+
+ weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
+
+ if bias is None:
+ output = F.linear(input, weights, bias)
+ orig_shape = output.shape
+ flattened_output = output.view(-1, output.size(-1))
+ f_scales = scales.view(-1, scales.shape[0])
+ b_scales = f_scales.expand(flattened_output.shape[0], -1)
+ flattened_output *= b_scales
+ return flattened_output.view(orig_shape)
+ else:
+ b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
+ -1, weights.shape[1])
+ weights *= b_scales
+ return F.linear(input, weights, bias)
+
+
+def dequant_weight_scale(
+ input: torch.Tensor, # [..., in_features]
+ codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
+ codebooks: torch.
+ Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
+ scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ output_partition_sizes: torch.IntTensor,
+ bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+
+ weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
+
+ b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
+ -1, weights.shape[1])
+ weights *= b_scales
+ return F.linear(input, weights, bias)
+
+
+def dequant_no_scale(
+ input: torch.Tensor, # [..., in_features]
+ codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
+ codebooks: torch.
+ Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
+ scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ output_partition_sizes: torch.IntTensor,
+ bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+
+ weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
+
+ return F.linear(input, weights, bias)
+
+
+# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
+# the generic pytorch version.
+# Just visual comparison.
+def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:
+
+ n = parts.sum().item()
+
+ device = torch.device('cuda:0')
+
+ code_range = (1 << bits) // 2
+ ingroups = 8
+
+ codes = torch.randint(-code_range,
+ code_range,
+ size=(n, k // ingroups, nbooks),
+ dtype=get_int_dtype(bits),
+ device=device)
+
+ codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
+ dtype=torch.float16,
+ device=device)
+
+ count = 0
+ for index in range(16):
+ for i in range(8):
+ for book in range(nbooks):
+ codebooks[book, index, 0, i] = count * (10**book)
+ count += 1
+
+ print("codes shape", codes.shape)
+
+ for i in range(16):
+ for book in range(nbooks):
+ codes[0, i, book] = i
+ codes[0, -i, book] = i
+
+ weights = dequantize_weight(codes, codebooks, None)
+ weights2 = ops.aqlm_dequant(codes, codebooks, parts)
+
+ print("weights shape:", weights.shape)
+ print("weights2 shape:", weights2.shape)
+
+ print("weights are:", weights)
+ print("weights2 are:", weights2)
+
+ print("first 128 weights are", weights[0, 0:128].to(torch.int32))
+ print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
+
+ print("last 128 weights are", weights[0, -128:])
+ print("last 128 weights2 are:", weights2[0, -128:])
+
+
+def main():
+
+ parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
+
+ # Add arguments
+ parser.add_argument("--nbooks",
+ type=int,
+ default=1,
+ help="Number of codebooks (default: 1)")
+ parser.add_argument("--bits",
+ type=int,
+ default=16,
+ help="Number of bits per code element (default: 16)")
+ parser.add_argument(
+ "--test",
+ type=bool,
+ default=False,
+ help="Run the decompression/dequant tester rather than benchmarking "
+ "(default: False)")
+
+ # Parse the arguments
+ args = parser.parse_args()
+
+ # Extract values
+ nbooks = args.nbooks
+ bits = args.bits
+
+ if args.test:
+ dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
+ return
+
+ # Otherwise, benchmark.
+ methods = [
+ ops.aqlm_gemm,
+ dequant_out_scale,
+ generic_dequantize_gemm,
+ optimized_dequantize_gemm,
+ dequant_weight_scale,
+ torch_mult,
+ dequant_no_scale,
+ ]
+
+ filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
+ print(f"writing benchmarks to file {filename}")
+ with open(filename, "w") as f:
+ sys.stdout = f
+
+ print('m | k | n | n parts', end='')
+ for method in methods:
+ print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
+ print('')
+
+ # These are reasonable prefill sizes.
+ ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
+ (4096, (11008, 11008)), (11008, (4096, )))
+
+ # reasonable ranges for m.
+ for m in [
+ 1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
+ 128, 256, 512, 1024, 1536, 2048, 3072, 4096
+ ]:
+ print(f'{m}', file=sys.__stdout__)
+ for ksp in ksandpartions:
+ run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
+ methods)
+
+ sys.stdout = sys.__stdout__
+
+
+def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
+ methods):
+
+ # I didn't see visible improvements from increasing these, but feel free :)
+ num_warmup_trials = 1
+ num_trials = 1
+
+ num_calls = 100
+
+ # warmup.
+ for method in methods:
+ for _ in range(num_warmup_trials):
+ run_timing(
+ num_calls=num_calls,
+ m=m,
+ k=k,
+ parts=parts,
+ nbooks=nbooks,
+ bits=bits,
+ method=method,
+ )
+
+ n = parts.sum().item()
+ print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
+
+ for method in methods:
+ best_time_us = 1e20
+ for _ in range(num_trials):
+ kernel_dur_ms = run_timing(
+ num_calls=num_calls,
+ m=m,
+ k=k,
+ parts=parts,
+ nbooks=nbooks,
+ bits=bits,
+ method=method,
+ )
+
+ kernel_dur_us = 1000 * kernel_dur_ms
+
+ if kernel_dur_us < best_time_us:
+ best_time_us = kernel_dur_us
+
+ print(f' | {best_time_us:.0f}', end='')
+
+ print('')
+
+
+def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
+ nbooks: int, bits: int, method) -> float:
+
+ n = parts.sum().item()
+
+ device = torch.device('cuda:0')
+
+ input = torch.randn((1, m, k), dtype=torch.float16, device=device)
+
+ code_range = (1 << bits) // 2
+ ingroups = 8
+
+ codes = torch.randint(-code_range,
+ code_range,
+ size=(n, k // ingroups, nbooks),
+ dtype=get_int_dtype(bits),
+ device=device)
+
+ codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
+ dtype=torch.float16,
+ device=device)
+
+ scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
+
+ # for comparison to just a pytorch mult.
+ weights = torch.randn((n, k), dtype=torch.float16, device=device)
+
+ start_event = torch.cuda.Event(enable_timing=True)
+ end_event = torch.cuda.Event(enable_timing=True)
+
+ start_event.record()
+
+ if method is torch_mult:
+ for i in range(num_calls):
+ torch_mult(input, weights, scales)
+ else:
+ for i in range(num_calls):
+ method(input, codes, codebooks, scales, parts, None)
+
+ end_event.record()
+ end_event.synchronize()
+
+ dur_ms = start_event.elapsed_time(end_event) / num_calls
+ return dur_ms
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/csrc/ops.h b/csrc/ops.h
index 41ecc1e89371b..a379c910d9cf3 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -86,6 +86,21 @@ void gelu_fast(
torch::Tensor& input);
#ifndef USE_ROCM
+torch::Tensor aqlm_gemm(
+ const torch::Tensor& input,
+ const torch::Tensor& codes,
+ const torch::Tensor& codebooks,
+ const torch::Tensor& scales,
+ const torch::Tensor& codebook_partition_sizes,
+ const std::optional<torch::Tensor>& bias
+);
+
+torch::Tensor aqlm_dequant(
+ const torch::Tensor& codes,
+ const torch::Tensor& codebooks,
+ const torch::Tensor& codebook_partition_sizes
+);
+
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index de02afc162113..42e92e5382e8e 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -63,6 +63,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Quantization ops
#ifndef USE_ROCM
+ ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
+ ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu
new file mode 100644
index 0000000000000..4415316e1e8cd
--- /dev/null
+++ b/csrc/quantization/aqlm/gemm_kernels.cu
@@ -0,0 +1,712 @@
+/*
+ * Modified by Neural Magic
+ * Adapted from https://github.com/Vahe1994/AQLM
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <cassert>
+#include <cstdint>
+
+
+namespace vllm {
+namespace aqlm {
+
+__global__ void Code1x16MatVec(
+ const int4* __restrict__ A,
+ const int4* __restrict__ B,
+ int4* __restrict__ C,
+ const int4* __restrict__ codebook,
+ const int prob_m,
+ const int prob_k,
+ const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
+ const int codebook_stride // as int4.
+) {
+ int a_gl_stride = prob_k / 8 / 8;
+ int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
+ bool pred = a_gl_rd < prob_m;
+
+ if (pred)
+ {
+ // advance to the correct codebook; this is easy because we only multiply one column of the codebook.
+ auto codebook_size = &codebook_a_sizes.x;
+ while (a_gl_rd >= *codebook_size)
+ {
+ codebook += codebook_stride;
+ ++codebook_size;
+ }
+ }
+
+ int b_gl_rd = 0;
+ int c_gl_wr = a_gl_rd;
+ a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
+ int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
+
+ __shared__ int4 sh_b[32 * 9];
+ float res = 0;
+
+ int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
+ while (iters--) {
+ // We pad shared memory to avoid bank conflicts during reads
+ __syncthreads();
+ for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
+ if (b_gl_rd + i < prob_k / 8)
+ sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
+ }
+ __syncthreads();
+ b_gl_rd += 32 * 8;
+
+ int b_sh_rd = 9 * (threadIdx.x % 32);
+ if (pred && a_gl_rd < a_gl_end) {
+ const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ uint32_t dec[4];
+ // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
+ // actually help us; this brings > 2x speedup.
+ asm volatile (
+ "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
+ : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
+ : "l"((void*) &codebook[enc[i]])
+ );
+ half2* a = reinterpret_cast<half2*>(&dec);
+ half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
+ half2 res2 = {};
+ #pragma unroll
+ for (int j = 0; j < 4; j++)
+ res2 = __hfma2(a[j], b[j], res2);
+ res += __half2float(res2.x) + __half2float(res2.y);
+ b_sh_rd++;
+ }
+ a_gl_rd += 32;
+ }
+ }
+
+ if (pred) {
+ #pragma unroll
+ for (int i = 16; i > 0; i /= 2)
+ res += __shfl_down_sync(0xffffffff, res, i);
+ if (threadIdx.x % 32 == 0)
+ reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
+ }
+}
+
+__global__ void Code2x8MatVec(
+ const int4* __restrict__ A,
+ const int4* __restrict__ B,
+ int4* __restrict__ C,
+ const int4* __restrict__ codebook,
+ int prob_m,
+ int prob_k,
+ const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
+ const int codebook_stride // as int4.
+
+) {
+ int a_gl_stride = prob_k / 8 / 8;
+ int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
+ bool pred = a_gl_rd < prob_m;
+
+ if (pred)
+ {
+ // advance to the correct codebook; this is easy because we only multiply one column of the codebook.
+ auto codebook_size = &codebook_a_sizes.x;
+ while (a_gl_rd >= *codebook_size)
+ {
+ codebook += codebook_stride;
+ ++codebook_size;
+ }
+ }
+
+ int b_gl_rd = 0;
+ int c_gl_wr = a_gl_rd;
+ a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
+ int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
+ int lane = threadIdx.x % 8;
+
+ extern __shared__ int4 sh[];
+ int4* sh_b = sh;
+ int4* sh_code = sh_b + 32 * 9;
+ int4* sh_code0 = sh_code;
+ int4* sh_code1 = sh_code + 256 * 8;
+
+ for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
+ int4 dec = codebook[i];
+ #pragma unroll
+ for (int j = 0; j < 8; j++)
+ sh_code[8 * i + (j + lane) % 8] = dec;
+ }
+ __syncthreads();
+
+ float res = 0;
+
+ int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
+ while (iters--) {
+ // We pad shared memory to avoid bank conflicts during reads
+ __syncthreads();
+ for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
+ if (b_gl_rd + i < prob_k / 8)
+ sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
+ }
+ __syncthreads();
+ b_gl_rd += 32 * 8;
+
+ int b_sh_rd = 9 * (threadIdx.x % 32);
+ if (pred && a_gl_rd < a_gl_end) {
+ const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
+ half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
+ half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
+ half2 res2 = {};
+ #pragma unroll
+ for (int j = 0; j < 4; j++)
+ res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
+ res += __half2float(res2.x) + __half2float(res2.y);
+ b_sh_rd++;
+ }
+ a_gl_rd += 32;
+ }
+ }
+
+ if (pred) {
+ #pragma unroll
+ for (int i = 16; i > 0; i /= 2)
+ res += __shfl_down_sync(0xffffffff, res, i);
+ if (threadIdx.x % 32 == 0)
+ reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
+ }
+}
+
+
+__global__ void Code1x16Dequant(
+ const int4* __restrict__ A,
+ int4* __restrict__ C,
+ const int4* __restrict__ codebook,
+ int prob_m,
+ int prob_k,
+ const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
+ const int codebook_stride // as int4
+) {
+ int a_gl_stride = prob_k / 8 / 8;
+ int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
+ bool pred = a_gl_rd < prob_m;
+
+ if (pred)
+ {
+ // advance to the correct codebook; this is easy because we only multiply one column of the codebook.
+ auto codebook_size = &codebook_a_sizes.x;
+ while (a_gl_rd >= *codebook_size)
+ {
+ codebook += codebook_stride;
+ ++codebook_size;
+ }
+ }
+
+ a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
+ int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
+
+ int c_gl_stride = prob_k / 8;
+ int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
+ c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
+
+ int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
+ while (iters--) {
+ if (pred && a_gl_rd < a_gl_end) {
+ const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ int4 chunk;
+ auto dec = reinterpret_cast<uint32_t*>(&chunk);
+ // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
+ // actually help us; this brings > 2x speedup.
+ asm volatile (
+ "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
+ : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
+ : "l"((void*) &codebook[enc[i]])
+ );
+
+ C[a_gl_rd * 8 + i] = chunk;
+ }
+ }
+ a_gl_rd += 32;
+ }
+}
+
+
+__global__ void Code2x8Dequant(
+ const int4* __restrict__ A,
+ int4* __restrict__ C,
+ const int4* __restrict__ codebook,
+ int prob_m,
+ int prob_k,
+ const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
+ const int codebook_stride // as int4
+) {
+ int a_gl_stride = prob_k / 8 / 8;
+ int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
+ bool pred = a_gl_rd < prob_m;
+
+ if (pred)
+ {
+ // advance to the correct codebook; this is easy because we only multiply one column of the codebook.
+ auto codebook_size = &codebook_a_sizes.x;
+ while (a_gl_rd >= *codebook_size)
+ {
+ codebook += codebook_stride;
+ ++codebook_size;
+ }
+ }
+
+ a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
+ int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
+ int lane = threadIdx.x % 8;
+
+ int c_gl_stride = prob_k / 8;
+ int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
+ c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
+
+ extern __shared__ int4 sh[];
+ int4* sh_code = sh;
+ int4* sh_code0 = sh_code;
+ int4* sh_code1 = sh_code + 256 * 8;
+
+ for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
+ int4 dec = codebook[i];
+ #pragma unroll
+ for (int j = 0; j < 8; j++)
+ sh_code[8 * i + (j + lane) % 8] = dec;
+ }
+ __syncthreads();
+
+ float res = 0;
+
+ int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
+ while (iters--) {
+ if (pred && a_gl_rd < a_gl_end) {
+ const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ int4 chunk;
+ half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
+ half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
+ #pragma unroll
+ for (int j = 0; j < 4; j++)
+ reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
+ C[a_gl_rd * 8 + i] = chunk;
+ }
+ }
+ a_gl_rd += 32;
+ }
+}
+
+inline int ceildiv(int a, int b) {
+ return (a + b - 1) / b;
+}
+
+const int THREAD_M = 16;
+
+void code1x16_matvec_cuda(
+ const void* __restrict__ A,
+ const void* __restrict__ B,
+ void* __restrict__ C,
+ const void* __restrict__ codebook,
+ int prob_m,
+ int prob_k,
+ const int4 codebook_a_sizes,
+ const int codebook_stride
+) {
+ int sms;
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
+ int waves = 0;
+ int thread_m;
+ do {
+ waves++;
+ thread_m = ceildiv(prob_m, waves * sms);
+ } while (thread_m > THREAD_M);
+
+ int blocks = ceildiv(prob_m, thread_m);
+ int threads = 32 * thread_m;
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ Code1x16MatVec<<<blocks, threads, 0, stream>>>(
+ (const int4*) A,
+ (const int4*) B,
+ (int4*) C,
+ (const int4*) codebook,
+ prob_m,
+ prob_k,
+ codebook_a_sizes,
+ codebook_stride
+ );
+}
+
+void code2x8_matvec_cuda(
+ const void* __restrict__ A,
+ const void* __restrict__ B,
+ void* __restrict__ C,
+ const void* __restrict__ codebook,
+ int prob_m,
+ int prob_k,
+ const int4 codebook_a_sizes,
+ const int codebook_stride
+) {
+ int sms;
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
+ int waves = 0;
+ int thread_m;
+ do {
+ waves++;
+ thread_m = ceildiv(prob_m, waves * sms);
+ } while (thread_m > THREAD_M);
+
+ int blocks = ceildiv(prob_m, thread_m);
+ int threads = 32 * thread_m;
+ int shared = 16 * (2 * 256 * 8 + 32 * 9);
+ cudaFuncSetAttribute(
+ Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
+ );
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ Code2x8MatVec<<<blocks, threads, shared, stream>>>(
+ (const int4*) A,
+ (const int4*) B,
+ (int4*) C,
+ (const int4*) codebook,
+ prob_m,
+ prob_k,
+ codebook_a_sizes,
+ codebook_stride
+ );
+}
+
+void code1x16_dequant_cuda(
+ const void* __restrict__ A,
+ void* __restrict__ C,
+ const void* __restrict__ codebook,
+ int prob_m,
+ int prob_k,
+ const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
+ const int codebook_stride // as int4.
+) {
+ int sms;
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
+ int waves = 0;
+ int thread_m;
+ do {
+ waves++;
+ thread_m = ceildiv(prob_m, waves * sms);
+ } while (thread_m > THREAD_M);
+
+ int blocks = ceildiv(prob_m, thread_m);
+ int threads = 32 * thread_m;
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+ Code1x16Dequant<<<blocks, threads, 0, stream>>>(
+ (const int4*) A,
+ (int4*) C,
+ (const int4*) codebook,
+ prob_m,
+ prob_k,
+ codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
+ codebook_stride // as int4.
+ );
+}
+
+// Dequantizes the code and codebook into weights.
+void code2x8_dequant_cuda(
+ const void* __restrict__ A,
+ void* __restrict__ C,
+ const void* __restrict__ codebook,
+ int prob_m,
+ int prob_k,
+ const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
+ const int codebook_stride // as int4
+) {
+ int sms;
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
+ int waves = 0;
+ int thread_m;
+ do {
+ waves++;
+ thread_m = ceildiv(prob_m, waves * sms);
+ } while (thread_m > THREAD_M);
+
+ int blocks = ceildiv(prob_m, thread_m);
+ int threads = 32 * thread_m;
+ int shared = 16 * (2 * 256 * 8 + 32 * 9);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+ cudaFuncSetAttribute(
+ Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
+ );
+ Code2x8Dequant<<<blocks, threads, shared, stream>>>(
+ (const int4*) A,
+ (int4*) C,
+ (const int4*) codebook,
+ prob_m,
+ prob_k,
+ codebook_a_sizes,
+ codebook_stride
+ );
+}
+
+int codebook_stride(const torch::Tensor& codebooks)
+{
+ return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
+}
+
+void code1x16_matvec(
+ const torch::Tensor& A,
+ const torch::Tensor& B,
+ torch::Tensor& C,
+ const torch::Tensor& codebook,
+ const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long.
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
+ int prob_m = C.size(0);
+ int prob_k = B.size(0);
+
+ code1x16_matvec_cuda(
+ A.data_ptr(),
+ B.data_ptr(),
+ C.data_ptr(),
+ codebook.data_ptr(),
+ prob_m,
+ prob_k,
+ codebook_a_sizes,
+ codebook_stride(codebook)
+ );
+}
+
+torch::Tensor code1x16_matmat(
+ const torch::Tensor& input,
+ const torch::Tensor& codes,
+ const torch::Tensor& codebooks,
+ const torch::Tensor& scales,
+ const int4 codebook_a_sizes,
+ const std::optional<torch::Tensor>& bias) {
+ auto input_sizes = input.sizes();
+ auto out_features = codes.size(0) * codebooks.size(2);
+ auto flat_input = input.reshape({-1, input.size(-1)});
+ auto flat_output = torch::empty({flat_input.size(0), out_features},
+ torch::TensorOptions()
+ .dtype(input.dtype())
+ .device(input.device())
+ );
+
+ for (int i = 0; i < flat_input.size(0); ++i) {
+ auto input_vec = flat_input.index({i});
+ auto output_vec = flat_output.index({i});
+ code1x16_matvec(
+ codes.squeeze(2),
+ input_vec,
+ output_vec,
+ codebooks,
+ codebook_a_sizes
+ );
+ }
+ flat_output *= scales.flatten().unsqueeze(0);
+
+ if (bias.has_value()) {
+ flat_output += bias->unsqueeze(0);
+ }
+
+ auto output_sizes = input_sizes.vec();
+ output_sizes.pop_back();
+ output_sizes.push_back(-1);
+ auto output = flat_output.reshape(output_sizes);
+ return output;
+}
+
+void code2x8_matvec(
+ const torch::Tensor& A,
+ const torch::Tensor& B,
+ torch::Tensor& C,
+ const torch::Tensor& codebook,
+ const int4 codebook_a_sizes
+) {
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
+ int prob_m = C.size(0);
+ int prob_k = B.size(0);
+ code2x8_matvec_cuda(
+ A.data_ptr(),
+ B.data_ptr(),
+ C.data_ptr(),
+ codebook.data_ptr(),
+ prob_m,
+ prob_k,
+ codebook_a_sizes,
+ 2 * codebook_stride(codebook)
+ );
+}
+
+torch::Tensor code2x8_matmat(
+ const torch::Tensor& input,
+ const torch::Tensor& codes,
+ const torch::Tensor& codebooks,
+ const torch::Tensor& scales,
+ const int4 codebook_a_sizes,
+ const std::optional<torch::Tensor>& bias
+) {
+ auto input_sizes = input.sizes();
+ auto out_features = codes.size(0) * codebooks.size(2);
+ auto flat_input = input.reshape({-1, input.size(-1)});
+ auto flat_output = torch::empty({flat_input.size(0), out_features},
+ torch::TensorOptions()
+ .dtype(input.dtype())
+ .device(input.device())
+ );
+
+ for (int i = 0; i < flat_input.size(0); ++i) {
+ auto input_vec = flat_input.index({i});
+ auto output_vec = flat_output.index({i});
+ code2x8_matvec(
+ codes.squeeze(2),
+ input_vec,
+ output_vec,
+ codebooks,
+ codebook_a_sizes
+ );
+ }
+ flat_output *= scales.flatten().unsqueeze(0);
+ if (bias.has_value()) {
+ flat_output += bias->unsqueeze(0);
+ }
+
+ auto output_sizes = input_sizes.vec();
+ output_sizes.pop_back();
+ output_sizes.push_back(-1);
+ auto output = flat_output.reshape(output_sizes);
+ return output;
+}
+
+// Accumulate the partition sizes.
+int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
+{
+ int4 cumulative_sizes;
+ auto cumulative_size = &cumulative_sizes.x;
+ int i = 0;
+ int last = 0;
+ assert(codebook_partition_sizes.size(0) <= 4);
+ for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size)
+ {
+     *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
+ last = *cumulative_size;
+ }
+ // fill in the rest with unreachable.
+ for (; i < 4; ++i, ++cumulative_size)
+ {
+ *cumulative_size = last*10;
+ }
+ return cumulative_sizes;
+}
+
+} // namespace aqlm
+} // namespace vllm
+
+
+torch::Tensor aqlm_gemm(
+ const torch::Tensor& input,
+ const torch::Tensor& codes,
+ const torch::Tensor& codebooks,
+ const torch::Tensor& scales,
+ const torch::Tensor& codebook_partition_sizes,
+  const std::optional<torch::Tensor>& bias
+)
+{
+ int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
+
+ int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+ int const entries = codebooks.size(1);
+
+ if (nbooks == 1 && entries == (1 << 16))
+ {
+ return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
+ }
+ if (nbooks == 2 && entries == (1 << 8))
+ {
+ return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
+ }
+
+ TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
+ return {};
+}
+
+torch::Tensor aqlm_dequant(
+ const torch::Tensor& codes,
+ const torch::Tensor& codebooks,
+ const torch::Tensor& codebook_partition_sizes
+)
+{
+ int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
+
+ int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
+ int const entries = codebooks.size(1);
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
+ int rows = codes.size(1);
+ int cols = codes.size(0);
+
+ auto in_features = codes.size(1) * 8;
+ auto out_features = codes.size(0);
+
+  assert(out_features == codebook_partition_sizes.sum().item<int>());
+
+ auto weights = torch::empty({out_features, in_features},
+ torch::TensorOptions()
+ .dtype(codebooks.dtype())
+ .device(codebooks.device())
+ );
+
+ if (nbooks == 1 && entries == (1 << 16))
+ {
+ vllm::aqlm::code1x16_dequant_cuda(
+ codes.data_ptr(),
+ weights.data_ptr(),
+ codebooks.data_ptr(),
+ out_features,
+ in_features,
+ cumulative_sizes,
+ vllm::aqlm::codebook_stride(codebooks));
+
+    // If you wanted to switch to scaling the weights instead (about 30% slower and not consistent with the gemv implementation):
+ // weights *= scales.index({"...", 0, 0});
+
+ return weights;
+ }
+
+ if (nbooks == 2 && entries == (1 << 8))
+ {
+ vllm::aqlm::code2x8_dequant_cuda(
+ codes.data_ptr(),
+ weights.data_ptr(),
+ codebooks.data_ptr(),
+ out_features,
+ in_features,
+ cumulative_sizes,
+ vllm::aqlm::codebook_stride(codebooks));
+
+    // If you wanted to switch to scaling the weights instead (about 30% slower and not consistent with the gemv implementation):
+ // weights *= scales.index({"...", 0, 0});
+
+ return weights;
+ }
+
+ TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
+ return {};
+}
diff --git a/examples/aqlm_example.py b/examples/aqlm_example.py
new file mode 100644
index 0000000000000..e7c17fa0362ae
--- /dev/null
+++ b/examples/aqlm_example.py
@@ -0,0 +1,46 @@
+import argparse
+
+from vllm import LLM, SamplingParams
+
+
+def main():
+
+ parser = argparse.ArgumentParser(description='AQLM examples')
+
+ parser.add_argument('--model',
+ '-m',
+ type=str,
+ default=None,
+ help='model path, as for HF')
+ parser.add_argument('--choice',
+ '-c',
+ type=int,
+ default=0,
+ help='known good models by index, [0-4]')
+ parser.add_argument('--tensor_parallel_size',
+ '-t',
+ type=int,
+ default=1,
+ help='tensor parallel size')
+
+ args = parser.parse_args()
+
+ models = [
+ "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf",
+ "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf",
+ "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf",
+ "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf",
+ "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf",
+ ]
+
+ model = LLM(args.model if args.model is not None else models[args.choice],
+ tensor_parallel_size=args.tensor_parallel_size)
+
+ sampling_params = SamplingParams(max_tokens=100, temperature=0)
+ outputs = model.generate("Hello my name is",
+ sampling_params=sampling_params)
+ print(outputs[0].outputs[0].text)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/models/test_aqlm.py b/tests/models/test_aqlm.py
new file mode 100644
index 0000000000000..a7abc011f57d7
--- /dev/null
+++ b/tests/models/test_aqlm.py
@@ -0,0 +1,95 @@
+"""Compare the outputs of a AQLM model between vLLM and HF Transformers
+
+Run `pytest tests/models/test_aqlm.py`.
+"""
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+
+capability = torch.cuda.get_device_capability()
+capability = capability[0] * 10 + capability[1]
+aqlm_not_supported = (capability <
+ QUANTIZATION_METHODS["aqlm"].get_min_capability())
+
+# In this test we hardcode prompts and generations for the model so we don't
+# need to require the AQLM package as a dependency
+example_prompts = [
+ 'vLLM is a high-throughput and memory-efficient inference and serving '
+ 'engine for LLMs.\n',
+ 'Briefly describe the major milestones in the development of artificial '
+ 'intelligence from 1950 to 2020.\n',
+ 'Compare and contrast artificial intelligence with human intelligence in '
+ 'terms of processing information.\n',
+ 'Describe the basic components of a neural network and how it can be '
+ 'trained.\n',
+ 'Write a short story about a robot that dreams for the first time.\n',
+ 'Analyze the impact of the COVID-19 pandemic on global economic structures '
+ 'and future business models.\n',
+ 'Explain the cultural significance of the Mona Lisa painting, and how its '
+ 'perception might vary in Western versus Eastern societies.\n',
+ "Translate the following English sentence into Japanese, French, and "
+ "Swahili: 'The early bird catches the worm.'\n"
+]
+
+# These ground truth generations were generated using `transformers==4.38.1
+# aqlm==1.1.0 torch==2.2.0`
+# and the below code:
+# ```python
+# from transformers import AutoTokenizer, AutoModelForCausalLM
+# model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
+# quantized_model = AutoModelForCausalLM.from_pretrained(model_id,
+# torch_dtype="auto", device_map="cuda").cuda()
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
+# outputs = []
+# for prompt in example_prompts:
+# input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
+# hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32)
+# outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:]))
+# print(outputs)
+# ```
+ground_truth_generations = [
+ '\n### Features\n\n- **High-throughput**: v',
+ 'The major milestones in the development of artificial intelligence from '
+ '195',
+ 'Compare and contrast artificial intelligence with human intelligence in '
+ 'terms of processing information. The',
+ 'Explain the difference between supervised and unsupervised learning.'
+ '\nExplain',
+ 'Write a short story about a robot that dreams for the first time. The',
+ 'Analyze the impact of the COVID-19 pandemic on global economic',
+ 'The Mona Lisa is a painting by Leonardo da Vinci, and it',
+ 'The early bird catches the worm.\nThe early bird catches the'
+]
+
+
+@pytest.mark.skipif(aqlm_not_supported,
+ reason="AQLM is not supported on this GPU type.")
+@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [16])
+@pytest.mark.parametrize("num_logprobs", [1])
+def test_models(
+ vllm_runner,
+ example_prompts,
+ model: str,
+ dtype: str,
+ max_tokens: int,
+ num_logprobs: int,
+) -> None:
+
+ vllm_model = vllm_runner(model, dtype=dtype)
+ vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
+ max_tokens,
+ num_logprobs)
+
+ # loop through the prompts to compare against the ground truth generations
+ for prompt_idx in range(len(example_prompts)):
+ vllm_output_ids, vllm_output_str, vllm_logprobs = vllm_outputs[
+ prompt_idx]
+
+ print("Prompt: ", repr(example_prompts[prompt_idx]))
+ print("Reference output:", repr(ground_truth_generations[prompt_idx]))
+ print("Output output: ", repr(vllm_output_str))
+ assert vllm_output_str == ground_truth_generations[prompt_idx]
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index d466d8807fc64..e56af9075e2fd 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -31,7 +31,7 @@ class LinearMethodBase(ABC):
@abstractmethod
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
- output_size_per_partition: int, input_size: int,
+ output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for a linear layer.
@@ -70,9 +70,10 @@ def __init__(self, separate_bias_add: bool = False):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
- output_size_per_partition: int, input_size: int,
+ output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
+ output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
@@ -127,7 +128,7 @@ def __init__(
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_method.create_weights(self, self.input_size,
- self.output_size, self.input_size,
+ [self.output_size], self.input_size,
self.output_size, self.params_dtype)
if bias:
self.bias = Parameter(
@@ -161,6 +162,8 @@ class ColumnParallelLinear(torch.nn.Module):
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
+ output_sizes: list of output sizes packed into one output, like for QKV
+ the list would be size 3.
"""
def __init__(
@@ -172,6 +175,7 @@ def __init__(
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
+ output_sizes: Optional[List[int]] = None,
):
super().__init__()
@@ -188,10 +192,12 @@ def __init__(
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
+ if output_sizes is None:
+ output_sizes = [output_size]
self.linear_method = linear_method
self.linear_method.create_weights(self,
self.input_size,
- self.output_size_per_partition,
+ [x // tp_size for x in output_sizes],
self.input_size,
self.output_size,
self.params_dtype,
@@ -268,14 +274,17 @@ def __init__(
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
- skip_bias_add, params_dtype, linear_method)
+ skip_bias_add, params_dtype, linear_method,
+ self.output_sizes)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
+
param_data = param.data
output_dim = getattr(param, "output_dim", None)
+ is_metadata = getattr(param, "is_metadata", False)
if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
@@ -328,6 +337,11 @@ def weight_loader(self,
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
+ elif is_metadata:
+ # metadata indicates fixed size concatenated along dim 0
+ shard_size = loaded_weight.shape[0]
+ shard_offset = loaded_shard_id * shard_size
+ param_data = param_data.narrow(0, shard_offset, shard_size)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
@@ -393,8 +407,14 @@ def __init__(
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
+ output_sizes = [
+ self.num_heads * tp_size * self.head_size,
+ self.num_kv_heads * tp_size * self.head_size,
+ self.num_kv_heads * tp_size * self.head_size
+ ]
+
super().__init__(input_size, output_size, bias, False, skip_bias_add,
- params_dtype, linear_method)
+ params_dtype, linear_method, output_sizes)
def weight_loader(self,
param: Parameter,
@@ -402,6 +422,7 @@ def weight_loader(self,
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
+ is_metadata = getattr(param, "is_metadata", False)
if loaded_shard_id is None:
# Loaded weight is already packed.
@@ -469,6 +490,12 @@ def weight_loader(self,
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
+ elif is_metadata:
+ # metadata indicates fixed size concatenated along dim 0
+ shard_size = loaded_weight.shape[0]
+ shard_index = ["q", "k", "v"].index(loaded_shard_id)
+ param_data = param_data.narrow(0, shard_index * shard_size,
+ shard_size)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
@@ -536,7 +563,7 @@ def __init__(
self.linear_method = linear_method
self.linear_method.create_weights(self,
self.input_size_per_partition,
- self.output_size,
+ [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index 0344d6e4e3e45..a525add458499 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -1,5 +1,6 @@
from typing import Type
+from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -9,6 +10,7 @@
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS = {
+ "aqlm": AQLMConfig,
"awq": AWQConfig,
"fp8": FP8Config,
"gptq": GPTQConfig,
diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py
new file mode 100644
index 0000000000000..6115b1de679ad
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/aqlm.py
@@ -0,0 +1,373 @@
+# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
+# and https://arxiv.org/pdf/2401.06118.pdf
+
+import math
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from vllm._C import ops
+from vllm.model_executor.layers.linear import (LinearMethodBase,
+ set_weight_attrs)
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+
+
+def get_int_dtype(nbits: int) -> torch.dtype:
+ if nbits <= 8:
+ return torch.int8
+ if nbits <= 16:
+ return torch.int16
+ if nbits <= 32:
+ return torch.int32
+ if nbits <= 64:
+ return torch.int64
+ raise ValueError(f"No dtype available for {nbits}-bit codebooks")
+
+
+@torch.inference_mode()
+def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
+ return data.to(torch.int64) % (2**nbits)
+
+
+def dequantize_weight(codes: torch.Tensor,
+ codebooks: torch.Tensor,
+ scales: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Decode float weights from quantization codes. Differentiable.
+ :param codes: tensor of integer quantization codes, shape
+ [*dims, num_out_groups, num_in_groups, num_codebooks]
+ :param codebooks: tensor of vectors for each quantization code,
+ [num_codebooks, codebook_size, out_group_size, in_group_size]
+    :param scales: weight will be multiplied by this factor, must be
+        broadcastable with
+        [*dims, num_out_groups, num_in_groups, out_group_size, in_group_size]
+    :return: reconstructed weight tensor of shape
+        [*dims, num_out_groups*out_group_size, num_in_groups*in_group_size]
+ """
+ num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
+ num_codebooks, codebook_size, out_group_size, in_group_size = \
+ codebooks.shape
+ out_features = num_out_groups * out_group_size
+ in_features = num_in_groups * in_group_size
+ codebook_offsets = torch.arange(
+ 0, num_codebooks * codebook_size, codebook_size,
+ device=codes.device) # shape: [num_codebooks]
+ reconstructed_weight_flat = F.embedding_bag(
+ codes.flatten(0, -2) + codebook_offsets,
+ codebooks.flatten(0, 1).flatten(-2, -1),
+ mode="sum"
+ ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size
+ # * in_group_size]
+
+ reconstructed_weight_groupwise = reconstructed_weight_flat.view(
+ list(codes.shape[:-3]) +
+ [num_out_groups, num_in_groups, out_group_size, in_group_size])
+ if scales is not None:
+ reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
+ scales)
+ return reconstructed_weight_groupwise.swapaxes(
+ -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
+
+
+def dequantize_gemm(
+ input: torch.Tensor, # [..., in_features]
+ codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
+ codebooks: torch.
+ Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
+ scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+ dequantized_weight = dequantize_weight(
+ unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
+ codebooks,
+ scales,
+ )
+ return F.linear(input, dequantized_weight, bias)
+
+
+# Generic dequantization, slow but flexible.
+def generic_dequantize_gemm(
+ input: torch.Tensor, # [..., in_features]
+ codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
+ codebooks: torch.
+ Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
+ scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ output_partition_sizes: torch.IntTensor,
+ bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+ output_shape = input.shape[:-1] + (scales.shape[0], )
+ output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
+ num_outputs = len(output_partition_sizes)
+
+ # break the inputs and codebooks apart then combine the outputs.
+ # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
+ # multiply at the end.
+ num_codebooks = codebooks.shape[0] // num_outputs
+ assert (scales.shape[0] == codes.shape[0])
+ assert (sum(output_partition_sizes) == scales.shape[0])
+ output_offset = 0
+ codebooks_offset = 0
+ for output_size in output_partition_sizes:
+ shard_output = dequantize_gemm(
+ input, codes.narrow(0, output_offset, output_size),
+ codebooks.narrow(0, codebooks_offset, num_codebooks),
+ scales.narrow(0, output_offset, output_size), None
+ if bias is None else bias.narrow(0, output_offset, output_size))
+
+ output_slice = output.narrow(-1, output_offset, output_size)
+ assert (output_slice.shape == shard_output.shape)
+ output_slice.copy_(shard_output)
+ output_offset += output_size
+ codebooks_offset += num_codebooks
+ return output
+
+
+# Optimized dequantize/decompression kernels, supporting 1x16 and 2x8
+# at roughly 6 and 9 times the speed of the generic version above, respectively.
+def optimized_dequantize_gemm(
+ input: torch.Tensor, # [..., in_features]
+ codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
+ codebooks: torch.
+ Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
+ scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
+ output_partition_sizes: torch.IntTensor,
+ bias: Optional[torch.Tensor],
+) -> torch.Tensor:
+ weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
+
+ if bias is None:
+ # scaling the output is fastest, so we do that when possible.
+ output = F.linear(input, weights, bias)
+ orig_shape = output.shape
+ flattened_output = output.view(-1, output.size(-1))
+ f_scales = scales.view(-1, scales.shape[0])
+ b_scales = f_scales.expand(flattened_output.shape[0], -1)
+ flattened_output *= b_scales
+ return output.view(orig_shape)
+ else:
+ b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
+ -1, weights.shape[1])
+ weights *= b_scales
+ return F.linear(input, weights, bias)
+
+
+class AQLMConfig(QuantizationConfig):
+ """Config class for AQLM.
+
+ Reference: https://github.com/Vahe1994/AQLM
+ """
+
+ def __init__(
+ self,
+ in_group_size: int,
+ nbits_per_codebook: int,
+ num_codebooks: int,
+ out_group_size: int,
+ ) -> None:
+ self.in_group_size = in_group_size
+ self.nbits_per_codebook = nbits_per_codebook
+ self.num_codebooks = num_codebooks
+ self.out_group_size = out_group_size
+
+ # out_group_size > 1 is untested, and probably won't work as-is.
+ assert (self.out_group_size == 1)
+ self.pack_factor = (self.in_group_size * self.out_group_size)
+
+ def __repr__(self) -> str:
+ return (f"AQLMConfig(in_group_size={self.in_group_size}, "
+ f"nbits_per_codebook={self.nbits_per_codebook}, "
+ f"num_codebooks={self.num_codebooks}, "
+ f"out_group_size={self.out_group_size})")
+
+ @classmethod
+ def get_name(cls) -> str:
+ return "aqlm"
+
+ @classmethod
+ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+ return [torch.half]
+
+ @classmethod
+ def get_min_capability(cls) -> int:
+ return 70
+
+ @classmethod
+ def get_config_filenames(cls) -> List[str]:
+ return [] # no extra configs.
+
+ @classmethod
+ def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
+ in_group_size = cls.get_from_keys(config, ["in_group_size"])
+ nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
+ num_code_books = cls.get_from_keys(config, ["num_codebooks"])
+ out_group_size = cls.get_from_keys(config, ["out_group_size"])
+ return cls(in_group_size, nbits_per_codebook, num_code_books,
+ out_group_size)
+
+ def get_linear_method(self) -> "AQLMLinearMethod":
+ return AQLMLinearMethod(self)
+
+ def get_scaled_act_names(self) -> List[str]:
+ return []
+
+
+class AQLMLinearMethod(LinearMethodBase):
+ """Linear method for AQLM.
+
+ Args:
+ quant_config: The AQLM quantization config.
+ """
+
+ def __init__(self, quant_config: AQLMConfig):
+ self.quant_config = quant_config
+
+ def create_weights(self, layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_partition_sizes: List[int], input_size: int,
+ output_size: int, params_dtype: torch.dtype,
+ **extra_weight_attrs):
+ del output_size # Unused.
+ del input_size # Unused.
+
+ if params_dtype != torch.half:
+ raise ValueError("Only half is currently supported by aqlm")
+ if input_size_per_partition % self.quant_config.in_group_size != 0:
+ raise ValueError(
+ "The input size is not aligned with the quantized "
+ "weight shape. This can be caused by too large "
+ "tensor parallel size.")
+
+ output_size_per_partition = sum(output_partition_sizes)
+ if output_size_per_partition % self.quant_config.out_group_size != 0:
+ raise ValueError(
+ "The output size is not aligned with the quantized "
+ "weight shape. This can be caused by too large "
+ "tensor parallel size.")
+
+ codes = Parameter(
+ torch.empty(
+ # There could actually be two pack factors, one along input and
+ # one along output, but we don't currently support
+ # out_group_size, and only the one along output needs to be
+ # marked with "packed_dim" in order for QKVLinear to work.
+ output_size_per_partition,
+ input_size_per_partition // self.quant_config.pack_factor,
+ self.quant_config.num_codebooks,
+ dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
+ ),
+ requires_grad=False,
+ )
+
+ set_weight_attrs(
+ codes,
+ {
+ "input_dim": 1,
+ "output_dim": 0,
+ "packed_dim": 1,
+ "pack_factor": self.quant_config.pack_factor,
+ },
+ )
+
+ codebooks = Parameter(
+ torch.empty(
+ self.quant_config.num_codebooks * len(output_partition_sizes),
+ 2**self.quant_config.nbits_per_codebook,
+ self.quant_config.out_group_size,
+ self.quant_config.in_group_size,
+ dtype=params_dtype,
+ ),
+ requires_grad=False,
+ )
+ set_weight_attrs(
+ codebooks,
+ {
+ # metadata indicates fixed size concatenated along dim 0
+ "is_metadata":
+ True,
+ "output_partition_sizes":
+ torch.tensor(output_partition_sizes, device='cpu'),
+ },
+ )
+
+ scales = Parameter(
+ torch.empty(
+ (
+ output_size_per_partition //
+ self.quant_config.out_group_size,
+ 1,
+ 1,
+ 1,
+ ),
+ dtype=params_dtype,
+ ),
+ requires_grad=False,
+ )
+ set_weight_attrs(
+ scales,
+ {
+ "output_dim": 0,
+ "packed_dim": 0,
+ "pack_factor": self.quant_config.out_group_size
+ },
+ )
+
+ layer.register_parameter("codes", codes)
+ set_weight_attrs(codes, extra_weight_attrs)
+ layer.register_parameter("codebooks", codebooks)
+ set_weight_attrs(codebooks, extra_weight_attrs)
+ layer.register_parameter("scales", scales)
+ set_weight_attrs(scales, extra_weight_attrs)
+
+ def apply_weights(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ codebooks = layer.codebooks
+ codes = layer.codes
+ scales = layer.scales
+ output_partition_sizes = getattr(codebooks, "output_partition_sizes",
+ None)
+
+ nbooks = codes.shape[2]
+ ingroups = codebooks.shape[3]
+ outgroups = codebooks.shape[2]
+ bits = codebooks.shape[1]
+
+ # We support these formats with dedicated gemm and decompression
+ # kernels.
+ if ingroups == 8 and outgroups == 1 and (
+ (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)):
+
+ # thresholds determined by timings on an A6000, one GPU
+ use_gemv = math.prod(x.shape[:-1]) <= 6
+
+ return ops.aqlm_gemm(
+ x,
+ codes,
+ codebooks,
+ scales,
+ output_partition_sizes,
+ bias,
+ ) if use_gemv else optimized_dequantize_gemm(
+ x,
+ codes,
+ codebooks,
+ scales,
+ output_partition_sizes,
+ bias,
+ )
+
+    # Fall back for all other (unoptimized) formats.
+ return generic_dequantize_gemm(
+ x,
+ codes,
+ codebooks,
+ scales,
+ output_partition_sizes,
+ bias,
+ )
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 98651aed8be0e..4f75134ee1889 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -81,7 +81,7 @@ def __init__(self, quant_config: AWQConfig):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
- output_size_per_partition: int, input_size: int,
+ output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
@@ -89,6 +89,8 @@ def create_weights(self, layer: torch.nn.Module,
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
+
+ output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index f370b94a210ee..92a5cdb9af928 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -91,7 +91,7 @@ def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
- output_size_per_partition: int,
+ output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
@@ -103,6 +103,7 @@ def create_weights(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
+ output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index bf0500f1155a1..00c3c404c2d7a 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -93,7 +93,7 @@ def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
- output_size_per_partition: int,
+ output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
@@ -106,6 +106,7 @@ def create_weights(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
+ output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
index 661ff9c55d0d1..cc44447d347b8 100644
--- a/vllm/model_executor/layers/quantization/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -70,7 +70,7 @@ def __init__(self, quant_config: SqueezeLLMConfig):
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
- output_size_per_partition: int, input_size: int,
+ output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.pack_factor != 0:
@@ -78,6 +78,8 @@ def create_weights(self, layer: torch.nn.Module,
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
+
+ output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
From 1e8f4252aa163041094a8fedbb701a98d7087d7c Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Wed, 24 Apr 2024 02:19:03 +0800
Subject: [PATCH 105/413] [Bugfix][Frontend] Raise exception when file-like
chat template fails to be opened (#4292)
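The gist of the fix: if opening the supplied chat template as a file fails, it is only
treated as an inline template when it contains Jinja-like characters; otherwise a
ValueError is raised. A condensed, standalone sketch of that heuristic (the helper name
below is illustrative; the real check lives in `OpenAIServingChat._load_chat_template`):

```python
# Standalone sketch of the path-vs-template heuristic added in this patch.
JINJA_CHARS = "{}\n"


def looks_like_file_path(chat_template: str) -> bool:
    # A string with no braces or newlines is almost certainly a path, so a
    # failed open should surface as an error instead of being silently inlined.
    return not any(c in chat_template for c in JINJA_CHARS)


assert looks_like_file_path("../../examples/does_not_exist")
assert not looks_like_file_path("{{ messages }}")
```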
---
tests/async_engine/test_chat_template.py | 19 ++++++++++++++-----
vllm/entrypoints/openai/serving_chat.py | 23 +++++++++++++++--------
2 files changed, 29 insertions(+), 13 deletions(-)
diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py
index 6972ae1dee4a1..8d6ad6706fb0e 100644
--- a/tests/async_engine/test_chat_template.py
+++ b/tests/async_engine/test_chat_template.py
@@ -76,20 +76,29 @@ def test_load_chat_template():
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
-def test_no_load_chat_template():
+def test_no_load_chat_template_filelike():
# Testing chatml template
template = "../../examples/does_not_exist"
tokenizer = MockTokenizer()
+ mock_serving_chat = MockServingChat(tokenizer)
+
+ with pytest.raises(ValueError, match="looks like a file path"):
+ OpenAIServingChat._load_chat_template(mock_serving_chat,
+ chat_template=template)
+
+
+def test_no_load_chat_template_literallike():
+ # Testing chatml template
+ template = "{{ messages }}"
+ tokenizer = MockTokenizer()
+
mock_serving_chat = MockServingChat(tokenizer)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=template)
template_content = tokenizer.chat_template
- # Test assertions
- assert template_content is not None
- # Hard coded value for template_chatml.jinja
- assert template_content == """../../examples/does_not_exist"""
+ assert template_content == template
@pytest.mark.asyncio
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index d502dd0a4eb75..2ff335eb71073 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -319,23 +319,30 @@ async def chat_completion_full_generator(
return response
def _load_chat_template(self, chat_template):
+ tokenizer = self.tokenizer
+
if chat_template is not None:
try:
with open(chat_template, "r") as f:
- self.tokenizer.chat_template = f.read()
- except OSError:
+ tokenizer.chat_template = f.read()
+ except OSError as e:
+ JINJA_CHARS = "{}\n"
+ if not any(c in chat_template for c in JINJA_CHARS):
+ msg = (f"The supplied chat template ({chat_template}) "
+ f"looks like a file path, but it failed to be "
+ f"opened. Reason: {e}")
+ raise ValueError(msg) from e
+
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
- self.tokenizer.chat_template = codecs.decode(
+ tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape")
logger.info(
- f"Using supplied chat template:\n{self.tokenizer.chat_template}"
- )
- elif self.tokenizer.chat_template is not None:
+ f"Using supplied chat template:\n{tokenizer.chat_template}")
+ elif tokenizer.chat_template is not None:
logger.info(
- f"Using default chat template:\n{self.tokenizer.chat_template}"
- )
+ f"Using default chat template:\n{tokenizer.chat_template}")
else:
logger.warning(
"No chat template provided. Chat API will not work.")
From eace8bf0b9118877c390e6d490502214c39db132 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Tue, 23 Apr 2024 18:18:23 -0700
Subject: [PATCH 106/413] [Kernel] FP8 support for MoE kernel / Mixtral (#4244)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This PR is the first step towards fixing https://github.com/vllm-project/vllm/pull/3208
It implements dynamic per-tensor scaling (see https://github.com/vllm-project/vllm/pull/4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this:
```python
from vllm import LLM, SamplingParams
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8")
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
**Performance**: For this PR, the focus is on keeping the code clean while still getting reasonable performance; a number of optimizations that significantly improve performance (similar to the numbers in https://github.com/vllm-project/vllm/pull/3954) will be submitted as a follow-up PR. With this PR, the results are as follows:
**Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:
```
| Groups |Version|Filter|n-shot|Metric|Value | |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu |N/A |none | 0|acc |0.7018|± |0.0036|
| - humanities |N/A |none | 5|acc |0.6472|± |0.0065|
| - other |N/A |none | 5|acc |0.7673|± |0.0072|
| - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070|
| - stem |N/A |none | 5|acc |0.6131|± |0.0083|
```
This compares favorably with the fp16 results, which are:
```
| Groups |Version|Filter|n-shot|Metric|Value | |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu |N/A |none | 0|acc |0.7020|± |0.1313|
| - humanities |N/A |none | 5|acc |0.6425|± |0.1349|
| - other |N/A |none | 5|acc |0.7744|± |0.1038|
| - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695|
| - stem |N/A |none | 5|acc |0.6108|± |0.1383|
```
Happy hacking!
---
CMakeLists.txt | 1 +
csrc/ops.h | 5 +
csrc/pybind.cpp | 1 +
csrc/quantization/fp8/fp8_cuda_kernels.cu | 103 ++++++++++++
vllm/_custom_ops.py | 10 +-
...me=NVIDIA_H100_80GB_HBM3,dtype=float8.json | 146 ++++++++++++++++++
.../layers/fused_moe/fused_moe.py | 93 ++++++++---
vllm/model_executor/model_loader/loader.py | 2 +
vllm/model_executor/model_loader/utils.py | 1 +
vllm/model_executor/models/mixtral.py | 44 +++++-
10 files changed, 385 insertions(+), 21 deletions(-)
create mode 100644 csrc/quantization/fp8/fp8_cuda_kernels.cu
create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b2d0cf3e568b7..4a99985d9abc4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
+ "csrc/quantization/fp8/fp8_cuda_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
diff --git a/csrc/ops.h b/csrc/ops.h
index a379c910d9cf3..ff7a3de1a0a8c 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -146,6 +146,11 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);
+void scaled_fp8_quant(
+ torch::Tensor& out,
+ torch::Tensor& input,
+ torch::Tensor& scale);
+
void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index 42e92e5382e8e..a5b16c5abc3ed 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -73,6 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
+ ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu
new file mode 100644
index 0000000000000..c3337cede1282
--- /dev/null
+++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu
@@ -0,0 +1,103 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cmath>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+namespace vllm {
+
+__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
+ float old;
+ old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
+ __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
+
+ return old;
+}
+
+// Compute the absolute maximum m of the input tensor and store
+// m / float8_e4m3::max() in *scale. Each thread block performs a
+// reduction tree and the memory in scale is atomically updated.
+// So to get the right answer, *scale needs to be initialized to
+// a value <= 0.0 and we need to wait for all thread blocks to
+// finish before consuming *scale.
+template<typename scalar_t>
+__global__ void segmented_max_reduction(
+ float* __restrict__ scale,
+ const scalar_t* __restrict__ input,
+ int64_t num_elems) {
+ __shared__ float cache[1024];
+ int i = blockDim.x * blockIdx.x + threadIdx.x;
+
+  // First store the maximum of all values processed by
+  // the current thread in cache[threadIdx.x]
+ scalar_t tmp = 0.0;
+ while (i < num_elems) {
+    float x = static_cast<float>(input[i]);
+ tmp = max(tmp, fabs(x));
+ i += blockDim.x * gridDim.x;
+ }
+ cache[threadIdx.x] = tmp;
+
+ __syncthreads();
+
+ // Now perform parallel reduction within the thread block
+ int ib = blockDim.x / 2;
+ while (ib != 0) {
+ if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
+ cache[threadIdx.x] = cache[threadIdx.x + ib];
+ }
+ __syncthreads();
+ ib /= 2;
+ }
+ // Finally, since cache[0] contains the maximum for this thread block,
+ // atomically write the max to the target location
+ if (threadIdx.x == 0) {
+    atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+ }
+}
+
+template<typename scalar_t>
+__global__ void scaled_fp8_quant_kernel(
+ c10::Float8_e4m3fn* __restrict__ out,
+ const scalar_t* __restrict__ input,
+ const float* __restrict__ scale,
+ int64_t num_elems) {
+ int i = blockDim.x * blockIdx.x + threadIdx.x;
+ while (i < num_elems) {
+    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
+ i += blockDim.x * gridDim.x;
+ }
+}
+
+} // namespace vllm
+
+void scaled_fp8_quant(
+ torch::Tensor& out, // [..., d]
+ torch::Tensor& input, // [..., d]
+ torch::Tensor& scale) // [1]
+{
+ int64_t num_tokens = input.numel() / input.size(-1);
+ int64_t num_elems = input.numel();
+ dim3 grid(num_tokens);
+ dim3 block(1024);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(),
+ "scaled_fp8_quant_kernel",
+ [&] {
+      vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
+        scale.data_ptr<float>(),
+        input.data_ptr<scalar_t>(),
+        num_elems);
+      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<c10::Float8_e4m3fn>(),
+        input.data_ptr<scalar_t>(),
+        scale.data_ptr<float>(),
+        num_elems);
+ });
+}
+
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a0837a20875fe..e4b16ed918d1a 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Dict, Optional, Tuple
import torch
@@ -153,6 +153,14 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k)
+# fp8
+def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+ output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+ vllm_ops.scaled_fp8_quant(output, input, scale)
+ return output, scale
+
+
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
new file mode 100644
index 0000000000000..2ad07bf79a25c
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
@@ -0,0 +1,146 @@
+{
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "48": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "64": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "96": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "128": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "256": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "512": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 8,
+ "num_stages": 4
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 377b6588dbf47..ac7c30e2a9727 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -21,6 +21,8 @@ def fused_moe_kernel(
a_ptr,
b_ptr,
c_ptr,
+ a_scale_ptr,
+ b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
@@ -49,6 +51,7 @@ def fused_moe_kernel(
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
+ use_fp8: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
@@ -111,6 +114,10 @@ def fused_moe_kernel(
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
+ if use_fp8:
+ a_scale = tl.load(a_scale_ptr)
+ b_scale = tl.load(b_scale_ptr + off_experts)
+
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
@@ -129,7 +136,10 @@ def fused_moe_kernel(
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
# We accumulate along the K dimension.
- accumulator += tl.dot(a, b)
+ if use_fp8:
+ accumulator = tl.dot(a, b, acc=accumulator)
+ else:
+ accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -140,7 +150,10 @@ def fused_moe_kernel(
other=0)
accumulator = accumulator * moe_weight[:, None]
- accumulator = accumulator.to(compute_type)
+ if use_fp8:
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+ else:
+ accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -207,15 +220,24 @@ def moe_align_block_size(
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
- topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+ B_scale: torch.Tensor, topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
- config: Dict[str, Any]) -> None:
+ config: Dict[str, Any], compute_type: tl.dtype,
+ use_fp8: bool) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
+ if not use_fp8:
+ A_scale = None
+ assert B_scale is None
+ else:
+ A, A_scale = ops.scaled_fp8_quant(A)
+ assert B_scale is not None
+
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
@@ -223,6 +245,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A,
B,
C,
+ A_scale,
+ B_scale,
topk_weights,
sorted_token_ids,
expert_ids,
@@ -240,18 +264,21 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
C.stride(2),
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
- compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
+ compute_type=compute_type,
+ use_fp8=use_fp8,
**config,
)
-def get_config_file_name(E: int, N: int) -> str:
+def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
device_name = torch.cuda.get_device_name().replace(" ", "_")
- return f"E={E},N={N},device_name={device_name}.json"
+ dtype_selector = "" if not dtype else f",dtype={dtype}"
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
@functools.lru_cache
-def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int,
+ dtype: Optional[str]) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
@@ -263,7 +290,7 @@ def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
# First look up if an optimized configuration is available in the configs
# directory
- json_file_name = get_config_file_name(E, N)
+ json_file_name = get_config_file_name(E, N, dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
@@ -288,6 +315,9 @@ def fused_moe(
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
+ use_fp8: bool = False,
+ w1_scale: Optional[torch.Tensor] = None,
+ w2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -305,6 +335,12 @@ def fused_moe(
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
+ - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+ products for w1 and w2. Defaults to False.
+ - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
+ w1.
+ - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
+ w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@@ -358,7 +394,8 @@ def fused_moe(
config = override_config
else:
# First try to load optimal config from the file
- configs = get_moe_configs(E, w2.shape[2])
+ configs = get_moe_configs(E, w2.shape[2],
+ "float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
@@ -394,17 +431,37 @@ def fused_moe(
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
- invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
- topk_weights, topk_ids, sorted_token_ids,
- expert_ids, num_tokens_post_padded, False,
- topk_ids.shape[1], config)
+ invoke_fused_moe_kernel(hidden_states,
+ w1,
+ intermediate_cache1,
+ w1_scale,
+ topk_weights,
+ topk_ids,
+ sorted_token_ids,
+ expert_ids,
+ num_tokens_post_padded,
+ False,
+ topk_ids.shape[1],
+ config,
+ compute_type=tl.float16,
+ use_fp8=use_fp8)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
- invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3,
- topk_weights, topk_ids, sorted_token_ids,
- expert_ids, num_tokens_post_padded, True, 1,
- config)
+ invoke_fused_moe_kernel(intermediate_cache2,
+ w2,
+ intermediate_cache3,
+ w2_scale,
+ topk_weights,
+ topk_ids,
+ sorted_token_ids,
+ expert_ids,
+ num_tokens_post_padded,
+ True,
+ 1,
+ config,
+ compute_type=tl.float16,
+ use_fp8=use_fp8)
if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 64cd186506bdb..f75c35a69d4a9 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -232,6 +232,8 @@ def load_model(self, *, model_config: ModelConfig,
linear_method = getattr(module, "linear_method", None)
if linear_method is not None:
linear_method.process_weights_after_loading(module)
+ if hasattr(module, "process_weights_after_loading"):
+ module.process_weights_after_loading()
return model.eval()
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index a0a3b2784614d..f7e0f56c1a46e 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -24,6 +24,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
+ and model_config.quantization != "fp8"
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 4d1755f2bbe63..a33b795d7088e 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -39,6 +39,8 @@
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod,
+ per_tensor_quantize)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -47,6 +49,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
+from vllm.utils import print_warning_once
class MixtralMoE(nn.Module):
@@ -66,6 +69,7 @@ def __init__(
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
+ linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
@@ -73,6 +77,9 @@ def __init__(
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
+ # FIXME(pcmoritz): Make this more general to support different
+ # quantization schemes
+ self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -97,6 +104,16 @@ def __init__(
device="cuda",
dtype=self.params_dtype))
+ # Scaling factors for FP8 weights
+ self.ws_scale = nn.Parameter(
+ torch.ones(
+ self.num_total_experts, device="cuda", dtype=torch.float32),
+ requires_grad=False) if self.use_fp8 else None
+ self.w2s_scale = nn.Parameter(
+ torch.ones(
+ self.num_total_experts, device="cuda", dtype=torch.float32),
+ requires_grad=False) if self.use_fp8 else None
+
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
@@ -118,6 +135,18 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
+ def process_weights_after_loading(self):
+ if self.use_fp8:
+ ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
+ w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
+ for expert in range(self.num_total_experts):
+ ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
+ self.ws.data[expert, :, :])
+ w2s[expert, :, :], self.w2s_scale[
+ expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
+ self.ws = nn.Parameter(ws, requires_grad=False)
+ self.w2s = nn.Parameter(w2s, requires_grad=False)
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
@@ -129,7 +158,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits,
self.top_k,
renormalize=True,
- inplace=True)
+ inplace=True,
+ use_fp8=self.use_fp8,
+ w1_scale=self.ws_scale,
+ w2_scale=self.w2s_scale)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
@@ -171,6 +203,13 @@ def __init__(self,
self.rope_theta = rope_theta
self.sliding_window = sliding_window
+ if isinstance(linear_method, Fp8LinearMethod):
+ print_warning_once(
+ "For Mixtral FP8 quantization, we currently do not quantize "
+ "the attention layers until their FP8 performance is improved."
+ )
+ linear_method = None
+
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
@@ -238,7 +277,8 @@ def __init__(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
- intermediate_size=config.intermediate_size)
+ intermediate_size=config.intermediate_size,
+ linear_method=linear_method)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
From 79a268c4ab2cbf44280eebd998b8efc383bac216 Mon Sep 17 00:00:00 2001
From: Robert Shaw
<114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Tue, 23 Apr 2024 21:26:33 -0400
Subject: [PATCH 107/413] [BUG] fixed fp8 conflict with aqlm (#4307)
Fixes the fp8 interface, which broke in the AQLM merge.
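The underlying interface change: `create_weights` now takes a list of logical
output-partition sizes instead of a single fused width, so methods that only need the
total (like fp8 here) must sum it. A minimal illustrative sketch, not the actual vLLM
classes:

```python
from typing import List


def fused_output_width(output_partition_sizes: List[int]) -> int:
    # Quant methods that only need the fused width (fp8, AWQ, GPTQ, ...) simply
    # sum the per-logical-weight sizes; AQLM keeps the list so it can place
    # per-shard metadata (codebooks) later.
    return sum(output_partition_sizes)


# For a fused QKV projection the list has one entry per Wq/Wk/Wv shard
# (sizes below are made up for illustration):
assert fused_output_width([4096, 1024, 1024]) == 6144
```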
---
.buildkite/test-pipeline.yaml | 3 +++
vllm/model_executor/layers/linear.py | 16 +++++++++++++---
vllm/model_executor/layers/quantization/fp8.py | 3 ++-
3 files changed, 18 insertions(+), 4 deletions(-)
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index f7c1569696249..11cda053260ec 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -96,6 +96,9 @@ steps:
- label: Metrics Test
command: pytest -v -s metrics
+- label: Quantization Test
+ command: pytest -v -s quantization
+
- label: Benchmarks
working_dir: "/vllm-workspace/.buildkite"
commands:
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index e56af9075e2fd..6ad7ae0f63197 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -34,9 +34,19 @@ def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
- """Create weights for a linear layer.
-
- The weights will be set as attributes of the layer."""
+ """Create weights for a linear layer.
+ The weights will be set as attributes of the layer.
+
+ Args:
+ layer: The layer that is using the LinearMethodBase factory.
+ input_size_per_partition: Size of the weight input dim on rank X.
+ output_partition_sizes: Sizes of the output dim of each logical
+ weight on rank X. E.g., output_partition_sizes for QKVLinear
+ is a list containing the widths of Wq, Wk, Wv on rank X.
+ input_size: Size of the input dim of the weight across all ranks.
+ output_size: Size of the output dim of the weight across all ranks.
+ params_dtype: Datatype of the parameters.
+ """
raise NotImplementedError
@abstractmethod
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 8df82e0e18edd..01e494c870e71 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -64,12 +64,13 @@ def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
- output_size_per_partition: int,
+ output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
+ output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
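
To make the interface change concrete, here is an illustrative (made-up) value of the new argument for a fused QKV projection on one tensor-parallel rank; the physical weight is allocated for the sum of the logical partition widths.

    output_partition_sizes = [4096, 1024, 1024]               # widths of Wq, Wk, Wv on this rank
    output_size_per_partition = sum(output_partition_sizes)   # 6144, used to allocate the weight
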
From 91f50a6fe240b2c92a99e171bb11d083f82e4a84 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 23 Apr 2024 18:32:19 -0700
Subject: [PATCH 108/413] [Core][Distributed] use cpu/gloo to initialize pynccl
(#4248)
---
tests/distributed/test_pynccl.py | 15 ++-
.../device_communicators/pynccl.py | 122 ++++++++++--------
.../device_communicators/pynccl_utils.py | 12 +-
vllm/distributed/parallel_state.py | 6 +
vllm/worker/worker.py | 9 +-
5 files changed, 93 insertions(+), 71 deletions(-)
diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index d58f621d36b86..6d7d4a5806bd0 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -5,6 +5,7 @@
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
+from vllm.distributed.parallel_state import init_distributed_environment
from vllm.utils import update_environment_variables
@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
for p in processes:
p.join()
+ for p in processes:
+ assert p.exitcode == 0
+
-def update_env(fn):
+def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
- def wrapper(env):
+ def wrapped_fn(env):
update_environment_variables(env)
+ init_distributed_environment()
fn()
- return wrapper
+ return wrapped_fn
-@update_env
+@worker_fn_wrapper
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
@@ -53,7 +58,7 @@ def test_pynccl():
distributed_run(worker_fn, 2)
-@update_env
+@worker_fn_wrapper
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 0707afe922f40..fcedf0fed34cb 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -20,14 +20,15 @@
# variable in the code.
import ctypes
-import datetime
import platform
+from typing import Optional, Union
# ===================== import region =====================
import torch
import torch.distributed as dist
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
+from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger
from vllm.utils import find_nccl_library, nccl_integrity_check
@@ -59,6 +60,18 @@
ncclResult_t = ctypes.c_int
+_c_ncclGetErrorString = nccl.ncclGetErrorString
+_c_ncclGetErrorString.restype = ctypes.c_char_p
+_c_ncclGetErrorString.argtypes = [ncclResult_t]
+
+
+def NCCL_CHECK(result: ncclResult_t) -> None:
+ if result != 0:
+ error_str = _c_ncclGetErrorString(result)
+ error_str = error_str.decode("utf-8")
+ raise RuntimeError(f"NCCL error: {error_str}")
+
+
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
@@ -68,8 +81,7 @@
def ncclGetVersion() -> str:
version = ctypes.c_int()
- result = _c_ncclGetVersion(ctypes.byref(version))
- assert result == 0
+ NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
# something like 21903 --> "2.19.3"
version_str = str(version.value)
major = version_str[0].lstrip("0")
@@ -91,8 +103,7 @@ class NcclUniqueId(ctypes.Structure):
def ncclGetUniqueId() -> NcclUniqueId:
unique_id = NcclUniqueId()
- result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
- assert result == 0
+ NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
return unique_id
@@ -199,66 +210,75 @@ class NCCLCommunicator:
def __init__(
self,
- backend=None,
- init_method=None,
- timeout=datetime.timedelta(seconds=10),
- world_size: int = -1,
- rank: int = -1,
- store=None,
- group_name: str = "",
- pg_options=None,
- local_rank: int = -1,
+ group: Optional[ProcessGroup] = None,
+ device: Optional[Union[int, str, torch.device]] = None,
):
- if not dist.is_initialized():
- backend = backend or "nccl"
- assert backend == 'nccl', (
- "only use nccl backend for starting the NCCL communicator")
- dist.init_process_group(backend=backend,
- init_method=init_method,
- timeout=timeout,
- world_size=world_size,
- rank=rank,
- store=store,
- group_name=group_name,
- pg_options=pg_options)
- self.rank = dist.get_rank()
- self.world_size = dist.get_world_size()
- if local_rank == -1:
- local_rank = self.rank
- self.local_rank = local_rank
- # don't use these args, as they can be -1
- # use `self.rank`, `self.local_rank` and `self.world_size` instead
- del world_size, rank, local_rank
- torch.cuda.set_device(self.local_rank)
+ """
+ Args:
+ group: the process group to work on. If None, it will use the
+ default process group.
+ device: the device to bind the NCCLCommunicator to. If None,
+ it will be bound to f"cuda:{local_rank}".
+ It is the caller's responsibility to make sure each communicator
+ is bound to a unique device.
+ """
+ assert dist.is_initialized()
+ group = get_cpu_world_group() if group is None else group
+ assert dist.get_backend(group) != dist.Backend.NCCL, (
+ "NCCLCommunicator should be attached to a non-NCCL group.")
+ self.group = group
+ self.rank = dist.get_rank(group)
+ self.world_size = dist.get_world_size(group)
if self.rank == 0:
self.unique_id = ncclGetUniqueId()
else:
self.unique_id = NcclUniqueId()
- tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
- self.local_rank)
- dist.broadcast(tensor, src=0)
- byte_list = tensor.cpu().tolist()
+ tensor = torch.ByteTensor(list(self.unique_id.internal))
+ dist.broadcast(tensor, src=0, group=group)
+ byte_list = tensor.tolist()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
- result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
- self.unique_id, self.rank)
- assert result == 0
- self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
+ if device is None:
+ local_rank = get_local_rank()
+ device = torch.device(f"cuda:{local_rank}")
+ elif isinstance(device, int):
+ device = torch.device(f"cuda:{device}")
+ elif isinstance(device, str):
+ device = torch.device(device)
+ # now `device` is a `torch.device` object
+ assert isinstance(device, torch.device)
+ self.device = device
+ # nccl communicator and stream will use this device
+ current_device = torch.cuda.current_device()
+ try:
+ torch.cuda.set_device(device)
+ NCCL_CHECK(
+ _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
+ self.unique_id, self.rank))
+ self.stream = torch.cuda.Stream()
+ finally:
+ torch.cuda.set_device(current_device)
def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
+ # nccl communicator created on a specific device
+ # will only work on tensors on the same device
+ # otherwise it will cause "illegal memory access"
+ assert tensor.device == self.device, (
+ f"this nccl communicator is created to work on {self.device}, "
+ f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
- result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
- ctypes.c_void_p(tensor.data_ptr()),
- tensor.numel(),
- ncclDataTypeEnum.from_torch(tensor.dtype),
- ncclRedOpTypeEnum.from_torch(op), self.comm,
- ctypes.c_void_p(stream.cuda_stream))
- assert result == 0
+ NCCL_CHECK(
+ _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
+ ctypes.c_void_p(tensor.data_ptr()),
+ tensor.numel(),
+ ncclDataTypeEnum.from_torch(tensor.dtype),
+ ncclRedOpTypeEnum.from_torch(op), self.comm,
+ ctypes.c_void_p(stream.cuda_stream)))
def __del__(self):
# `dist` module might have been already destroyed
diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py
index 916dc814af7eb..a717fddb695ba 100644
--- a/vllm/distributed/device_communicators/pynccl_utils.py
+++ b/vllm/distributed/device_communicators/pynccl_utils.py
@@ -2,7 +2,7 @@
from typing import Optional
import torch
-from torch.distributed import ReduceOp
+from torch.distributed import ProcessGroup, ReduceOp
from vllm.logger import init_logger
@@ -37,17 +37,11 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
pass
-def init_process_group(world_size: int,
- rank: int,
- init_method: str,
- local_rank: int = -1) -> None:
+def init_process_group(group: Optional[ProcessGroup] = None) -> None:
assert not is_initialized()
global comm
logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
- comm = NCCLCommunicator(init_method=init_method,
- world_size=world_size,
- local_rank=local_rank,
- rank=rank)
+ comm = NCCLCommunicator(group=group)
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index e2473736375e0..515f2212511b7 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -4,6 +4,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
+import os
from typing import Optional
import torch
@@ -73,6 +74,11 @@ def init_distributed_environment(
ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
backend="gloo")
+ # set the local rank
+ # local_rank is not available in torch ProcessGroup,
+ # see https://github.com/pytorch/pytorch/issues/122816
+ if local_rank == -1 and distributed_init_method == "env://":
+ local_rank = int(os.environ['LOCAL_RANK'])
global _LOCAL_RANK
_LOCAL_RANK = local_rank
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 2203570b37ad6..39ad428f16fe3 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -298,12 +298,9 @@ def init_worker_distributed_environment(
elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1.
- pynccl_utils.init_process_group(
- world_size=parallel_config.world_size,
- local_rank=local_rank,
- rank=rank,
- init_method=distributed_init_method,
- )
+ # NOTE(kaichao): By default, pynccl will use information inside
+ # `parallel_state` for initialization.
+ pynccl_utils.init_process_group()
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
From e4bf860a54a302ccb2d80489368d5df686e46923 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 23 Apr 2024 18:33:12 -0700
Subject: [PATCH 109/413] [CI][Build] change pynvml to nvidia-ml-py (#4302)
---
requirements-cuda.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index c6d2cd46aee54..1bddae4c6f40f 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -3,7 +3,7 @@
# Dependencies for NVIDIA GPUs
ray >= 2.9
-pynvml == 11.5.0
+nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.2.1
xformers == 0.0.25 # Requires PyTorch 2.2.1
From 468d761b32e3b3c5d64eeaa797e54ab809b7e50c Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 23 Apr 2024 18:54:33 -0700
Subject: [PATCH 110/413] [Misc] Reduce supported Punica dtypes (#4304)
---
CMakeLists.txt | 12 --------
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu | 4 ---
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu | 4 ---
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu | 4 ---
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu | 4 ---
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu | 4 ---
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu | 4 ---
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu | 4 ---
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu | 4 ---
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu | 4 ---
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu | 4 ---
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu | 4 ---
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu | 4 ---
csrc/punica/bgmv/generator.py | 20 ++++++++++++
csrc/punica/punica_ops.cc | 17 ++++++++++
tests/lora/test_layers.py | 41 +++++++++++++++++--------
16 files changed, 66 insertions(+), 72 deletions(-)
delete mode 100644 csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
delete mode 100644 csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4a99985d9abc4..e9262b57d0867 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -212,23 +212,11 @@ define_gpu_extension_target(
set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
- "csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
- "csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
- "csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
- "csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
- "csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
- "csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
- "csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
- "csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
- "csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
- "csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
- "csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
- "csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
"csrc/punica/punica_ops.cc")
#
diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
deleted file mode 100644
index e8202dff561d9..0000000000000
--- a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
deleted file mode 100644
index 3e7cf31dead0f..0000000000000
--- a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
deleted file mode 100644
index 68277fa6b7d56..0000000000000
--- a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
deleted file mode 100644
index 3b7531b8fbcfc..0000000000000
--- a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
deleted file mode 100644
index b3b74aa3ec904..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
deleted file mode 100644
index 3cc87f5df76a1..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
deleted file mode 100644
index 9eda98bd8ddcf..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
deleted file mode 100644
index 060f9ebb8c2b1..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
deleted file mode 100644
index b37e44570bf40..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
deleted file mode 100644
index 06718cbb0a3e9..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
deleted file mode 100644
index 41fb0e45ef4e6..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
deleted file mode 100644
index 50b7ead9fcefd..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py
index c347d4f2ab9f4..9bf7f6358880f 100644
--- a/csrc/punica/bgmv/generator.py
+++ b/csrc/punica/bgmv/generator.py
@@ -18,6 +18,26 @@
if weight_dtype == "fp32":
# FP32 weights are not supported.
continue
+ if output_dtype == "fp32":
+ # LoRA A matrix.
+ if input_dtype != weight_dtype:
+ # NOTE(woosuk): While Punica supports the case where the
+ # input and weight dtypes are different, we only generate
+ # the kernels with the same dtypes to reduce the binary size.
+ continue
+ elif input_dtype == "fp32":
+ # LoRA B matrix.
+ if output_dtype != weight_dtype:
+ # NOTE(woosuk): While Punica supports the case where the
+ # output and weight dtypes are different, we only generate
+ # the kernels with the same dtypes to reduce the binary size.
+ continue
+ elif not (input_dtype == output_dtype == weight_dtype):
+ # NOTE(woosuk): While Punica supports mixed data types for
+ # input, output, and weight, we only generate the kernels with
+ # the same data types to reduce the binary size.
+ continue
+
kernel_definition = TEMPLATE.format(
input_dtype=DTYPE_MAP[input_dtype],
output_dtype=DTYPE_MAP[output_dtype],
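
The net effect of this filter is easier to see by enumerating it. Below is a small standalone sketch that mirrors the logic above and prints the (input, output, weight) dtype combinations for which bgmv kernels are still generated; it is illustrative only and not part of the patch.

    import itertools

    DTYPES = ["fp16", "bf16", "fp32"]
    kept = []
    for input_dtype, output_dtype, weight_dtype in itertools.product(DTYPES, repeat=3):
        if weight_dtype == "fp32":
            continue  # FP32 weights are not supported.
        if output_dtype == "fp32":
            # LoRA A matrix: keep only matching input/weight dtypes.
            if input_dtype != weight_dtype:
                continue
        elif input_dtype == "fp32":
            # LoRA B matrix: keep only matching output/weight dtypes.
            if output_dtype != weight_dtype:
                continue
        elif not (input_dtype == output_dtype == weight_dtype):
            continue
        kept.append((input_dtype, output_dtype, weight_dtype))
    print(kept)  # 6 combinations remain, matching the six .cu files kept in the CMakeLists.txt hunk above
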
diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc
index 7ebfd851c4feb..a1eaa90e85f27 100644
--- a/csrc/punica/punica_ops.cc
+++ b/csrc/punica/punica_ops.cc
@@ -50,6 +50,23 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
+ // NOTE(woosuk): While Punica supports various combinations of input/output
+ // data types, we limit the supported data types to reduce the binary size.
+ constexpr bool is_input_float = std::is_same<in_T, float>::value;
+ constexpr bool is_output_float = std::is_same<out_T, float>::value;
+ if (is_input_float) {
+ if (!std::is_same<out_T, W_T>::value) {
+ return false;
+ }
+ } else if (is_output_float) {
+ if (!std::is_same<in_T, W_T>::value) {
+ return false;
+ }
+ } else if (!(std::is_same<in_T, W_T>::value &&
+ std::is_same<out_T, W_T>::value)) {
+ return false;
+ }
+
switch (pack_u32(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u32(feat_in, feat_out): \
diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index e9e0c8554c1ef..1616fdfd4cff9 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -413,7 +413,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
def _pretest():
linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
- 1024, vocab_size)
+ 1024,
+ vocab_size,
+ params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, vocab_size:] = 0
logits_processor = LogitsProcessor(
@@ -445,7 +447,7 @@ def _pretest():
num_inputs=8 * num_loras, # * 3,
input_size=(1, 1024),
input_range=(0, 1),
- input_type=torch.float32,
+ input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
@@ -494,7 +496,7 @@ def _pretest():
num_inputs=8 * num_loras * 3,
input_size=(1, 1024),
input_range=(0, 1),
- input_type=torch.float32,
+ input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
@@ -533,11 +535,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
def create_random_linear_parallel_layer():
if orientation == "row":
- linear = RowParallelLinear(4096, 4096, bias=False)
+ linear = RowParallelLinear(4096,
+ 4096,
+ bias=False,
+ params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = RowParallelLinearWithLoRA(linear)
else:
- linear = ColumnParallelLinear(4096, 4096, bias=False)
+ linear = ColumnParallelLinear(4096,
+ 4096,
+ bias=False,
+ params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ColumnParallelLinearWithLoRA(linear)
lora_linear.create_lora_weights(max_loras, lora_config)
@@ -561,7 +569,7 @@ def create_random_linear_parallel_layer():
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
- input_type=torch.float32,
+ input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
@@ -600,7 +608,7 @@ def create_random_linear_parallel_layer():
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
- input_type=torch.float32,
+ input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
@@ -633,15 +641,24 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
def create_column_parallel_packed_layer():
if repeats == 2:
linear = MergedColumnParallelLinear(4096, [4096] * repeats,
- bias=False)
+ bias=False,
+ params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedColumnParallelLinearWithLoRA(linear)
elif repeats == 3:
- linear = QKVParallelLinear(4096, 64, 32, bias=False)
+ linear = QKVParallelLinear(4096,
+ 64,
+ 32,
+ bias=False,
+ params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedQKVParallelLinearWithLora(linear)
else:
- linear = QKVParallelLinear(4096, 64, 32, bias=False)
+ linear = QKVParallelLinear(4096,
+ 64,
+ 32,
+ bias=False,
+ params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = QKVParallelLinearWithLora(linear)
@@ -676,7 +693,7 @@ class FakeConfig:
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
- input_type=torch.float32,
+ input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
@@ -716,7 +733,7 @@ class FakeConfig:
num_inputs=32 * num_loras,
input_size=(1, 4096),
input_range=(0, 1),
- input_type=torch.float32,
+ input_type=torch.float16,
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
From 3cd9b5bb2d4a0d5eed07186ae140f5dc8f839708 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 24 Apr 2024 09:00:20 -0700
Subject: [PATCH 111/413] [Core][Distributed] use existing torch.cuda.device
(#4318)
[Core][Distributed] use existing torch.cuda.device context manager (#4318)
---
vllm/distributed/device_communicators/pynccl.py | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index fcedf0fed34cb..e922beba44bfa 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -250,15 +250,13 @@ def __init__(
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
- current_device = torch.cuda.current_device()
- try:
- torch.cuda.set_device(device)
+ # `torch.cuda.device` is a context manager that changes the
+ # current cuda device to the specified one
+ with torch.cuda.device(device):
NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.cuda.Stream()
- finally:
- torch.cuda.set_device(current_device)
def all_reduce(self,
tensor: torch.Tensor,
From 7923dcad12edae7dcfd6e0cf7ce2984b3bcecf0f Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Wed, 24 Apr 2024 09:49:13 -0700
Subject: [PATCH 112/413] [Misc] Update ShareGPT Dataset Sampling in Serving
Benchmark (#4279)
---
benchmarks/benchmark_serving.py | 50 ++++++++++++++++++---------------
1 file changed, 28 insertions(+), 22 deletions(-)
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 6054df439fa57..2c2d69da4a7d1 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -27,7 +27,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
-from typing import AsyncGenerator, List, Tuple
+from typing import AsyncGenerator, List, Optional, Tuple
import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
@@ -58,7 +58,11 @@ def sample_sharegpt_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
+ fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
+ if fixed_output_len is not None and fixed_output_len < 4:
+ raise ValueError("output_len too small")
+
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
@@ -68,38 +72,32 @@ def sample_sharegpt_requests(
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
- # some of these will be filtered out, so sample more than we need
- sampled_indices = random.sample(range(len(dataset)),
- int(num_requests * 1.2))
- dataset = [dataset[i] for i in sampled_indices]
-
- # Tokenize the prompts and completions.
- prompts = [prompt for prompt, _ in dataset]
- prompt_token_ids = tokenizer(prompts).input_ids
- completions = [completion for _, completion in dataset]
- completion_token_ids = tokenizer(completions).input_ids
- tokenized_dataset = []
- for i in range(len(dataset)):
- output_len = len(completion_token_ids[i])
- tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+ # Shuffle the dataset.
+ random.shuffle(dataset)
- # Filter out too long sequences.
+ # Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
- for prompt, prompt_token_ids, output_len in tokenized_dataset:
+ for i in range(len(dataset)):
+ if len(filtered_dataset) == num_requests:
+ break
+
+ # Tokenize the prompts and completions.
+ prompt = dataset[i][0]
+ prompt_token_ids = tokenizer(prompt).input_ids
+ completion = dataset[i][1]
+ completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
+ output_len = len(completion_token_ids
+ ) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
- # This is because TGI causes errors when the input or output length
- # is too short.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
- # Sample the requests.
- sampled_requests = random.sample(filtered_dataset, num_requests)
- return sampled_requests
+ return filtered_dataset
def sample_sonnet_requests(
@@ -361,6 +359,7 @@ def main(args: argparse.Namespace):
dataset_path=args.dataset,
num_requests=args.num_prompts,
tokenizer=tokenizer,
+ fixed_output_len=args.sharegpt_output_len,
)
elif args.dataset_name == "sharegpt":
@@ -368,6 +367,7 @@ def main(args: argparse.Namespace):
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
+ fixed_output_len=args.sharegpt_output_len,
)
elif args.dataset_name == "sonnet":
@@ -524,6 +524,12 @@ def main(args: argparse.Namespace):
default=1000,
help="Number of prompts to process.",
)
+ parser.add_argument(
+ "--sharegpt-output-len",
+ type=int,
+ default=None,
+ help="Output length for each request. Overrides the output length "
+ "from the ShareGPT dataset.")
parser.add_argument(
"--sonnet-input-len",
type=int,
From aae08249acca69060d0a8220cab920e00520932c Mon Sep 17 00:00:00 2001
From: alexm-nm <59768536+alexm-nm@users.noreply.github.com>
Date: Wed, 24 Apr 2024 13:35:01 -0400
Subject: [PATCH 113/413] [Bugfix] Fix marlin kernel crash on H100 (#4218)
This PR addresses the Marlin kernel H100 crash that was reported here: neuralmagic#187.
The reason for the crash was the inline PTX assembly that introduced the async_copy with streaming behavior. The solution is to use the more standard PTX for async_copy (without the fractional L2 policy for "evict_first"). There is no performance difference between standard async_copy PTX and the previous one.
---
.../quantization/marlin/marlin_cuda_kernel.cu | 23 +++++++------------
1 file changed, 8 insertions(+), 15 deletions(-)
diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu
index cf1b0afdec8b4..002a70001885d 100644
--- a/csrc/quantization/marlin/marlin_cuda_kernel.cu
+++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu
@@ -67,20 +67,13 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
-// Asynchronous global->shared copy with a cache hint indicating that the values
-// may be evicted immediately; used for quantized weights B, which are only
-// accessed precisely once and should thus not pollute the L2 cache which we
-// need for inputs A and outputs C.
-__device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) {
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
- asm volatile(
- "{\n"
- " .reg .b64 p;\n"
- " createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
- " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
- "}\n" ::"r"(smem),
- "l"(glob_ptr), "n"(BYTES));
+ asm volatile("{\n"
+ " cp.async.cg.shared.global [%0], [%1], %2;\n"
+ "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
@@ -448,14 +441,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
- cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+ cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
B_ptr[i] += b_gl_rd_delta_o;
}
// Only fetch scales if this tile starts a new group
if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) {
int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
if (s_sh_wr_pred)
- cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
+ cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]);
s_gl_rd += s_gl_rd_delta;
}
}
@@ -750,7 +743,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// write-out
if (group_blocks == -1 && last) {
if (s_sh_wr_pred)
- cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]);
+ cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]);
cp_async_fence();
}
thread_block_reduce();
From 2768884ac4a026609efceef92edea55839af0c30 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 24 Apr 2024 14:09:44 -0700
Subject: [PATCH 114/413] [Doc] Add note for docker user (#4340)
Co-authored-by: Simon Mo
---
docs/source/serving/deploying_with_docker.rst | 3 +++
1 file changed, 3 insertions(+)
diff --git a/docs/source/serving/deploying_with_docker.rst b/docs/source/serving/deploying_with_docker.rst
index 7ec769630300d..cfc462ff33b90 100644
--- a/docs/source/serving/deploying_with_docker.rst
+++ b/docs/source/serving/deploying_with_docker.rst
@@ -49,3 +49,6 @@ To run vLLM:
--env "HUGGING_FACE_HUB_TOKEN=" \
vllm/vllm-openai
+.. note::
+
+ The vLLM docker image is currently designed to run as the root user (contributions to change this are welcome!). At runtime it tries to load a library from the root user's home directory, e.g. `/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1`. If you run the container as a different user, you may need to change the permissions of that library (and all of its parent directories) so that the user can access it, and then run vLLM with the environment variable `VLLM_NCCL_SO_PATH=/root/.config/vllm/nccl/cu12/libnccl.so.2.18.1`.
From a395a638c2f18d549e7d01655cf7a6dbee566f91 Mon Sep 17 00:00:00 2001
From: zifeitong
Date: Wed, 24 Apr 2024 14:10:24 -0700
Subject: [PATCH 115/413] [Misc] Use public API in benchmark_throughput (#4300)
---
benchmarks/benchmark_throughput.py | 29 +++++++++++++----------------
1 file changed, 13 insertions(+), 16 deletions(-)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 6bb889d1eceba..695d06e7b243d 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -103,25 +103,22 @@ def run_vllm(
)
# Add the requests to the engine.
+ prompts = []
+ sampling_params = []
for prompt, _, output_len in requests:
- sampling_params = SamplingParams(
- n=n,
- temperature=0.0 if use_beam_search else 1.0,
- top_p=1.0,
- use_beam_search=use_beam_search,
- ignore_eos=True,
- max_tokens=output_len,
- )
- # FIXME(woosuk): Do not use internal method.
- llm._add_request(
- prompt=prompt,
- prompt_token_ids=None,
- sampling_params=sampling_params,
- )
+ prompts.append(prompt)
+ sampling_params.append(
+ SamplingParams(
+ n=n,
+ temperature=0.0 if use_beam_search else 1.0,
+ top_p=1.0,
+ use_beam_search=use_beam_search,
+ ignore_eos=True,
+ max_tokens=output_len,
+ ))
start = time.perf_counter()
- # FIXME(woosuk): Do not use internal method.
- llm._run_engine(use_tqdm=True)
+ llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
return end - start
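
For reference, the public-API pattern the benchmark now relies on is a list of prompts plus a parallel list of per-prompt SamplingParams passed to LLM.generate(). A minimal sketch of that usage follows; the model name and lengths are placeholders, not values from the benchmark.

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    prompts = ["Hello, my name is", "The capital of France is"]
    sampling_params = [
        SamplingParams(temperature=1.0, top_p=1.0, ignore_eos=True, max_tokens=16)
        for _ in prompts
    ]
    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
    print(outputs[0].outputs[0].text)
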
From 96e90fdeb3c4ebacfe24513556afccb918722b7c Mon Sep 17 00:00:00 2001
From: Caio Mendes
Date: Thu, 25 Apr 2024 00:06:57 -0300
Subject: [PATCH 116/413] [Model] Adds Phi-3 support (#4298)
---
README.md | 1 +
docs/source/models/supported_models.rst | 4 +
vllm/config.py | 2 +-
.../model_executor/layers/rotary_embedding.py | 136 +++++++++++++++++-
vllm/model_executor/models/__init__.py | 1 +
vllm/model_executor/models/llama.py | 14 +-
6 files changed, 149 insertions(+), 9 deletions(-)
diff --git a/README.md b/README.md
index 947d50d4ad764..fce3de3b70430 100644
--- a/README.md
+++ b/README.md
@@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
+- Phi3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 951fc3aac0c75..f4dd5d52ad873 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -115,6 +115,10 @@ Alongside each architecture, we include some popular models that use it.
- Phi
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
-
+ * - :code:`Phi3ForCausalLM`
+ - Phi-3
+ - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
+ -
* - :code:`QWenLMHeadModel`
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
diff --git a/vllm/config.py b/vllm/config.py
index 2ff42de08f8f7..311a69f822571 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1056,7 +1056,7 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
- if rope_scaling is not None:
+ if rope_scaling is not None and rope_scaling["type"] != "su":
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index a5225148d7828..b8361af61ae3f 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -338,6 +338,114 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
return cache
+class Phi3SuScaledRotaryEmbedding(nn.Module):
+ """Phi3 family of models scaled rotary embedding.
+
+ Based on the original RotaryEmbedding implementation.
+ """
+
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ original_max_position_embeddings: int,
+ base: int,
+ is_neox_style: bool,
+ short_factor: List[float],
+ long_factor: List[float],
+ short_mscale: float = 1.1,
+ long_mscale: float = 1.225,
+ ):
+ super().__init__()
+
+ if rotary_dim != head_size:
+ raise ValueError(
+ f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
+ head_size ({rotary_dim}!={head_size}).")
+ if is_neox_style is False:
+ raise ValueError(
+ "`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
+
+ self.head_size = head_size
+ self.max_position_embeddings = max_position_embeddings
+ self.original_max_position_embeddings = original_max_position_embeddings
+ self.base = base
+ self.short_factor = short_factor
+ self.long_factor = long_factor
+ self.short_mscale = short_mscale
+ self.long_mscale = long_mscale
+
+ short_cache = self._compute_cos_sin_cache(
+ original_max_position_embeddings, short_factor, short_mscale)
+ short_cache = short_cache.to(torch.get_default_dtype())
+ self.register_buffer("short_cos_sin_cache",
+ short_cache,
+ persistent=False)
+
+ long_cache = self._compute_cos_sin_cache(max_position_embeddings,
+ long_factor, long_mscale)
+ long_cache = long_cache.to(torch.get_default_dtype())
+ self.register_buffer("long_cos_sin_cache",
+ long_cache,
+ persistent=False)
+
+ long_short_cache = torch.cat(
+ [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+ self.register_buffer("long_short_cos_sin_cache",
+ long_short_cache,
+ persistent=False)
+
+ def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
+ rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
+ inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
+ 0, self.head_size, 2, dtype=torch.float) / self.head_size)))
+ return inv_freq
+
+ def _compute_cos_sin_cache(
+ self,
+ max_position_embeddings: int,
+ rescale_factors: List[float],
+ mscale: float,
+ ) -> torch.Tensor:
+ inv_freq = self._compute_inv_freq(rescale_factors)
+ t = torch.arange(max_position_embeddings, dtype=torch.float)
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
+ cos = freqs.cos() * mscale
+ sin = freqs.sin() * mscale
+ cache = torch.cat((cos, sin), dim=-1)
+ return cache
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ offsets: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ query = query.view(*query.shape[:-1], -1, self.head_size)
+ key = key.view(*key.shape[:-1], -1, self.head_size)
+
+ k = self.original_max_position_embeddings
+ long_prompt_offset = (torch.any(positions > k).float() *
+ torch.full_like(positions, k)).long()
+ idx = (torch.add(positions, long_prompt_offset)
+ if long_prompt_offset is not None else positions)
+ self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
+ idx.device)
+ idx = torch.add(idx, offsets) if offsets is not None else idx
+ cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
+
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ cos = cos.repeat(1, 2).unsqueeze(-2)
+ sin = sin.repeat(1, 2).unsqueeze(-2)
+
+ query = query * cos + _rotate_neox(query) * sin
+ key = key * cos + _rotate_neox(key) * sin
+
+ return query.flatten(-2), key.flatten(-2)
+
+
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
@@ -349,17 +457,26 @@ def get_rope(
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
+ if rope_scaling is not None:
+ # Transforms every value that is a list into a tuple for caching calls
+ rope_scaling_tuple = {
+ k: tuple(v) if isinstance(v, list) else v
+ for k, v in rope_scaling.items()
+ }
+ rope_scaling_args = tuple(rope_scaling_tuple.items())
+ else:
+ rope_scaling_args = None
key = (head_size, rotary_dim, max_position, base, is_neox_style,
- tuple(rope_scaling.items()) if rope_scaling is not None else None)
+ rope_scaling_args)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
-
if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style)
else:
scaling_type = rope_scaling["type"]
- scaling_factor = rope_scaling["factor"]
+ if scaling_type != "su":
+ scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
@@ -383,6 +500,19 @@ def get_rope(
base, is_neox_style,
scaling_factor,
**extra_kwargs)
+ elif scaling_type == "su":
+ short_factor = rope_scaling["short_factor"]
+ long_factor = rope_scaling["long_factor"]
+ original_max_position = rope_scaling[
+ "original_max_position_embeddings"]
+ extra_kwargs = {
+ k: v
+ for k, v in rope_scaling.items()
+ if k in ("short_mscale", "long_mscale")
+ }
+ rotary_emb = Phi3SuScaledRotaryEmbedding(
+ head_size, rotary_dim, max_position, original_max_position,
+ base, is_neox_style, short_factor, long_factor, **extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
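
For orientation, the new "su" branch is selected purely by the shape of the model's rope_scaling config. The snippet below is an illustrative sketch only: the factor lists are made-up placeholders (the real values ship in the Phi-3 config.json), and it assumes a vLLM build that includes this patch.

    from vllm.model_executor.layers.rotary_embedding import get_rope

    rope_scaling = {
        "type": "su",
        "short_factor": [1.0] * 48,   # length must equal head_size // 2
        "long_factor": [1.5] * 48,
        "original_max_position_embeddings": 4096,
    }
    rotary_emb = get_rope(
        head_size=96,
        rotary_dim=96,
        max_position=131072,
        base=10000,
        is_neox_style=True,
        rope_scaling=rope_scaling,
    )  # dispatches to Phi3SuScaledRotaryEmbedding
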
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 17fc970568042..c0aeab5dd3032 100755
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -46,6 +46,7 @@
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
+ "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 016e3b039d1e8..c102b40045c92 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -180,6 +180,10 @@ def __init__(
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
+ if rope_scaling is not None and getattr(
+ config, "original_max_position_embeddings", None):
+ rope_scaling["original_max_position_embeddings"] = (
+ config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
@@ -378,11 +382,11 @@ def sample(
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
- ("qkv_proj", "q_proj", "q"),
- ("qkv_proj", "k_proj", "k"),
- ("qkv_proj", "v_proj", "v"),
- ("gate_up_proj", "gate_proj", 0),
- ("gate_up_proj", "up_proj", 1),
+ (".qkv_proj", ".q_proj", "q"),
+ (".qkv_proj", ".k_proj", "k"),
+ (".qkv_proj", ".v_proj", "v"),
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
From 479d69fad0538f04cb22bf13e76ff91cfeb8a4e5 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Wed, 24 Apr 2024 23:52:22 -0700
Subject: [PATCH 117/413] [Core] Move ray_utils.py from `engine` to `executor`
package (#4347)
---
vllm/__init__.py | 2 +-
vllm/engine/async_llm_engine.py | 2 +-
vllm/engine/llm_engine.py | 2 +-
vllm/executor/ray_gpu_executor.py | 10 ++++++----
vllm/{engine => executor}/ray_utils.py | 0
vllm/transformers_utils/tokenizer_group/__init__.py | 2 +-
.../tokenizer_group/ray_tokenizer_group.py | 2 +-
7 files changed, 11 insertions(+), 9 deletions(-)
rename vllm/{engine => executor}/ray_utils.py (100%)
diff --git a/vllm/__init__.py b/vllm/__init__.py
index 5ca4680227598..ca454efd44b24 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -3,8 +3,8 @@
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
-from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.entrypoints.llm import LLM
+from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 3a2f7db679358..4b007d71e9cfc 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -10,7 +10,7 @@
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
-from vllm.engine.ray_utils import initialize_ray_cluster, ray
+from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 19e58fb1722cf..56c2417d6a6e6 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -15,8 +15,8 @@
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
-from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index e69f104e7d5a4..14b3f803782c6 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -5,8 +5,8 @@
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
-from vllm.engine.ray_utils import RayWorkerWrapper, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -74,7 +74,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
- self.driver_dummy_worker: RayWorkerWrapper = None
+ self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
@@ -318,6 +318,7 @@ def _run_workers(
driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
else:
+ assert self.driver_dummy_worker is not None
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
@@ -353,8 +354,9 @@ def _compiled_ray_dag(self):
# a dummy value for now. It will be fixed soon.
with InputNode() as input_data:
forward_dag = MultiOutputNode([
- worker.execute_model_compiled_dag_remote.bind(input_data)
- for worker in self.workers
+ worker.execute_model_compiled_dag_remote.
+ bind( # type: ignore[attr-defined]
+ input_data) for worker in self.workers
])
return forward_dag.experimental_compile()
diff --git a/vllm/engine/ray_utils.py b/vllm/executor/ray_utils.py
similarity index 100%
rename from vllm/engine/ray_utils.py
rename to vllm/executor/ray_utils.py
diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py
index 69380d67f9b94..0195c40c27f60 100644
--- a/vllm/transformers_utils/tokenizer_group/__init__.py
+++ b/vllm/transformers_utils/tokenizer_group/__init__.py
@@ -1,7 +1,7 @@
from typing import Optional
from vllm.config import TokenizerPoolConfig
-from vllm.engine.ray_utils import ray
+from vllm.executor.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
index f3cdc00564dbb..7c605416854b8 100644
--- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -6,7 +6,7 @@
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
-from vllm.engine.ray_utils import ray
+from vllm.executor.ray_utils import ray
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
From fbf152d976e655f8734561452a1036e63081ccfd Mon Sep 17 00:00:00 2001
From: Isotr0py <41363108+Isotr0py@users.noreply.github.com>
Date: Fri, 26 Apr 2024 00:35:56 +0800
Subject: [PATCH 118/413] [Bugfix][Model] Refactor OLMo model to support new HF
format in transformers 4.40.0 (#4324)
Co-authored-by: Woosuk Kwon
---
README.md | 2 +-
docs/source/models/supported_models.rst | 2 +-
requirements-dev.txt | 1 -
vllm/model_executor/models/__init__.py | 2 +-
vllm/model_executor/models/olmo.py | 305 ++++++++++++------------
5 files changed, 150 insertions(+), 162 deletions(-)
diff --git a/README.md b/README.md
index fce3de3b70430..e59a1c60cc369 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
-- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
+- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index f4dd5d52ad873..ceb658bbd5c66 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -101,7 +101,7 @@ Alongside each architecture, we include some popular models that use it.
-
* - :code:`OLMoForCausalLM`
- OLMo
- - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc.
+ - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
-
* - :code:`OPTForCausalLM`
- OPT, OPT-IML
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1317e51b2dd11..d9816828d007d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -26,7 +26,6 @@ requests
ray
peft
awscli
-ai2-olmo # required for OLMo
# Benchmarking
aiohttp
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index c0aeab5dd3032..6afb2f31c1334 100755
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -42,7 +42,7 @@
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
- "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
+ "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index b92003bc0e067..15527569b9e20 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -1,53 +1,36 @@
# coding=utf-8
# Adapted from
-# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
-# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
-# Copyright 2023 The vLLM team.
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
+# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
+# Copyright 2024 The vLLM team.
+# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
-# BSD 3-Clause License
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
-# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
-# All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
+# http://www.apache.org/licenses/LICENSE-2.0
#
-# * Redistributions of source code must retain the above copyright notice, this
-# list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
import torch
-# this model must need this dependency
-from hf_olmo import OLMoConfig
from torch import nn
+from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
+from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
@@ -55,7 +38,7 @@
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding)
+ ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
@@ -70,55 +53,52 @@ class OlmoAttention(nn.Module):
def __init__(
self,
- config: OLMoConfig,
+ config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
- self.hidden_size = config.d_model
- assert config.d_model % config.n_heads == 0
+ self.hidden_size = config.hidden_size
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
- self.total_num_heads = self.config.n_heads
+ self.total_num_heads = config.num_attention_heads
+
+ assert self.hidden_size % self.total_num_heads == 0
assert self.total_num_heads % tensor_model_parallel_world_size == 0
+
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = self.hidden_size // self.total_num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.clip_qkv = config.clip_qkv
- # Layer norms.
- self.attn_norm = nn.LayerNorm(config.d_model,
- elementwise_affine=False,
- bias=False)
# Attention input projection. Projects x -> (q, k, v)
- self.att_proj = QKVParallelLinear(
- config.d_model,
+ self.qkv_proj = QKVParallelLinear(
+ self.hidden_size,
self.head_dim,
self.total_num_heads,
- bias=config.include_bias,
+ bias=config.attention_bias,
linear_method=linear_method,
)
# Rotary embeddings.
- if self.config.rope:
- rope_theta = getattr(config, "rope_theta", 10000)
- max_position_embeddings = getattr(config,
- "max_position_embeddings", 8192)
- self.rotary_emb = get_rope(
- self.head_dim,
- rotary_dim=self.head_dim,
- max_position=max_position_embeddings,
- base=rope_theta,
- )
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling)
# Attention output projection.
- self.attn_out = RowParallelLinear(
- config.d_model,
- config.d_model,
- bias=config.include_bias,
+ self.o_proj = RowParallelLinear(
+ self.hidden_size,
+ self.hidden_size,
+ bias=config.attention_bias,
linear_method=linear_method,
)
@@ -129,13 +109,13 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
- hidden_states = self.attn_norm(hidden_states)
- qkv, _ = self.att_proj(hidden_states)
+ qkv, _ = self.qkv_proj(hidden_states)
+ if self.clip_qkv is not None:
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1)
- if self.config.rope:
- q, k = self.rotary_emb(positions, q, k)
+ q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
- output, _ = self.attn_out(attn_output)
+ output, _ = self.o_proj(attn_output)
return output
@@ -148,37 +128,30 @@ class OlmoMLP(nn.Module):
def __init__(
self,
- config: OLMoConfig,
+ config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
- self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
- is not None else config.mlp_ratio * config.d_model)
-
- # Layer norms.
- self.ff_norm = nn.LayerNorm(config.d_model,
- elementwise_affine=False,
- bias=False)
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
# Feed-forward input projection.
- self.ff_proj = MergedColumnParallelLinear(
- config.d_model,
- [self.hidden_size // 2] * 2,
- bias=config.include_bias,
+ self.gate_up_proj = MergedColumnParallelLinear(
+ self.hidden_size,
+ [self.intermediate_size] * 2,
+ bias=False,
linear_method=linear_method,
)
# Activation function.
- self.act = SiluAndMul()
- self.act.output_multiplier = 0.5
- assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+ self.act_fn = SiluAndMul()
# Feed-forward output projection.
- self.ff_out = RowParallelLinear(
- int(self.act.output_multiplier * self.hidden_size),
- config.d_model,
- bias=config.include_bias,
+ self.down_proj = RowParallelLinear(
+ self.intermediate_size,
+ self.hidden_size,
+ bias=False,
linear_method=linear_method,
)
@@ -186,19 +159,13 @@ def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
- # Add feed-forward projection.
- # shape: (batch_size, seq_len, d_model)
- og_x = x
- x = self.ff_norm(x)
- x, _ = self.ff_proj(x)
- x = self.act(x)
- x, _ = self.ff_out(x)
- x = og_x + x
-
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
return x
-class OlmoBlock(nn.Module):
+class OlmoDecoderLayer(nn.Module):
"""
This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
@@ -206,15 +173,23 @@ class OlmoBlock(nn.Module):
"""
def __init__(self,
- config: OLMoConfig,
+ config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
# Attention block.
- self.attn = OlmoAttention(config, linear_method)
+ self.self_attn = OlmoAttention(config, linear_method)
# MLP block.
self.mlp = OlmoMLP(config, linear_method)
+ # LayerNorm
+ self.input_layernorm = nn.LayerNorm(config.hidden_size,
+ elementwise_affine=False,
+ bias=False)
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+ elementwise_affine=False,
+ bias=False)
+
def forward(
self,
positions: torch.Tensor,
@@ -223,52 +198,37 @@ def forward(
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention block.
- og_x = hidden_states
- x = self.attn(positions, hidden_states, kv_cache, attn_metadata)
- x = x + og_x
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.self_attn(positions, hidden_states, kv_cache,
+ attn_metadata)
+ hidden_states = hidden_states + residual
# MLP block.
- hidden_states = self.mlp(x)
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
return hidden_states
class OlmoModel(nn.Module):
def __init__(self,
- config: OLMoConfig,
+ config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
- self.transformer = nn.ModuleDict(
- dict(
- wte=VocabParallelEmbedding(
- config.embedding_size or config.vocab_size,
- config.d_model,
- ),
- ln_f=nn.LayerNorm(config.d_model,
- elementwise_affine=False,
- bias=False),
- ))
-
- blocks = [
- OlmoBlock(config, linear_method) for i in range(config.n_layers)
- ]
- if self.config.block_group_size > 1:
- raise NotImplementedError("Block group size > 1 not supported yet")
- else:
- self.transformer.update({"blocks": nn.ModuleList(blocks)})
-
- if not config.weight_tying:
- self.transformer.update({
- "ff_out":
- ColumnParallelLinear(
- config.d_model,
- config.embedding_size or config.vocab_size,
- bias=config.include_bias,
- linear_method=linear_method,
- )
- })
+ self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+ config.hidden_size)
+ self.layers = nn.ModuleList([
+ OlmoDecoderLayer(config, linear_method)
+ for layer_idx in range(config.num_hidden_layers)
+ ])
+ self.norm = nn.LayerNorm(config.hidden_size,
+ elementwise_affine=False,
+ bias=False)
def forward(
self,
@@ -282,39 +242,49 @@ def forward(
"""
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
- x = self.transformer.wte(input_ids) # type: ignore
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # embed positions
+ hidden_states = inputs_embeds
# Apply blocks one-by-one.
- for block_idx, block in enumerate(self.transformer.blocks):
+ for layer_idx, decoder_layer in enumerate(self.layers):
# shape: (batch_size, seq_len, d_model)
- x = block(
+ hidden_states = decoder_layer(
positions,
- x,
- kv_caches[block_idx],
+ hidden_states,
+ kv_caches[layer_idx],
attn_metadata,
)
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
- x = self.transformer.ln_f(x) # type: ignore
- return x
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
-class OLMoForCausalLM(nn.Module):
+class OlmoForCausalLM(nn.Module):
"""
Extremely barebones HF model wrapper.
"""
def __init__(self,
- config: OLMoConfig,
+ config: OlmoConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = OlmoModel(config, linear_method)
- self.lm_head_weight = (self.model.transformer.wte.weight
- if config.weight_tying else
- self.model.transformer.ff_out.weight)
+ if config.tie_word_embeddings:
+ self.lm_head_weight = self.model.embed_tokens.weight
+ else:
+ self.unpadded_vocab_size = config.vocab_size
+ self.lm_head = ParallelLMHead(
+ self.unpadded_vocab_size,
+ config.hidden_size,
+ org_num_embeddings=config.vocab_size,
+ )
+ self.lm_head_weight = self.lm_head.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@@ -348,20 +318,39 @@ def sample(
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
- # attention
- if ".att" in name:
- name = name.replace(".att", ".attn.att")
- # mlp
- if ".ff_proj" in name:
- name = name.replace(".ff_proj", ".mlp.ff_proj")
- # Reverse the weight for the MergeColumnParallelLinear
- loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
- if ".ff_out" in name and "transformer.ff_out" not in name:
- name = name.replace(".ff_out", ".mlp.ff_out")
- # there is no bias in olmo
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader",
- default_weight_loader)
- weight_loader(param, loaded_weight)
+ if "rotary_emb.inv_freq" in name:
+ continue
+ if ("rotary_emb.cos_cached" in name
+ or "rotary_emb.sin_cached" in name):
+ # Models trained using ColossalAI may include these tensors in
+ # the checkpoint. Skip them.
+ continue
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
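The rewritten load_weights above routes per-projection checkpoint tensors (q_proj, k_proj, v_proj, gate_proj, up_proj) into vLLM's fused parameters through stacked_params_mapping. The standalone sketch below illustrates just that routing step; route_weights and its return shape are hypothetical helpers for this example, not vLLM APIs.

    # Standalone sketch of the stacked-parameter routing; route_weights is a
    # hypothetical helper for illustration, not part of vLLM.
    from typing import Dict, Iterable, List, Tuple
    import torch

    STACKED_PARAMS_MAPPING = [
        # (fused param name, checkpoint shard name, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    def route_weights(
            weights: Iterable[Tuple[str, torch.Tensor]]
    ) -> Dict[str, List[Tuple[object, torch.Tensor]]]:
        """Group checkpoint tensors under the fused parameter they feed."""
        routed: Dict[str, List[Tuple[object, torch.Tensor]]] = {}
        for name, tensor in weights:
            for fused, shard, shard_id in STACKED_PARAMS_MAPPING:
                if shard in name:
                    routed.setdefault(name.replace(shard, fused),
                                      []).append((shard_id, tensor))
                    break
            else:  # unfused parameters load as-is
                routed.setdefault(name, []).append((None, tensor))
        return routed

    demo = [(f"model.layers.0.self_attn.{p}.weight", torch.zeros(2, 2))
            for p in ("q_proj", "k_proj", "v_proj")]
    print(list(route_weights(demo)))
    # ['model.layers.0.self_attn.qkv_proj.weight']

Each fused entry carries the shard id that the real weight_loader then uses to copy the tensor into the correct slice of the fused parameter.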
From 7ee82bef1e71febc28757807f5df8191bb36d88e Mon Sep 17 00:00:00 2001
From: Alexei-V-Ivanov-AMD
<156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Date: Thu, 25 Apr 2024 11:37:20 -0500
Subject: [PATCH 119/413] [CI/Build] Adding functionality to reset the node's
GPUs before processing. (#4213)
---
.buildkite/run-amd-test.sh | 16 +++++++++++++++-
1 file changed, 15 insertions(+), 1 deletion(-)
diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh
index 83a56e25aca73..38aff57a410dc 100644
--- a/.buildkite/run-amd-test.sh
+++ b/.buildkite/run-amd-test.sh
@@ -5,6 +5,19 @@ set -ex
# Print ROCm version
rocminfo
+
+echo "reset" > /opt/amdgpu/etc/gpu_state
+
+while true; do
+ sleep 3
+ if grep -q clean /opt/amdgpu/etc/gpu_state; then
+ echo "GPUs state is \"clean\""
+ break
+ fi
+done
+
+
+
# Try building the docker image
docker build -t rocm -f Dockerfile.rocm .
@@ -14,7 +27,8 @@ trap remove_docker_container EXIT
remove_docker_container
# Run the image
-docker run --device /dev/kfd --device /dev/dri --network host --name rocm rocm python3 -m vllm.entrypoints.api_server &
+export HIP_VISIBLE_DEVICES=1
+docker run --device /dev/kfd --device /dev/dri --network host -e HIP_VISIBLE_DEVICES --name rocm rocm python3 -m vllm.entrypoints.api_server &
# Wait for the server to start
wait_for_server_to_start() {
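The added shell loop polls /opt/amdgpu/etc/gpu_state every three seconds until it reports "clean", with no upper bound on the wait. A minimal Python sketch of the same wait-for-clean pattern with an added timeout guard follows; the state-file path and limits are assumptions for illustration, not values taken from run-amd-test.sh.

    # Illustrative wait-for-clean poll with a timeout; the state-file path and
    # limits are assumptions for this sketch, not part of the CI script.
    import time

    def wait_for_gpu_reset(state_file: str = "/opt/amdgpu/etc/gpu_state",
                           timeout_s: float = 300.0,
                           poll_s: float = 3.0) -> None:
        deadline = time.monotonic() + timeout_s
        while time.monotonic() < deadline:
            try:
                with open(state_file) as f:
                    if "clean" in f.read():
                        print('GPUs state is "clean"')
                        return
            except FileNotFoundError:
                pass  # state file not written yet; keep polling
            time.sleep(poll_s)
        raise TimeoutError(f"GPUs not clean after {timeout_s:.0f}s")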
From bd7a8eef25cd85be7eb9f2a94fd752d27ee7dce3 Mon Sep 17 00:00:00 2001
From: Caio Mendes
Date: Thu, 25 Apr 2024 14:32:00 -0300
Subject: [PATCH 120/413] [Doc] README Phi-3 name fix. (#4372)
Co-authored-by: Caio Mendes
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index e59a1c60cc369..524d027137aba 100644
--- a/README.md
+++ b/README.md
@@ -78,7 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
-- Phi3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
+- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
From f4bc4de1b1a1bd33bd3ec7dea3377eb75884250a Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Thu, 25 Apr 2024 19:03:56 +0000
Subject: [PATCH 121/413] [Core]refactor aqlm quant ops (#4351)
---
benchmarks/kernels/benchmark_aqlm.py | 2 +-
vllm/_custom_ops.py | 14 ++++++++++++++
vllm/model_executor/layers/quantization/aqlm.py | 2 +-
3 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py
index 9602d20bcbc74..59392947b15c8 100644
--- a/benchmarks/kernels/benchmark_aqlm.py
+++ b/benchmarks/kernels/benchmark_aqlm.py
@@ -6,7 +6,7 @@
import torch
import torch.nn.functional as F
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index e4b16ed918d1a..508d35656eb00 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -153,6 +153,20 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k)
+# aqlm
+def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
+ codebooks: torch.Tensor, scales: torch.Tensor,
+ codebook_partition_sizes: torch.Tensor,
+ bias: Optional[torch.Tensor]) -> torch.Tensor:
+ return vllm_ops.aqlm_gemm(input, codes, codebooks, scales,
+ codebook_partition_sizes, bias)
+
+
+def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
+ codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
+ return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)
+
+
# fp8
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py
index 6115b1de679ad..b48c6e1702be4 100644
--- a/vllm/model_executor/layers/quantization/aqlm.py
+++ b/vllm/model_executor/layers/quantization/aqlm.py
@@ -8,7 +8,7 @@
import torch.nn.functional as F
from torch.nn.parameter import Parameter
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
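Both call sites above now import the pure-Python facade (from vllm import _custom_ops as ops) instead of the compiled vllm._C module directly. The sketch below shows the same facade pattern with a toy backend so it runs without the extension; scaled_add and _Backend are stand-ins, not vLLM APIs.

    # Toy version of the wrapper-module pattern used by vllm/_custom_ops.py:
    # callers depend on one stable Python namespace, and the backend behind it
    # can change without touching call sites. _Backend is a stand-in here.
    import torch

    class _Backend:
        @staticmethod
        def scaled_add(a: torch.Tensor, b: torch.Tensor,
                       scale: float) -> torch.Tensor:
            return a + scale * b

    def scaled_add(a: torch.Tensor, b: torch.Tensor,
                   scale: float = 1.0) -> torch.Tensor:
        """Public wrapper with a stable signature; forwards to the backend."""
        return _Backend.scaled_add(a, b, scale)

    print(scaled_add(torch.ones(2), torch.ones(2), scale=2.0))  # tensor([3., 3.])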
From b5b4a398a794e3276729a525e58eaa92f5fc0212 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Fri, 26 Apr 2024 04:13:50 +0900
Subject: [PATCH 122/413] [Mypy] Typing lora folder (#4337)
---
.github/workflows/mypy.yaml | 7 ++--
format.sh | 2 +-
vllm/lora/layers.py | 35 ++++++++++++--------
vllm/lora/lora.py | 28 +++++++++-------
vllm/lora/models.py | 64 ++++++++++++++++++++-----------------
vllm/lora/worker_manager.py | 21 ++++++------
vllm/worker/model_runner.py | 4 +--
7 files changed, 91 insertions(+), 70 deletions(-)
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index 9f1855696e20a..089c7d18ad6f2 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -33,8 +33,6 @@ jobs:
- name: Mypy
run: |
mypy vllm/attention --config-file pyproject.toml
- # TODO(sang): Fix nested dir
- mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
@@ -44,8 +42,9 @@ jobs:
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
+ mypy vllm/lora --config-file pyproject.toml
+
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
- # TODO(sang): Fix nested dir
- # mypy vllm/lora/*.py --config-file pyproject.toml
+ mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
diff --git a/format.sh b/format.sh
index bd2e9e89e1806..4ac1842daef0a 100755
--- a/format.sh
+++ b/format.sh
@@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml
-# mypy vllm/lora/*.py --config-file pyproject.toml
+mypy vllm/lora --config-file pyproject.toml
CODESPELL_EXCLUDES=(
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index aac86351b15e1..98e74168002c4 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
+ self.embeddings_slice: Optional[Tuple[int, int]]
+ self.embeddings_weights: Optional[torch.Tensor]
def create_lora_weights(
self,
@@ -233,9 +235,10 @@ def create_lora_weights(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
- self.indices: Optional[torch.Tensor] = None
- self.indices_len: Optional[List[int]] = None
- self.embeddings_indices = None
+ # Lazily initialized.
+ self.indices: torch.Tensor
+ self.indices_len: List[int]
+ self.embeddings_indices: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
@@ -267,6 +270,7 @@ def set_lora(
self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2]
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
+ assert self.embeddings_weights is not None
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
def set_mapping(
@@ -343,11 +347,12 @@ def create_lora_weights(
dtype=lora_config.lora_dtype,
device=self.device,
)
-
- self.indices: Optional[torch.Tensor] = None
- self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[2]
+ # lazily initialized.
+ self.indices: torch.Tensor
+ self.indices_len: List[int]
+
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
@@ -475,8 +480,9 @@ def create_lora_weights(
device=self.device,
) for _ in range(n_slices))
- self.indices: Optional[torch.Tensor] = None
self.output_dim = self.lora_b_stacked[0].shape[2]
+ # Lazily initialized.
+ self.indices: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
@@ -690,7 +696,8 @@ def create_lora_weights(
self.kv_proj_shard_size)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
- self.indices_len: Optional[List[int]] = None
+ # lazily initialized.
+ self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
@@ -814,8 +821,9 @@ def create_lora_weights(
dtype=lora_config.lora_dtype,
device=self.device,
)
- self.indices: Optional[torch.Tensor] = None
- self.indices_len: Optional[List[int]] = None
+ # Lazily initialized
+ self.indices: torch.Tensor
+ self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
@@ -991,9 +999,10 @@ def create_lora_weights(
dtype=self.dtype,
device=self.device,
)
- self.indices = None
- self.indices_padded = None
- self.indices_len = None
+ # Lazily initialized.
+ self.indices: torch.Tensor
+ self.indices_len: List[int]
+ self.indices_padded: torch.Tensor
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py
index fefad16700fe3..d7794aa7cd35c 100644
--- a/vllm/lora/lora.py
+++ b/vllm/lora/lora.py
@@ -97,9 +97,9 @@ def __init__(
self,
module_name: str,
rank: int,
- lora_alphas: List[int],
- lora_a: List[torch.Tensor],
- lora_b: List[torch.Tensor],
+ lora_alphas: List[Optional[int]],
+ lora_a: List[Optional[torch.Tensor]],
+ lora_b: List[Optional[torch.Tensor]],
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
@@ -108,17 +108,20 @@ def __init__(
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
- scaling=scaling,
+ scaling=scaling, # type: ignore
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
- self.scaling = [
- lora_alpha / self.rank for lora_alpha in self.lora_alphas
+ self.scaling = [ # type: ignore
+ lora_alpha / self.rank # type: ignore # noqa
+ for lora_alpha in self.lora_alphas
]
@classmethod
- def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
+ def pack(
+ cls, loras: List[Optional["LoRALayerWeights"]]
+ ) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
@@ -136,16 +139,19 @@ def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
- scaling=[1 if lora is not None else None for lora in loras])
+ scaling=[
+ 1 if lora is not None else None # type: ignore
+ for lora in loras
+ ])
return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
- if self.scaling[i] == 1 or self.lora_b[i] is None:
+ if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue
- self.lora_b[i] *= self.scaling[i]
- self.scaling[i] = 1
+ self.lora_b[i] *= self.scaling[i] # type: ignore
+ self.scaling[i] = 1 # type: ignore
return self
@property
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 6bb9fee27d535..c249497a4d893 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -3,7 +3,7 @@
import math
import os
import re
-from typing import Callable, Dict, Hashable, List, Optional, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type
import safetensors.torch
import torch
@@ -53,44 +53,46 @@ def convert_mapping(
embeddings.
indices_len: List of lengths of the above tensors.
"""
- indices = list(mapping.index_mapping).copy()
- embedding_indices = indices.copy()
- lora_indices = indices.copy()
- prompt_mapping = [
+ index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
+ embedding_indices = index_mapping_indices.copy()
+ lora_indices = index_mapping_indices.copy()
+ prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
]
lora_idx = None
- for i in range(len(indices)):
+ for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
- lora_idx = (lora_index_to_id.index(indices[i])
- if indices[i] > 0 else -1)
- embedding_indices[i] = lora_idx if indices[i] > 0 else 0
- indices[i] = i
+ lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
+ if index_mapping_indices[i] > 0 else -1)
+ embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
+ index_mapping_indices[i] = i
lora_indices[i] = lora_idx
- indices = torch.tensor([indices, lora_indices, embedding_indices],
- dtype=torch.long,
- device="cuda")
- prompt_mapping = torch.tensor(prompt_mapping,
- device="cuda",
- dtype=torch.long)
+ indices = torch.tensor(
+ [index_mapping_indices, lora_indices, embedding_indices],
+ dtype=torch.long,
+ device="cuda")
+ prompt_mapping_tensor = torch.tensor(prompt_mapping,
+ device="cuda",
+ dtype=torch.long)
embeddings_indices = torch.stack([
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size)
])
embeddings_indices[embeddings_indices == -1] = max_loras - 1
base_indices = indices[1]
- sampler_indices = prompt_mapping
+ sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
sampler_indices_padded = (
torch.arange(
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
(sampler_indices_padded * len(sampler_indices_padded)))
- indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
- sampler_indices_padded.shape[-1],
- embeddings_indices.shape[-1])
+ indices_len = [
+ base_indices.shape[-1], sampler_indices.shape[-1],
+ sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
+ ]
return (base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices, indices_len)
@@ -149,6 +151,7 @@ def from_lora_tensors(
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
+ assert embedding_modules is not None
embeddings_module = next(
(k for k in embedding_modules if k in module_name),
None)
@@ -171,6 +174,7 @@ def from_lora_tensors(
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
+ assert embedding_padding_modules is not None
if any(name in module_name
for name in embedding_padding_modules
) and target_embedding_padding is not None:
@@ -295,11 +299,10 @@ def __init__(
self.max_num_batched_tokens,
dtype=torch.long,
device="cuda")
- self.offsets = []
# 4 is the number of indices tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
- self.indices_len = [None] * 4
+ self.indices_len: List[Optional[int]] = [None] * 4
self.model: nn.Module = model
if hasattr(self.model, "supported_lora_modules"):
@@ -312,7 +315,7 @@ def __init__(
self._registered_loras: Dict[int, LoRAModel] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_loras: Dict[int, None] = {}
- self._last_mapping = None
+ self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
@@ -370,7 +373,7 @@ def deactivate_lora(self, lora_id: int) -> bool:
return True
return False
- def _add_lora(self, lora: LoRAModel) -> bool:
+ def _add_lora(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_loras[lora.id] = lora
@@ -418,7 +421,7 @@ def list_loras(self) -> Dict[int, LoRAModel]:
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
return self._registered_loras.get(lora_id, None)
- def remove_all_loras(self) -> bool:
+ def remove_all_loras(self):
"""Remove all LoRAModels from the manager."""
self._registered_loras.clear()
self.lora_index_to_id = [None] * self.lora_slots
@@ -467,6 +470,7 @@ def create_dummy_lora(
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
+ assert embedding_modules is not None
if parts[-1] in embedding_modules:
input_dim = (module.base_layer.org_vocab_size +
self.lora_config.lora_extra_vocab_size if
@@ -500,7 +504,7 @@ def create_dummy_lora(
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
- subloras = []
+ subloras: List[Optional["LoRALayerWeights"]] = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
@@ -538,7 +542,7 @@ def _register_packed_modules(self, module_full_name: str) -> None:
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
- replacement_loras = []
+ replacement_loras: List[Optional[LoRALayerWeights]] = []
has_replacement = False
for r in new_module_names:
lora = lora_model.get_lora(r)
@@ -557,12 +561,12 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
class LoRALRUCache(LRUCache[LoRAModel]):
- def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
- None]):
+ def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
+ bool]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn
- def _on_remove(self, key: Hashable, value: LoRAModel):
+ def _on_remove(self, key: int, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index 5356b79537b05..ec3c10c591a18 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Optional, Set, Type
+from typing import Any, Dict, List, Set, Type
import torch
@@ -37,7 +37,7 @@ def create_lora_manager(
...
@abstractmethod
- def set_active_loras(self, lora_requests: List[LoRARequest],
+ def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
...
@@ -54,7 +54,7 @@ def remove_lora(self, lora_id: int) -> bool:
...
@abstractmethod
- def remove_all_loras(self) -> bool:
+ def remove_all_loras(self):
...
@abstractmethod
@@ -81,10 +81,11 @@ def __init__(
embedding_padding_modules: List[str],
lora_model_cls: Type[LoRAModel] = LoRAModel,
):
- self._lora_manager: Optional[LoRAModelManager] = None
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
+ # Lazily initialized by create_lora_manager.
+ self._lora_manager: LoRAModelManager
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
lora_config, device)
@@ -104,7 +105,7 @@ def create_lora_manager(
lora_config=self.lora_config,
lora_manager_cls=self._lora_manager_cls,
)
- self._lora_manager: LoRAModelManager = lora_manager
+ self._lora_manager = lora_manager
return lora_manager.model
def set_active_loras(self, lora_requests: Set[LoRARequest],
@@ -188,7 +189,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)
- def remove_all_loras(self) -> bool:
+ def remove_all_loras(self):
self._lora_manager.remove_all_loras()
def list_loras(self) -> Set[int]:
@@ -217,10 +218,10 @@ def create_lora_manager(
lora_config=self.lora_config,
max_num_batched_tokens=self.max_num_batched_tokens,
)
- self._lora_manager: LRUCacheLoRAModelManager = lora_manager
+ self._lora_manager = lora_manager
return lora_manager.model
- def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
+ def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests if lora_request
@@ -237,12 +238,14 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
if lora_request.lora_int_id not in self.list_loras():
# Remove before we load the new lora to save memory
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
+ assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
self._lora_manager.remove_oldest_lora()
lora = self._load_lora(lora_request)
loaded = self._lora_manager.add_lora(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
- loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
+ loaded = self._lora_manager.get_lora(
+ lora_request.lora_int_id) is not None
self._lora_manager.activate_lora(lora_request.lora_int_id)
return loaded
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 31e08789dfd1f..33dbf8d90c35d 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -928,10 +928,10 @@ def profile_run(self) -> None:
torch.cuda.synchronize()
return
- def remove_all_loras(self) -> bool:
+ def remove_all_loras(self):
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
- return self.lora_manager.remove_all_loras()
+ self.lora_manager.remove_all_loras()
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
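A recurring change in this patch is replacing Optional[...] = None attributes with bare annotations for fields that are always assigned before use ("lazily initialized"), so mypy treats them as non-optional and call sites need no None checks. A small self-contained sketch of the pattern; the class and attributes are hypothetical, not taken from vLLM.

    # Sketch of the "declare now, assign later" typing pattern adopted above.
    from typing import List, Optional
    import torch

    class LoRASlots:
        def __init__(self) -> None:
            # Bare annotations: no value yet, but mypy sees non-Optional types,
            # so later uses need no `assert x is not None` narrowing.
            self.indices: torch.Tensor
            self.indices_len: List[int]
            # Contrast: an Optional field forces a None check at every use.
            self.last_mapping: Optional[dict] = None

        def set_mapping(self, indices: torch.Tensor) -> None:
            self.indices = indices
            self.indices_len = [indices.shape[-1]]

    slots = LoRASlots()
    slots.set_mapping(torch.arange(4))
    print(slots.indices_len)  # [4]

The trade-off is a runtime AttributeError instead of a silent None if a field is read before being set, which is acceptable for fields that create_lora_weights or set_mapping always populate first.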
From b6dcb4d44281a9e85cafcfa6376c373d02286779 Mon Sep 17 00:00:00 2001
From: Roy
Date: Fri, 26 Apr 2024 03:43:32 +0800
Subject: [PATCH 123/413] [Misc] Fix flash attention backend log (#4368)
---
vllm/attention/selector.py | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 554e802cd5513..7cc17f21dcd0e 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -25,7 +25,7 @@ class _Backend(enum.Enum):
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
backend = _which_attn_to_use(dtype)
if backend == _Backend.FLASH_ATTN:
- logger.info("Using FlashAttention backend.")
+ logger.info("Using FlashAttention-2 backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
@@ -62,12 +62,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
# NVIDIA GPUs.
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
- logger.info("Cannot use FlashAttention backend for Volta and Turing "
+ logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
return _Backend.XFORMERS
if dtype not in (torch.float16, torch.bfloat16):
- logger.info("Cannot use FlashAttention backend for dtype other than "
+ logger.info("Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
return _Backend.XFORMERS
@@ -75,8 +75,8 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
import flash_attn # noqa: F401
except ImportError:
logger.info(
- "Cannot use FlashAttention backend because the flash_attn package "
- "is not found. Please install it for better performance.")
+ "Cannot use FlashAttention-2 backend because the flash_attn "
+ "package is not found. Please install it for better performance.")
return _Backend.XFORMERS
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
From 15e7c675b0dc36109c7b591f856f102e96493a94 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 25 Apr 2024 16:32:48 -0700
Subject: [PATCH 124/413] [Core] Add `shutdown()` method to `ExecutorBase`
(#4349)
---
vllm/engine/llm_engine.py | 6 ++++++
vllm/executor/executor_base.py | 7 +++++++
2 files changed, 13 insertions(+)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 56c2417d6a6e6..7de60d738113e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -287,6 +287,12 @@ def __reduce__(self):
# the closure used to initialize Ray worker actors
raise RuntimeError("LLMEngine should not be pickled!")
+ def __del__(self):
+ # Shutdown model executor when engine is garbage collected
+ # Use getattr since __init__ can fail before the field is set
+ if model_executor := getattr(self, "model_executor", None):
+ model_executor.shutdown()
+
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None)
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 1839b5603ff3e..1838c34be2fda 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -95,6 +95,13 @@ def check_health(self) -> None:
exception."""
raise NotImplementedError
+ def shutdown(self) -> None:
+ """Shutdown the executor."""
+ return
+
+ def __del__(self):
+ self.shutdown()
+
class ExecutorAsyncBase(ExecutorBase):
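shutdown() is deliberately a no-op on ExecutorBase; concrete executors that own processes or actors can override it, and __del__ makes cleanup best-effort at garbage collection. The sketch below shows a subclass honoring that contract with a hypothetical worker pool; it is not one of vLLM's real executor implementations.

    # Standalone sketch of the shutdown()/__del__ contract; the worker pool is
    # hypothetical.
    class ExecutorBase:
        def shutdown(self) -> None:
            """Release resources; the base implementation does nothing."""
            return

        def __del__(self):
            self.shutdown()  # best-effort cleanup when garbage collected

    class PoolExecutor(ExecutorBase):
        def __init__(self, num_workers: int = 2) -> None:
            self._workers = [f"worker-{i}" for i in range(num_workers)]

        def shutdown(self) -> None:
            # Idempotent so it is safe from both __del__ and explicit callers.
            while self._workers:
                print("stopping", self._workers.pop())

    executor = PoolExecutor()
    executor.shutdown()  # explicit shutdown
    del executor         # __del__ calls shutdown() again; now a no-op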
From efffb63f584c1ce4fdcf4e7b7fd0bfc8b33a733a Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 25 Apr 2024 16:45:12 -0700
Subject: [PATCH 125/413] [Core] Move function tracing setup to util function
(#4352)
---
vllm/utils.py | 21 ++++++++++++++++++++-
vllm/worker/worker_base.py | 18 ++++--------------
2 files changed, 24 insertions(+), 15 deletions(-)
diff --git a/vllm/utils.py b/vllm/utils.py
index 15c8818cc4506..79ac1db01fc69 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1,10 +1,13 @@
import asyncio
+import datetime
import enum
import gc
import glob
import os
import socket
import subprocess
+import tempfile
+import threading
import uuid
import warnings
from collections import defaultdict
@@ -18,7 +21,7 @@
import torch
from packaging.version import Version, parse
-from vllm.logger import init_logger
+from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T")
logger = init_logger(__name__)
@@ -607,3 +610,19 @@ def find_nccl_library():
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Found nccl from library {so_file}")
return so_file
+
+
+def enable_trace_function_call_for_thread() -> None:
+ """Set up function tracing for the current thread,
+ if enabled via the VLLM_TRACE_FUNCTION environment variable
+ """
+
+ if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
+ tmp_dir = tempfile.gettempdir()
+ filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
+ f"_thread_{threading.get_ident()}_"
+ f"at_{datetime.datetime.now()}.log").replace(" ", "_")
+ log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
+ filename)
+ os.makedirs(os.path.dirname(log_path), exist_ok=True)
+ enable_trace_function_call(log_path)
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index b5dade0a770a0..0a89e3a79769f 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -1,15 +1,13 @@
-import datetime
import importlib
import os
-import tempfile
-import threading
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple
-from vllm.logger import enable_trace_function_call, init_logger
+from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
-from vllm.utils import get_vllm_instance_id, update_environment_variables
+from vllm.utils import (enable_trace_function_call_for_thread,
+ update_environment_variables)
logger = init_logger(__name__)
@@ -128,15 +126,7 @@ def init_worker(self, *args, **kwargs):
function tracing if required.
Arguments are passed to the worker class constructor.
"""
- if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
- tmp_dir = tempfile.gettempdir()
- filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
- f"_thread_{threading.get_ident()}_"
- f"at_{datetime.datetime.now()}.log").replace(" ", "_")
- log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
- filename)
- os.makedirs(os.path.dirname(log_path), exist_ok=True)
- enable_trace_function_call(log_path)
+ enable_trace_function_call_for_thread()
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
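With the setup logic moved into vllm.utils, any thread can opt into function tracing, not just worker initialization. A hedged usage sketch, assuming vllm is installed; the thread body is illustrative.

    # Usage sketch: enable per-thread function tracing via the environment
    # variable gate. Assumes vllm is installed.
    import os
    import threading

    os.environ["VLLM_TRACE_FUNCTION"] = "1"

    def worker_body() -> None:
        from vllm.utils import enable_trace_function_call_for_thread
        enable_trace_function_call_for_thread()
        # ... worker setup and model execution would follow; function calls in
        # this thread are now written to a per-thread log file under the
        # system temp directory.

    t = threading.Thread(target=worker_body)
    t.start()
    t.join()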
From cf29b7eda47d5a8c5fe6a7a53490271da8520563 Mon Sep 17 00:00:00 2001
From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Date: Thu, 25 Apr 2024 21:12:25 -0400
Subject: [PATCH 126/413] [ROCm][Hardware][AMD][Doc] Documentation update for
ROCm (#4376)
Co-authored-by: WoosukKwon
---
.../getting_started/amd-installation.rst | 165 +++++++-----------
1 file changed, 65 insertions(+), 100 deletions(-)
diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst
index 3d736bf7120ec..61fcd45a26347 100644
--- a/docs/source/getting_started/amd-installation.rst
+++ b/docs/source/getting_started/amd-installation.rst
@@ -3,9 +3,7 @@
Installation with ROCm
======================
-vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm.
-At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
-Data types currently supported in ROCm are FP16 and BF16.
+vLLM supports AMD GPUs with ROCm 5.7 and 6.0.
Requirements
------------
@@ -13,114 +11,57 @@ Requirements
* OS: Linux
* Python: 3.8 -- 3.11
* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
-* Pytorch 2.0.1/2.1.1/2.2
-* ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)
+* ROCm 6.0 and ROCm 5.7
Installation options:
-#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image `
-#. :ref:`Build from source `
#. :ref:`Build from source with docker `
+#. :ref:`Build from source `
-.. _quick_start_docker_rocm:
-
-(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
----------------------------------------------------------------------------
-
-This option is for ROCm 5.7 only:
-
-.. code-block:: console
-
- $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
- $ docker run -it \
- --network=host \
- --group-add=video \
- --ipc=host \
- --cap-add=SYS_PTRACE \
- --security-opt seccomp=unconfined \
- --device /dev/kfd \
- --device /dev/dri \
- -v :/app/model \
- embeddedllminfo/vllm-rocm \
- bash
-
-
-.. _build_from_source_rocm:
-
-Option 2: Build from source
----------------------------
-
-You can build and install vLLM from source:
-
-Below instruction is for ROCm 5.7 only.
-At the time of this documentation update, PyTorch on ROCm 6.0 wheel is not yet available on the PyTorch website.
-
-0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
-
-- `ROCm `_
-- `Pytorch `_
-
- .. code-block:: console
-
- $ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version
-
-
-1. Install `flash attention for ROCm `_
-
- Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_
-
-.. note::
- - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
- - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
- - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
- - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
-
-2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
+.. _build_from_source_docker_rocm:
- .. code-block:: console
+Option 1: Build from source with docker (recommended)
+-----------------------------------------------------
- $ pip install xformers==0.0.23 --no-deps
- $ bash patch_xformers.rocm.sh
+You can build and install vLLM from source.
-3. Build vLLM.
+First, build a docker image from `Dockerfile.rocm `_ and launch a docker container from the image.
- .. code-block:: console
+`Dockerfile.rocm `_ uses ROCm 6.0 by default, but also supports ROCm 5.7.
+It provides flexibility to customize the build of docker image using the following arguments:
- $ cd vllm
- $ pip install -U -r requirements-rocm.txt
- $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
+* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
+* `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For `Radeon RX 7900 series (gfx1100) `_, this should be set to 0 until flash-attention supports this target.
+* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
+* `FA_BRANCH`: specifies the branch used to build the CK flash-attention in `ROCm's flash-attention repo `_. The default is `ae7928c`
+* `BUILD_TRITON`: specifies whether to build triton flash-attention. The default value is 1.
+Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
-.. _build_from_source_docker_rocm:
-Option 3: Build from source with docker
------------------------------------------------------
+To build vllm on ROCm 6.0 for MI200 and MI300 series, you can use the default:
-You can build and install vLLM from source:
+.. code-block:: console
-Build a docker image from `Dockerfile.rocm`, and launch a docker container.
+ $ docker build -f Dockerfile.rocm -t vllm-rocm .
-The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later versions. It provides flexibility to customize the build of docker image using the following arguments:
+To build vllm on ROCm 6.0 for Radeon RX 7900 series (gfx1100), you should specify ``BUILD_FA`` as below:
-* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
-* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
-* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo `_. The default is `3d2b6f5`
-* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) `_, this should be set to 0 before flash-attention supports this target.
+.. code-block:: console
-Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
+ $ docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm .
-For example, to build docker image for vllm on ROCm 5.7, you can run:
+To build docker image for vllm on ROCm 5.7, you can specify ``BASE_IMAGE`` as below:
.. code-block:: console
$ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
-f Dockerfile.rocm -t vllm-rocm .
-To build vllm on ROCm 6.0, you can use the default:
+To run the above docker image ``vllm-rocm``, use the below command:
.. code-block:: console
- $ docker build -f Dockerfile.rocm -t vllm-rocm .
$ docker run -it \
--network=host \
--group-add=video \
@@ -133,7 +74,13 @@ To build vllm on ROCm 6.0, you can use the default:
vllm-rocm \
bash
-Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below:
+Where the `` is the location where the model is stored, for example, the weights for llama2 or llama3 models.
+
+
+.. _build_from_source_rocm:
+
+Option 2: Build from source
+---------------------------
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
@@ -141,32 +88,50 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from
- `Pytorch `_
- `hipBLAS `_
-1. Install `flash attention for ROCm `_
+For installing PyTorch, you can start from a fresh docker image, e.g., `rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`.
- Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_
+Alternatively, you can install PyTorch using PyTorch wheels. You can check the PyTorch installation guide in the PyTorch `Getting Started `_
+
+For rocm6.0:
+
+.. code-block:: console
+
+ $ pip3 install torch --index-url https://download.pytorch.org/whl/rocm6.0
+
+
+For rocm5.7:
+
+.. code-block:: console
+
+ $ pip install torch --index-url https://download.pytorch.org/whl/rocm5.7
+
+
+1. Install `Triton flash attention for ROCm `_
+
+Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton `_
+
+2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm `_
+
+Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention `_
.. note::
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
- - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
+ - If you fail to install `ROCm/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
- You might need to downgrade the "ninja" version to 1.10 as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
-2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
-
- .. code-block:: console
-
- $ pip install xformers==0.0.23 --no-deps
- $ bash patch_xformers.rocm.sh
-
3. Build vLLM.
- .. code-block:: console
+.. code-block:: console
- $ cd vllm
- $ pip install -U -r requirements-rocm.txt
- $ python setup.py install # This may take 5-10 minutes.
+ $ cd vllm
+ $ pip install -U -r requirements-rocm.txt
+ $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
-.. note::
- - You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation.
+.. tip::
+ - You may need to turn on the ``--enforce-eager`` flag if you experience a process hang when running the `benchmark_throughput.py` script to test your installation.
+ - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
+ - To use CK flash-attention, set ``export VLLM_USE_FLASH_ATTN_TRITON=0`` to turn off Triton flash attention.
+ - Ideally, the ROCm version of PyTorch should match the ROCm driver version.
From a74dee9b62d10767eb0580f196f5e508e9e80a2d Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Fri, 26 Apr 2024 10:10:48 +0800
Subject: [PATCH 127/413] [Bugfix] Fix parameter name in `get_tokenizer`
(#4107)
---
tests/tokenization/test_tokenizer.py | 20 ++++++++++++++++++++
vllm/transformers_utils/tokenizer.py | 11 ++++++-----
2 files changed, 26 insertions(+), 5 deletions(-)
create mode 100644 tests/tokenization/test_tokenizer.py
diff --git a/tests/tokenization/test_tokenizer.py b/tests/tokenization/test_tokenizer.py
new file mode 100644
index 0000000000000..8db7204f15d4e
--- /dev/null
+++ b/tests/tokenization/test_tokenizer.py
@@ -0,0 +1,20 @@
+import pytest
+from transformers import PreTrainedTokenizerBase
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+TOKENIZER_NAMES = [
+ "facebook/opt-125m",
+ "gpt2",
+]
+
+
+@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES)
+def test_tokenizer_revision(tokenizer_name: str):
+ # Assume that "main" branch always exists
+ tokenizer = get_tokenizer(tokenizer_name, revision="main")
+ assert isinstance(tokenizer, PreTrainedTokenizerBase)
+
+ # Assume that the "never" branch never exists
+ with pytest.raises(OSError, match='not a valid git identifier'):
+ get_tokenizer(tokenizer_name, revision="never")
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index c98a673bfed4b..afc02c434dd43 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -58,11 +58,12 @@ def get_tokenizer(
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
- tokenizer_revision: Optional[str] = None,
+ revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
- """Gets a tokenizer for the given model name via Huggingface/modelscope."""
+ """Gets a tokenizer for the given model name via HuggingFace or ModelScope.
+ """
if VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
@@ -74,7 +75,7 @@ def get_tokenizer(
tokenizer_path = snapshot_download(
model_id=tokenizer_name,
cache_dir=download_dir,
- revision=tokenizer_revision,
+ revision=revision,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
tokenizer_name = tokenizer_path
@@ -90,7 +91,7 @@ def get_tokenizer(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
- tokenizer_revision=tokenizer_revision,
+ revision=revision,
**kwargs)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
@@ -114,7 +115,7 @@ def get_tokenizer(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
- tokenizer_revision=tokenizer_revision,
+ revision=revision,
**kwargs)
else:
raise e
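After the rename, callers pass revision= (matching the keyword that Hugging Face's from_pretrained expects) rather than tokenizer_revision=. A short usage sketch, assuming vllm is installed and the Hugging Face Hub is reachable:

    # Usage sketch for the renamed keyword.
    from vllm.transformers_utils.tokenizer import get_tokenizer

    tokenizer = get_tokenizer("facebook/opt-125m", revision="main")
    print(type(tokenizer).__name__)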
From 2f30e7c72fca61c8225654880ee1ef89cad1690c Mon Sep 17 00:00:00 2001
From: Norman Mu
Date: Thu, 25 Apr 2024 22:36:01 -0700
Subject: [PATCH 128/413] [Frontend] Add --log-level option to api server
(#4377)
---
vllm/entrypoints/api_server.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 587142adb9c6b..075de0b4efb2d 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -100,6 +100,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
+ parser.add_argument("--log-level", type=str, default="debug")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
@@ -110,7 +111,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
uvicorn.run(app,
host=args.host,
port=args.port,
- log_level="debug",
+ log_level=args.log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
From a88081bf768fcc1c662e4f588bd01ca9ddcc6aad Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Fri, 26 Apr 2024 16:16:58 +0900
Subject: [PATCH 129/413] [CI] Disable non-lazy string operation on logging
(#4326)
Co-authored-by: Danny Guinther
---
docs/source/conf.py | 5 +-
pyproject.toml | 1 +
setup.py | 7 ++-
vllm/config.py | 16 ++---
vllm/core/scheduler.py | 10 +--
.../device_communicators/custom_all_reduce.py | 8 +--
.../device_communicators/pynccl.py | 15 ++---
.../device_communicators/pynccl_utils.py | 4 +-
vllm/distributed/parallel_state.py | 6 +-
vllm/distributed/utils.py | 4 +-
vllm/engine/async_llm_engine.py | 18 +++---
vllm/engine/llm_engine.py | 61 +++++++++++--------
vllm/engine/metrics.py | 21 ++++---
vllm/entrypoints/openai/api_server.py | 4 +-
vllm/entrypoints/openai/serving_chat.py | 11 ++--
vllm/executor/cpu_executor.py | 2 +-
vllm/executor/gpu_executor.py | 4 +-
vllm/executor/ray_gpu_executor.py | 4 +-
vllm/executor/ray_utils.py | 6 +-
vllm/logger.py | 2 +-
vllm/lora/models.py | 6 +-
.../layers/fused_moe/fused_moe.py | 4 +-
.../model_executor/model_loader/tensorizer.py | 8 +--
.../model_loader/weight_utils.py | 14 ++---
vllm/model_executor/models/__init__.py | 10 +--
vllm/model_executor/models/gemma.py | 6 +-
vllm/spec_decode/spec_decode_worker.py | 3 +-
vllm/transformers_utils/configs/dbrx.py | 13 ++--
vllm/transformers_utils/tokenizer.py | 5 +-
vllm/utils.py | 20 +++---
vllm/worker/model_runner.py | 27 ++++----
31 files changed, 176 insertions(+), 149 deletions(-)
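Background for the conversions below: %-style ("lazy") logging defers string interpolation until the logging framework decides a record will actually be emitted, whereas an f-string is always evaluated at the call site; the "G" ruleset added to pyproject.toml (ruff's flake8-logging-format rules) flags the eager form. A minimal standalone sketch:

    import logging

    logging.basicConfig(level=logging.INFO)  # DEBUG records are dropped
    logger = logging.getLogger("lazy-logging-demo")

    state = {"num_requests": 3}

    # Eager: the f-string is built even though the DEBUG record is discarded.
    logger.debug(f"scheduler state: {state}")

    # Lazy: the format string and argument are stored on the record and only
    # interpolated if a handler actually emits it.
    logger.debug("scheduler state: %s", state)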
diff --git a/docs/source/conf.py b/docs/source/conf.py
index aac8cbb63ebeb..9da5a4991734d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -98,9 +98,10 @@ def setup(app):
for mock_target in autodoc_mock_imports:
if mock_target in sys.modules:
logger.info(
- f"Potentially problematic mock target ({mock_target}) found; "
+ "Potentially problematic mock target (%s) found; "
"autodoc_mock_imports cannot mock modules that have already "
- "been loaded into sys.modules when the sphinx build starts.")
+ "been loaded into sys.modules when the sphinx build starts.",
+ mock_target)
class MockedClassDocumenter(autodoc.ClassDocumenter):
diff --git a/pyproject.toml b/pyproject.toml
index a171d45b4e064..2e026c1ac8911 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,7 @@ select = [
"SIM",
# isort
# "I",
+ "G",
]
ignore = [
# star imports
diff --git a/setup.py b/setup.py
index 4b672e1af8494..6ba36b85ea318 100644
--- a/setup.py
+++ b/setup.py
@@ -63,7 +63,7 @@ def compute_num_jobs(self):
num_jobs = os.environ.get("MAX_JOBS", None)
if num_jobs is not None:
num_jobs = int(num_jobs)
- logger.info(f"Using MAX_JOBS={num_jobs} as the number of jobs.")
+ logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
else:
try:
# os.sched_getaffinity() isn't universally available, so fall
@@ -81,8 +81,9 @@ def compute_num_jobs(self):
nvcc_threads = os.getenv("NVCC_THREADS", None)
if nvcc_threads is not None:
nvcc_threads = int(nvcc_threads)
- logger.info(f"Using NVCC_THREADS={nvcc_threads} as the number"
- " of nvcc threads.")
+ logger.info(
+ "Using NVCC_THREADS=%d as the number of nvcc threads.",
+ nvcc_threads)
else:
nvcc_threads = 1
num_jobs = max(1, num_jobs // nvcc_threads)
diff --git a/vllm/config.py b/vllm/config.py
index 311a69f822571..887a73d9462dc 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -167,9 +167,9 @@ def _verify_quantization(self) -> None:
f"supported in ROCm.")
if self.quantization != "marlin":
logger.warning(
- f"{self.quantization} quantization is not fully "
+ "%s quantization is not fully "
"optimized yet. The speed can be slower than "
- "non-quantized models.")
+ "non-quantized models.", self.quantization)
def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
@@ -360,7 +360,7 @@ def verify_with_parallel_config(
if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory:
- logger.warning("Possibly too large swap space. " + msg)
+ logger.warning("Possibly too large swap space. %s", msg)
@dataclass
@@ -898,8 +898,8 @@ def verify_with_model_config(self, model_config: ModelConfig):
"awq", "gptq"
]:
# TODO support marlin and squeezellm
- logger.warning(f"{model_config.quantization} quantization is not "
- "tested with LoRA yet.")
+ logger.warning("%s quantization is not tested with LoRA yet.",
+ model_config.quantization)
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528:
@@ -1008,7 +1008,7 @@ def _get_and_verify_dtype(
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
- logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
+ logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
return torch_dtype
@@ -1051,8 +1051,8 @@ def _get_and_verify_max_len(
logger.warning(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
- f"{possible_keys}. Assuming the model's maximum length is "
- f"{default_max_len}.")
+ "%d. Assuming the model's maximum length is %d.", possible_keys,
+ default_max_len)
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 99f7a34d336a4..ac3bd7d228e94 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -617,8 +617,9 @@ def _schedule_prefills(
if num_new_tokens > self.prompt_limit:
logger.warning(
- f"Input prompt ({num_new_tokens} tokens) is too long"
- f" and exceeds limit of {self.prompt_limit}")
+ "Input prompt (%d tokens) is too long"
+ " and exceeds limit of %d", num_new_tokens,
+ self.prompt_limit)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
@@ -631,8 +632,9 @@ def _schedule_prefills(
break
elif can_allocate == AllocStatus.NEVER:
logger.warning(
- f"Input prompt ({num_new_tokens} tokens) is too long"
- f" and exceeds the capacity of block_manager")
+ "Input prompt (%d tokens) is too long"
+ " and exceeds the capacity of block_manager",
+ num_new_tokens)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 9dbb427d91ff1..ec4533326e841 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -37,7 +37,7 @@ def init_custom_ar() -> None:
return
if world_size not in _SUPPORTED_WORLD_SIZES:
- logger.warn(
+ logger.warning(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.", world_size,
@@ -47,7 +47,7 @@ def init_custom_ar() -> None:
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if num_dev < world_size:
- logger.warn(
+ logger.warning(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
@@ -62,7 +62,7 @@ def init_custom_ar() -> None:
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(device_ids)
if world_size > 2 and not full_nvlink:
- logger.warn(
+ logger.warning(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
@@ -71,7 +71,7 @@ def init_custom_ar() -> None:
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
- logger.warn(
+ logger.warning(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index e922beba44bfa..9434867e1b120 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -43,15 +43,16 @@
nccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
- f"Failed to load NCCL library from {so_file} ."
+ "Failed to load NCCL library from %s ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
- f"or it does not support the current platform {platform.platform()}."
- f"One solution is to download libnccl2 version 2.18 from "
- f"https://developer.download.nvidia.com/compute/cuda/repos/ "
- f"and extract the libnccl.so.2 file. If you already have the "
- f"library, please set the environment variable VLLM_NCCL_SO_PATH"
- " to point to the correct nccl library path.")
+ "or it does not support the current platform %s."
+ "One solution is to download libnccl2 version 2.18 from "
+ "https://developer.download.nvidia.com/compute/cuda/repos/ "
+ "and extract the libnccl.so.2 file. If you already have the "
+ "library, please set the environment variable VLLM_NCCL_SO_PATH"
+ " to point to the correct nccl library path.", so_file,
+ platform.platform())
raise e
# === export types and functions from nccl to Python ===
diff --git a/vllm/distributed/device_communicators/pynccl_utils.py b/vllm/distributed/device_communicators/pynccl_utils.py
index a717fddb695ba..44e4f39217a41 100644
--- a/vllm/distributed/device_communicators/pynccl_utils.py
+++ b/vllm/distributed/device_communicators/pynccl_utils.py
@@ -14,7 +14,7 @@
except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs
- logger.info(f"Failed to import NCCL library: {e}")
+ logger.info("Failed to import NCCL library: %s", e)
logger.info("It is expected if you are not running on NVIDIA GPUs.")
pass
@@ -40,7 +40,7 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
assert not is_initialized()
global comm
- logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
+ logger.info("vLLM is using nccl==%s", ncclGetVersion())
comm = NCCLCommunicator(group=group)
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 515f2212511b7..6ca6fc5b5f9fe 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -57,8 +57,10 @@ def init_distributed_environment(
local_rank: int = -1,
backend: str = "nccl",
):
- logger.debug(f"{world_size=} {rank=} {local_rank=} "
- f"{distributed_init_method=} {backend=}")
+ logger.debug(
+ "world_size=%d rank=%d local_rank=%d "
+ "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
+ distributed_init_method, backend)
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index e0a871ebe1756..9a13b94c3ada1 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -112,7 +112,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
and (not os.path.exists(path)):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
- logger.info(f"generating GPU P2P access cache for in {path}")
+ logger.info("generating GPU P2P access cache for in %s", path)
cache = {}
for _i in range(num_dev):
for _j in range(num_dev):
@@ -126,7 +126,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
if is_distributed:
cpu_world_group = get_cpu_world_group()
dist.barrier(cpu_world_group)
- logger.info(f"reading GPU P2P access cache from {path}")
+ logger.info("reading GPU P2P access cache from %s", path)
with open(path, "r") as f:
cache = json.load(f)
_gpu_p2p_access_cache = cache
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 4b007d71e9cfc..518532e4a280d 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -117,7 +117,7 @@ def process_request_output(self,
self._request_streams[request_id].put(request_output)
if request_output.finished:
if verbose:
- logger.info(f"Finished request {request_id}.")
+ logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
def process_exception(self,
@@ -128,7 +128,7 @@ def process_exception(self,
"""Propagate an exception from the engine."""
self._request_streams[request_id].put(exception)
if verbose:
- logger.info(f"Finished request {request_id}.")
+ logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
def add_request(self, request_id: str,
@@ -151,7 +151,7 @@ def add_request(self, request_id: str,
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
- logger.info(f"Aborted request {request_id}.")
+ logger.info("Aborted request %s.", request_id)
self._finished_requests.put_nowait(request_id)
@@ -521,11 +521,11 @@ async def add_request(
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.
max_log_len]
- logger.info(f"Received request {request_id}: "
- f"prompt: {shortened_prompt!r}, "
- f"sampling_params: {sampling_params}, "
- f"prompt_token_ids: {shortened_token_ids}, "
- f"lora_request: {lora_request}.")
+ logger.info(
+ "Received request %s: prompt: %r, "
+ "sampling_params: %s, prompt_token_ids: %s, "
+ "lora_request: %s.", request_id, shortened_prompt,
+ sampling_params, shortened_token_ids, lora_request)
if not self.is_running:
if self.start_engine_loop:
@@ -717,4 +717,4 @@ async def check_health(self) -> None:
raise RuntimeError("Engine is dead.") from e
else:
await self.engine.check_health_async()
- logger.debug(f"Health check took {time.perf_counter()-t}s")
+ logger.debug("Health check took %fs", time.perf_counter() - t)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 7de60d738113e..d2f5379e621c6 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -96,29 +96,38 @@ def __init__(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> None:
logger.info(
- f"Initializing an LLM engine (v{vllm.__version__}) with config: "
- f"model={model_config.model!r}, "
- f"speculative_config={speculative_config!r}, "
- f"tokenizer={model_config.tokenizer!r}, "
- f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
- f"tokenizer_mode={model_config.tokenizer_mode}, "
- f"revision={model_config.revision}, "
- f"tokenizer_revision={model_config.tokenizer_revision}, "
- f"trust_remote_code={model_config.trust_remote_code}, "
- f"dtype={model_config.dtype}, "
- f"max_seq_len={model_config.max_model_len}, "
- f"download_dir={load_config.download_dir!r}, "
- f"load_format={load_config.load_format}, "
- f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
- f"disable_custom_all_reduce="
- f"{parallel_config.disable_custom_all_reduce}, "
- f"quantization={model_config.quantization}, "
- f"enforce_eager={model_config.enforce_eager}, "
- f"kv_cache_dtype={cache_config.cache_dtype}, "
- f"quantization_param_path={model_config.quantization_param_path}, "
- f"device_config={device_config.device}, "
- f"decoding_config={decoding_config!r}, "
- f"seed={model_config.seed})")
+ "Initializing an LLM engine (v%s) with config: "
+ "model=%r, speculative_config=%r, tokenizer=%r, "
+ "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
+ "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
+ "max_seq_len=%d, download_dir=%r, load_format=%s, "
+ "tensor_parallel_size=%d, disable_custom_all_reduce=%s"
+ "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
+ "quantization_param_path=%s, device_config=%s, "
+ "decoding_config=%r, seed=%d)",
+ vllm.__version__,
+ model_config.model,
+ speculative_config,
+ model_config.tokenizer,
+ model_config.skip_tokenizer_init,
+ model_config.tokenizer_mode,
+ model_config.revision,
+ model_config.tokenizer_revision,
+ model_config.trust_remote_code,
+ model_config.dtype,
+ model_config.max_model_len,
+ load_config.download_dir,
+ load_config.load_format,
+ parallel_config.tensor_parallel_size,
+ parallel_config.disable_custom_all_reduce,
+ model_config.quantization,
+ model_config.enforce_eager,
+ cache_config.cache_dtype,
+ model_config.quantization_param_path,
+ device_config.device,
+ decoding_config,
+ model_config.seed,
+ )
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
@@ -237,8 +246,10 @@ def _initialize_kv_caches(self) -> None:
if self.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
- logger.info(f"Overriding {num_gpu_blocks=} with "
- f"{num_gpu_blocks_override=}")
+ logger.info(
+ "Overriding num_gpu_blocks=%d with "
+ "num_gpu_blocks_override=%d", num_gpu_blocks,
+ num_gpu_blocks_override)
num_gpu_blocks = num_gpu_blocks_override
self.cache_config.num_gpu_blocks = num_gpu_blocks
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 25e96f6c7eaf7..d3560f5fefff1 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -227,14 +227,19 @@ def log(self, stats: Stats) -> None:
# Log to stdout.
logger.info(
- f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
- f"Avg generation throughput: "
- f"{generation_throughput:.1f} tokens/s, "
- f"Running: {stats.num_running} reqs, "
- f"Swapped: {stats.num_swapped} reqs, "
- f"Pending: {stats.num_waiting} reqs, "
- f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, "
- f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%")
+ "Avg prompt throughput: %.1f tokens/s, "
+ "Avg generation throughput: %.1f tokens/s, "
+ "Running: %d reqs, Swapped: %d reqs, "
+ "Pending: %d reqs, GPU KV cache usage: %.1f%, "
+ "CPU KV cache usage: %.1f%",
+ prompt_throughput,
+ generation_throughput,
+ stats.num_running,
+ stats.num_swapped,
+ stats.num_waiting,
+ stats.gpu_cache_usage * 100,
+ stats.cpu_cache_usage * 100,
+ )
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
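One detail in the metrics message above: with %-style formatting a literal percent sign must be escaped as %%, which is why the cache-usage fields are written as %.1f%%. A tiny sketch:

    import logging

    logging.basicConfig(level=logging.INFO)
    logging.getLogger("metrics-demo").info(
        "GPU KV cache usage: %.1f%%, CPU KV cache usage: %.1f%%", 42.0, 3.5)
    # emits: GPU KV cache usage: 42.0%, CPU KV cache usage: 3.5%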
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 37d76b8e74055..af9ba7a3bc825 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -148,8 +148,8 @@ async def authentication(request: Request, call_next):
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")
- logger.info(f"vLLM API server version {vllm.__version__}")
- logger.info(f"args: {args}")
+ logger.info("vLLM API server version %s", vllm.__version__)
+ logger.info("args: %s", args)
if args.served_model_name is not None:
served_model_names = args.served_model_name
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 2ff335eb71073..f6011b6fc4cb6 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -57,8 +57,7 @@ async def create_chat_completion(
tokenize=False,
add_generation_prompt=request.add_generation_prompt)
except Exception as e:
- logger.error(
- f"Error in applying chat template from request: {str(e)}")
+ logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}"
@@ -338,11 +337,11 @@ def _load_chat_template(self, chat_template):
tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape")
- logger.info(
- f"Using supplied chat template:\n{tokenizer.chat_template}")
+ logger.info("Using supplied chat template:\n%s",
+ tokenizer.chat_template)
elif tokenizer.chat_template is not None:
- logger.info(
- f"Using default chat template:\n{tokenizer.chat_template}")
+ logger.info("Using default chat template:\n%s",
+ tokenizer.chat_template)
else:
logger.warning(
"No chat template provided. Chat API will not work.")
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index 8d6a1fff91fd8..aa810f9743395 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -69,7 +69,7 @@ def initialize_cache(self, num_gpu_blocks: int,
# NOTE: `cpu block` for CPU backend is located on CPU memory but is
# referred as `gpu block`. Because we want to reuse the existing block
# management procedure.
- logger.info(f"# CPU blocks: {num_gpu_blocks}")
+ logger.info("# CPU blocks: %d", num_gpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self,
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index d413a7d27ff37..d2c60a3b68e14 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -116,8 +116,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
- logger.info(f"# GPU blocks: {num_gpu_blocks}, "
- f"# CPU blocks: {num_cpu_blocks}")
+ logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
+ num_cpu_blocks)
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 14b3f803782c6..6f72babe14fd5 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -214,8 +214,8 @@ def initialize_cache(self, num_gpu_blocks: int,
# NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors
# have GPUs.
- logger.info(f"# GPU blocks: {num_gpu_blocks}, "
- f"# CPU blocks: {num_cpu_blocks}")
+ logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
+ num_cpu_blocks)
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index febae42b84549..9db3ae2ff8298 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -43,9 +43,9 @@ def execute_model_compiled_dag_remote(self, ignored):
return output
except ImportError as e:
- logger.warning(f"Failed to import Ray with {e!r}. "
- "For distributed inference, please install Ray with "
- "`pip install ray`.")
+ logger.warning(
+ "Failed to import Ray with %r. For distributed inference, "
+ "please install Ray with `pip install ray`.", e)
ray = None # type: ignore
RayWorkerWrapper = None # type: ignore
diff --git a/vllm/logger.py b/vllm/logger.py
index 341fc473585d7..3928e5367d1e6 100644
--- a/vllm/logger.py
+++ b/vllm/logger.py
@@ -126,7 +126,7 @@ def enable_trace_function_call(log_file_path: str,
"VLLM_TRACE_FUNCTION is enabled. It will record every"
" function executed by Python. This will slow down the code. It "
"is suggested to be used for debugging hang or crashes only.")
- logger.info(f"Trace frame log is saved to {log_file_path}")
+ logger.info("Trace frame log is saved to %s", log_file_path)
if root_dir is None:
# by default, this is the vllm root directory
root_dir = os.path.dirname(os.path.dirname(__file__))
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index c249497a4d893..6a077e9b0c755 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -345,8 +345,8 @@ def activate_lora(
index, _ = first_free_slot
self._active_loras[lora_id] = None
lora_model = self._registered_loras[lora_id]
- logger.debug(
- f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
+ logger.debug("Activating LoRA. int id: %d, slot index: %d",
+ lora_model.id, index)
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name)
@@ -567,7 +567,7 @@ def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: int, value: LoRAModel):
- logger.debug(f"Removing LoRA. int id: {key}")
+ logger.debug("Removing LoRA. int id: %d", key)
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index ac7c30e2a9727..aed2c350bdd10 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -296,8 +296,8 @@ def get_moe_configs(E: int, N: int,
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
- logger.info(
- f"Using configuration from {config_file_path} for MoE layer.")
+ logger.info("Using configuration from %s for MoE layer.",
+ config_file_path)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py
index 16be0ecf9ce07..7e65d54bc522f 100644
--- a/vllm/model_executor/model_loader/tensorizer.py
+++ b/vllm/model_executor/model_loader/tensorizer.py
@@ -334,10 +334,10 @@ def deserialize(self):
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
deserializer.close()
- logger.info(f"Deserialized {total_bytes_str} in "
- f"{end - start:0.2f}s, {per_second}/s")
- logger.info(f"Memory usage before: {before_mem}")
- logger.info(f"Memory usage after: {after_mem}")
+ logger.info("Deserialized %s in %0.2fs, %f/s", total_bytes_str,
+ end - start, per_second)
+ logger.info("Memory usage before: %s", before_mem)
+ logger.info("Memory usage after: %s", after_mem)
self._check_tensors_on_meta_device()
self._resize_lora_embeddings()
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 9995f2afe3cf7..c0905b9051314 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -190,7 +190,7 @@ def download_weights_from_hf(model_name_or_path: str,
allow_patterns = [pattern]
break
- logger.info(f"Using model weights format {allow_patterns}")
+ logger.info("Using model weights format %s", allow_patterns)
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
@@ -310,17 +310,17 @@ def kv_cache_scales_loader(
return layer_scales_map.items()
except FileNotFoundError:
- logger.error(f"File or directory '{filename}' not found.")
+ logger.error("File or directory '%s' not found.", filename)
except json.JSONDecodeError:
- logger.error(f"Error decoding JSON in file '{filename}'.")
+ logger.error("Error decoding JSON in file '%s'.", filename)
except Exception as e:
- logger.error(f"An error occurred while reading '{filename}': {e}")
+ logger.error("An error occurred while reading '%s': %s", filename, e)
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
- logger.warning("Defaulting to KV cache scaling factors = 1.0 "
- f"for all layers in TP rank {tp_rank} "
- "as an error occurred during loading.")
+ logger.warning(
+ "Defaulting to KV cache scaling factors = 1.0 for all "
+ "layers in TP rank %d as an error occurred during loading.", tp_rank)
return []
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 6afb2f31c1334..c5cdc059473b3 100755
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -91,8 +91,8 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
- f"Model architecture {model_arch} is partially supported "
- "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+ "Model architecture %s is partially supported by ROCm: %s",
+ model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
@@ -107,9 +107,9 @@ def get_supported_archs() -> List[str]:
def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS:
logger.warning(
- f"Model architecture {model_arch} is already registered, "
- "and will be overwritten by the new model "
- f"class {model_cls.__name__}.")
+ "Model architecture %s is already registered, and will be "
+ "overwritten by the new model class %s.", model_arch,
+ model_cls.__name__)
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 6d01537c5c344..c3193258d6418 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -55,10 +55,10 @@ def _get_gemma_act_fn(
"in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
- f"`{hidden_act}`, edit the config JSON to set "
- f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
+ "`%s`, edit the config JSON to set "
+ "`hidden_activation=%s` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 "
- "for more details.")
+ "for more details.", hidden_act, hidden_act)
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu_pytorch_tanh":
return GeluAndMul(approximate="tanh")
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 2c6642f5a3c81..4e70ea9686005 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -183,7 +183,8 @@ def execute_model(
"speculative decoding "
"requires non-None seq_group_metadata_list")
- logger.info(f"spec_decode_worker.execute_model {num_lookahead_slots=}")
+ logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
+ num_lookahead_slots)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py
index 1d2724f22abd6..0dc9664723d34 100644
--- a/vllm/transformers_utils/configs/dbrx.py
+++ b/vllm/transformers_utils/configs/dbrx.py
@@ -72,9 +72,10 @@ def from_pretrained(
and config_dict["model_type"] != cls.model_type
):
logger.warning(
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
- + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
- )
+ "You are using a model of type %s to instantiate a model of "
+ "type %s. This is not supported for all configurations of "
+ "models and can yield errors.",
+ config_dict["model_type"], cls.model_type)
return cls.from_dict(config_dict, **kwargs)
@@ -151,9 +152,9 @@ def from_pretrained(
and config_dict["model_type"] != cls.model_type
):
logger.warning(
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
- + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
- )
+ "You are using a model of type %s to instantiate a model of "
+ "type %s. This is not supported for all "
+ "configurations of models and can yield errors.", config_dict["model_type"], cls.model_type)
return cls.from_dict(config_dict, **kwargs)
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index afc02c434dd43..2fcddc3bea5ab 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -138,9 +138,8 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
- f"No tokenizer found in {lora_request.lora_local_path}, "
- "using base model tokenizer instead. "
- f"(Exception: {str(e)})")
+ "No tokenizer found in %s, using base model tokenizer instead. "
+ "(Exception: %s)", lora_request.lora_local_path, e)
tokenizer = None
return tokenizer
diff --git a/vllm/utils.py b/vllm/utils.py
index 79ac1db01fc69..76c2fc66e47c3 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -289,8 +289,9 @@ def get_open_port() -> int:
def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items():
if k in os.environ and os.environ[k] != v:
- logger.warning(f"Overwriting environment variable {k} "
- f"from '{os.environ[k]}' to '{v}'")
+ logger.warning(
+ "Overwriting environment variable %s "
+ "from '%s' to '%s'", k, os.environ[k], v)
os.environ[k] = v
@@ -310,11 +311,12 @@ def get_nvcc_cuda_version() -> Optional[Version]:
if not cuda_home:
cuda_home = '/usr/local/cuda'
if os.path.isfile(cuda_home + '/bin/nvcc'):
- logger.info(f'CUDA_HOME is not found in the environment. '
- f'Using {cuda_home} as CUDA_HOME.')
+ logger.info(
+ 'CUDA_HOME is not found in the environment. '
+ 'Using %s as CUDA_HOME.', cuda_home)
else:
- logger.warning(
- f'Not found nvcc in {cuda_home}. Skip cuda version check!')
+            logger.warning('nvcc not found in %s. Skipping CUDA version check.',
+ cuda_home)
return None
nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
universal_newlines=True)
@@ -599,8 +601,8 @@ def find_nccl_library():
# manually load the nccl library
if so_file:
logger.info(
- f"Found nccl from environment variable VLLM_NCCL_SO_PATH={so_file}"
- )
+ "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s",
+ so_file)
else:
if torch.version.cuda is not None:
so_file = vllm_nccl_path or find_library("libnccl.so.2")
@@ -608,7 +610,7 @@ def find_nccl_library():
so_file = find_library("librccl.so.1")
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
- logger.info(f"Found nccl from library {so_file}")
+ logger.info("Found nccl from library %s", so_file)
return so_file
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 33dbf8d90c35d..c6da28f110325 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -170,8 +170,8 @@ def load_model(self) -> None:
)
self.model_memory_usage = m.consumed_memory
- logger.info(f"Loading model weights took "
- f"{self.model_memory_usage / float(2**30):.4f} GB")
+ logger.info("Loading model weights took %.4f GB",
+ self.model_memory_usage / float(2**30))
if self.lora_config:
assert hasattr(self.model, "supported_lora_modules"
@@ -196,18 +196,19 @@ def load_model(self) -> None:
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
else:
- raise RuntimeError("Using FP8 KV cache and scaling "
- "factors provided but model "
- f"{self.model.__class__} does not "
- "support loading scaling factors.")
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        f"model {self.model.__class__} does not support "
+                        "loading scaling factors.")
else:
- logger.warn("Using FP8 KV cache but no scaling factors "
- "provided. Defaulting to scaling factors of 1.0. "
- "This may lead to less accurate results!")
+ logger.warning(
+ "Using FP8 KV cache but no scaling factors "
+ "provided. Defaulting to scaling factors of 1.0. "
+ "This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None:
- logger.warn("KV cache scaling factors provided, "
- "but the KV cache data type is not FP8. "
- "KV cache scaling factors will not be used.")
+ logger.warning("KV cache scaling factors provided, "
+ "but the KV cache data type is not FP8. "
+ "KV cache scaling factors will not be used.")
def set_block_size(self, block_size: int) -> None:
self.block_size = block_size
@@ -1054,7 +1055,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
end_time = time.perf_counter()
elapsed_time = end_time - start_time
# This usually takes < 10 seconds.
- logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
+ logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
def __del__(self) -> None:
# Delete the CUDA graphs before deleting the pynccl communicator.
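A caveat illustrated by the model_runner hunk above: the deferred "%s plus arguments" style applies only to logging calls; exception constructors do not interpolate their arguments, so a message passed to raise still has to be built eagerly (an f-string is fine there and is not flagged by the logging lint rules). A minimal sketch, not tied to vLLM:

    def check_scaling_support(model_cls: type, supported: bool) -> None:
        if not supported:
            # Exceptions do not lazily format arguments the way logging calls
            # do, so the message is assembled before raising.
            raise RuntimeError(
                f"model {model_cls.__name__} does not support "
                "loading scaling factors.")

    try:
        check_scaling_support(dict, supported=False)
    except RuntimeError as err:
        print(err)  # -> model dict does not support loading scaling factors.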
From 603ad8481594321ceae7d54e2c0050b3638c6502 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Fri, 26 Apr 2024 22:02:02 +0900
Subject: [PATCH 130/413] [Core] Refactoring sampler and support prompt logprob
for chunked prefill (#4309)
---
tests/samplers/test_logprobs.py | 44 +-
tests/samplers/test_sampler.py | 47 +-
tests/test_logits_processor.py | 10 +-
tests/worker/test_model_runner.py | 19 +-
vllm/core/scheduler.py | 15 +
vllm/engine/async_llm_engine.py | 2 +-
vllm/engine/llm_engine.py | 25 +-
vllm/engine/output_processor/interfaces.py | 6 +
vllm/engine/output_processor/multi_step.py | 9 +
vllm/engine/output_processor/single_step.py | 22 +-
vllm/engine/output_processor/util.py | 7 +-
.../model_executor/layers/logits_processor.py | 27 +-
vllm/model_executor/layers/sampler.py | 544 ++++++++++++------
vllm/model_executor/sampling_metadata.py | 349 ++++++++---
vllm/sequence.py | 11 +-
vllm/worker/cpu_model_runner.py | 116 +---
vllm/worker/model_runner.py | 123 +---
vllm/worker/neuron_model_runner.py | 119 +---
18 files changed, 862 insertions(+), 633 deletions(-)
diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py
index 41b7f3da1e839..57d6d2a410ee5 100644
--- a/tests/samplers/test_logprobs.py
+++ b/tests/samplers/test_logprobs.py
@@ -9,15 +9,26 @@
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
+@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
def test_get_prompt_logprobs(
hf_runner,
vllm_runner,
model,
dtype,
+ chunked_prefill_token_size: int,
+ num_top_logprobs: int,
example_prompts,
):
+ max_num_seqs = 256
+ enable_chunked_prefill = False
+ max_num_batched_tokens = None
+ if chunked_prefill_token_size != -1:
+ enable_chunked_prefill = True
+ max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
+ max_num_batched_tokens = chunked_prefill_token_size
+
max_tokens = 5
- num_top_logprobs = 6
hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
@@ -25,10 +36,17 @@ def test_get_prompt_logprobs(
)
del hf_model
- vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
+ vllm_model = vllm_runner(
+ model,
+ dtype=dtype,
+ max_logprobs=num_top_logprobs,
+ enable_chunked_prefill=enable_chunked_prefill,
+ max_num_batched_tokens=max_num_batched_tokens,
+ max_num_seqs=max_num_seqs,
+ )
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=num_top_logprobs,
- prompt_logprobs=5,
+ prompt_logprobs=num_top_logprobs,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
@@ -52,9 +70,18 @@ def test_get_prompt_logprobs(
"The output text from the top logprob for each token position "
"should be the same as the output text in the result.")
+ # The first prompt logprob is always None
+ assert result.prompt_logprobs[0] is None
+ for prompt_logprobs in result.prompt_logprobs[1:]:
+        # If the prompt token is not included in the top X
+        # logprobs, one extra entry may be returned.
+ assert (len(prompt_logprobs) == num_top_logprobs
+ or len(prompt_logprobs) == num_top_logprobs + 1)
+
# Test whether prompt logprobs are consistent with HF
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
# Check prompt logprobs
+        # The first prompt logprob is always None, so compare from index 1.
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items():
@@ -74,6 +101,17 @@ def test_get_prompt_logprobs(
"The token should be decoded by the time it is returned "
" to the user.")
+ # Test if prompt logprobs are correctly set.
+ for vllm_result in vllm_results:
+ token_ids = vllm_result.prompt_token_ids
+ prompt_logprobs = vllm_result.prompt_logprobs
+
+ # The first token doesn't have logprob.
+ assert prompt_logprobs[0] is None
+
+ for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
+ assert token_id in logprob_dict
+
def test_max_logprobs():
runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 52a2b0ca52aaa..6f2145f8cdcf4 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -8,6 +8,7 @@
from transformers import GenerationConfig, GenerationMixin
from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter
@@ -54,6 +55,7 @@ def _do_sample(
sampler: MockLogitsSampler,
model_runner: ModelRunner,
sampling_params: SamplingParams,
+ device: str,
):
seq_group_metadata_list = []
prompt_lens = []
@@ -68,9 +70,12 @@ def _do_sample(
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
- sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
- prompt_lens,
- subquery_lens=prompt_lens)
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list,
+ prompt_lens,
+ subquery_lens=prompt_lens,
+ device=device,
+ pin_memory=model_runner.pin_memory)
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
@@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str):
sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
- sampling_params)
+ sampling_params, device)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str):
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
- sampling_params)
+ sampling_params, device)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
- sampling_params)
+ sampling_params, device)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
@@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
- model_runner, sampling_params)
+ model_runner, sampling_params, device)
second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
- model_runner, sampling_params)
+ model_runner, sampling_params, device)
assert first_sampler_output == second_sampler_output
@@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str):
best_of=2,
use_beam_search=True,
)
- _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params)
+ _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
+ device)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
@@ -443,10 +449,12 @@ def run_test_case(*,
"batch size")
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
- sampling_metadata = model_runner._prepare_sample(
+ sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
prompt_lens=prompt_lens if prompt_lens else None,
- subquery_lens=prompt_lens if prompt_lens else None)
+ subquery_lens=prompt_lens if prompt_lens else None,
+ device=device,
+ pin_memory=model_runner.pin_memory)
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
@@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str):
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
def test_sampling(model_runner: ModelRunner):
- sampling_metadata = model_runner._prepare_sample(
- seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list,
+ prompt_lens,
+ subquery_lens=prompt_lens,
+ device=device,
+ pin_memory=model_runner.pin_memory)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)
@@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str):
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
- sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
- prompt_lens,
- subquery_lens=prompt_lens)
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list,
+ prompt_lens,
+ subquery_lens=prompt_lens,
+ device=device,
+ pin_memory=model_runner.pin_memory)
sample_probs = None
diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py
index 5bb93ca74855b..dbaeb4de18258 100644
--- a/tests/test_logits_processor.py
+++ b/tests/test_logits_processor.py
@@ -6,6 +6,7 @@
import torch
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
@@ -82,9 +83,12 @@ def pick_ith(token_ids, logits):
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
- sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
- prompt_lens,
- subquery_lens=prompt_lens)
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list,
+ prompt_lens,
+ subquery_lens=prompt_lens,
+ device=model_runner.device,
+ pin_memory=model_runner.pin_memory)
logits_processor_output = logits_processor(
embedding=None,
hidden_states=input_tensor,
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index 59bed2ce0dad3..abb401f25c100 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -2,6 +2,7 @@
import torch
from vllm.config import ModelConfig, SchedulerConfig
+from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size):
assert len(input_positions) == sum(prompt_lens)
torch.testing.assert_close(input_tokens, input_positions)
- sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
- prompt_lens,
- subquery_lens=prompt_lens)
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list,
+ prompt_lens,
+ subquery_lens=prompt_lens,
+ device=model_runner.device,
+ pin_memory=model_runner.pin_memory)
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
actual = sampling_metadata.selected_token_indices
@@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size):
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
- sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
- prompt_lens,
- subquery_lens=prompt_lens)
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list,
+ prompt_lens,
+ subquery_lens=prompt_lens,
+ device=model_runner.device,
+ pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index ac3bd7d228e94..7439f7dc33e8d 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -915,6 +915,20 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
+ do_sample = True
+ if seq_group.is_prefill():
+ seqs = seq_group.get_seqs()
+ # Prefill has only 1 sequence.
+ assert len(seqs) == 1
+            # If not all prompt tokens are computed in the next iteration,
+            # the prefill is chunked and sampling is not needed yet.
+ # NOTE: We use get_len instead of get_prompt_len because when
+ # a sequence is preempted, prefill includes previous generated
+ # output tokens.
+ if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
+ seqs[0].data.get_len()):
+ do_sample = False
+
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt = seq_group.is_prefill()
@@ -924,6 +938,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
+ do_sample=do_sample,
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
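To make the new do_sample flag concrete, here is a small self-contained restatement of the rule above (the function and argument names are illustrative, not the scheduler's API): a prefill samples a token only when the scheduled chunk reaches the end of the sequence's tokens, so intermediate chunks of a chunked prefill skip sampling.

    def needs_sampling(token_chunk_size: int, num_computed_tokens: int,
                       total_len: int, is_prefill: bool) -> bool:
        """Return False while a chunked prefill has not yet covered the whole
        prompt (plus any previously generated tokens after preemption)."""
        if not is_prefill:
            return True
        return token_chunk_size + num_computed_tokens >= total_len

    # Example: a 30-token prompt prefilled in chunks of 16 tokens.
    assert needs_sampling(16, 0, 30, is_prefill=True) is False   # first chunk
    assert needs_sampling(14, 16, 30, is_prefill=True) is True   # final chunk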
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 518532e4a280d..89ee3f0db491c 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -219,7 +219,7 @@ async def step_async(self) -> List[RequestOutput]:
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
- scheduler_outputs.ignored_seq_groups)
+ scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Log stats.
if self.log_stats:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index d2f5379e621c6..741d3bcd80890 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -22,7 +22,7 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
- SequenceGroup, SequenceStage)
+ SequenceGroup, SequenceGroupMetadata)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@@ -476,9 +476,12 @@ def has_unfinished_requests(self) -> bool:
return self.scheduler.has_unfinished_seqs()
def _process_model_outputs(
- self, output: List[SamplerOutput],
- scheduled_seq_groups: List[SequenceGroup],
- ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
+ self,
+ output: List[SamplerOutput],
+ scheduled_seq_groups: List[SequenceGroup],
+ ignored_seq_groups: List[SequenceGroup],
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ ) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
@@ -492,17 +495,15 @@ def _process_model_outputs(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs.
- for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
- output_by_sequence_group):
+ for scheduled_seq_group, outputs, seq_group_meta in zip(
+ scheduled_seq_groups, output_by_sequence_group,
+ seq_group_metadata_list):
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
- # If all sequences in the sequence group are in DECODE, then we can
- # process the output tokens. Otherwise, they are (chunked) prefill
- # samples and should not be processed.
- stages = [seq.data._stage for seq in seq_group.seqs_dict.values()]
- if all(stage == SequenceStage.DECODE for stage in stages):
+ self.output_processor.process_prompt_logprob(seq_group, outputs)
+ if seq_group_meta.do_sample:
self.output_processor.process_outputs(seq_group, outputs)
# Free the finished sequence groups.
@@ -585,7 +586,7 @@ def step(self) -> List[RequestOutput]:
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
- scheduler_outputs.ignored_seq_groups)
+ scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Log stats.
if self.log_stats:
diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py
index f307ea4da3011..9ddb6a3648b8c 100644
--- a/vllm/engine/output_processor/interfaces.py
+++ b/vllm/engine/output_processor/interfaces.py
@@ -68,3 +68,9 @@ def process_outputs(self, sequence_group: SequenceGroup,
scheduler.
"""
pass
+
+ @abstractmethod
+ def process_prompt_logprob(self, seq_group: SequenceGroup,
+ outputs: List[SequenceGroupOutput]) -> None:
+ """Update prompt logprobs received from outputs to seq_group."""
+ pass
diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py
index 39e99d06ed875..9abd87a4d5a9a 100644
--- a/vllm/engine/output_processor/multi_step.py
+++ b/vllm/engine/output_processor/multi_step.py
@@ -44,6 +44,15 @@ def __init__(
self.get_tokenizer_for_seq = get_tokenizer_for_seq
self.stop_checker = stop_checker
+ def process_prompt_logprob(self, seq_group: SequenceGroup,
+ outputs: List[SequenceGroupOutput]) -> None:
+ # TODO(sang): Prompt logprob currently not implemented in multi step
+ # workers.
+ logger.warning(
+ "Prompt logprob is not supported by multi step workers. "
+ "(e.g., speculative decode uses multi step workers).")
+ pass
+
def process_outputs(self, sequence_group: SequenceGroup,
outputs: List[SequenceGroupOutput]) -> None:
"""Append new tokens in the outputs to sequences in the sequence group.
diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index 7e9d652446703..07b140584bbe2 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -55,17 +55,23 @@ def process_outputs(self, sequence_group: SequenceGroup,
), f"{type(self)} does not support multiple outputs per step"
return self._process_sequence_group_outputs(sequence_group, outputs[0])
- def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
- outputs: SequenceGroupOutput) -> None:
-
- # Process prompt logprobs
- prompt_logprobs = outputs.prompt_logprobs
- if prompt_logprobs is not None and \
- seq_group.sampling_params.detokenize and self.detokenizer:
+ def process_prompt_logprob(self, seq_group: SequenceGroup,
+ outputs: List[SequenceGroupOutput]) -> None:
+        assert len(outputs) == 1, ("Single step should only have 1 output.")
+ output = outputs[0]
+ prompt_logprobs = output.prompt_logprobs
+ if (prompt_logprobs is not None
+ and seq_group.sampling_params.detokenize and self.detokenizer):
self.detokenizer.decode_prompt_logprobs_inplace(
seq_group, prompt_logprobs)
- seq_group.prompt_logprobs = prompt_logprobs
+ if not seq_group.prompt_logprobs:
+            # The first prompt token's logprob is None because it has
+            # no preceding tokens.
+ seq_group.prompt_logprobs = [None]
+ seq_group.prompt_logprobs.extend(prompt_logprobs)
+ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+ outputs: SequenceGroupOutput) -> None:
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py
index d076fee8c2a36..9816e966c1e36 100644
--- a/vllm/engine/output_processor/util.py
+++ b/vllm/engine/output_processor/util.py
@@ -1,10 +1,11 @@
from typing import List
-from vllm.sequence import SamplerOutput
+from vllm.sequence import SamplerOutput, SequenceGroupOutput
-def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput],
- num_seq_groups: int):
+def create_output_by_sequence_group(
+ sampler_outputs: List[SamplerOutput],
+ num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index e556e31f99378..22620d9fc86d9 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -83,30 +83,27 @@ def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
- logits_row_idx = 0
found_logits_processors = False
- for i, seq_group in enumerate(sampling_metadata.seq_groups):
- seq_ids, sampling_params = seq_group
+ logits_processed = 0
+ for seq_group in sampling_metadata.seq_groups:
+ seq_ids = seq_group.seq_ids
+ sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors
- # handle prompt_logprobs by skipping rows in logits added for
- # the prompt tokens (prompt logprobs are not processed)
- if (i < sampling_metadata.num_prompts
- and sampling_params.prompt_logprobs is not None):
- assert len(seq_ids) == 1
- logits_row_idx += sampling_metadata.prompt_lens[i] - 1
if logits_processors:
found_logits_processors = True
- for seq_id in seq_ids:
+ for seq_id, logits_row_idx in zip(seq_ids,
+ seq_group.sample_indices):
logits_row = logits[logits_row_idx]
- token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
+ token_ids = seq_group.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
- logits_row_idx += 1
- else:
- logits_row_idx += len(seq_ids)
+
+ logits_processed += len(seq_group.sample_indices) + len(
+ seq_group.prompt_logprob_indices)
+
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
- assert logits_row_idx == logits.shape[0]
+ assert logits_processed == logits.shape[0]
return logits
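As a reading aid for the reworked accounting in _apply_logits_processors, here is a toy sketch (ToySeqGroup and the index values are illustrative stand-ins, not vLLM types) showing why the per-group sums of sample and prompt-logprob indices must add up to the number of pruned logits rows:

from dataclasses import dataclass, field
from typing import List

@dataclass
class ToySeqGroup:  # stand-in for SequenceGroupToSample
    prompt_logprob_indices: List[int] = field(default_factory=list)
    sample_indices: List[int] = field(default_factory=list)

# One prefill group with prompt_logprobs (3 prompt rows + 1 sample row),
# followed by one decode group with a single sample row.
groups = [
    ToySeqGroup(prompt_logprob_indices=[0, 1, 2], sample_indices=[3]),
    ToySeqGroup(sample_indices=[4]),
]
num_logits_rows = 5
logits_processed = sum(
    len(g.prompt_logprob_indices) + len(g.sample_indices) for g in groups)
assert logits_processed == num_logits_rows  # mirrors the assert in the diff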
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index c4b11cb33a677..2ffa8227cc4ed 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -7,11 +7,11 @@
from vllm.model_executor.layers.ops.sample import sample as sample_triton
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
- SamplingTensors)
-from vllm.sampling_params import SamplingParams, SamplingType
+ SamplingTensors,
+ SequenceGroupToSample)
+from vllm.sampling_params import SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
- SamplerOutput, SequenceData, SequenceGroupOutput,
- SequenceOutput)
+ SamplerOutput, SequenceGroupOutput, SequenceOutput)
class Sampler(nn.Module):
@@ -48,11 +48,14 @@ def forward(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
+ """
+ Args:
+ logits: (num_tokens, vocab_size).
+ sampling_metadata: Metadata for sampling.
+ """
assert logits is not None
_, vocab_size = logits.shape
- # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
- # have not been generated yet
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Prepare sampling tensors with pinned memory to avoid blocking.
@@ -83,7 +86,6 @@ def forward(
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
- # Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
@@ -149,24 +151,28 @@ def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
+ """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
+ have not been generated yet
+ """
# list of indices in logits that will be set to -inf
logits_to_penalize = []
- start_idx = 0
- for i, seq_group in enumerate(sampling_metadata.seq_groups):
- seq_ids, sampling_params = seq_group
-
- # handle prompt_logprobs by skipping rows in logits added for the prompt
- # tokens (prompt logprobs are not penalized)
- if (i < sampling_metadata.num_prompts
- and sampling_params.prompt_logprobs is not None):
- assert len(seq_ids) == 1
- start_idx += sampling_metadata.prompt_lens[i] - 1
+ logits_applied = 0
+ for seq_group in sampling_metadata.seq_groups:
+ seq_ids = seq_group.seq_ids
+ sampling_params = seq_group.sampling_params
+
+ sample_indices = seq_group.sample_indices
+ logits_applied += len(sample_indices) + len(
+ seq_group.prompt_logprob_indices)
+ if not seq_group.do_sample:
+ continue
+ start_idx = sample_indices[0]
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
seqs_to_penalize = []
for i, seq_id in enumerate(seq_ids):
- seq_data = sampling_metadata.seq_data[seq_id]
+ seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
seqs_to_penalize.append(i)
@@ -180,15 +186,13 @@ def _apply_min_tokens_penalty(
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))
- start_idx += len(seq_ids)
-
if logits_to_penalize:
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
# verifies that no rows in logits were missed unexpectedly
- assert start_idx == logits.shape[0]
+ assert logits_applied == logits.shape[0]
return logits
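The penalty above is applied through advanced indexing over (row, token) pairs. A small sketch with made-up shapes and token ids, assuming only plain PyTorch, illustrates the scatter:

import torch

logits = torch.zeros(4, 10)
# Suppose rows 2 and 3 belong to sequences that have not yet produced
# min_tokens outputs and the stop-token ids to suppress are {0, 7}.
logits_to_penalize = [(2, 0), (2, 7), (3, 0), (3, 7)]
# zip(*...) regroups the pairs into one tuple of rows and one of columns,
# e.g. [(2, 0), (2, 7)] -> ((2, 2), (0, 7)), which is valid fancy indexing.
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
assert torch.isinf(logits[2, 7]) and logits[1, 7] == 0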
@@ -265,14 +269,30 @@ def _apply_min_p(
def _greedy_sample(
- selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+ selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
+ """Run greedy sampling on a given samples.
+
+ Args:
+ selected_seq_groups: A list of sequence groups batched.
+ samples: (num_selected_samples,) A tensor of samples. The length of
+ samples could be smaller than selected_seq_groups if
+ seq_group.do_sample is False.
+ Returns:
+ A list of (next_token_ids, parent_ids) tuples whose length is the
+ same as that of selected_seq_groups. If the corresponding
+ seq_group has do_sample=False, the tuple is ([], []).
+ """
samples = samples.tolist()
sample_idx = 0
results = []
for seq_group in selected_seq_groups:
- seq_ids, _ = seq_group
+ if not seq_group.do_sample:
+ results.append(([], []))
+ continue
+
+ seq_ids = seq_group.seq_ids
num_parent_seqs = len(seq_ids)
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
@@ -284,16 +304,33 @@ def _greedy_sample(
def _random_sample(
- selected_seq_groups: List[Tuple[List[int], SamplingParams]],
- is_prompts: List[bool],
+ selected_seq_groups: List[SequenceGroupToSample],
random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
+ """Run random sampling on a given samples.
+
+ Args:
+ selected_seq_groups: A list of sequence groups batched.
+ random_samples: (num_selected_samples,) A tensor of samples. The
+ length of samples could be smaller than selected_seq_groups if
+ seq_group.do_sample is False.
+ Returns:
+ A list of (next_token_ids, parent_ids) tuples whose length is the
+ same as that of selected_seq_groups. If the corresponding
+ seq_group has do_sample=False, the tuple is ([], []).
+ """
# Find the maximum best_of value of the prompt phase requests.
random_samples = random_samples.cpu()
sample_idx = 0
results = []
- for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
- seq_ids, sampling_params = seq_group
+ for seq_group in selected_seq_groups:
+ if not seq_group.do_sample:
+ results.append(([], []))
+ continue
+
+ seq_ids = seq_group.seq_ids
+ sampling_params = seq_group.sampling_params
+ is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids)
if is_prompt:
# Prompt phase.
@@ -311,11 +348,20 @@ def _random_sample(
def _beam_search_sample(
- selected_seq_groups: List[Tuple[List[int], SamplingParams]],
- is_prompts: List[bool],
- seq_data: Dict[int, SequenceData],
+ selected_seq_groups: List[SequenceGroupToSample],
logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
+ """Run beam sampling on a given samples.
+
+ Args:
+ selected_seq_groups: A list of sequence groups batched.
+ logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
+ on selected sample indices.
+ Returns:
+ A list of (next_token_ids, parent_ids) tuples whose length is the
+ same as that of selected_seq_groups. If the corresponding
+ seq_group has do_sample=False, the tuple is ([], []).
+ """
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
@@ -327,8 +373,13 @@ def _beam_search_sample(
# other sampling methods.
sample_idx = 0
results = []
- for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
- seq_ids, sampling_params = seq_group
+ for seq_group in selected_seq_groups:
+ if not seq_group.do_sample:
+ results.append(([], []))
+ continue
+
+ is_prompt = seq_group.is_prompt
+ seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
num_parent_seqs = len(seq_ids)
beam_width = sampling_params.best_of
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
@@ -343,7 +394,8 @@ def _beam_search_sample(
else:
# Generation phase.
cumulative_logprobs = [
- seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
+ seq_group.seq_data[seq_id].cumulative_logprob
+ for seq_id in seq_ids
]
cumulative_logprobs = torch.tensor(
cumulative_logprobs,
@@ -371,8 +423,7 @@ def _beam_search_sample(
def _multinomial(
probs: torch.Tensor,
num_samples: int,
- seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
- generators: Optional[List[torch.Generator]] = None,
+ seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
@@ -388,9 +439,11 @@ def _multinomial(
q.exponential_()
else:
sample_idx = 0
- for (seq_ids, _), generator in zip(seq_groups, generators):
+ for seq_group in seq_groups:
+ seq_ids = seq_group.seq_ids
next_sample_idx = sample_idx + len(seq_ids) * num_samples
- q[sample_idx:next_sample_idx].exponential_(generator=generator)
+ q[sample_idx:next_sample_idx].exponential_(
+ generator=seq_group.generator)
sample_idx = next_sample_idx
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
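_multinomial keeps the exponential-race trick (argmax of probs divided by Exponential(1) noise is a categorical draw) but now reads the generator from each SequenceGroupToSample. A small sanity check, assuming only plain PyTorch and made-up probabilities, of both the trick and seeded reproducibility:

import torch

probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.5, 0.3, 0.2]])
gen = torch.Generator().manual_seed(1234)

q = torch.empty_like(probs)
q.exponential_(generator=gen)          # iid Exponential(1) noise per entry
samples = probs.div(q).argmax(dim=1)   # one categorical draw per row

# Re-seeding an equivalent generator reproduces the same draw, which is what
# the per-group generator buys us for seeded requests.
gen2 = torch.Generator().manual_seed(1234)
q2 = torch.empty_like(probs)
q2.exponential_(generator=gen2)
assert torch.equal(samples, probs.div(q2).argmax(dim=1))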
@@ -405,7 +458,7 @@ def _sample_with_torch(
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
- _, sampling_params = seq_group
+ sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
@@ -429,13 +482,11 @@ def _sample_with_torch(
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
- seq_group_ids = categorized_seq_group_ids[sampling_type]
- seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
- is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
- sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
- is_prompts, sample_indices)
- long_sample_indices = sample_indices.long()
+ seq_group_id = categorized_seq_group_ids[sampling_type]
+ seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
+ sample_metadata[sampling_type] = (seq_group_id, seq_groups)
+ long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
@@ -455,14 +506,13 @@ def _sample_with_torch(
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1
- for seq_group, is_prompt in zip(seq_groups, is_prompts):
- if is_prompt:
- _, sampling_params = seq_group
+ for seq_group in seq_groups:
+ if seq_group.is_prompt:
+ sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
- "generators": sampling_metadata.generators,
}
multinomial_samples[sampling_type] = _multinomial(
@@ -481,25 +531,22 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
-
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
- seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
- sampling_type]
+ (seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
- sample_results = _random_sample(seq_groups, is_prompts,
+ sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
- sample_results = _beam_search_sample(seq_groups, is_prompts,
- sampling_metadata.seq_data,
+ sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
- sample_results_dict.update(zip(seq_group_ids, sample_results))
+ sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
- sample_results_dict[i]
+ sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results, sampled_token_ids_tensor
@@ -514,7 +561,7 @@ def _sample_with_triton_kernel(
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
- _, sampling_params = seq_group
+ sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
@@ -530,17 +577,16 @@ def _sample_with_triton_kernel(
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
- seq_group_ids = categorized_seq_group_ids[sampling_type]
- seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
- is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
- sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
- is_prompts, sample_indices,
+ seq_group_id = categorized_seq_group_ids[sampling_type]
+ seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
+ sample_metadata[sampling_type] = (seq_group_id, seq_groups,
+ sample_indices,
sampled_token_indices)
if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
- for seq_group, is_prompt in zip(seq_groups, is_prompts):
- if is_prompt:
- _, sampling_params = seq_group
+ for seq_group in seq_groups:
+ if seq_group.is_prompt:
+ sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
elif sampling_type == SamplingType.BEAM:
@@ -564,22 +610,21 @@ def _sample_with_triton_kernel(
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
- (seq_group_ids, seq_groups, is_prompts, sample_indices,
+ (seq_group_id, seq_groups, sample_indices,
sampled_token_indices) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(
seq_groups, sampled_tokens[sampled_token_indices][:, 0])
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(
- seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
+ seq_groups, sampled_tokens[sampled_token_indices])
elif sampling_type == SamplingType.BEAM:
- sample_results = _beam_search_sample(seq_groups, is_prompts,
- sampling_metadata.seq_data,
+ sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
- sample_results_dict.update(zip(seq_group_ids, sample_results))
+ sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
- sample_results_dict[i]
+ sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
@@ -590,6 +635,18 @@ def _sample(
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+ """
+ Args:
+ probs: (num_query_tokens_in_batch, num_vocab)
+ logprobs: (num_query_tokens_in_batch, num_vocab)
+ sampling_metadata: The metadata for a batch for sampling.
+ sampling_tensors: Tensors that include sampling related metadata.
+
+ Returns:
+ (next_token_ids, parent_seq_ids) for each seq group in the batch.
+ If sampling is skipped, ([], []) is returned instead.
+ sampled_token_ids_tensor: A tensor of sampled token ids.
+ """
return _sample_with_torch(
probs,
logprobs,
@@ -626,56 +683,97 @@ def _get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sample_results: List[Tuple[List[int], List[int]]],
-) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
- int, float]]]]:
- # Prepare query indices
- batched_logprobs_query_seq_indices: List[int] = []
- batched_logprobs_query_token_indices: List[int] = []
- # at least get one logprob for each token
+) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
+ """Return sample lobprobs and prompt logprobs.
+
+ The logic consists of 3 parts.
+ - Select indices to compute logprob from, ranks of token ids, and
+ the top k token ids from logprobs.
+ - Compute prompt logprobs if required.
+ - Compute sample logprobs if required.
+
+ Args:
+ logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
+ logprob per vocab. Sequence groups' query tokens are batched in a
+ single flattened tensor. For example, assuming there are N
+ seq groups, it is sorted by prefill tokens for seq_group_1 (if
+ prompt logprob is enabled), decode tokens for seq_group_1 (if
+ sampling is required), prefill tokens for seq_group_2, ...
+ sampling_metadata: The sampling metadata.
+ sample_results: (num_seq_groups) The tuple of (next_token_ids,
+ parent_ids) for each sequence group. When beam search is enabled,
+ sample_results can contain a different number of seq_ids than
+ sampling_metadata.seq_groups, because beam search creates
+ 2 * BEAM_WIDTH samples (whereas there are only up to
+ BEAM_WIDTH seq_ids).
+
+ Returns:
+ A tuple of prompt and sample logprobs per sequence group in a batch.
+ """
+ # The indices of query tokens to calculate logprobs for. They include
+ # both prompt and sample logprob indices.
+ query_indices: List[int] = []
+ # The next token ids to get the logprob value from.
+ next_token_ids: List[int] = []
+ # The largest number of logprobs requested across the batch. We compute
+ # this many top logprobs for every selected token in this call.
largest_num_logprobs = 1
- sample_idx = 0
- for i, (seq_group, sample_result) in enumerate(
- zip(sampling_metadata.seq_groups, sample_results)):
- seq_ids, sampling_params = seq_group
- next_token_ids, parent_ids = sample_result
- num_parent_seqs = len(seq_ids)
- if (i < sampling_metadata.num_prompts
+
+ # Select indices to compute logprob from, ranks of token ids, and the top
+ # k token ids from logprobs.
+ for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
+ sample_results):
+ sampling_params = seq_group.sampling_params
+
+ # Update indices and tokens for prompt logprobs.
+ if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.prompt_logprobs)
- prompt_len = sampling_metadata.prompt_lens[i]
- prompt_tokens = sampling_metadata.seq_data[
- seq_ids[0]].prompt_token_ids
- batched_logprobs_query_seq_indices.extend(
- sample_idx + j for j in range(prompt_len - 1))
- batched_logprobs_query_token_indices.extend(
- token_id for token_id in prompt_tokens[1:])
- sample_idx += prompt_len - 1
- batched_logprobs_query_seq_indices.extend(
- [sample_idx + parent_id for parent_id in parent_ids])
- batched_logprobs_query_token_indices.extend(next_token_ids)
- if sampling_params.logprobs is not None:
- largest_num_logprobs = max(largest_num_logprobs,
- sampling_params.logprobs)
- sample_idx += num_parent_seqs
- assert sample_idx == logprobs.size(0)
-
- batched_logprobs_query_seq_indices_gpu = torch.tensor(
- batched_logprobs_query_seq_indices, device=logprobs.device)
- batched_logprobs_query_token_indices_gpu = torch.tensor(
- batched_logprobs_query_token_indices, device=logprobs.device)
-
- # Batched query for logprobs of selected token
- batched_logprobs_query_result = logprobs[[
- batched_logprobs_query_seq_indices_gpu,
- batched_logprobs_query_token_indices_gpu
+ next_prompt_tokens = _get_next_prompt_tokens(seq_group)
+ query_indices.extend(seq_group.prompt_logprob_indices)
+ next_token_ids.extend(next_prompt_tokens)
+
+ # Update indices and next tokens for sample logprobs.
+ if seq_group.do_sample:
+ token_ids, parent_seq_ids = sample_result
+ # NOTE: We cannot directly use sample_indices because
+ # sample_indices only contain parent seq_ids of a previous step.
+ # The current step may have different number of seq_ids, and
+ # we can obtain it from `sample_result[1]`.
+ query_idx = seq_group.sample_indices[0]
+ query_indices.extend(
+ [query_idx + parent_id for parent_id in parent_seq_ids])
+ next_token_ids.extend(token_ids)
+
+ if sampling_params.logprobs is not None:
+ largest_num_logprobs = max(largest_num_logprobs,
+ sampling_params.logprobs)
+
+ assert len(next_token_ids) == len(query_indices)
+
+ if len(query_indices) == 0:
+ empty_sampled_logprob = []
+ empty_prompt_logprob = None
+ return [empty_prompt_logprob], [empty_sampled_logprob]
+
+ query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
+ next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
+
+ # (num_selected_query_tokens, num_logprobs). Note that query_indices can
+ # contain duplicates if beam search is enabled.
+ selected_logprobs = logprobs[[
+ query_indices_gpu,
+ next_token_ids_gpu,
]]
+ ranks = _get_ranks(
+ logprobs[query_indices_gpu],
+ next_token_ids_gpu,
+ )
+ assert selected_logprobs.shape[0] == ranks.shape[0]
- batched_ranks_query_result = _get_ranks(
- logprobs[batched_logprobs_query_seq_indices_gpu],
- batched_logprobs_query_token_indices_gpu)
-
- # Batched query for logprobs of topk tokens
+ # Logprobs of topk tokens for a batch of sequence groups.
+ # (num_query_tokens_across_batch).
if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
@@ -685,79 +783,136 @@ def _get_logprobs(
else:
top_logprobs, top_token_ids = None, None
- batched_logprobs_query_result = batched_logprobs_query_result.cpu()
- batched_ranks_query_result = batched_ranks_query_result.cpu()
-
- # Gather results
- result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
- result_sample_logprobs: List[SampleLogprobs] = []
- sample_idx = 0
- query_result_idx = 0
- for i, (seq_group, sample_result) in enumerate(
- zip(sampling_metadata.seq_groups, sample_results)):
- seq_ids, sampling_params = seq_group
- next_token_ids, parent_ids = sample_result
+ selected_logprobs = selected_logprobs.cpu()
+ ranks = ranks.cpu()
+
+ # Find prompt/sample logprobs.
+ prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
+ sample_logprobs_per_seq_group: List[SampleLogprobs] = []
+ top_logprob_idx = 0
+ selected_logprobs_idx = 0
+
+ for seq_group, sample_result in zip(sampling_metadata.seq_groups,
+ sample_results):
+ (prompt_logprobs, top_logprob_idx,
+ selected_logprobs_idx) = _get_prompt_logprob_if_needed(
+ seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
+ selected_logprobs_idx, top_logprob_idx)
+ prompt_logprobs_per_seq_group.append(prompt_logprobs)
+
+ (sampled_logprobs, top_logprob_idx,
+ selected_logprobs_idx) = _get_sampled_logprob_if_needed(
+ seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
+ top_logprobs, selected_logprobs_idx, top_logprob_idx)
+ sample_logprobs_per_seq_group.append(sampled_logprobs)
+
+ return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
+
+
+def _get_prompt_logprob_if_needed(
+ seq_group: SequenceGroupToSample,
+ selected_logprobs: torch.Tensor,
+ ranks: torch.Tensor,
+ top_token_ids: torch.Tensor,
+ top_logprobs: torch.Tensor,
+ selected_logprobs_idx: int,
+ top_logprob_idx: int,
+):
+ """Compute the prompt logprob from a sequence group if needed."""
+ sampling_params = seq_group.sampling_params
+ is_prompt = seq_group.is_prompt
+
+ # Find prompt logprobs
+ prompt_logprobs: Optional[PromptLogprobs] = None
+ if (is_prompt and sampling_params.prompt_logprobs is not None):
+ prompt_logprobs = []
+ num_logprobs = sampling_params.prompt_logprobs
+ next_prompt_tokens = _get_next_prompt_tokens(seq_group)
+ for token_id in next_prompt_tokens:
+ # Calculate the prompt logprob of the real prompt tokens.
+ # Use tuple here for performance (to use tolist()).
+ # {token_id: (logprob, rank_from_vocab)}
+ prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
+ token_id: (selected_logprobs[selected_logprobs_idx].item(),
+ ranks[selected_logprobs_idx].item())
+ }
- # Prompt logprobs
- if (i < sampling_metadata.num_prompts
- and sampling_params.prompt_logprobs is not None):
- num_logprobs = sampling_params.prompt_logprobs
- prompt_tokens = sampling_metadata.seq_data[
- seq_ids[0]].prompt_token_ids
- group_prompt_logprobs: PromptLogprobs = [None]
- for token_id in prompt_tokens[1:]:
- prompt_logprobs_dict = {
- token_id:
- (batched_logprobs_query_result[query_result_idx].item(),
- batched_ranks_query_result[query_result_idx].item())
- }
- if num_logprobs > 0:
- prompt_logprobs_dict.update(
+ # Add top K prompt logprobs along with its rank.
+ if num_logprobs > 0:
+ prompt_logprobs_dict.update(
+ zip(
+ top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
zip(
- top_token_ids[sample_idx, :num_logprobs].tolist(),
- zip(
- top_logprobs[
- sample_idx, :num_logprobs].tolist(),
- range(1, num_logprobs + 1))))
- group_prompt_logprobs.append({
- token_id: Logprob(*logprob_rank)
- for token_id, logprob_rank in prompt_logprobs_dict.items()
- })
- sample_idx += 1
- query_result_idx += 1
- result_prompt_logprobs.append(group_prompt_logprobs)
- else:
- result_prompt_logprobs.append(None)
-
- # Sample logprobs
- num_logprobs = sampling_params.logprobs
- if num_logprobs is None:
- num_logprobs = 0
- group_sample_logprobs: SampleLogprobs = []
- for next_token_id, parent_id in zip(next_token_ids, parent_ids):
- sample_logprobs_dict = {
+ top_logprobs[
+ top_logprob_idx, :num_logprobs].tolist(),
+ # This is ranks. Since top_logprob is sorted,
+ # we can just use a range here.
+ range(1, num_logprobs + 1))))
+ prompt_logprobs.append({
+ token_id: Logprob(*logprob_and_rank)
+ for token_id, logprob_and_rank in prompt_logprobs_dict.items()
+ })
+ # + 1 to go to the next prompt token.
+ top_logprob_idx += 1
+ selected_logprobs_idx += 1
+ return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
+
+
+def _get_sampled_logprob_if_needed(
+ seq_group: SequenceGroupToSample,
+ sample_result: Tuple[List[int], List[int]],
+ selected_logprobs: torch.Tensor,
+ ranks: torch.Tensor,
+ top_token_ids: torch.Tensor,
+ top_logprobs: torch.Tensor,
+ selected_logprobs_idx: int,
+ top_logprob_idx: int,
+):
+ """Compute the sample logprob if needed."""
+ seq_ids = seq_group.seq_ids
+ num_logprobs = seq_group.sampling_params.logprobs
+ if num_logprobs is None:
+ num_logprobs = 0
+ sampled_logprobs: SampleLogprobs = []
+ next_token_ids, parent_seq_ids = sample_result
+
+ if seq_group.do_sample:
+ assert len(next_token_ids) > 0
+ for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
+ # Calculate the sample logprob of the real sampled tokens.
+ # Use tuple here for performance (to use tolist()).
+ # token_id: (logprob, rank_from_vocab)
+ sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
next_token_id:
- (batched_logprobs_query_result[query_result_idx].item(),
- batched_ranks_query_result[query_result_idx].item())
+ (selected_logprobs[selected_logprobs_idx].item(),
+ ranks[selected_logprobs_idx].item())
}
- query_result_idx += 1
+ # +1 to go to the next sampled token. Note that
+ # selected_logprobs can contain duplicates unlike top_logprobs
+ # when beam search is enabled.
+ selected_logprobs_idx += 1
+
+ # Second, add top K logprobs along with its rank.
if num_logprobs >= 0:
- sample_logprobs_dict.update(
+ sampled_logprobs_dict.update(
zip(
- top_token_ids[sample_idx +
+ top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
zip(
- top_logprobs[sample_idx +
+ top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
+ # This is rank. Since top_logprob is sorted, we
+ # can just use a range here.
range(1, num_logprobs + 1))))
- group_sample_logprobs.append({
- token_id: Logprob(*logprob_rank)
- for token_id, logprob_rank in sample_logprobs_dict.items()
+ sampled_logprobs.append({
+ token_id: Logprob(*logprob_and_rank)
+ for token_id, logprob_and_rank in
+ sampled_logprobs_dict.items()
})
- result_sample_logprobs.append(group_sample_logprobs)
- sample_idx += len(seq_ids)
-
- return result_prompt_logprobs, result_sample_logprobs
+ # There are len(seq_ids) sampled tokens for the current
+ # sequence group in top_logprobs. Jump to the next seq_group.
+ top_logprob_idx += len(seq_ids)
+ return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
@@ -832,7 +987,7 @@ def _build_sampler_output(
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
- seq_ids, _ = seq_group
+ seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs = []
for parent_id, next_token_id, logprobs in zip(parent_ids,
@@ -854,3 +1009,36 @@ def _build_sampler_output(
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
)
+
+
+def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
+ """Get a list of next prompt token ids to compute logprobs for, from a
+ given sequence group.
+
+ It is used to compute prompt logprobs. There is a logprob row for each
+ query token, and each query token needs the id of the *next* prompt token
+ to look up its prompt logprob. This helper returns those next token ids.
+
+ This API must only be used when the caller knows the seq_group is in the
+ prefill stage.
+
+ Returns:
+ A list of next prompt token ids to compute logprobs from.
+ """
+ assert seq_group.is_prompt, (
+ "Caller should ensure the sequence group is in a prefill stage.")
+ seq_ids = seq_group.seq_ids
+ subquery_len = seq_group.subquery_len
+ assert subquery_len is not None
+ # prompt has only 1 seq id.
+ assert len(seq_ids) == 1
+ seq_data = seq_group.seq_data[seq_ids[0]]
+ computed_len = seq_data.get_num_computed_tokens()
+ prompt_tokens = seq_data.prompt_token_ids
+ # +1 because we are looking for a next prompt token.
+ next_token_index_start = computed_len + 1
+ next_token_index_end = min(computed_len + subquery_len + 1,
+ len(prompt_tokens))
+ next_prompt_tokens = prompt_tokens[
+ next_token_index_start:next_token_index_end]
+ return next_prompt_tokens
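A toy walk-through of the slicing in _get_next_prompt_tokens (the token ids and lengths are invented): each query position needs the id of the token that follows it, hence the +1 shift and the clamp at the prompt length.

prompt_tokens = [101, 5, 6, 7, 8]   # 5 prompt token ids
computed_len = 2                    # tokens already processed in earlier chunks
subquery_len = 3                    # tokens computed in this chunk

next_token_index_start = computed_len + 1
next_token_index_end = min(computed_len + subquery_len + 1, len(prompt_tokens))
next_prompt_tokens = prompt_tokens[next_token_index_start:next_token_index_end]
# Query positions 2, 3 and 4 would need the ids at positions 3, 4 and 5, but
# position 5 is past the end of the prompt, so only two targets remain.
assert next_prompt_tokens == [7, 8]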
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 31032c4cead20..12156b2ba1aa2 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -6,57 +6,275 @@
from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SequenceData
-from vllm.utils import is_pin_memory_available
+from vllm.sequence import SequenceData, SequenceGroupMetadata
+from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
+ maybe_expand_dim)
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
+@dataclass
+class SequenceGroupToSample:
+ # Sequence ids for the sequence group in a previous step.
+ seq_ids: List[int]
+ sampling_params: SamplingParams
+ # seq_id -> sequence data.
+ seq_data: Dict[int, SequenceData]
+ # The length of the prompt of the sequence group. None if it is in a decode
+ # stage.
+ prompt_len: Optional[int]
+ # The number of query tokens to compute in the current step. None if it
+ # is in a decode stage; subquery_len <= prompt_len.
+ subquery_len: Optional[int]
+ # A random number generator for sampling.
+ generator: Optional[torch.Generator]
+ # True if the sequence group is in prefill stage. False if it is in a
+ # decode stage.
+ is_prompt: bool
+ # Query token indices in logits used to compute prompt logprobs. Empty if
+ # prompt logprob is not required.
+ prompt_logprob_indices: List[int]
+ # Sample token indices from logits. Empty if sampling is not required.
+ sample_indices: List[int]
+
+ @property
+ def do_sample(self):
+ return len(self.sample_indices) > 0
+
+ def __post_init__(self):
+ if len(self.prompt_logprob_indices) > 0:
+ assert self.sampling_params.prompt_logprobs is not None
+ if self.is_prompt:
+ assert self.prompt_len is not None
+ assert self.subquery_len is not None
+
+
class SamplingMetadata:
"""Metadata for input sequences. Used in sampler.
+ The usage is as follows:
+ ```
+ hidden_states = execute_model(...)
+ logits = hidden_states[sampling_metadata.selected_token_indices]
+ sample(logits)
+
+ def sample(logits):
+ # Use categorized_sample_indices for sampling....
+ ```
+
Args:
- seq_groups: List of (seq_ids, sampling_params).
- seq_data: Seq_id -> SequenceData.
- prompt_lens: Lengths of prompts.
- selected_token_indices: Token indices selected for sampling.
+ seq_groups: List of batched sequence groups.
+ selected_token_indices: (num_query_tokens_to_logprob). Indices to find
+ logits from the initial model output hidden states.
categorized_sample_indices: SamplingType -> token indices to sample.
- generators: List of torch.Generators to use for seeded sampling
- perform_sampling: Whether to perform sampling. This option is used to
- make the sampling only happens in the driver worker, and disable
- sampling in other worker processes.
+ Each entry is a 2D tensor of shape (num_indices, 2), where the
+ first item is the sample index within the returned logits
+ (before pruning padding), and the second item is the sample
+ index after pruning using selected_token_indices.
+ For example, if the returned logit indices are [1, 2, 3] and we
+ select [1, 2] for sampling, the pruned logits are [2, 3]. In this
+ case, the first column is [1, 2] (sample indices within the
+ original logits) and the second is [0, 1] (within the pruned logits).
+ num_prompts: Number of prompt sequence groups in seq_groups.
"""
def __init__(
self,
- seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
- seq_data: Optional[Dict[int, SequenceData]],
- prompt_lens: Optional[List[int]],
+ seq_groups: List[SequenceGroupToSample],
selected_token_indices: torch.Tensor,
- categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
- generators: Optional[List[torch.Generator]] = None,
- perform_sampling: bool = True,
+ categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+ num_prompts: int,
) -> None:
self.seq_groups = seq_groups
- self.seq_data = seq_data
- self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
- self.generators = generators
- self.perform_sampling = perform_sampling
+ self.num_prompts = num_prompts
- self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
+ @staticmethod
+ def prepare(
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ prompt_lens: List[int],
+ subquery_lens: Optional[List[int]],
+ device: str,
+ pin_memory: bool,
+ ) -> "SamplingMetadata":
+ (
+ seq_groups,
+ selected_token_indices,
+ categorized_sample_indices,
+ num_prompts,
+ ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
+ subquery_lens, device)
+ selected_token_indices = async_tensor_h2d(selected_token_indices,
+ dtype=torch.long,
+ target_device=device,
+ pin_memory=pin_memory)
+ categorized_sample_indices = {
+ t: maybe_expand_dim(
+ async_tensor_h2d(seq_ids,
+ dtype=torch.int,
+ target_device=device,
+ pin_memory=pin_memory), 2, 2)
+ for t, seq_ids in categorized_sample_indices.items()
+ }
+
+ sampling_metadata = SamplingMetadata(
+ seq_groups=seq_groups,
+ selected_token_indices=selected_token_indices,
+ categorized_sample_indices=categorized_sample_indices,
+ num_prompts=num_prompts,
+ )
+ return sampling_metadata
def __repr__(self) -> str:
return (
"SamplingMetadata("
f"seq_groups={self.seq_groups}, "
- f"seq_data={self.seq_data}, "
- f"prompt_lens={self.prompt_lens}, "
f"selected_token_indices={self.selected_token_indices}, "
- f"categorized_sample_indices={self.categorized_sample_indices}), "
- f"perform_sampling={self.perform_sampling})")
+ f"categorized_sample_indices={self.categorized_sample_indices}), ")
+
+
+def _prepare_seq_groups(
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ prompt_lens: List[int],
+ subquery_lens: Optional[List[int]],
+ device: str,
+) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
+ SamplingType, List[Tuple[int, int]]], int]:
+ """Prepare sequence groups and indices for sampling.
+
+ Args:
+ seq_group_metadata_list: A list of sequence group to batch.
+ prompt_lens: A list of prompt lens per sequence group.
+ Index of prompt len should match with seq_group_metadata_list.
+ subquery_lens: A list of query lengths. Unlike prompt_lens, which
+ covers the entire prompt, a subquery length can be shorter.
+ device: A device to use for random number generator,
+ `SequenceGroupToSample.generator`.
+
+ Returns:
+ seq_groups: A list of sequence group to sample.
+ selected_token_indices: See the definition from `SamplingMetadata`.
+ categorized_sample_indices: See the definition from `SamplingMetadata`.
+ num_prompts: Total number of prompts from `seq_group_metadata_list`.
+ """
+ # Batched sequence groups for the current model forward step.
+ seq_groups: List[SequenceGroupToSample] = []
+ # A list of token indices to sample/compute logprobs from. It is used to
+ # prune the model's output logits for performance.
+ selected_token_indices: List[int] = []
+ # Used for selected_token_indices.
+ model_output_idx = 0
+
+ # Sampling type -> (
+ # indices to sample/prompt logprob within pruned output logits,
+ # indices to sample within pruned logits)
+ categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
+ t: []
+ for t in SamplingType
+ }
+ # Index of logits to compute logprob. Logits include both prompt logprob
+ # and sample logprob indices.
+ logit_idx = 0
+ # Index to sample from a sample tensor. It is used by triton sample kernel.
+ # See `_sample_with_triton_kernel` for more details.
+ sample_idx = 0
+ # Total number of prompts from given sequence groups.
+ num_prompts = 0
+
+ for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+ seq_ids = list(seq_group_metadata.seq_data.keys())
+ sampling_params = seq_group_metadata.sampling_params
+ is_prompt = seq_group_metadata.is_prompt
+ generator: Optional[torch.Generator] = None
+ # If the current seq group is in decode stage, it is None.
+ prompt_len: Optional[int] = None
+ subquery_len: Optional[int] = None
+ prompt_logprob_indices: List[int] = []
+ sample_indices: List[int] = []
+ do_sample = seq_group_metadata.do_sample
+
+ if seq_group_metadata.is_prompt:
+ if sampling_params.seed is not None:
+ seq_group_metadata.state.generator = torch.Generator(
+ device=device).manual_seed(sampling_params.seed)
+
+ num_prompts += 1
+ num_prefill_sample = len(seq_ids)
+ assert num_prefill_sample == 1
+ assert subquery_lens is not None and prompt_lens is not None
+ subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
+ # If we need sampling, exclude num_prefill_sample tokens from
+ # prompt logprob.
+ prompt_logprob_len = (subquery_len - num_prefill_sample
+ if do_sample else subquery_len)
+ sample_len = num_prefill_sample if do_sample else 0
+ else:
+ # Decode
+ prompt_logprob_len = 0
+ sample_len = len(seq_ids) if do_sample else 0
+
+ # Update indices to select from the model output.
+ """
+ This block computes selected_token_indices, which is used in the
+ following way.
+
+ hidden_states = model(...)
+ logits = hidden_states[selected_token_indices]
+ """
+
+ if sampling_params.prompt_logprobs:
+ selected_token_indices.extend(
+ range(model_output_idx, model_output_idx + prompt_logprob_len))
+ model_output_idx += prompt_logprob_len
+ if do_sample:
+ selected_token_indices.extend(
+ range(model_output_idx, model_output_idx + sample_len))
+ model_output_idx += sample_len
+
+ # We now find indices for logprob computation and sampling.
+ """
+ This block computes categorized_sample_indices which is used in the
+ following way.
+
+ hidden_states = model(...)
+ logits = hidden_states[selected_token_indices]
+ def sample(logits):
+ # Use categorized_sample_indices for sampling.
+ # prompt_logprob_indices to find prompt logprob indices.
+ # sample_indices to find sample indices.
+ """
+
+ if sampling_params.prompt_logprobs is not None:
+ prompt_logprob_indices.extend(
+ range(logit_idx, logit_idx + prompt_logprob_len))
+ logit_idx += prompt_logprob_len
+ if do_sample:
+ sample_indices.extend(range(logit_idx, logit_idx + sample_len))
+ categorized_sample_indices[sampling_params.sampling_type].extend(
+ list(
+ zip(range(logit_idx, logit_idx + sample_len),
+ range(sample_idx, sample_idx + sample_len))))
+ logit_idx += sample_len
+ sample_idx += sample_len
+
+ if sampling_params.seed is not None:
+ generator = seq_group_metadata.state.generator
+
+ seq_groups.append(
+ SequenceGroupToSample(
+ seq_ids=seq_ids,
+ sampling_params=sampling_params,
+ seq_data=seq_group_metadata.seq_data,
+ prompt_len=prompt_len,
+ subquery_len=subquery_len,
+ generator=generator,
+ is_prompt=is_prompt,
+ prompt_logprob_indices=list(prompt_logprob_indices),
+ sample_indices=list(sample_indices)))
+ return (seq_groups, selected_token_indices, categorized_sample_indices,
+ num_prompts)
@dataclass
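To make the index bookkeeping in _prepare_seq_groups easier to follow, here is a hand-computed sketch (toy lengths only, no vLLM objects) for one prefill group with prompt_logprobs enabled and do_sample=True, followed by one single-sequence decode group:

subquery_len = 4                       # prefill chunk length
prompt_logprob_len = subquery_len - 1  # the last prefill token is the sample
sample_len_prefill = 1
sample_len_decode = 1

selected_token_indices = []
model_output_idx = 0
# Prefill group: keep its prompt-logprob rows, then its sample row.
selected_token_indices += range(model_output_idx,
                                model_output_idx + prompt_logprob_len)
model_output_idx += prompt_logprob_len
selected_token_indices += range(model_output_idx,
                                model_output_idx + sample_len_prefill)
model_output_idx += sample_len_prefill
# Decode group: a single sample row.
selected_token_indices += range(model_output_idx,
                                model_output_idx + sample_len_decode)

# After pruning, rows 0..2 are the prefill group's prompt-logprob rows, row 3
# is its sample row, and row 4 is the decode group's sample row.
assert selected_token_indices == [0, 1, 2, 3, 4]
prompt_logprob_indices_group0 = [0, 1, 2]
sample_indices_group0 = [3]
sample_indices_group1 = [4]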
@@ -112,11 +330,10 @@ def from_sampling_metadata(
seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size))
- sample_indices_start_idx = 0
assert sampling_metadata.seq_groups is not None
- assert sampling_metadata.seq_data is not None
- for i, seq_group in enumerate(sampling_metadata.seq_groups):
- seq_ids, sampling_params = seq_group
+ for seq_group in sampling_metadata.seq_groups:
+ seq_ids = seq_group.seq_ids
+ sampling_params = seq_group.sampling_params
temperature = sampling_params.temperature
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
@@ -145,45 +362,46 @@ def from_sampling_metadata(
or abs(r - 1.0) >= _SAMPLING_EPS):
do_penalties = True
- if (i < sampling_metadata.num_prompts
+ is_prompt = seq_group.is_prompt
+ if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
- assert sampling_metadata.prompt_lens is not None
- prompt_len = sampling_metadata.prompt_lens[i]
- temperatures += [temperature] * (prompt_len - 1)
- top_ps += [top_p] * (prompt_len - 1)
- top_ks += [top_k] * (prompt_len - 1)
- min_ps += [min_p] * (prompt_len - 1)
- presence_penalties += [0] * (prompt_len - 1)
- frequency_penalties += [0] * (prompt_len - 1)
- repetition_penalties += [1] * (prompt_len - 1)
- prompt_tokens.extend([] for _ in range(prompt_len - 1))
- output_tokens.extend([] for _ in range(prompt_len - 1))
- for seq_id in seq_ids:
- seq_data = sampling_metadata.seq_data[seq_id]
- prompt_tokens.append(seq_data.prompt_token_ids)
- output_tokens.append(seq_data.output_token_ids)
- temperatures += [temperature] * len(seq_ids)
- top_ps += [top_p] * len(seq_ids)
- top_ks += [top_k] * len(seq_ids)
- min_ps += [min_p] * len(seq_ids)
- presence_penalties += [p] * len(seq_ids)
- frequency_penalties += [f] * len(seq_ids)
- repetition_penalties += [r] * len(seq_ids)
-
- is_prompt = i < sampling_metadata.num_prompts
+ subquery_len = seq_group.subquery_len
+ assert subquery_len is not None
+ prefill_len = len(seq_group.prompt_logprob_indices)
+ temperatures += [temperature] * prefill_len
+ top_ps += [top_p] * prefill_len
+ top_ks += [top_k] * prefill_len
+ min_ps += [min_p] * prefill_len
+ presence_penalties += [0] * prefill_len
+ frequency_penalties += [0] * prefill_len
+ repetition_penalties += [1] * prefill_len
+ prompt_tokens.extend([] for _ in range(prefill_len))
+ output_tokens.extend([] for _ in range(prefill_len))
+
+ if seq_group.do_sample:
+ sample_lens = len(seq_group.sample_indices)
+ assert sample_lens == len(seq_ids)
+ for seq_id in seq_ids:
+ seq_data = seq_group.seq_data[seq_id]
+ prompt_tokens.append(seq_data.prompt_token_ids)
+ output_tokens.append(seq_data.output_token_ids)
+ temperatures += [temperature] * len(seq_ids)
+ top_ps += [top_p] * len(seq_ids)
+ top_ks += [top_k] * len(seq_ids)
+ min_ps += [min_p] * len(seq_ids)
+ presence_penalties += [p] * len(seq_ids)
+ frequency_penalties += [f] * len(seq_ids)
+ repetition_penalties += [r] * len(seq_ids)
+
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
- assert sampling_metadata.prompt_lens is not None
- prompt_len = sampling_metadata.prompt_lens[i]
+ subquery_len = seq_group.subquery_len
+ assert subquery_len is not None
- if sampling_params.prompt_logprobs is not None:
- # NOTE: the sampling position is the last token
- # in the prompt
- sample_indices_start_idx += prompt_len - 1
for seq_id in seq_ids:
- seq_data = sampling_metadata.seq_data[seq_id]
+ seq_data = seq_group.seq_data[seq_id]
extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds(
seed,
@@ -193,8 +411,7 @@ def from_sampling_metadata(
seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy)
sampling_seeds.append(seq_seeds)
- sample_indices.append(sample_indices_start_idx)
- sample_indices_start_idx += 1
+ sample_indices.extend(seq_group.sample_indices)
sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,
@@ -217,12 +434,14 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float],
# Note that the performance will be very bad without
# pinned memory.
pin_memory = is_pin_memory_available()
- prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
+ prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
+ default=0)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
for tokens in prompt_tokens
]
- output_max_len = max(len(tokens) for tokens in output_tokens)
+ output_max_len = max([len(tokens) for tokens in output_tokens],
+ default=0)
output_padded_tokens = [
tokens + [vocab_size] * (output_max_len - len(tokens))
for tokens in output_tokens
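The switch to max(..., default=0) above guards the case where a step contributes no prompt or output token lists at all (for instance, prefill chunks that neither sample nor request prompt logprobs); plain max() over an empty sequence raises. A minimal illustration with nothing beyond the standard library:

prompt_tokens = []  # no sequence group contributed a token list this step

try:
    max(len(tokens) for tokens in prompt_tokens)
except ValueError:
    pass  # "max() arg is an empty sequence"

prompt_max_len = max([len(tokens) for tokens in prompt_tokens], default=0)
assert prompt_max_len == 0  # the padding list comprehension then yields []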
diff --git a/vllm/sequence.py b/vllm/sequence.py
index b296b37a84f15..567fca5709518 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -28,7 +28,10 @@ class Logprob:
decoded_token: Optional[str] = None
+# {token_id -> logprob} for each sequence group. None if the corresponding
+# sequence group doesn't require prompt logprob.
PromptLogprobs = List[Optional[Dict[int, Logprob]]]
+# {token_id -> logprob} for each sequence group.
SampleLogprobs = List[Dict[int, Logprob]]
@@ -215,7 +218,7 @@ def __init__(
self.eos_token_id = eos_token_id
self.lora_request = lora_request
- self.data = SequenceData(prompt_token_ids)
+ self.data: SequenceData = SequenceData(prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
@@ -559,6 +562,9 @@ class SequenceGroupMetadata:
sampling_params: The sampling parameters used to generate the outputs.
block_tables: The block tables. (Seq id -> list of physical block
numbers)
+ do_sample: True if sampling is required. Sampling is not required
+ when, e.g., prefill is chunked and the current iteration only
+ computes query tokens for the prompt.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
state: Internal state tied to this sequence group.
@@ -573,6 +579,7 @@ def __init__(
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
+ do_sample: bool = True,
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
@@ -589,6 +596,7 @@ def __init__(
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self._token_chunk_size = token_chunk_size
+ self.do_sample = do_sample
if self._token_chunk_size is None:
if is_prompt:
@@ -650,6 +658,7 @@ def __init__(
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
+ # Prompt logprob for each prompt query token.
self.prompt_logprobs = prompt_logprobs
def __repr__(self) -> str:
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index bf0a6c84e6f07..34d7d3dffea18 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
import torch
from torch import nn
@@ -10,9 +10,8 @@
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
-from vllm.utils import make_tensor_with_pad, maybe_expand_dim
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import make_tensor_with_pad
logger = init_logger(__name__)
@@ -38,6 +37,8 @@ def __init__(
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
+ # Currently, CPU worker doesn't support chunked prefill.
+ assert self.scheduler_config.chunked_prefill_enabled is False
self.lora_config = lora_config
self.vision_language_config = vision_language_config
self.load_config = load_config
@@ -252,99 +253,6 @@ def _prepare_decode(
attn_metadata,
)
- def _prepare_sample(
- self,
- seq_group_metadata_list: List[SequenceGroupMetadata],
- prompt_lens: List[int],
- ) -> SamplingMetadata:
- seq_groups: List[Tuple[List[int], SamplingParams]] = []
- selected_token_indices: List[int] = []
- generators: List[torch.Generator] = []
- selected_token_start_idx = 0
- categorized_sample_indices: Dict[SamplingType,
- List[Tuple[int, int]]] = {
- t: []
- for t in SamplingType
- }
- categorized_sample_indices_start_idx = 0
- categorized_sampled_token_indices_start_idx = 0
-
- for i, seq_group_metadata in enumerate(seq_group_metadata_list):
- seq_ids = list(seq_group_metadata.seq_data.keys())
- sampling_params = seq_group_metadata.sampling_params
- seq_groups.append((seq_ids, sampling_params))
-
- if seq_group_metadata.is_prompt:
- assert len(seq_ids) == 1
- subquery_len = prompt_lens[i]
- if sampling_params.prompt_logprobs is not None:
- # NOTE: prompt token positions do not need sample, skip
- categorized_sample_indices_start_idx += subquery_len - 1
-
- categorized_sample_indices[
- sampling_params.sampling_type].append(
- (categorized_sample_indices_start_idx,
- categorized_sampled_token_indices_start_idx))
- categorized_sample_indices_start_idx += 1
- categorized_sampled_token_indices_start_idx += 1
-
- if sampling_params.prompt_logprobs is not None:
- selected_token_indices.extend(
- range(selected_token_start_idx,
- selected_token_start_idx + subquery_len - 1))
- selected_token_indices.append(selected_token_start_idx +
- subquery_len - 1)
- selected_token_start_idx += subquery_len
-
- if sampling_params.seed is not None:
- seq_group_metadata.state.generator = torch.Generator(
- device=self.device).manual_seed(sampling_params.seed)
- else:
- num_seqs = len(seq_ids)
- selected_token_indices.extend(
- range(selected_token_start_idx,
- selected_token_start_idx + num_seqs))
- selected_token_start_idx += num_seqs
-
- categorized_sample_indices[
- sampling_params.sampling_type].extend(
- zip(
- range(
- categorized_sample_indices_start_idx,
- categorized_sample_indices_start_idx +
- num_seqs),
- range(
- categorized_sampled_token_indices_start_idx,
- categorized_sampled_token_indices_start_idx +
- num_seqs)))
- categorized_sample_indices_start_idx += num_seqs
- categorized_sampled_token_indices_start_idx += num_seqs
-
- if sampling_params.seed is not None:
- generators.append(seq_group_metadata.state.generator)
-
- selected_token_indices = torch.tensor(selected_token_indices,
- dtype=torch.long)
-
- categorized_sample_indices = {
- t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2)
- for t, seq_ids in categorized_sample_indices.items()
- }
-
- seq_data: Dict[int, SequenceData] = {}
- for seq_group_metadata in seq_group_metadata_list:
- seq_data.update(seq_group_metadata.seq_data)
-
- sampling_metadata = SamplingMetadata(
- seq_groups=seq_groups,
- seq_data=seq_data,
- prompt_lens=prompt_lens,
- selected_token_indices=selected_token_indices,
- categorized_sample_indices=categorized_sample_indices,
- generators=generators,
- )
- return sampling_metadata
-
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -364,8 +272,15 @@ def prepare_input_tensors(
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
- sampling_metadata = self._prepare_sample(seq_group_metadata_list,
- prompt_lens)
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list,
+ prompt_lens,
+ # subquery_lens is not needed if chunked prefill is not
+ # supported. Since the CPU worker doesn't support chunked prefill,
+ # just use prompt_lens instead.
+ prompt_lens,
+ self.device,
+ pin_memory=False)
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,
@@ -389,7 +304,6 @@ def prepare_input_tensors(
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
- perform_sampling=False,
)
return (input_tokens, input_positions, attn_metadata,
@@ -421,7 +335,7 @@ def execute_model(
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
- if not sampling_metadata.perform_sampling:
+ if not self.is_driver_worker:
return None
# Sample the next token.
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index c6da28f110325..0704f5fec54d0 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -20,12 +20,11 @@
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
-from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata)
-from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
- is_pin_memory_available, make_tensor_with_pad,
- maybe_expand_dim)
+from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available,
+ make_tensor_with_pad)
logger = init_logger(__name__)
@@ -547,108 +546,6 @@ def _prepare_decode(
slot_mapping=slot_mapping,
)
- def _prepare_sample(
- self,
- seq_group_metadata_list: List[SequenceGroupMetadata],
- prompt_lens: List[int],
- subquery_lens: Optional[List[int]],
- ) -> SamplingMetadata:
- seq_groups: List[Tuple[List[int], SamplingParams]] = []
- selected_token_indices: List[int] = []
- generators: List[torch.Generator] = []
- selected_token_start_idx = 0
- categorized_sample_indices: Dict[SamplingType,
- List[Tuple[int, int]]] = {
- t: []
- for t in SamplingType
- }
- categorized_sample_indices_start_idx = 0
- categorized_sampled_token_indices_start_idx = 0
-
- for i, seq_group_metadata in enumerate(seq_group_metadata_list):
- seq_ids = list(seq_group_metadata.seq_data.keys())
- sampling_params = seq_group_metadata.sampling_params
- seq_groups.append((seq_ids, sampling_params))
-
- if seq_group_metadata.is_prompt:
- assert len(seq_ids) == 1
- assert subquery_lens is not None
- subquery_len = subquery_lens[i]
- if sampling_params.prompt_logprobs is not None:
- # NOTE: prompt token positions do not need sample, skip
- categorized_sample_indices_start_idx += subquery_len - 1
-
- categorized_sample_indices[
- sampling_params.sampling_type].append(
- (categorized_sample_indices_start_idx,
- categorized_sampled_token_indices_start_idx))
- categorized_sample_indices_start_idx += 1
- categorized_sampled_token_indices_start_idx += 1
-
- if sampling_params.prompt_logprobs is not None:
- selected_token_indices.extend(
- range(selected_token_start_idx,
- selected_token_start_idx + subquery_len - 1))
- selected_token_indices.append(selected_token_start_idx +
- subquery_len - 1)
- selected_token_start_idx += subquery_len
-
- if sampling_params.seed is not None:
- seq_group_metadata.state.generator = torch.Generator(
- device=self.device).manual_seed(sampling_params.seed)
- else:
- num_seqs = len(seq_ids)
- selected_token_indices.extend(
- range(selected_token_start_idx,
- selected_token_start_idx + num_seqs))
- selected_token_start_idx += num_seqs
-
- categorized_sample_indices[
- sampling_params.sampling_type].extend(
- list(
- zip(
- range(
- categorized_sample_indices_start_idx,
- categorized_sample_indices_start_idx +
- num_seqs),
- range(
- categorized_sampled_token_indices_start_idx,
- categorized_sampled_token_indices_start_idx
- + num_seqs))))
- categorized_sample_indices_start_idx += num_seqs
- categorized_sampled_token_indices_start_idx += num_seqs
-
- if sampling_params.seed is not None:
- generators.append(seq_group_metadata.state.generator)
-
- selected_token_indices = async_tensor_h2d(selected_token_indices,
- dtype=torch.long,
- target_device=self.device,
- pin_memory=self.pin_memory)
-
- categorized_sample_indices = {
- t: maybe_expand_dim(
- async_tensor_h2d(seq_ids,
- dtype=torch.int,
- target_device=self.device,
- pin_memory=self.pin_memory), 2, 2)
- for t, seq_ids in categorized_sample_indices.items()
- }
-
- seq_data: Dict[int, SequenceData] = {}
- for seq_group_metadata in seq_group_metadata_list:
- seq_data.update(seq_group_metadata.seq_data)
-
- sampling_metadata = SamplingMetadata(
- seq_groups=seq_groups,
- seq_data=seq_data,
- prompt_lens=prompt_lens,
- selected_token_indices=selected_token_indices,
- categorized_sample_indices=categorized_sample_indices,
- generators=generators,
- )
- return sampling_metadata
-
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -685,9 +582,9 @@ def prepare_input_tensors(
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
- sampling_metadata = self._prepare_sample(seq_group_metadata_list,
- prompt_lens,
- subquery_lens)
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list, prompt_lens, subquery_lens,
+ self.device, self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
@@ -788,12 +685,9 @@ def prepare_input_tensors(
**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
- seq_data=None,
- prompt_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
- generators=None,
- perform_sampling=False,
+ num_prompts=0,
)
# if it is a mixed batch, decode attn_metadata is broadcasted
@@ -852,7 +746,7 @@ def execute_model(
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
- if not sampling_metadata.perform_sampling:
+ if not self.is_driver_worker:
return None
# Sample the next token.
@@ -860,6 +754,7 @@ def execute_model(
logits=logits,
sampling_metadata=sampling_metadata,
)
+
return output
@torch.inference_mode()
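With the perform_sampling flag removed from SamplingMetadata, the runner now gates sampling on its own role. A minimal hedged sketch of that control flow (illustrative names only, not the real worker classes):

    # Hedged sketch: non-driver tensor-parallel workers return None after the
    # forward pass; only the driver worker runs the sampler on the logits.
    from typing import Callable, List, Optional

    def maybe_sample(is_driver_worker: bool,
                     logits: List[float],
                     sample_fn: Callable[[List[float]], List[float]]
                     ) -> Optional[List[float]]:
        if not is_driver_worker:
            return None
        return sample_fn(logits)

    assert maybe_sample(False, [0.1, 0.9], lambda l: [max(l)]) is None
    assert maybe_sample(True, [0.1, 0.9], lambda l: [max(l)]) == [0.9]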
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index 487df334d73e3..a974e85c22f45 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
import torch
from torch import nn
@@ -8,10 +8,8 @@
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
-from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
- make_tensor_with_pad, maybe_expand_dim)
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import is_pin_memory_available, make_tensor_with_pad
logger = init_logger(__name__)
@@ -141,106 +139,6 @@ def _prepare_decode(
return input_tokens, input_positions, input_block_ids
- def _prepare_sample(
- self,
- seq_group_metadata_list: List[SequenceGroupMetadata],
- prompt_lens: List[int],
- ) -> SamplingMetadata:
- seq_groups: List[Tuple[List[int], SamplingParams]] = []
- selected_token_indices: List[int] = []
- generators: List[torch.Generator] = []
- selected_token_start_idx = 0
- categorized_sample_indices: Dict[SamplingType,
- List[Tuple[int, int]]] = {
- t: []
- for t in SamplingType
- }
- categorized_sample_indices_start_idx = 0
- categorized_sampled_token_indices_start_idx = 0
-
- for i, seq_group_metadata in enumerate(seq_group_metadata_list):
- seq_ids = list(seq_group_metadata.seq_data.keys())
- sampling_params = seq_group_metadata.sampling_params
- seq_groups.append((seq_ids, sampling_params))
-
- if seq_group_metadata.is_prompt:
- assert len(seq_ids) == 1
- assert prompt_lens is not None
- prompt_len = prompt_lens[i]
- if sampling_params.prompt_logprobs is not None:
- # NOTE: prompt token positions do not need sample, skip
- categorized_sample_indices_start_idx += prompt_len - 1
-
- categorized_sample_indices[
- sampling_params.sampling_type].append(
- (categorized_sample_indices_start_idx,
- categorized_sampled_token_indices_start_idx))
- categorized_sample_indices_start_idx += 1
- categorized_sampled_token_indices_start_idx += 1
-
- if sampling_params.prompt_logprobs is not None:
- selected_token_indices.extend(
- range(selected_token_start_idx,
- selected_token_start_idx + prompt_len - 1))
- selected_token_indices.append(selected_token_start_idx +
- prompt_len - 1)
- selected_token_start_idx += prompt_len
-
- if sampling_params.seed is not None:
- seq_group_metadata.state.generator = torch.Generator(
- device=self.device).manual_seed(sampling_params.seed)
- else:
- num_seqs = len(seq_ids)
- selected_token_indices.extend(
- range(selected_token_start_idx,
- selected_token_start_idx + num_seqs))
- selected_token_start_idx += num_seqs
-
- categorized_sample_indices[
- sampling_params.sampling_type].extend(
- zip(
- range(
- categorized_sample_indices_start_idx,
- categorized_sample_indices_start_idx +
- num_seqs),
- range(
- categorized_sampled_token_indices_start_idx,
- categorized_sampled_token_indices_start_idx +
- num_seqs)))
- categorized_sample_indices_start_idx += num_seqs
- categorized_sampled_token_indices_start_idx += num_seqs
-
- if sampling_params.seed is not None:
- generators.append(seq_group_metadata.state.generator)
-
- selected_token_indices = async_tensor_h2d(selected_token_indices,
- dtype=torch.long,
- target_device=self.device,
- pin_memory=self.pin_memory)
-
- categorized_sample_indices = {
- t: maybe_expand_dim(
- async_tensor_h2d(seq_ids,
- dtype=torch.int,
- target_device=self.device,
- pin_memory=self.pin_memory), 2, 2)
- for t, seq_ids in categorized_sample_indices.items()
- }
-
- seq_data: Dict[int, SequenceData] = {}
- for seq_group_metadata in seq_group_metadata_list:
- seq_data.update(seq_group_metadata.seq_data)
-
- sampling_metadata = SamplingMetadata(
- seq_groups=seq_groups,
- seq_data=seq_data,
- prompt_lens=prompt_lens,
- selected_token_indices=selected_token_indices,
- categorized_sample_indices=categorized_sample_indices,
- generators=generators,
- )
- return sampling_metadata
-
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -256,8 +154,15 @@ def prepare_input_tensors(
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
- sampling_metadata = self._prepare_sample(seq_group_metadata_list,
- prompt_lens)
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list,
+ prompt_lens,
+ # subquery_lens is not needed when chunked prefill is not
+ # supported. Since the neuron worker doesn't support chunked
+ # prefill, just use prompt_lens instead.
+ prompt_lens,
+ self.device,
+ self.pin_memory)
return (input_tokens, input_positions, input_block_ids,
sampling_metadata)
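Both the GPU and neuron runners now delegate to a shared SamplingMetadata.prepare classmethod rather than carrying their own _prepare_sample copies. A self-contained sketch of that factory pattern (a toy metadata class, not vLLM's real SamplingMetadata):

    # Hedged sketch of the "prepare" classmethod pattern: the per-runner
    # _prepare_sample logic moves into one factory on the metadata class.
    from dataclasses import dataclass
    from typing import List

    import torch

    @dataclass
    class ToyMetadata:
        selected_token_indices: torch.Tensor
        num_prompts: int

        @classmethod
        def prepare(cls, prompt_lens: List[int], device: str) -> "ToyMetadata":
            # Keep only the last position of each prompt: that is the row of
            # logits needed to sample the first generated token.
            indices, start = [], 0
            for plen in prompt_lens:
                indices.append(start + plen - 1)
                start += plen
            return cls(
                selected_token_indices=torch.tensor(indices,
                                                    dtype=torch.long,
                                                    device=device),
                num_prompts=len(prompt_lens),
            )

    metadata = ToyMetadata.prepare([5, 3, 7], device="cpu")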
From a62aaf1df558d69658a42c1ab749368ab0325f35 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Fri, 26 Apr 2024 13:41:14 -0700
Subject: [PATCH 131/413] [Misc][Refactor] Generalize linear_method to be
quant_method (#4373)
---
tests/quantization/test_fp8.py | 2 +-
tests/tensorizer_loader/test_tensorizer.py | 4 +-
vllm/lora/layers.py | 30 ++--
vllm/model_executor/layers/linear.py | 169 +++++++++---------
.../layers/quantization/__init__.py | 4 +-
.../layers/quantization/aqlm.py | 13 +-
.../model_executor/layers/quantization/awq.py | 19 +-
.../layers/quantization/base_config.py | 31 +++-
.../model_executor/layers/quantization/fp8.py | 60 +++----
.../layers/quantization/gptq.py | 19 +-
.../layers/quantization/marlin.py | 13 +-
.../layers/quantization/squeezellm.py | 24 +--
vllm/model_executor/model_loader/loader.py | 38 ++--
.../model_executor/model_loader/tensorizer.py | 13 +-
vllm/model_executor/models/baichuan.py | 43 ++---
vllm/model_executor/models/bloom.py | 33 ++--
vllm/model_executor/models/chatglm.py | 37 ++--
vllm/model_executor/models/commandr.py | 33 ++--
vllm/model_executor/models/dbrx.py | 35 ++--
vllm/model_executor/models/decilm.py | 7 +-
vllm/model_executor/models/deepseek.py | 45 +++--
vllm/model_executor/models/falcon.py | 33 ++--
vllm/model_executor/models/gemma.py | 33 ++--
vllm/model_executor/models/gpt2.py | 33 ++--
vllm/model_executor/models/gpt_bigcode.py | 33 ++--
vllm/model_executor/models/gpt_j.py | 33 ++--
vllm/model_executor/models/gpt_neox.py | 33 ++--
vllm/model_executor/models/internlm2.py | 33 ++--
vllm/model_executor/models/jais.py | 33 ++--
vllm/model_executor/models/llama.py | 32 ++--
vllm/model_executor/models/llava.py | 9 +-
vllm/model_executor/models/minicpm.py | 35 ++--
vllm/model_executor/models/mixtral.py | 44 ++---
vllm/model_executor/models/mixtral_quant.py | 41 ++---
vllm/model_executor/models/mpt.py | 33 ++--
vllm/model_executor/models/olmo.py | 32 ++--
vllm/model_executor/models/opt.py | 37 ++--
vllm/model_executor/models/orion.py | 33 ++--
vllm/model_executor/models/phi.py | 35 ++--
vllm/model_executor/models/qwen.py | 33 ++--
vllm/model_executor/models/qwen2.py | 33 ++--
vllm/model_executor/models/qwen2_moe.py | 45 +++--
vllm/model_executor/models/stablelm.py | 29 +--
vllm/model_executor/models/starcoder2.py | 32 ++--
vllm/model_executor/models/xverse.py | 33 ++--
45 files changed, 759 insertions(+), 713 deletions(-)
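The diffstat touches nearly every model file because the change is mechanical: configs hand out a quant method per layer, and call sites switch from linear_method.apply_weights to quant_method.apply. A hedged before/after sketch with toy classes (the real ones live under vllm.model_executor.layers):

    # Old interface (removed in this series):
    #     linear_method = quant_config.get_linear_method()
    #     out = layer.linear_method.apply_weights(layer, x, bias)
    # New interface:
    #     quant_method = quant_config.get_quant_method(layer)  # may be None
    #     out = layer.quant_method.apply(layer, x, bias)
    # Toy illustration of the new shape:
    class ToyMethod:
        def apply(self, layer, x, bias=None):
            return [xi * layer.scale + (bias or 0.0) for xi in x]

    class ToyConfig:
        def get_quant_method(self, layer):
            return ToyMethod()

    class ToyLayer:
        def __init__(self, quant_config):
            self.scale = 2.0
            self.quant_method = quant_config.get_quant_method(self)

    layer = ToyLayer(ToyConfig())
    assert layer.quant_method.apply(layer, [1.0, 2.0]) == [2.0, 4.0]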
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index fa10e60de10a7..607544a1c8394 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
fc1 = model.model.decoder.layers[0].fc1
- assert isinstance(fc1.linear_method, Fp8LinearMethod)
+ assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn
diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py
index a97cc0b3706b4..df1db4e6c4001 100644
--- a/tests/tensorizer_loader/test_tensorizer.py
+++ b/tests/tensorizer_loader/test_tensorizer.py
@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_agent_instance.deserialize.return_value = MagicMock()
result = load_with_tensorizer(tensorizer_config,
- linear_method=mock_linear_method)
+ quant_method=mock_linear_method)
mock_agent.assert_called_once_with(tensorizer_config,
- linear_method=mock_linear_method)
+ quant_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 98e74168002c4..4eaf73fbcfda4 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -389,10 +389,9 @@ def set_mapping(
self.indices = base_indices
self.indices_len = indices_len
- def apply_weights(self, x: torch.Tensor,
- bias: Optional[torch.Tensor]) -> torch.Tensor:
- output = self.base_layer.linear_method.apply_weights(
- self.base_layer, x, bias)
+ def apply(self, x: torch.Tensor,
+ bias: Optional[torch.Tensor]) -> torch.Tensor:
+ output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora(
x,
self.lora_a_stacked,
@@ -416,7 +415,7 @@ def forward(self, input_):
if not self.base_layer.skip_bias_add else None)
# Matrix multiply.
- output_parallel = self.apply_weights(input_, bias)
+ output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
@@ -523,10 +522,9 @@ def set_lora(
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
- def apply_weights(self, x: torch.Tensor,
- bias: Optional[torch.Tensor]) -> torch.Tensor:
- output = self.base_layer.linear_method.apply_weights(
- self.base_layer, x, bias)
+ def apply(self, x: torch.Tensor,
+ bias: Optional[torch.Tensor]) -> torch.Tensor:
+ output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
@@ -765,10 +763,9 @@ def set_lora(
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)
- def apply_weights(self, x: torch.Tensor,
- bias: Optional[torch.Tensor]) -> torch.Tensor:
- output = self.base_layer.linear_method.apply_weights(
- self.base_layer, x, bias)
+ def apply(self, x: torch.Tensor,
+ bias: Optional[torch.Tensor]) -> torch.Tensor:
+ output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
_apply_lora_packed_nslice(
x,
self.lora_a_stacked,
@@ -862,9 +859,8 @@ def set_mapping(
self.indices = base_indices
self.indices_len = indices_len
- def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
- output = self.base_layer.linear_method.apply_weights(
- self.base_layer, x)
+ def apply(self, x: torch.Tensor) -> torch.Tensor:
+ output = self.base_layer.quant_method.apply(self.base_layer, x)
_apply_lora(
x,
self.lora_a_stacked,
@@ -897,7 +893,7 @@ def forward(self, input_):
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
- output_parallel = self.apply_weights(input_parallel)
+ output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
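In the LoRA wrappers the only behavioral change is the delegate call: the base layer's quant_method.apply produces the (possibly quantized) base output and the low-rank update is added on top. A minimal standalone sketch of that composition (plain LoRA math, not vLLM's packed/stacked kernels):

    # Hedged sketch: base output from the layer's quant method plus a
    # low-rank correction x @ A @ B scaled by alpha / rank.
    import torch
    import torch.nn.functional as F

    class ToyLoRALinear(torch.nn.Module):
        def __init__(self, in_f: int, out_f: int, rank: int = 8,
                     alpha: float = 16.0):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_f, in_f) * 0.02)
            self.lora_a = torch.nn.Parameter(torch.randn(in_f, rank) * 0.01)
            self.lora_b = torch.nn.Parameter(torch.zeros(rank, out_f))
            self.scaling = alpha / rank

        def apply(self, x: torch.Tensor) -> torch.Tensor:
            # Stand-in for self.base_layer.quant_method.apply(self.base_layer, x).
            base = F.linear(x, self.weight)
            return base + (x @ self.lora_a @ self.lora_b) * self.scaling

    y = ToyLoRALinear(16, 32).apply(torch.randn(4, 16))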
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 6ad7ae0f63197..1bd6c42ab3fd8 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1,9 +1,8 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
from typing import List, Optional
import torch
import torch.nn.functional as F
-from torch import nn
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -12,6 +11,8 @@
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
@@ -25,7 +26,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
-class LinearMethodBase(ABC):
+class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
@@ -50,22 +51,15 @@ def create_weights(self, layer: torch.nn.Module,
raise NotImplementedError
@abstractmethod
- def apply_weights(self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def apply(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
- def process_weights_after_loading(self, layer: nn.Module) -> None:
- """Process the weight after loading.
-
- This can be used for example, to transpose weights for computation.
- """
- return
-
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization.
@@ -92,10 +86,10 @@ def create_weights(self, layer: torch.nn.Module,
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
- def apply_weights(self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def apply(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
if self.separate_bias_add:
if bias is not None:
@@ -104,8 +98,8 @@ def apply_weights(self,
return F.linear(x, weight, bias)
-class ReplicatedLinear(torch.nn.Module):
- """Replicated linear layer.
+class LinearBase(torch.nn.Module):
+ """Base linear layer.
Args:
input_size: input dimension of the linear layer.
@@ -113,17 +107,16 @@ class ReplicatedLinear(torch.nn.Module):
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
- linear_method: (Maybe quantized) linear method.
+ quant_config: Quantization config.
"""
def __init__(
self,
input_size: int,
output_size: int,
- bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -134,12 +127,43 @@ def __init__(
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
- if linear_method is None:
- linear_method = UnquantizedLinearMethod()
- self.linear_method = linear_method
- self.linear_method.create_weights(self, self.input_size,
- [self.output_size], self.input_size,
- self.output_size, self.params_dtype)
+ if quant_config is None:
+ self.quant_method = UnquantizedLinearMethod()
+ else:
+ self.quant_method = quant_config.get_quant_method(self)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ raise NotImplementedError
+
+
+class ReplicatedLinear(LinearBase):
+ """Replicated linear layer.
+
+ Args:
+ input_size: input dimension of the linear layer.
+ output_size: output dimension of the linear layer.
+ bias: If true, add bias.
+ skip_bias_add: If true, skip adding bias but instead return it.
+ params_dtype: Data type for the parameters.
+ quant_config: Quantization config.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ bias: bool = True,
+ skip_bias_add: bool = False,
+ params_dtype: Optional[torch.dtype] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+ quant_config)
+
+ self.quant_method.create_weights(self, self.input_size,
+ [self.output_size], self.input_size,
+ self.output_size, self.params_dtype)
+
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
@@ -149,12 +173,12 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
- output = self.linear_method.apply_weights(self, x, bias)
+ output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
-class ColumnParallelLinear(torch.nn.Module):
+class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
@@ -171,7 +195,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
- linear_method: (Maybe quantized) linear method.
+ quant_config: Quantization config.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
"""
@@ -184,34 +208,26 @@ def __init__(
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
):
- super().__init__()
+ super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+ quant_config)
- # Keep input parameters
- self.input_size = input_size
- self.output_size = output_size
self.gather_output = gather_output
+
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size)
- self.skip_bias_add = skip_bias_add
- if params_dtype is None:
- params_dtype = torch.get_default_dtype()
- self.params_dtype = params_dtype
- if linear_method is None:
- linear_method = UnquantizedLinearMethod()
if output_sizes is None:
output_sizes = [output_size]
- self.linear_method = linear_method
- self.linear_method.create_weights(self,
- self.input_size,
- [x // tp_size for x in output_sizes],
- self.input_size,
- self.output_size,
- self.params_dtype,
- weight_loader=self.weight_loader)
+ self.quant_method.create_weights(self,
+ self.input_size,
+ [x // tp_size for x in output_sizes],
+ self.input_size,
+ self.output_size,
+ self.params_dtype,
+ weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
@@ -239,7 +255,7 @@ def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
- output_parallel = self.linear_method.apply_weights(self, input_, bias)
+ output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
@@ -267,7 +283,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
- linear_method: (Maybe quantized) linear method.
+ quant_config: Quantization config.
"""
def __init__(
@@ -278,13 +294,13 @@ def __init__(
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
- skip_bias_add, params_dtype, linear_method,
+ skip_bias_add, params_dtype, quant_config,
self.output_sizes)
def weight_loader(self,
@@ -384,7 +400,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
- linear_method: (Maybe quantized) linear method.
+ quant_config: Quantization config.
"""
def __init__(
@@ -396,7 +412,7 @@ def __init__(
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
self.hidden_size = hidden_size
self.head_size = head_size
@@ -424,7 +440,7 @@ def __init__(
]
super().__init__(input_size, output_size, bias, False, skip_bias_add,
- params_dtype, linear_method, output_sizes)
+ params_dtype, quant_config, output_sizes)
def weight_loader(self,
param: Parameter,
@@ -517,7 +533,7 @@ def weight_loader(self,
param_data.copy_(loaded_weight)
-class RowParallelLinear(torch.nn.Module):
+class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
@@ -540,7 +556,7 @@ class RowParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
- linear_method: (Maybe quantized) linear method.
+ quant_config: Quantization config.
"""
def __init__(
@@ -552,32 +568,24 @@ def __init__(
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
- super().__init__()
- # Keep input parameters
- self.input_size = input_size
- self.output_size = output_size
+ super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+ quant_config)
+
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
- if params_dtype is None:
- params_dtype = torch.get_default_dtype()
- self.params_dtype = params_dtype
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
- self.skip_bias_add = skip_bias_add
- if linear_method is None:
- linear_method = UnquantizedLinearMethod()
- self.linear_method = linear_method
- self.linear_method.create_weights(self,
- self.input_size_per_partition,
- [self.output_size],
- self.input_size,
- self.output_size,
- self.params_dtype,
- weight_loader=self.weight_loader)
+ self.quant_method.create_weights(self,
+ self.input_size_per_partition,
+ [self.output_size],
+ self.input_size,
+ self.output_size,
+ self.params_dtype,
+ weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
@@ -616,8 +624,7 @@ def forward(self, input_):
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
- output_parallel = self.linear_method.apply_weights(
- self, input_parallel)
+ output_parallel = self.quant_method.apply(self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
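The new LinearBase centralizes what each parallel variant used to repeat: resolve a quant method from the optional quant_config (falling back to the unquantized path), then let that method create and apply the weights. A self-contained sketch of the pattern (toy classes, single device, no tensor parallelism):

    # Hedged sketch of the LinearBase pattern: the quant method is resolved
    # once in the base class; subclasses only decide how to shard and gather.
    from typing import Optional

    import torch
    import torch.nn.functional as F

    class ToyUnquantizedMethod:
        def create_weights(self, layer, in_f: int, out_f: int, dtype):
            layer.weight = torch.nn.Parameter(torch.empty(out_f, in_f, dtype=dtype))
            torch.nn.init.xavier_uniform_(layer.weight)

        def apply(self, layer, x, bias=None):
            return F.linear(x, layer.weight, bias)

    class ToyQuantConfig:
        def get_quant_method(self, layer):
            # A real config returns a quantized method for LinearBase layers
            # and None otherwise; the toy always returns the unquantized one.
            return ToyUnquantizedMethod()

    class ToyLinearBase(torch.nn.Module):
        def __init__(self, in_f: int, out_f: int,
                     quant_config: Optional[ToyQuantConfig] = None,
                     dtype: torch.dtype = torch.float32):
            super().__init__()
            self.quant_method = (ToyUnquantizedMethod() if quant_config is None
                                 else quant_config.get_quant_method(self))
            self.quant_method.create_weights(self, in_f, out_f, dtype)

        def forward(self, x):
            return self.quant_method.apply(self, x)

    out = ToyLinearBase(8, 16)(torch.randn(2, 8))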
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index a525add458499..0820f17c5c50d 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -4,7 +4,7 @@
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
-from vllm.model_executor.layers.quantization.fp8 import FP8Config
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
@@ -12,7 +12,7 @@
QUANTIZATION_METHODS = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
- "fp8": FP8Config,
+ "fp8": Fp8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py
index b48c6e1702be4..83e24fadc1405 100644
--- a/vllm/model_executor/layers/quantization/aqlm.py
+++ b/vllm/model_executor/layers/quantization/aqlm.py
@@ -9,10 +9,10 @@
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- set_weight_attrs)
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm.model_executor.utils import set_weight_attrs
def get_int_dtype(nbits: int) -> torch.dtype:
@@ -207,8 +207,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size)
- def get_linear_method(self) -> "AQLMLinearMethod":
- return AQLMLinearMethod(self)
+ def get_quant_method(
+ self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
+ if isinstance(layer, LinearBase):
+ return AQLMLinearMethod(self)
+ return None
def get_scaled_act_names(self) -> List[str]:
return []
@@ -321,7 +324,7 @@ def create_weights(self, layer: torch.nn.Module,
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
- def apply_weights(
+ def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 4f75134ee1889..f4fc7ce020e95 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -4,10 +4,10 @@
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- set_weight_attrs)
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm.model_executor.utils import set_weight_attrs
class AWQConfig(QuantizationConfig):
@@ -62,8 +62,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
- def get_linear_method(self) -> "AWQLinearMethod":
- return AWQLinearMethod(self)
+ def get_quant_method(
+ self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
+ if isinstance(layer, LinearBase):
+ return AWQLinearMethod(self)
+ return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@@ -147,10 +150,10 @@ def create_weights(self, layer: torch.nn.Module,
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
- def apply_weights(self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def apply(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index 6115e7c3be956..b755b1328504a 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -2,8 +2,33 @@
from typing import Any, Dict, List
import torch
+from torch import nn
-from vllm.model_executor.layers.linear import LinearMethodBase
+
+class QuantizeMethodBase(ABC):
+ """Base class for different quantized methods."""
+
+ @abstractmethod
+ def create_weights(self, layer: torch.nn.Module, *weight_args,
+ **extra_weight_attrs):
+ """Create weights for a layer.
+
+ The weights will be set as attributes of the layer."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+ """Apply the weights in layer to the input tensor.
+
+ Expects create_weights to have been called before on the layer."""
+ raise NotImplementedError
+
+ def process_weights_after_loading(self, layer: nn.Module) -> None:
+ """Process the weight after loading.
+
+ This can be used for example, to transpose weights for computation.
+ """
+ return
class QuantizationConfig(ABC):
@@ -51,8 +76,8 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"quantization config.")
@abstractmethod
- def get_linear_method(self) -> LinearMethodBase:
- """Get the linear method to use for the quantized linear layer."""
+ def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
+ """Get the quantize method to use for the quantized layer."""
raise NotImplementedError
@abstractmethod
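QuantizeMethodBase is what lets non-linear layers opt out: a config's get_quant_method returns a method object for layers it knows how to quantize and None for everything else. A hedged sketch of the three hooks for a made-up int8 weight-only scheme (not one of vLLM's real methods):

    # Hedged sketch: a toy weight-only int8 method implementing
    # create_weights / apply / process_weights_after_loading.
    import torch
    import torch.nn.functional as F

    class ToyInt8Method:
        def create_weights(self, layer, in_f: int, out_f: int):
            # Full-precision staging weight; quantized by the post-load hook.
            layer.weight = torch.nn.Parameter(torch.zeros(out_f, in_f),
                                              requires_grad=False)

        def process_weights_after_loading(self, layer):
            w = layer.weight.data
            scale = w.abs().amax().clamp(min=1e-8) / 127.0
            layer.qweight = torch.round(w / scale).to(torch.int8)
            layer.weight_scale = scale
            del layer.weight  # drop the unquantized copy

        def apply(self, layer, x, bias=None):
            w = layer.qweight.to(x.dtype) * layer.weight_scale
            return F.linear(x, w, bias)

    class ToyConfig:
        def get_quant_method(self, layer):
            # Mirrors the isinstance(layer, LinearBase) gate: only linear-like
            # layers get a method, anything else gets None.
            if isinstance(layer, torch.nn.Linear):
                return ToyInt8Method()
            return None

    lin = torch.nn.Linear(4, 8, bias=False)
    method = ToyConfig().get_quant_method(lin)
    method.create_weights(lin, 4, 8)
    lin.weight.data.normal_()  # stand-in for checkpoint loading
    method.process_weights_after_loading(lin)
    y = method.apply(lin, torch.randn(2, 4))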
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 01e494c870e71..39679834b545c 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1,16 +1,17 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- set_weight_attrs)
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
- QuantizationConfig)
+ QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.utils import set_weight_attrs
-class FP8Config(QuantizationConfig):
+class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
@classmethod
@@ -33,11 +34,14 @@ def get_config_filenames(cls) -> List[str]:
return []
@classmethod
- def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
+ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
return cls()
- def get_linear_method(self) -> "Fp8LinearMethod":
- return Fp8LinearMethod(self)
+ def get_quant_method(
+ self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+ if isinstance(layer, LinearBase):
+ return Fp8LinearMethod(self)
+ return None
def get_scaled_act_names(self) -> List[str]:
return []
@@ -57,7 +61,7 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config.
"""
- def __init__(self, quant_config: FP8Config):
+ def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def create_weights(
@@ -86,24 +90,24 @@ def create_weights(
layer.register_parameter("weight_scaling_factor", w_scale)
def process_weights_after_loading(self, layer: Module) -> None:
- # Although the linear_method is propagated to all layers,
+ # Although the quant_method is propagated to all layers,
# only linear layers invoke "create_weights". So we check
# whether "weight_scaling_facor" is registered to determine
# whether the layer is a linear layer that requires quantization.
if not hasattr(layer, "weight_scaling_factor"):
return
- qweight, weight_scale = per_tensor_quantize(layer.weight)
+ qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
# torch._scaled_mm requires column-major in the second
# input (weight), so we transpose the quantized weight.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scaling_factor.data.copy_(weight_scale)
- def apply_weights(self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
- qinput, x_scale = per_tensor_quantize(x)
+ def apply(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ qinput, x_scale = ops.scaled_fp8_quant(x)
output, _ = torch._scaled_mm(
qinput,
layer.weight,
@@ -113,27 +117,3 @@ def apply_weights(self,
bias=bias,
)
return output
-
-
-def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
- """Quantize a tensor using per-tensor static scaling factor.
-
- Args:
- tensor: The input tensor.
- """
- finfo = torch.finfo(torch.float8_e4m3fn)
- # Calculate the scale as dtype max divided by absmax.
- # Since .abs() creates a new tensor, we use aminmax to get
- # the min and max first and then calculate the absmax.
- min_val, max_val = tensor.aminmax()
- amax = min_val.abs().max(max_val.abs())
- scale = finfo.max / amax.clamp(min=1e-12)
- # scale and clamp the tensor to bring it to
- # the representative range of float8 data type
- # (as default cast is unsaturated)
- qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
- # Return both float8 data and the inverse scale (as float),
- # as both required as inputs to torch._scaled_mm
- qweight = qweight.to(torch.float8_e4m3fn)
- scale = scale.float().reciprocal()
- return qweight, scale
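The pure-Python fallback deleted above is superseded by the fused ops.scaled_fp8_quant kernel, but the math is worth keeping in view. A standalone sketch equivalent to the removed helper (requires a torch build that exposes float8_e4m3fn; purely illustrative):

    # Hedged sketch of per-tensor FP8-E4M3 quantization: scale by
    # fp8_max / absmax, clamp to the representable range, and return the
    # fp8 tensor plus the inverse scale expected by torch._scaled_mm.
    import torch

    def per_tensor_quantize_sketch(tensor: torch.Tensor):
        finfo = torch.finfo(torch.float8_e4m3fn)
        amax = tensor.abs().max()
        scale = finfo.max / amax.clamp(min=1e-12)
        qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight, scale.float().reciprocal()

    q, inv_scale = per_tensor_quantize_sketch(torch.randn(16, 16))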
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 92a5cdb9af928..ae9f7019f0592 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -7,10 +7,10 @@
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- set_weight_attrs)
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm.model_executor.utils import set_weight_attrs
class GPTQConfig(QuantizationConfig):
@@ -63,8 +63,11 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act)
- def get_linear_method(self) -> "GPTQLinearMethod":
- return GPTQLinearMethod(self)
+ def get_quant_method(
+ self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
+ if isinstance(layer, LinearBase):
+ return GPTQLinearMethod(self)
+ return None
def get_scaled_act_names(self) -> List[str]:
return []
@@ -194,10 +197,10 @@ def create_weights(
layer.exllama_state = exllama_state
- def apply_weights(self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def apply(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py
index 00c3c404c2d7a..94aba620ea083 100644
--- a/vllm/model_executor/layers/quantization/marlin.py
+++ b/vllm/model_executor/layers/quantization/marlin.py
@@ -4,10 +4,10 @@
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- set_weight_attrs)
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm.model_executor.utils import set_weight_attrs
class MarlinConfig(QuantizationConfig):
@@ -72,8 +72,11 @@ def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)
- def get_linear_method(self) -> "MarlinLinearMethod":
- return MarlinLinearMethod(self)
+ def get_quant_method(
+ self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
+ if isinstance(layer, LinearBase):
+ return MarlinLinearMethod(self)
+ return None
def get_scaled_act_names(self) -> List[str]:
return []
@@ -197,7 +200,7 @@ def create_weights(
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)
- def apply_weights(
+ def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
index cc44447d347b8..971078fe25a9b 100644
--- a/vllm/model_executor/layers/quantization/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -4,10 +4,10 @@
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- set_weight_attrs)
+from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.base_config import (
- QuantizationConfig)
+ QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip
@@ -51,14 +51,18 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
- def get_linear_method(self) -> "SqueezeLLMLinearMethod":
- return SqueezeLLMLinearMethod(self)
+ def get_quant_method(
+ self,
+ layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
+ if isinstance(layer, LinearBase):
+ return SqueezeLLMLinearMethod(self)
+ return None
def get_scaled_act_names(self) -> List[str]:
return []
-class SqueezeLLMLinearMethod(LinearMethodBase):
+class SqueezeLLMLinearMethod(QuantizeMethodBase):
"""Linear method for SqueezeLLM.
Args:
@@ -112,10 +116,10 @@ def create_weights(self, layer: torch.nn.Module,
layer.register_parameter("lookup_table", lookup_table)
set_weight_attrs(lookup_table, extra_weight_attrs)
- def apply_weights(self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def apply(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
lookup_table = layer.lookup_table
out_shape = x.shape[:-1] + (qweight.shape[-1], )
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index f75c35a69d4a9..ad80243019a65 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -3,8 +3,7 @@
import glob
import os
from abc import ABC, abstractmethod
-from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
- Type)
+from typing import Any, Dict, Generator, List, Optional, Tuple, Type
import torch
from torch import nn
@@ -13,6 +12,8 @@
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
tensorizer_weights_iterator)
@@ -24,9 +25,6 @@
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
-if TYPE_CHECKING:
- from vllm.model_executor.layers.linear import LinearMethodBase
-
_VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration,
]
@@ -34,11 +32,10 @@
logger = init_logger(__name__)
-def _get_linear_method(
+def _get_quantization_config(
model_config: ModelConfig,
- load_config: LoadConfig) -> Optional["LinearMethodBase"]:
- """Get the (maybe quantized) linear method."""
- linear_method = None
+ load_config: LoadConfig) -> Optional[QuantizationConfig]:
+ """Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability()
@@ -55,9 +52,8 @@ def _get_linear_method(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
-
- linear_method = quant_config.get_linear_method()
- return linear_method
+ return quant_config
+ return None
def _get_model_initialization_kwargs(
@@ -85,10 +81,10 @@ def _initialize_model(
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
- linear_method = _get_linear_method(model_config, load_config)
+ quant_config = _get_quantization_config(model_config, load_config)
return model_class(config=model_config.hf_config,
- linear_method=linear_method,
+ quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config))
@@ -229,9 +225,11 @@ def load_model(self, *, model_config: ModelConfig,
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules():
- linear_method = getattr(module, "linear_method", None)
- if linear_method is not None:
- linear_method.process_weights_after_loading(module)
+ quant_method = getattr(module, "quant_method", None)
+ if quant_method is not None:
+ quant_method.process_weights_after_loading(module)
+ # FIXME: Remove this after Mixtral is updated
+ # to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
return model.eval()
@@ -314,11 +312,11 @@ def _load_model_serialized(
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
- linear_method = _get_linear_method(model_config,
- self.load_config)
+ quant_config = _get_quantization_config(
+ model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)
- extra_kwargs["linear_method"] = linear_method
+ extra_kwargs["quant_config"] = quant_config
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
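After the weights are loaded, the loader walks every module once and gives its quant method a chance to repack weights; this is what triggers the FP8 quantize-and-transpose above. A hedged standalone sketch of that hook loop (toy modules, not the real loader):

    # Hedged sketch: the post-load pass asks each module's quant_method to
    # finalize its weights (e.g. quantize or transpose them for the kernel).
    import torch

    class ToyMethod:
        def process_weights_after_loading(self, module):
            module.weight = torch.nn.Parameter(
                module.weight.data.t().contiguous(), requires_grad=False)

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    model[0].quant_method = ToyMethod()  # attached by hand for the sketch

    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)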
diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py
index 7e65d54bc522f..8fc6d16672117 100644
--- a/vllm/model_executor/model_loader/tensorizer.py
+++ b/vllm/model_executor/model_loader/tensorizer.py
@@ -13,7 +13,8 @@
from vllm.config import ModelConfig, ParallelConfig
from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import LinearMethodBase
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -251,7 +252,7 @@ class TensorizerAgent:
"""
def __init__(self, tensorizer_config: TensorizerConfig,
- linear_method: LinearMethodBase, **extra_kwargs):
+ quant_config: QuantizationConfig, **extra_kwargs):
if tensorizer_load_fail is not None:
raise ImportError(
"Tensorizer is not installed. Please install tensorizer "
@@ -262,10 +263,10 @@ def __init__(self, tensorizer_config: TensorizerConfig,
self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args())
self.extra_kwargs = extra_kwargs
- if extra_kwargs.get("linear_method", None) is not None:
- self.linear_method = extra_kwargs["linear_method"]
+ if extra_kwargs.get("quant_config", None) is not None:
+ self.quant_config = extra_kwargs["quant_config"]
else:
- self.linear_method = linear_method
+ self.quant_config = quant_config
self.model = self._init_model()
def _init_model(self):
@@ -274,7 +275,7 @@ def _init_model(self):
with no_init_or_tensor():
return self.tensorizer_config.model_class(
config=model_args,
- linear_method=self.linear_method,
+ quant_config=self.quant_config,
**self.extra_kwargs)
def _resize_lora_embeddings(self):
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 69162b0a92d65..186cee2584369 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -31,11 +31,12 @@
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -77,17 +78,17 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -110,7 +111,7 @@ def __init__(
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = hidden_size
@@ -132,13 +133,13 @@ def __init__(
self.total_num_heads,
self.total_num_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
@@ -196,13 +197,13 @@ def __init__(self,
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
@@ -254,7 +255,7 @@ def __init__(self,
config.hidden_size,
)
self.layers = nn.ModuleList([
- BaiChuanDecoderLayer(config, position_embedding, linear_method)
+ BaiChuanDecoderLayer(config, position_embedding, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -303,13 +304,13 @@ def __init__(
self,
config,
position_embedding: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = BaiChuanModel(config, position_embedding, linear_method)
+ self.quant_config = quant_config
+ self.model = BaiChuanModel(config, position_embedding, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
if config.hidden_size == 4096: # baichuan2 7b
- super().__init__(config, "ROPE", linear_method, lora_config)
+ super().__init__(config, "ROPE", quant_config, lora_config)
else: # baichuan 13b, baichuan2 13b
- super().__init__(config, "ALIBI", linear_method, lora_config)
+ super().__init__(config, "ALIBI", quant_config, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
- super().__init__(config, "ROPE", linear_method, lora_config)
+ super().__init__(config, "ROPE", quant_config, lora_config)
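The model-side changes are pure plumbing: each model class accepts a quant_config and threads it into every parallel linear it builds. A hedged sketch of that top-down flow with toy layers (the real models pass it to QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear, and so on):

    # Hedged sketch: quant_config flows from the causal-LM wrapper into each
    # decoder layer's projections without changing the module tree.
    from typing import Optional

    import torch
    import torch.nn.functional as F

    class ToyLinear(torch.nn.Module):
        def __init__(self, in_f: int, out_f: int, quant_config=None):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_f, in_f) * 0.02)
            # A real layer would resolve quant_config.get_quant_method(self).
            self.quant_config = quant_config

        def forward(self, x):
            return F.linear(x, self.weight)

    class ToyMLP(torch.nn.Module):
        def __init__(self, hidden: int, inter: int, quant_config=None):
            super().__init__()
            self.gate_up = ToyLinear(hidden, inter, quant_config=quant_config)
            self.down = ToyLinear(inter, hidden, quant_config=quant_config)

        def forward(self, x):
            return self.down(F.silu(self.gate_up(x)))

    class ToyModel(torch.nn.Module):
        def __init__(self, hidden: int = 32, inter: int = 64, layers: int = 2,
                     quant_config: Optional[object] = None):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                ToyMLP(hidden, inter, quant_config=quant_config)
                for _ in range(layers))

        def forward(self, x):
            for layer in self.layers:
                x = x + layer(x)
            return x

    out = ToyModel(quant_config=None)(torch.randn(2, 32))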
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index 14f325e624f41..b425af4863c36 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -28,10 +28,11 @@
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
def __init__(
self,
config: BloomConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
@@ -87,13 +88,13 @@ def __init__(
self.head_dim,
self.total_num_heads,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
# Create the alibi slopes and slice them.
@@ -129,21 +130,21 @@ class BloomMLP(nn.Module):
def __init__(
self,
config: BloomConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
4 * hidden_size,
- linear_method=linear_method,
+ quant_config=quant_config,
)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config = getattr(quant_config, "quant_config", None)
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
- linear_method=linear_method,
+ quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -158,17 +159,17 @@ class BloomBlock(nn.Module):
def __init__(
self,
config: BloomConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
- self.self_attention = BloomAttention(config, linear_method)
+ self.self_attention = BloomAttention(config, quant_config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
- self.mlp = BloomMLP(config, linear_method)
+ self.mlp = BloomMLP(config, quant_config)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
@@ -214,7 +215,7 @@ class BloomModel(nn.Module):
def __init__(
self,
config: BloomConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.hidden_size
@@ -229,7 +230,7 @@ def __init__(
# Transformer blocks
self.h = nn.ModuleList([
- BloomBlock(config, linear_method)
+ BloomBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
@@ -262,12 +263,12 @@ class BloomForCausalLM(nn.Module):
def __init__(
self,
config: BloomConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.transformer = BloomModel(config, linear_method)
+ self.quant_config = quant_config
+ self.transformer = BloomModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 3cdb7a7bca1c1..e116af2ed080d 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -13,11 +13,12 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
@@ -65,13 +66,13 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
- linear_method=linear_method,
+ quant_config=quant_config,
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -134,7 +135,7 @@ def __init__(
config.hidden_size,
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.activation_func = SiluAndMul()
@@ -144,7 +145,7 @@ def __init__(
config.ffn_hidden_size,
config.hidden_size,
bias=config.add_bias_linear,
- linear_method=linear_method,
+ quant_config=quant_config,
)
def forward(self, hidden_states):
@@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
@@ -180,7 +181,7 @@ def __init__(
eps=config.layernorm_epsilon)
# Self attention.
- self.self_attention = GLMAttention(config, linear_method)
+ self.self_attention = GLMAttention(config, quant_config)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
@@ -188,7 +189,7 @@ def __init__(
config.hidden_size, eps=config.layernorm_epsilon)
# MLP
- self.mlp = GLMMLP(config, linear_method)
+ self.mlp = GLMMLP(config, quant_config)
def forward(
self,
@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
@@ -246,7 +247,7 @@ def __init__(
# Transformer layers.
self.layers = nn.ModuleList(
- [GLMBlock(config, linear_method) for i in range(self.num_layers)])
+ [GLMBlock(config, quant_config) for i in range(self.num_layers)])
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -291,7 +292,7 @@ def __init__(
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
- self.encoder = GLMTransformer(config, linear_method)
+ self.encoder = GLMTransformer(config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
def __init__(
self,
config: ChatGLMConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
- self.linear_method = linear_method
- self.transformer = ChatGLMModel(config, linear_method)
+ self.quant_config = quant_config
+ self.transformer = ChatGLMModel(config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index d80969773e163..17c2f1223d96b 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -32,11 +32,12 @@
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -101,13 +102,13 @@ def __init__(
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.act_fn = SiluAndMul()
@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
def __init__(
self,
config: CohereConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
@@ -158,13 +159,13 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
def __init__(self,
config: CohereConfig,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = CohereAttention(config, linear_method=linear_method)
+ self.self_attn = CohereAttention(config, quant_config=quant_config)
- self.mlp = CohereMLP(config, linear_method=linear_method)
+ self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
def __init__(
self,
config: CohereConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -265,7 +266,7 @@ def __init__(
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
- CohereDecoderLayer(config, linear_method=linear_method)
+ CohereDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = LayerNorm(param_shape=(config.hidden_size),
@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
def __init__(
self,
config: CohereConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
+ self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size,
scale=config.logit_scale)
- self.model = CohereModel(config, linear_method)
+ self.model = CohereModel(config, quant_config)
self.sampler = Sampler()
@torch.no_grad()
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 179094b8fd7aa..a4a0ae50c645e 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -9,11 +9,12 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- QKVParallelLinear,
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -44,7 +45,7 @@ def __init__(
self.num_total_experts,
bias=False,
params_dtype=params_dtype,
- linear_method=None,
+ quant_config=None,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
def __init__(
self,
config: DbrxConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
@@ -183,13 +184,13 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
- self.attn = DbrxAttention(config, linear_method)
+ self.attn = DbrxAttention(config, quant_config)
self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model)
@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
def __init__(
self,
config: DbrxConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
- self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
- self.ffn = DbrxExperts(config, linear_method)
+ self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
+ self.ffn = DbrxExperts(config, quant_config)
def forward(
self,
@@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
def __init__(
self,
config: DbrxConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.wte = VocabParallelEmbedding(
@@ -315,7 +316,7 @@ def __init__(
config.d_model,
)
self.blocks = nn.ModuleList(
- [DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
+ [DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias,
@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
def __init__(
self,
config: DbrxConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
+ self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
- self.transformer = DbrxModel(config, linear_method)
+ self.transformer = DbrxModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.d_model,
diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py
index d476630ee6f11..be9a6b6813f8f 100644
--- a/vllm/model_executor/models/decilm.py
+++ b/vllm/model_executor/models/decilm.py
@@ -29,7 +29,8 @@
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
-from vllm.model_executor.layers.linear import LinearMethodBase
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__(
self,
config: Optional[PretrainedConfig] = None,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config,
- linear_method=linear_method,
+ quant_config=quant_config,
lora_config=lora_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py
index 46101a152ec0d..e5f7ba086a35d 100644
--- a/vllm/model_executor/models/deepseek.py
+++ b/vllm/model_executor/models/deepseek.py
@@ -34,12 +34,13 @@
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -56,18 +57,18 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -103,7 +104,7 @@ def __init__(
DeepseekMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
@@ -112,7 +113,7 @@ def __init__(
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
- linear_method=None)
+ quant_config=None)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
@@ -121,7 +122,7 @@ def __init__(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
reduce_results=False,
)
@@ -177,7 +178,7 @@ def __init__(
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -208,14 +209,14 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -251,7 +252,7 @@ def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -266,18 +267,18 @@ def __init__(
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
- linear_method=linear_method,
+ quant_config=quant_config,
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
- self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
+ self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
@@ -331,9 +332,7 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
- DeepseekDecoderLayer(config,
- layer_idx,
- linear_method=linear_method)
+ DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = DeepseekModel(config, linear_method)
+ self.quant_config = quant_config
+ self.model = DeepseekModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 25ce239d14662..4be1f064cdd3e 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -32,10 +32,11 @@
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -76,7 +77,7 @@ class FalconAttention(nn.Module):
def __init__(
self,
config: FalconConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -115,7 +116,7 @@ def __init__(
self.total_num_kv_heads,
bias=config.bias,
skip_bias_add=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
@@ -129,7 +130,7 @@ def __init__(
self.hidden_size,
bias=config.bias,
skip_bias_add=True,
- linear_method=linear_method,
+ quant_config=quant_config,
reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary
@@ -192,7 +193,7 @@ class FalconMLP(nn.Module):
def __init__(
self,
config: FalconConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
@@ -201,8 +202,8 @@ def __init__(
4 * hidden_size,
bias=config.bias,
skip_bias_add=True,
- linear_method=linear_method)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config=quant_config)
+ quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
@@ -212,7 +213,7 @@ def __init__(
bias=config.bias,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results,
- linear_method=linear_method)
+ quant_config=quant_config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
@@ -229,13 +230,13 @@ class FalconDecoderLayer(nn.Module):
def __init__(
self,
config: FalconConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
- self.self_attention = FalconAttention(config, linear_method)
- self.mlp = FalconMLP(config, linear_method)
+ self.self_attention = FalconAttention(config, quant_config)
+ self.mlp = FalconMLP(config, quant_config)
self.config = config
if config.new_decoder_architecture:
@@ -311,7 +312,7 @@ class FalconModel(nn.Module):
def __init__(
self,
config: FalconConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -327,7 +328,7 @@ def __init__(
# Transformer blocks
self.h = nn.ModuleList([
- FalconDecoderLayer(config, linear_method)
+ FalconDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
@@ -359,12 +360,12 @@ class FalconForCausalLM(nn.Module):
def __init__(
self,
config: FalconConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.transformer = FalconModel(config, linear_method)
+ self.quant_config = quant_config
+ self.transformer = FalconModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index c3193258d6418..bb73ff4d206da 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -27,11 +27,12 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -77,17 +78,17 @@ def __init__(
intermediate_size: int,
hidden_act: Optional[str] = None,
hidden_activation: Optional[str] = None,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
def forward(self, x):
@@ -106,7 +107,7 @@ def __init__(self,
head_dim: int,
max_position_embeddings: int = 8192,
rope_theta: float = 10000,
- linear_method: Optional[LinearMethodBase] = None) -> None:
+ quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -135,13 +136,13 @@ def __init__(self,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: GemmaConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -187,14 +188,14 @@ def __init__(
head_dim=config.head_dim,
max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.mlp = GemmaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
hidden_activation=getattr(config, "hidden_activation", None),
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -235,7 +236,7 @@ class GemmaModel(nn.Module):
def __init__(
self,
config: GemmaConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
@@ -245,7 +246,7 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
- GemmaDecoderLayer(config, linear_method)
+ GemmaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module):
def __init__(
self,
config: GemmaConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config # Unused.
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = GemmaModel(config, linear_method)
+ self.quant_config = quant_config
+ self.model = GemmaModel(config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 850050c7232d0..ac1dce6dec8a6 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -27,10 +27,11 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -44,7 +45,7 @@ class GPT2Attention(nn.Module):
def __init__(
self,
config: GPT2Config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
@@ -61,13 +62,13 @@ def __init__(
self.head_dim,
total_num_heads,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
@@ -90,7 +91,7 @@ def __init__(
self,
intermediate_size: int,
config: GPT2Config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
@@ -98,15 +99,15 @@ def __init__(
hidden_size,
intermediate_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
@@ -122,7 +123,7 @@ class GPT2Block(nn.Module):
def __init__(
self,
config: GPT2Config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
@@ -130,9 +131,9 @@ def __init__(
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = GPT2Attention(config, linear_method)
+ self.attn = GPT2Attention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.mlp = GPT2MLP(inner_dim, config, linear_method)
+ self.mlp = GPT2MLP(inner_dim, config, quant_config)
def forward(
self,
@@ -163,7 +164,7 @@ class GPT2Model(nn.Module):
def __init__(
self,
config: GPT2Config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -174,7 +175,7 @@ def __init__(
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
- GPT2Block(config, linear_method)
+ GPT2Block(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -203,12 +204,12 @@ class GPT2LMHeadModel(nn.Module):
def __init__(
self,
config: GPT2Config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.transformer = GPT2Model(config, linear_method)
+ self.quant_config = quant_config
+ self.transformer = GPT2Model(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 8278ba02514d5..e52ac679f5d03 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -28,10 +28,11 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
@@ -72,14 +73,14 @@ def __init__(
total_num_heads,
total_num_kv_heads,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
@@ -111,7 +112,7 @@ def __init__(
self,
intermediate_size: int,
config: GPTBigCodeConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
@@ -119,15 +120,15 @@ def __init__(
hidden_size,
intermediate_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
@@ -143,7 +144,7 @@ class GPTBigCodeBlock(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
@@ -151,9 +152,9 @@ def __init__(
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = GPTBigCodeAttention(config, linear_method)
+ self.attn = GPTBigCodeAttention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.mlp = GPTBigMLP(inner_dim, config, linear_method)
+ self.mlp = GPTBigMLP(inner_dim, config, quant_config)
def forward(
self,
@@ -184,7 +185,7 @@ class GPTBigCodeModel(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -195,7 +196,7 @@ def __init__(
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
- GPTBigCodeBlock(config, linear_method)
+ GPTBigCodeBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -224,12 +225,12 @@ class GPTBigCodeForCausalLM(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.transformer = GPTBigCodeModel(config, linear_method)
+ self.quant_config = quant_config
+ self.transformer = GPTBigCodeModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 7a830d7f9c965..287f4186f7469 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -26,10 +26,11 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -44,7 +45,7 @@ class GPTJAttention(nn.Module):
def __init__(
self,
config: GPTJConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.total_num_heads = config.num_attention_heads
@@ -56,13 +57,13 @@ def __init__(
self.head_size,
self.total_num_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
@@ -105,21 +106,21 @@ def __init__(
self,
intermediate_size: int,
config: GPTJConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(
hidden_size,
intermediate_size,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.fc_out = RowParallelLinear(
intermediate_size,
hidden_size,
- linear_method=linear_method,
+ quant_config=quant_config,
)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
@@ -135,14 +136,14 @@ class GPTJBlock(nn.Module):
def __init__(
self,
config: GPTJConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
inner_dim = (4 * config.n_embd
if config.n_inner is None else config.n_inner)
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
- self.attn = GPTJAttention(config, linear_method)
- self.mlp = GPTJMLP(inner_dim, config, linear_method)
+ self.attn = GPTJAttention(config, quant_config)
+ self.mlp = GPTJMLP(inner_dim, config, quant_config)
def forward(
self,
@@ -169,7 +170,7 @@ class GPTJModel(nn.Module):
def __init__(
self,
config: GPTJConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -179,7 +180,7 @@ def __init__(
self.embed_dim,
)
self.h = nn.ModuleList(
- [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
+ [GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
@@ -207,13 +208,13 @@ class GPTJForCausalLM(nn.Module):
def __init__(
self,
config: GPTJConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
+ self.quant_config = quant_config
assert not config.tie_word_embeddings
- self.transformer = GPTJModel(config, linear_method)
+ self.transformer = GPTJModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.n_embd,
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index b946aed92ed35..cbc5115bd377b 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -26,10 +26,11 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.total_num_heads = config.num_attention_heads
@@ -63,13 +64,13 @@ def __init__(
self.head_size,
self.total_num_heads,
bias=self.bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=self.bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
@@ -105,20 +106,20 @@ class GPTNeoXMLP(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
- linear_method=linear_method,
+ quant_config=quant_config,
)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
@@ -134,7 +135,7 @@ class GPTNeoXLayer(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
@@ -142,8 +143,8 @@ def __init__(
eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
- self.attention = GPTNeoXAttention(config, linear_method)
- self.mlp = GPTNeoXMLP(config, linear_method)
+ self.attention = GPTNeoXAttention(config, quant_config)
+ self.mlp = GPTNeoXMLP(config, quant_config)
def forward(
self,
@@ -182,7 +183,7 @@ class GPTNeoXModel(nn.Module):
def __init__(
self,
config: GPTNeoXConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -192,7 +193,7 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
- GPTNeoXLayer(config, linear_method)
+ GPTNeoXLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
@@ -223,12 +224,12 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.gpt_neox = GPTNeoXModel(config, linear_method)
+ self.quant_config = quant_config
+ self.gpt_neox = GPTNeoXModel(config, quant_config)
self.embed_out = ParallelLMHead(
config.vocab_size,
config.hidden_size,
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index db1da8bdc4fb9..5811cae83bf8b 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -9,11 +9,12 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -30,17 +31,17 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.w2 = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -63,7 +64,7 @@ def __init__(
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -94,13 +95,13 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -150,13 +151,13 @@ def __init__(
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -195,7 +196,7 @@ class InternLM2Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
@@ -206,7 +207,7 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
- InternLMDecoderLayer(config, linear_method)
+ InternLMDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = InternLM2Model(config, linear_method)
+ self.quant_config = quant_config
+ self.model = InternLM2Model(config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index e7ee749e824e4..bd6a180ec8dfc 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -29,10 +29,11 @@
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -68,7 +69,7 @@ class JAISAttention(nn.Module):
def __init__(
self,
config: JAISConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
@@ -88,13 +89,13 @@ def __init__(
self.head_dim,
total_num_heads,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
tp_rank = get_tensor_model_parallel_rank()
@@ -128,7 +129,7 @@ def __init__(
self,
intermediate_size: int,
config: JAISConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
@@ -137,19 +138,19 @@ def __init__(
hidden_size,
intermediate_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.c_fc2 = (ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
) if self.swiglu else None)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.act = SwiGLUActivation()
@@ -169,7 +170,7 @@ class JAISBlock(nn.Module):
def __init__(
self,
config: JAISConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
@@ -177,9 +178,9 @@ def __init__(
hidden_size)
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = JAISAttention(config, linear_method)
+ self.attn = JAISAttention(config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.mlp = JAISMLP(inner_dim, config, linear_method)
+ self.mlp = JAISMLP(inner_dim, config, quant_config)
def forward(
self,
@@ -210,7 +211,7 @@ class JAISModel(nn.Module):
def __init__(
self,
config: JAISConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -227,7 +228,7 @@ def __init__(
else:
self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([
- JAISBlock(config, linear_method)
+ JAISBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module):
def __init__(
self,
config: JAISConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.transformer = JAISModel(config, linear_method)
+ self.quant_config = quant_config
+ self.transformer = JAISModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c102b40045c92..f6d7fc8733fce 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -33,11 +33,12 @@
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -56,17 +57,17 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -89,7 +90,7 @@ def __init__(
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
) -> None:
@@ -131,13 +132,13 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -174,7 +175,7 @@ class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -199,7 +200,7 @@ def __init__(
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
- linear_method=linear_method,
+ quant_config=quant_config,
bias=attention_bias,
sliding_window=sliding_window,
)
@@ -207,7 +208,7 @@ def __init__(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -248,7 +249,7 @@ class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
@@ -264,7 +265,7 @@ def __init__(
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
- LlamaDecoderLayer(config, linear_method)
+ LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -329,13 +330,12 @@ class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = LlamaModel(config, linear_method, lora_config=lora_config)
+ self.model = LlamaModel(config, quant_config, lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 314a2792bf167..dcde4dfa0795e 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -9,8 +9,9 @@
from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
def __init__(self,
config: "LlavaConfig",
vision_language_config: VisionLanguageConfig,
- linear_method: Optional["LinearMethodBase"] = None) -> None:
+ quant_config: Optional["QuantizationConfig"] = None) -> None:
super().__init__()
self.config = config
@@ -83,8 +84,8 @@ def __init__(self,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
- self.linear_method = linear_method
- self.language_model = LlamaModel(config.text_config, linear_method)
+ self.quant_config = quant_config
+ self.language_model = LlamaModel(config.text_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index f0d72fafcaf70..c90bcfbfc4707 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -35,12 +35,13 @@
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -84,7 +85,7 @@ def __init__(
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
- linear_method=None)
+ quant_config=None)
self.ws = nn.Parameter(
torch.empty(self.num_total_experts,
@@ -147,17 +148,17 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -180,7 +181,7 @@ def __init__(
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -211,13 +212,13 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
@@ -274,7 +275,7 @@ def __init__(
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.num_experts = getattr(self.config, "num_experts", 0)
if self.num_experts == 0:
@@ -282,7 +283,7 @@ def __init__(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
)
else:
self.mlp = MiniCPMMoE(num_experts=config.num_experts,
@@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
@@ -345,7 +346,7 @@ def __init__(
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
- MiniCPMDecoderLayer(config, linear_method)
+ MiniCPMDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.num_experts = getattr(self.config, "num_experts", 0)
- self.linear_method = linear_method
+ self.quant_config = quant_config
self.model = MiniCPMModel(config,
- linear_method,
+ quant_config,
lora_config=lora_config)
unpadded_vocab_size = config.vocab_size
if lora_config:
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index a33b795d7088e..7847df735ab44 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -27,6 +27,7 @@
from torch import nn
from transformers import MixtralConfig
+from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -34,13 +35,13 @@
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- QKVParallelLinear,
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod,
- per_tensor_quantize)
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -69,7 +70,7 @@ def __init__(
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
@@ -79,7 +80,7 @@ def __init__(
self.intermediate_size = intermediate_size // self.tp_size
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
- self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
+ self.use_fp8 = isinstance(quant_config, Fp8Config)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -89,7 +90,7 @@ def __init__(
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
- linear_method=None)
+ quant_config=None)
self.ws = nn.Parameter(
torch.empty(self.num_total_experts,
@@ -140,10 +141,10 @@ def process_weights_after_loading(self):
ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
- ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
+ ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
self.ws.data[expert, :, :])
w2s[expert, :, :], self.w2s_scale[
- expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
+ expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
self.ws = nn.Parameter(ws, requires_grad=False)
self.w2s = nn.Parameter(w2s, requires_grad=False)
@@ -178,7 +179,7 @@ def __init__(self,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -203,12 +204,12 @@ def __init__(self,
self.rope_theta = rope_theta
self.sliding_window = sliding_window
- if isinstance(linear_method, Fp8LinearMethod):
+ if isinstance(quant_config, Fp8Config):
print_warning_once(
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
)
- linear_method = None
+ quant_config = None
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -216,13 +217,13 @@ def __init__(self,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -259,7 +260,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -272,13 +273,13 @@ def __init__(
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
- linear_method=linear_method)
+ quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
- linear_method=linear_method)
+ quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -318,7 +319,7 @@ class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
@@ -334,7 +335,7 @@ def __init__(
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
- MixtralDecoderLayer(config, linear_method=linear_method)
+ MixtralDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -384,14 +385,13 @@ class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
self.model = MixtralModel(config,
- linear_method,
+ quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
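For Mixtral, the FP8 path changes in two ways: the quantization gate becomes isinstance(quant_config, Fp8Config) (attention layers are still left unquantized, with a warning), and per-tensor expert-weight quantization now calls ops.scaled_fp8_quant from vllm._custom_ops. Below is a hedged sketch of that per-expert loop outside the model, assuming a CUDA build of vLLM with FP8 support; the tensor shapes and dtypes are made up for illustration, and the call signature (weight in, quantized weight and scale out) is taken from the hunk above.

    import torch

    from vllm import _custom_ops as ops

    # Hypothetical shapes; the weights would normally come from the loaded model.
    num_experts, out_dim, in_dim = 8, 1024, 512
    ws = torch.randn(num_experts, out_dim, in_dim,
                     dtype=torch.float16, device="cuda")

    ws_fp8 = torch.empty_like(ws, dtype=torch.float8_e4m3fn)
    ws_scale = torch.empty(num_experts, dtype=torch.float32, device="cuda")
    for expert in range(num_experts):
        # scaled_fp8_quant returns the quantized weight and its per-tensor
        # scale, mirroring process_weights_after_loading above.
        ws_fp8[expert, :, :], ws_scale[expert] = ops.scaled_fp8_quant(
            ws[expert, :, :])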
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index acd13cc27f159..38c62afced28a 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -34,11 +34,12 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- QKVParallelLinear,
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -55,7 +56,7 @@ def __init__(
num_experts: int,
hidden_size: int,
intermediate_size: int,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
@@ -65,15 +66,15 @@ def __init__(
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
@@ -92,7 +93,7 @@ class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -115,14 +116,14 @@ def __init__(
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
- linear_method=linear_method)
+ quant_config=quant_config)
if idx in self.expert_indicies else None
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts,
bias=False,
- linear_method=None)
+ quant_config=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
@@ -162,7 +163,7 @@ def __init__(self,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -193,13 +194,13 @@ def __init__(self,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -249,9 +250,9 @@ def __init__(
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
- linear_method=linear_method)
+ quant_config=quant_config)
self.block_sparse_moe = MixtralMoE(config=config,
- linear_method=linear_method)
+ quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -291,7 +292,7 @@ class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
@@ -302,7 +303,7 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
- MixtralDecoderLayer(config, linear_method=linear_method)
+ MixtralDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = MixtralModel(config, linear_method)
+ self.quant_config = quant_config
+ self.model = MixtralModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 340f63286739b..8c5e7e77c9306 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -11,10 +11,11 @@
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -42,7 +43,7 @@ class MPTAttention(nn.Module):
def __init__(
self,
config: MPTConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
@@ -65,7 +66,7 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=not config.no_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
@@ -74,7 +75,7 @@ def __init__(
self.d_model,
self.d_model,
bias=not config.no_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
@@ -133,7 +134,7 @@ class MPTMLP(nn.Module):
def __init__(
self,
config: MPTConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.d_model
@@ -143,15 +144,15 @@ def __init__(
hidden_size,
intermediate_size,
bias=not config.no_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, intermediate_size)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=not config.no_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -166,14 +167,14 @@ class MPTBlock(nn.Module):
def __init__(
self,
config: MPTConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
- self.attn = MPTAttention(config, linear_method)
+ self.attn = MPTAttention(config, quant_config)
self.norm_2 = nn.LayerNorm(hidden_size)
- self.ffn = MPTMLP(config, linear_method)
+ self.ffn = MPTMLP(config, quant_config)
def forward(
self,
@@ -201,7 +202,7 @@ class MPTModel(nn.Module):
def __init__(
self,
config: MPTConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
assert config.embedding_fraction == 1.0
@@ -212,7 +213,7 @@ def __init__(
config.d_model,
)
self.blocks = nn.ModuleList(
- [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
+ [MPTBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
for module in self.modules():
@@ -246,14 +247,14 @@ class MPTForCausalLM(nn.Module):
def __init__(
self,
config: MPTConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
assert config.tie_word_embeddings
- self.linear_method = linear_method
+ self.quant_config = quant_config
- self.transformer = MPTModel(config, linear_method)
+ self.transformer = MPTModel(config, quant_config)
self.lm_head_weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 15527569b9e20..f212ea2166e1d 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -30,11 +30,12 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -54,7 +55,7 @@ class OlmoAttention(nn.Module):
def __init__(
self,
config: OlmoConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -79,7 +80,7 @@ def __init__(
self.head_dim,
self.total_num_heads,
bias=config.attention_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
# Rotary embeddings.
@@ -99,7 +100,7 @@ def __init__(
self.hidden_size,
self.hidden_size,
bias=config.attention_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
def forward(
@@ -129,7 +130,7 @@ class OlmoMLP(nn.Module):
def __init__(
self,
config: OlmoConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -141,7 +142,7 @@ def __init__(
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
# Activation function.
@@ -152,7 +153,7 @@ def __init__(
self.intermediate_size,
self.hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
def forward(
@@ -174,13 +175,13 @@ class OlmoDecoderLayer(nn.Module):
def __init__(self,
config: OlmoConfig,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
# Attention block.
- self.self_attn = OlmoAttention(config, linear_method)
+ self.self_attn = OlmoAttention(config, quant_config)
# MLP block.
- self.mlp = OlmoMLP(config, linear_method)
+ self.mlp = OlmoMLP(config, quant_config)
# LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size,
@@ -216,14 +217,14 @@ class OlmoModel(nn.Module):
def __init__(self,
config: OlmoConfig,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
- OlmoDecoderLayer(config, linear_method)
+ OlmoDecoderLayer(config, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size,
@@ -270,11 +271,10 @@ class OlmoForCausalLM(nn.Module):
def __init__(self,
config: OlmoConfig,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = OlmoModel(config, linear_method)
+ self.model = OlmoModel(config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 89263166bca81..838a2f0adc4d1 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -27,11 +27,12 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -60,7 +61,7 @@ def __init__(
embed_dim: int,
num_heads: int,
bias: bool = True,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_dim = embed_dim
@@ -77,13 +78,13 @@ def __init__(
self.head_dim,
total_num_heads,
bias=bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
@@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module):
def __init__(
self,
config: OPTConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -116,7 +117,7 @@ def __init__(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
bias=config.enable_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.do_layer_norm_before = config.do_layer_norm_before
@@ -127,16 +128,16 @@ def __init__(
self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config = getattr(quant_config, "quant_config", None)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
@@ -181,7 +182,7 @@ class OPTDecoder(nn.Module):
def __init__(
self,
config: OPTConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -202,7 +203,7 @@ def __init__(
self.project_out = ReplicatedLinear(config.hidden_size,
config.word_embed_proj_dim,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
else:
self.project_out = None
@@ -210,7 +211,7 @@ def __init__(
self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
config.hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
else:
self.project_in = None
@@ -226,7 +227,7 @@ def __init__(
self.final_layer_norm = None
self.layers = nn.ModuleList([
- OPTDecoderLayer(config, linear_method)
+ OPTDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
@@ -259,10 +260,10 @@ class OPTModel(nn.Module):
def __init__(
self,
config: OPTConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
- self.decoder = OPTDecoder(config, linear_method)
+ self.decoder = OPTDecoder(config, quant_config)
def forward(
self,
@@ -279,12 +280,12 @@ class OPTForCausalLM(nn.Module):
def __init__(
self,
config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = OPTModel(config, linear_method)
+ self.quant_config = quant_config
+ self.model = OPTModel(config, quant_config)
self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index bbb9fa5347cc8..9ab5dfb97c19a 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -13,11 +13,12 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -34,17 +35,17 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -67,7 +68,7 @@ def __init__(
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -98,13 +99,13 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -154,13 +155,13 @@ def __init__(
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.mlp = OrionMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
@@ -201,7 +202,7 @@ class OrionModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
@@ -212,7 +213,7 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
- OrionDecoderLayer(config, linear_method)
+ OrionDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = OrionModel(config, linear_method)
+ self.quant_config = quant_config
+ self.model = OrionModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index f974b78a0fbda..7a9b8dcd6a509 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -45,10 +45,11 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -62,7 +63,7 @@ class PhiAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.total_num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
@@ -80,12 +81,12 @@ def __init__(self,
self.head_size,
self.total_num_heads,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
- linear_method=linear_method,
+ quant_config=quant_config,
)
scaling = self.head_size**-0.5
@@ -125,7 +126,7 @@ class PhiMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
n_inner = getattr(config, "n_inner", None)
@@ -134,14 +135,14 @@ def __init__(self,
self.fc1 = ColumnParallelLinear(
config.hidden_size,
n_inner,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
n_inner,
config.hidden_size,
- linear_method=linear_method,
+ quant_config=quant_config,
)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
def forward(self, hidden_states):
@@ -155,12 +156,12 @@ class PhiLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
- self.self_attn = PhiAttention(config, linear_method)
- self.mlp = PhiMLP(config, linear_method)
+ self.self_attn = PhiAttention(config, quant_config)
+ self.mlp = PhiMLP(config, quant_config)
def forward(
self,
@@ -186,14 +187,14 @@ class PhiModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
- self.linear_method = linear_method
+ self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
- PhiLayer(config, linear_method)
+ PhiLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
@@ -225,12 +226,12 @@ class PhiForCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
- self.linear_method = linear_method
+ self.quant_config = quant_config
- self.model = PhiModel(config, linear_method)
+ self.model = PhiModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index a77da7cb15984..e5e0028888c88 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -14,11 +14,12 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -35,17 +36,17 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str = "silu",
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -67,7 +68,7 @@ def __init__(
max_position_embeddings: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = hidden_size
@@ -83,13 +84,13 @@ def __init__(
self.head_dim,
self.total_num_heads,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.scaling = self.head_dim**-0.5
@@ -122,7 +123,7 @@ class QWenBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -134,13 +135,13 @@ def __init__(
config.max_position_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
- linear_method=linear_method)
+ quant_config=quant_config)
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(config.hidden_size,
config.intermediate_size // 2,
- linear_method=linear_method)
+ quant_config=quant_config)
def forward(
self,
@@ -174,7 +175,7 @@ class QWenModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -185,7 +186,7 @@ def __init__(
config.hidden_size,
)
self.h = nn.ModuleList([
- QWenBlock(config, linear_method)
+ QWenBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.transformer = QWenModel(config, linear_method)
+ self.quant_config = quant_config
+ self.transformer = QWenModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 71b906e20ac19..62bc7fe22c367 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -33,11 +33,12 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -54,17 +55,17 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -86,7 +87,7 @@ def __init__(self,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
use_sliding_window: bool = False,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -117,13 +118,13 @@ def __init__(self,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -159,7 +160,7 @@ def __init__(
self,
config: Qwen2Config,
layer_idx: int,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -174,13 +175,13 @@ def __init__(
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
use_sliding_window=use_sliding_window,
- linear_method=linear_method,
+ quant_config=quant_config,
sliding_window=config.sliding_window)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -221,7 +222,7 @@ class Qwen2Model(nn.Module):
def __init__(
self,
config: Qwen2Config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
@@ -233,7 +234,7 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
- Qwen2DecoderLayer(config, layer_idx, linear_method)
+ Qwen2DecoderLayer(config, layer_idx, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module):
def __init__(
self,
config: Qwen2Config,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = Qwen2Model(config, linear_method)
+ self.quant_config = quant_config
+ self.model = Qwen2Model(config, quant_config)
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 59908bc9ef26a..8da89a2b7ba6c 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -36,12 +36,13 @@
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -58,18 +59,18 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@@ -105,7 +106,7 @@ def __init__(
Qwen2MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
@@ -114,13 +115,13 @@ def __init__(
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
- linear_method=None)
+ quant_config=None)
if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen2MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
reduce_results=False,
)
else:
@@ -186,7 +187,7 @@ def __init__(
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -217,14 +218,14 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -260,7 +261,7 @@ def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -275,18 +276,18 @@ def __init__(
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
- linear_method=linear_method,
+ quant_config=quant_config,
)
if (config.num_experts is not None
and (layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
- linear_method=linear_method)
+ quant_config=quant_config)
else:
self.mlp = Qwen2MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
@@ -338,9 +339,7 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
- Qwen2MoeDecoderLayer(config,
- layer_idx,
- linear_method=linear_method)
+ Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = Qwen2MoeModel(config, linear_method)
+ self.quant_config = quant_config
+ self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
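In the MoE models (MixtralMoE earlier, Qwen2MoeSparseMoeBlock here), the routing gate is built with quant_config=None so it stays unquantized, while the routed and shared expert MLPs receive the model-wide quant_config. A small sketch of that split, assuming the layers and the Qwen2MoeMLP class from this file; ToySparseMoE and its sizes are hypothetical.

    from typing import Optional

    from torch import nn

    from vllm.model_executor.layers.linear import ReplicatedLinear
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig)
    from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP


    class ToySparseMoE(nn.Module):
        """Hypothetical block mirroring the gate/expert split shown above."""

        def __init__(self,
                     hidden_size: int,
                     intermediate_size: int,
                     num_experts: int,
                     quant_config: Optional[QuantizationConfig] = None) -> None:
            super().__init__()
            # The routing gate is deliberately built without quantization.
            self.gate = ReplicatedLinear(hidden_size,
                                         num_experts,
                                         bias=False,
                                         quant_config=None)
            # The expert MLPs share the model-wide quantization config.
            self.experts = nn.ModuleList([
                Qwen2MoeMLP(hidden_size=hidden_size,
                            intermediate_size=intermediate_size,
                            hidden_act="silu",
                            quant_config=quant_config,
                            reduce_results=False)
                for _ in range(num_experts)
            ])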
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 3e6c2db6f3c65..3d4f4f700f867 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -28,11 +28,12 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -46,7 +47,7 @@ class StablelmMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None) -> None:
+ quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
@@ -54,7 +55,7 @@ def __init__(self,
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=False)
@@ -71,7 +72,7 @@ class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None) -> None:
+ quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
@@ -109,11 +110,11 @@ def __init__(self,
self.total_num_heads,
self.total_num_key_value_heads,
self.qkv_bias,
- linear_method=linear_method)
+ quant_config=quant_config)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
@@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config)
- self.mlp = StablelmMLP(config, linear_method)
+ self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None) -> None:
+ quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
- StablelmDecoderLayer(config, linear_method)
+ StablelmDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
norm_eps = getattr(config, "norm_eps",
@@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = StableLMEpochModel(config, linear_method)
+ self.quant_config = quant_config
+ self.model = StableLMEpochModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index b90f3da141c2e..29d887b21032b 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -28,10 +28,11 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
- LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
@@ -79,13 +80,13 @@ def __init__(self,
self.total_num_heads,
self.total_num_kv_heads,
bias=self.use_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=self.use_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -121,21 +122,21 @@ class Starcoder2MLP(nn.Module):
def __init__(self,
config: Starcoder2Config,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.c_fc = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=config.use_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=config.use_bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
- quant_config = getattr(linear_method, "quant_config", None)
+ quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
@@ -150,12 +151,11 @@ class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = Starcoder2Attention(config,
- linear_method=linear_method)
- self.mlp = Starcoder2MLP(config, linear_method=linear_method)
+ self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
+ self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
@@ -192,7 +192,7 @@ class Starcoder2Model(nn.Module):
def __init__(self,
config: Starcoder2Config,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
@@ -202,7 +202,7 @@ def __init__(self,
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
- Starcoder2DecoderLayer(config, linear_method=linear_method)
+ Starcoder2DecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
@@ -227,10 +227,10 @@ class Starcoder2ForCausalLM(nn.Module):
def __init__(self,
config: Starcoder2Config,
- linear_method: Optional[LinearMethodBase] = None):
+ quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
- self.model = Starcoder2Model(config, linear_method=linear_method)
+ self.model = Starcoder2Model(config, quant_config=quant_config)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 4e905390c2340..0fb2662b2f715 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -31,11 +31,12 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
- MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -52,17 +53,17 @@ def __init__(
hidden_size: int,
intermediate_size: int,
hidden_act: str,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
- linear_method=linear_method)
+ quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -85,7 +86,7 @@ def __init__(
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
) -> None:
@@ -112,13 +113,13 @@ def __init__(
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=bias,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.rotary_emb = get_rope(
@@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -171,7 +172,7 @@ def __init__(
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
- linear_method=linear_method,
+ quant_config=quant_config,
bias=getattr(config, "bias", False),
sliding_window=sliding_window,
)
@@ -179,7 +180,7 @@ def __init__(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
- linear_method=linear_method,
+ quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -220,7 +221,7 @@ class XverseModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
@@ -236,7 +237,7 @@ def __init__(
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
- XverseDecoderLayer(config, linear_method)
+ XverseDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
- linear_method: Optional[LinearMethodBase] = None,
+ quant_config: Optional[QuantizationConfig] = None,
lora_config=None,
) -> None:
super().__init__()
self.config = config
- self.linear_method = linear_method
- self.model = XverseModel(config, linear_method)
+ self.quant_config = quant_config
+ self.model = XverseModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
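Illustrative sketch (not part of the patch): after this refactor, model layers take a QuantizationConfig instead of a pre-built LinearMethodBase and pass it straight through to the parallel linear layers. A minimal example of the new constructor pattern; ExampleMLP is a hypothetical layer, and the imports mirror the xverse.py diff above.

from typing import Optional

import torch
import torch.nn as nn

from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class ExampleMLP(nn.Module):
    """Hypothetical layer showing the post-refactor constructor signature."""

    def __init__(self,
                 hidden_size: int,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        # The quantization config (or None for unquantized weights) is
        # threaded through to every parallel linear layer.
        self.proj = RowParallelLinear(hidden_size,
                                      hidden_size,
                                      bias=False,
                                      quant_config=quant_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.proj(x)
        return out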
From aba47be3fef57f37196a91f899068506b9075f4f Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 26 Apr 2024 15:47:45 -0700
Subject: [PATCH 132/413] [Misc] add RFC issue template (#4401)
Co-authored-by: Simon Mo
---
.github/ISSUE_TEMPLATE/750-RFC.yml | 49 ++++++++++++++++++++++++++++++
1 file changed, 49 insertions(+)
create mode 100644 .github/ISSUE_TEMPLATE/750-RFC.yml
diff --git a/.github/ISSUE_TEMPLATE/750-RFC.yml b/.github/ISSUE_TEMPLATE/750-RFC.yml
new file mode 100644
index 0000000000000..5382b124dcd79
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/750-RFC.yml
@@ -0,0 +1,49 @@
+name: 💬 Request for comments (RFC).
+description: Ask for feedback on major architectural changes or design choices.
+title: "[RFC]: "
+labels: ["RFC"]
+
+body:
+- type: markdown
+ attributes:
+ value: >
+ #### Please take a look at previous [RFCs](https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference.
+- type: textarea
+ attributes:
+ label: Motivation.
+ description: >
+ The motivation of the RFC.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Proposed Change.
+ description: >
+ The proposed change of the RFC.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Feedback Period.
+ description: >
+ The feedback period of the RFC. Usually at least one week.
+ validations:
+ required: false
+- type: textarea
+ attributes:
+ label: CC List.
+ description: >
+ The list of people you want to CC.
+ validations:
+ required: false
+- type: textarea
+ attributes:
+ label: Any Other Things.
+ description: >
+ Any other things you would like to mention.
+ validations:
+ required: false
+- type: markdown
+ attributes:
+ value: >
+ Thanks for contributing 🎉!
From 258a2c58d08fc7a242556120877a89404861fbce Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 26 Apr 2024 21:14:26 -0700
Subject: [PATCH 133/413] [Core] Introduce `DistributedGPUExecutor` abstract
class (#4348)
---
vllm/executor/distributed_gpu_executor.py | 114 ++++++++++++++++++++++
vllm/executor/ray_gpu_executor.py | 94 ++----------------
2 files changed, 122 insertions(+), 86 deletions(-)
create mode 100644 vllm/executor/distributed_gpu_executor.py
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
new file mode 100644
index 0000000000000..9dccfa4946391
--- /dev/null
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -0,0 +1,114 @@
+from abc import abstractmethod
+from typing import Any, Dict, Optional, Set, Tuple
+
+from vllm.executor.executor_base import ExecutorAsyncBase
+from vllm.executor.gpu_executor import GPUExecutor
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.sequence import SamplerOutput
+
+logger = init_logger(__name__)
+
+
+class DistributedGPUExecutor(GPUExecutor):
+ """Abstract superclass of multi-GPU executor implementations."""
+
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
+ """Determine the number of available KV blocks.
+
+ This invokes `determine_num_available_blocks` on each worker and takes
+ the min of the results, guaranteeing that the selected cache sizes are
+ compatible with all workers.
+
+ Returns:
+ - tuple[num_gpu_blocks, num_cpu_blocks]
+ """
+ # Get the maximum number of blocks that can be allocated on GPU and CPU.
+ num_blocks = self._run_workers("determine_num_available_blocks", )
+
+ # Since we use a shared centralized controller, we take the minimum
+ # number of blocks across all workers to make sure all the memory
+ # operators can be applied to all workers.
+ num_gpu_blocks = min(b[0] for b in num_blocks)
+ num_cpu_blocks = min(b[1] for b in num_blocks)
+
+ return num_gpu_blocks, num_cpu_blocks
+
+ def initialize_cache(self, num_gpu_blocks: int,
+ num_cpu_blocks: int) -> None:
+ """Initialize the KV cache in all workers.
+ """
+
+ # NOTE: We log here to avoid multiple logs when number of workers is
+ # greater than one. We could log in the engine, but not all executors
+ # have GPUs.
+ logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
+ num_cpu_blocks)
+
+ self.cache_config.num_gpu_blocks = num_gpu_blocks
+ self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+ self._run_workers("initialize_cache",
+ num_gpu_blocks=num_gpu_blocks,
+ num_cpu_blocks=num_cpu_blocks)
+
+ def execute_model(self, *args, **kwargs) -> SamplerOutput:
+ all_outputs = self._run_workers("execute_model",
+ driver_args=args,
+ driver_kwargs=kwargs)
+
+ # Only the driver worker returns the sampling results.
+ return all_outputs[0]
+
+ def add_lora(self, lora_request: LoRARequest) -> bool:
+ assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+ return self._run_workers(
+ "add_lora",
+ lora_request=lora_request,
+ )
+
+ def remove_lora(self, lora_id: int) -> bool:
+ assert lora_id > 0, "lora_id must be greater than 0."
+ return self._run_workers(
+ "remove_lora",
+ lora_id=lora_id,
+ )
+
+ def list_loras(self) -> Set[int]:
+ return self._run_workers("list_loras")
+
+ @abstractmethod
+ def _run_workers(
+ self,
+ method: str,
+ *args,
+ driver_args: Optional[Tuple[Any, ...]] = None,
+ driver_kwargs: Optional[Dict[str, Any]] = None,
+ max_concurrent_workers: Optional[int] = None,
+ **kwargs,
+ ) -> Any:
+ """Runs the given method on all workers."""
+ raise NotImplementedError
+
+
+class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
+
+ @abstractmethod
+ async def _run_workers_async(
+ self,
+ method: str,
+ *args,
+ driver_args: Optional[Tuple[Any, ...]] = None,
+ driver_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ) -> Any:
+ """Runs the given method on all workers."""
+ raise NotImplementedError
+
+ async def execute_model_async(self, *args, **kwargs) -> SamplerOutput:
+ all_outputs = await self._run_workers_async("execute_model",
+ driver_args=args,
+ driver_kwargs=kwargs)
+
+ # Only the driver worker returns the sampling results.
+ return all_outputs[0]
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 6f72babe14fd5..1082984828357 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -3,12 +3,12 @@
import pickle
from collections import defaultdict
from itertools import islice, repeat
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from vllm.executor.distributed_gpu_executor import ( # yapf: disable
+ DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger
-from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async)
@@ -27,7 +27,7 @@
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
-class RayGPUExecutor(ExecutorBase):
+class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None:
assert (not self.speculative_config
@@ -179,50 +179,9 @@ def collect_arg_helper_func(**kwargs):
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device")
- self._run_workers(
- "load_model",
- max_concurrent_workers=self.parallel_config.
- max_parallel_loading_workers,
- )
-
- def determine_num_available_blocks(self) -> Tuple[int, int]:
- """Determine the number of available KV blocks.
-
- This invokes `determine_num_available_blocks` on each worker and takes
- the min of the results, guaranteeing that the selected cache sizes are
- compatible with all workers.
-
- Returns:
- - Tuple[num_gpu_blocks, num_cpu_blocks]
- """
- # Get the maximum number of blocks that can be allocated on GPU and CPU.
- num_blocks = self._run_workers("determine_num_available_blocks", )
-
- # Since we use a shared centralized controller, we take the minimum
- # number of blocks across all workers to make sure all the memory
- # operators can be applied to all workers.
- num_gpu_blocks = min(b[0] for b in num_blocks)
- num_cpu_blocks = min(b[1] for b in num_blocks)
-
- return num_gpu_blocks, num_cpu_blocks
-
- def initialize_cache(self, num_gpu_blocks: int,
- num_cpu_blocks: int) -> None:
- """Initialize the KV cache in all workers.
- """
-
- # NOTE: We log here to avoid multiple logs when number of workers is
- # greater than one. We could log in the engine, but not all executors
- # have GPUs.
- logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
- num_cpu_blocks)
-
- self.cache_config.num_gpu_blocks = num_gpu_blocks
- self.cache_config.num_cpu_blocks = num_cpu_blocks
-
- self._run_workers("initialize_cache",
- num_gpu_blocks=num_gpu_blocks,
- num_cpu_blocks=num_cpu_blocks)
+ self._run_workers("load_model",
+ max_concurrent_workers=self.parallel_config.
+ max_parallel_loading_workers)
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -244,23 +203,6 @@ def execute_model(self,
output = all_outputs[0]
return output
- def add_lora(self, lora_request: LoRARequest) -> bool:
- assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
- return self._run_workers(
- "add_lora",
- lora_request=lora_request,
- )
-
- def remove_lora(self, lora_id: int) -> bool:
- assert lora_id > 0, "lora_id must be greater than 0."
- return self._run_workers(
- "remove_lora",
- lora_id=lora_id,
- )
-
- def list_loras(self) -> Set[int]:
- return self._run_workers("list_loras")
-
def _run_workers(
self,
method: str,
@@ -378,7 +320,7 @@ def _check_if_any_actor_is_dead(self):
f"Dead Workers: {dead_actors}. ")
-class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
+class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -409,23 +351,3 @@ async def _run_workers_async(
all_outputs = await asyncio.gather(*coros)
return all_outputs
-
- async def execute_model_async(
- self,
- seq_group_metadata_list: List[SequenceGroupMetadata],
- blocks_to_swap_in: Dict[int, int],
- blocks_to_swap_out: Dict[int, int],
- blocks_to_copy: Dict[int, List[int]],
- ) -> SamplerOutput:
- all_outputs = await self._run_workers_async(
- "execute_model",
- driver_kwargs={
- "seq_group_metadata_list": seq_group_metadata_list,
- "blocks_to_swap_in": blocks_to_swap_in,
- "blocks_to_swap_out": blocks_to_swap_out,
- "blocks_to_copy": blocks_to_copy,
- })
-
- # Only the driver worker returns the sampling results.
- output = all_outputs[0]
- return output
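Illustrative sketch (not part of the patch): determine_num_available_blocks in the new DistributedGPUExecutor base class takes the per-dimension minimum of the worker results, so the engine only allocates a KV cache that fits on every worker. A tiny example with hypothetical worker results:

# Hypothetical (num_gpu_blocks, num_cpu_blocks) reported by two workers.
num_blocks = [(800, 1000), (750, 1100)]

num_gpu_blocks = min(b[0] for b in num_blocks)  # 750
num_cpu_blocks = min(b[1] for b in num_blocks)  # 1000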
From 12628d3c787efd3483aaa74b5ae4175b28fd5805 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 26 Apr 2024 21:49:59 -0700
Subject: [PATCH 134/413] [Kernel] Optimize FP8 support for MoE kernel /
Mixtral via static scales (#4343)
Co-authored-by: Woosuk Kwon
---
csrc/ops.h | 7 ++-
csrc/pybind.cpp | 3 +-
csrc/quantization/fp8/fp8_cuda_kernels.cu | 25 ++++++++++-
vllm/_custom_ops.py | 12 +++--
.../layers/fused_moe/fused_moe.py | 13 ++++--
.../model_executor/layers/quantization/fp8.py | 9 +++-
vllm/model_executor/models/mixtral.py | 44 ++++++++++++++++---
7 files changed, 95 insertions(+), 18 deletions(-)
diff --git a/csrc/ops.h b/csrc/ops.h
index ff7a3de1a0a8c..03bb1e24dc68e 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -146,7 +146,12 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);
-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+ torch::Tensor& out,
+ torch::Tensor& input,
+ torch::Tensor& scale);
+
+void dynamic_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index a5b16c5abc3ed..2250c7f69f0ab 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
- ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
+ ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
+ ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu
index c3337cede1282..2477051eb60d7 100644
--- a/csrc/quantization/fp8/fp8_cuda_kernels.cu
+++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu
@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(
} // namespace vllm
-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+ torch::Tensor& out, // [..., d]
+ torch::Tensor& input, // [..., d]
+ torch::Tensor& scale) // [1]
+{
+ int64_t num_tokens = input.numel() / input.size(-1);
+ int64_t num_elems = input.numel();
+ dim3 grid(num_tokens);
+ dim3 block(1024);
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(),
+ "scaled_fp8_quant_kernel",
+ [&] {
+      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<c10::Float8_e4m3fn>(),
+        input.data_ptr<scalar_t>(),
+        scale.data_ptr<float>(),
+        num_elems);
+ });
+}
+
+void dynamic_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 508d35656eb00..5ba104bada7ac 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -168,10 +168,16 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
# fp8
-def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+def scaled_fp8_quant(
+ input: torch.Tensor,
+ scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
- vllm_ops.scaled_fp8_quant(output, input, scale)
+ if scale is None:
+ scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+ vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
+ else:
+ vllm_ops.static_scaled_fp8_quant(output, input, scale)
return output, scale
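Illustrative sketch (not part of the patch): the wrapper now dispatches on whether a scale is supplied, so callers pick static or dynamic FP8 quantization through the same entry point. A hedged usage example; it assumes an FP8-capable GPU and uses the `ops` alias seen in the fused_moe diff below.

import torch

from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic: the kernel computes the scaling factor from the input.
x_fp8, dyn_scale = ops.scaled_fp8_quant(x)

# Static: reuse a precomputed activation scale (shape [1], float32).
static_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
x_fp8_static, _ = ops.scaled_fp8_quant(x, static_scale)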
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index aed2c350bdd10..d37837a0b2ce8 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -220,8 +220,9 @@ def moe_align_block_size(
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
- B_scale: torch.Tensor, topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
+ A_scale: Optional[torch.Tensor],
+ B_scale: Optional[torch.Tensor],
+ topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
@@ -232,10 +233,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
assert sorted_token_ids.stride(0) == 1
if not use_fp8:
- A_scale = None
+ assert A_scale is None
assert B_scale is None
else:
- A, A_scale = ops.scaled_fp8_quant(A)
+ A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
@@ -318,6 +319,8 @@ def fused_moe(
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
+ a1_scale: Optional[torch.Tensor] = None,
+ a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -434,6 +437,7 @@ def fused_moe(
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
+ a1_scale,
w1_scale,
topk_weights,
topk_ids,
@@ -451,6 +455,7 @@ def fused_moe(
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
+ a2_scale,
w2_scale,
topk_weights,
topk_ids,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 39679834b545c..ba9f3149649c1 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -14,6 +14,12 @@
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
+ def __init__(
+ self,
+ activation_scheme: str = "dynamic",
+ ) -> None:
+ self.activation_scheme = activation_scheme
+
@classmethod
def get_name(cls) -> str:
return "fp8"
@@ -35,7 +41,8 @@ def get_config_filenames(cls) -> List[str]:
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
- return cls()
+ activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+ return cls(activation_scheme)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
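Illustrative sketch (not part of the patch): from_config now reads the activation scheme from the checkpoint's quantization config; "static" enables the precomputed activation scales used by the Mixtral changes below. The dict here is a hypothetical minimal config.

from vllm.model_executor.layers.quantization.fp8 import Fp8Config

# Hypothetical minimal quantization config as it might appear in a checkpoint.
fp8_config = Fp8Config.from_config({"activation_scheme": "static"})
assert fp8_config.activation_scheme == "static"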
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 7847df735ab44..c5dd1a63e2f7a 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -105,6 +105,13 @@ def __init__(
device="cuda",
dtype=self.params_dtype))
+ set_weight_attrs(self.ws, {
+ "weight_loader": self.weight_loader,
+ })
+ set_weight_attrs(self.w2s, {
+ "weight_loader": self.weight_loader,
+ })
+
# Scaling factors for FP8 weights
self.ws_scale = nn.Parameter(
torch.ones(
@@ -115,12 +122,23 @@ def __init__(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None
- set_weight_attrs(self.ws, {
- "weight_loader": self.weight_loader,
- })
- set_weight_attrs(self.w2s, {
- "weight_loader": self.weight_loader,
- })
+ # Scaling factors for FP8 activations
+ need_act_scales = (self.use_fp8
+ and quant_config.activation_scheme == "static")
+ self.as_scale = nn.Parameter(
+ torch.zeros(1, device="cuda", dtype=torch.float32),
+ requires_grad=False) if need_act_scales else None
+ self.a2s_scale = nn.Parameter(
+ torch.zeros(1, device="cuda", dtype=torch.float32),
+ requires_grad=False) if need_act_scales else None
+
+ if need_act_scales:
+ set_weight_attrs(self.as_scale, {
+ "weight_loader": self.weight_loader,
+ })
+ set_weight_attrs(self.a2s_scale, {
+ "weight_loader": self.weight_loader,
+ })
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
@@ -135,6 +153,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
+ if "act_scale" in weight_name:
+ param_data[:] = param_data[:].max(loaded_weight)
def process_weights_after_loading(self):
if self.use_fp8:
@@ -162,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.ws_scale,
- w2_scale=self.w2s_scale)
+ w2_scale=self.w2s_scale,
+ a1_scale=self.as_scale,
+ a2_scale=self.a2s_scale)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
@@ -443,11 +465,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
]
expert_params_mapping = [
+ # These are the weights for the experts
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
+ ] + [
+ # These are the activation scales for the experts
+ # (param_name, weight_name, expert_id)
+ ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+ f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+ for expert_id in range(self.config.num_local_experts)
+ for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
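Illustrative sketch (not part of the patch): with static activation scales, the Mixtral MoE layer keeps one shared scale per projection, and the weight loader folds each expert's act_scale into it with an elementwise max, the conservative choice when a single scale must cover all experts. A toy example with made-up scale values:

import torch

# Hypothetical per-expert activation scales loaded from the checkpoint.
expert_scales = [torch.tensor([0.4]), torch.tensor([0.7]), torch.tensor([0.5])]

shared_scale = torch.zeros(1)
for loaded_weight in expert_scales:
    # Same update as the weight_loader above: keep the running maximum.
    shared_scale[:] = shared_scale[:].max(loaded_weight)

print(shared_scale)  # tensor([0.7000])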
From 8947bc3c156963dfc66e7ca1e4c436506ed6a512 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sat, 27 Apr 2024 13:08:24 +0800
Subject: [PATCH 135/413] [Frontend][Bugfix] Disallow extra fields in OpenAI
API (#4355)
---
requirements-common.txt | 1 +
requirements-dev.txt | 1 -
tests/entrypoints/test_openai_server.py | 16 +++++
vllm/entrypoints/openai/cli_args.py | 4 +-
vllm/entrypoints/openai/protocol.py | 64 ++++++++++---------
vllm/entrypoints/openai/serving_chat.py | 55 ++++++++++++----
vllm/entrypoints/openai/serving_completion.py | 9 +--
vllm/entrypoints/openai/serving_engine.py | 18 +++---
8 files changed, 113 insertions(+), 55 deletions(-)
diff --git a/requirements-common.txt b/requirements-common.txt
index 3cc7bba8f84db..e9db261c6aec9 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -8,6 +8,7 @@ py-cpuinfo
transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
+openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d9816828d007d..324039186142b 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -21,7 +21,6 @@ pytest-rerunfailures
pytest-shard
httpx
einops # required for MPT
-openai
requests
ray
peft
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 85a7ef464c032..68332228ace08 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -15,6 +15,7 @@
import requests
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
+from openai import BadRequestError
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -770,6 +771,21 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
assert loaded == {"result": 2}, loaded
+async def test_extra_fields(server, client: openai.AsyncOpenAI):
+ with pytest.raises(BadRequestError) as exc_info:
+ await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=[{
+ "role": "system",
+ "content": "You are a helpful assistant.",
+ "extra_field": "0",
+ }], # type: ignore
+ temperature=0,
+ seed=0)
+
+ assert "extra_forbidden" in exc_info.value.message
+
+
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """
start: select_statement
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 5c361b4d184ee..16c5b6c08d37f 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -9,7 +9,7 @@
import ssl
from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.openai.serving_engine import LoRA
+from vllm.entrypoints.openai.serving_engine import LoRAModulePath
class LoRAParserAction(argparse.Action):
@@ -18,7 +18,7 @@ def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
- lora_list.append(LoRA(name, path))
+ lora_list.append(LoRAModulePath(name, path))
setattr(namespace, self.dest, lora_list)
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index d9763d024eb83..0a949f9867754 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -4,14 +4,20 @@
from typing import Dict, List, Literal, Optional, Union
import torch
-from pydantic import BaseModel, Field, model_validator
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
-class ErrorResponse(BaseModel):
+class OpenAIBaseModel(BaseModel):
+ # OpenAI API does not allow extra fields
+ model_config = ConfigDict(extra="forbid")
+
+
+class ErrorResponse(OpenAIBaseModel):
object: str = "error"
message: str
type: str
@@ -19,7 +25,7 @@ class ErrorResponse(BaseModel):
code: int
-class ModelPermission(BaseModel):
+class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -34,7 +40,7 @@ class ModelPermission(BaseModel):
is_blocking: bool = False
-class ModelCard(BaseModel):
+class ModelCard(OpenAIBaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -44,26 +50,26 @@ class ModelCard(BaseModel):
permission: List[ModelPermission] = Field(default_factory=list)
-class ModelList(BaseModel):
+class ModelList(OpenAIBaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
-class UsageInfo(BaseModel):
+class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
-class ResponseFormat(BaseModel):
+class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
-class ChatCompletionRequest(BaseModel):
+class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
- messages: List[Dict[str, str]]
+ messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
@@ -204,7 +210,7 @@ def check_guided_decoding_count(cls, data):
return data
-class CompletionRequest(BaseModel):
+class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
@@ -343,19 +349,19 @@ def check_guided_decoding_count(cls, data):
return data
-class LogProbs(BaseModel):
+class LogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
-class CompletionResponseChoice(BaseModel):
+class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
- finish_reason: Optional[Literal["stop", "length"]] = None
- stop_reason: Union[None, int, str] = Field(
+ finish_reason: Optional[str] = None
+ stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
@@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel):
)
-class CompletionResponse(BaseModel):
+class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -373,12 +379,12 @@ class CompletionResponse(BaseModel):
usage: UsageInfo
-class CompletionResponseStreamChoice(BaseModel):
+class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
- finish_reason: Optional[Literal["stop", "length"]] = None
- stop_reason: Union[None, int, str] = Field(
+ finish_reason: Optional[str] = None
+ stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
@@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel):
)
-class CompletionStreamResponse(BaseModel):
+class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel):
usage: Optional[UsageInfo] = Field(default=None)
-class ChatMessage(BaseModel):
+class ChatMessage(OpenAIBaseModel):
role: str
content: str
-class ChatCompletionResponseChoice(BaseModel):
+class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
- finish_reason: Optional[Literal["stop", "length"]] = None
- stop_reason: Union[None, int, str] = None
+ finish_reason: Optional[str] = None
+ stop_reason: Optional[Union[int, str]] = None
-class ChatCompletionResponse(BaseModel):
+class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
@@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel):
usage: UsageInfo
-class DeltaMessage(BaseModel):
+class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
-class ChatCompletionResponseStreamChoice(BaseModel):
+class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
- finish_reason: Optional[Literal["stop", "length"]] = None
- stop_reason: Union[None, int, str] = None
+ finish_reason: Optional[str] = None
+ stop_reason: Optional[Union[int, str]] = None
-class ChatCompletionStreamResponse(BaseModel):
+class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index f6011b6fc4cb6..629dd929dc1af 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1,8 +1,11 @@
import codecs
import time
-from typing import AsyncGenerator, AsyncIterator, List, Optional, Union
+from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
+ Optional, Tuple, TypedDict, Union, final)
from fastapi import Request
+from openai.types.chat import (ChatCompletionContentPartParam,
+ ChatCompletionRole)
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
@@ -10,7 +13,8 @@
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+ OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
@@ -20,20 +24,41 @@
logger = init_logger(__name__)
+@final # So that it should be compatible with Dict[str, str]
+class ConversationMessage(TypedDict):
+ role: str
+ content: str
+
+
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
response_role: str,
- lora_modules: Optional[List[LoRA]] = None,
- chat_template=None):
+ lora_modules: Optional[List[LoRAModulePath]] = None,
+ chat_template: Optional[str] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules)
self.response_role = response_role
self._load_chat_template(chat_template)
+ def _parse_chat_message_content(
+ self,
+ role: ChatCompletionRole,
+ content: Optional[Union[str,
+ Iterable[ChatCompletionContentPartParam]]],
+ ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
+ if content is None:
+ return [], []
+ if isinstance(content, str):
+ return [ConversationMessage(role=role, content=content)], []
+
+ # To be implemented: https://github.com/vllm-project/vllm/pull/3467
+ # To be implemented: https://github.com/vllm-project/vllm/pull/4200
+ raise NotImplementedError("Complex input not supported yet")
+
async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
) -> Union[ErrorResponse, AsyncGenerator[str, None],
@@ -52,10 +77,19 @@ async def create_chat_completion(
return error_check_ret
try:
+ conversation: List[ConversationMessage] = []
+
+ for m in request.messages:
+ messages, _ = self._parse_chat_message_content(
+ m["role"], m["content"])
+
+ conversation.extend(messages)
+
prompt = self.tokenizer.apply_chat_template(
- conversation=request.messages,
+ conversation=conversation,
tokenize=False,
- add_generation_prompt=request.add_generation_prompt)
+ add_generation_prompt=request.add_generation_prompt,
+ )
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
@@ -105,9 +139,8 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
- result_generator: AsyncIterator[RequestOutput], request_id: str
- ) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
-
+ result_generator: AsyncIterator[RequestOutput],
+ request_id: str) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
@@ -252,7 +285,7 @@ async def chat_completion_full_generator(
model_name = self.served_model_names[0]
created_time = int(time.time())
- final_res: RequestOutput = None
+ final_res: Optional[RequestOutput] = None
async for res in result_generator:
if await raw_request.is_disconnected():
@@ -317,7 +350,7 @@ async def chat_completion_full_generator(
return response
- def _load_chat_template(self, chat_template):
+ def _load_chat_template(self, chat_template: Optional[str]):
tokenizer = self.tokenizer
if chat_template is not None:
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 211b2e0424c3e..7904bb698c45a 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -11,7 +11,8 @@
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+ OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
@@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
- lora_modules: Optional[List[LoRA]] = None):
+ lora_modules: Optional[List[LoRAModulePath]] = None):
super().__init__(engine=engine,
served_model_names=served_model_names,
lora_modules=lora_modules)
@@ -84,7 +85,7 @@ async def create_completion(self, request: CompletionRequest,
created_time = int(time.time())
# Schedule the request and get the result generator.
- generators = []
+ generators: List[AsyncIterator[RequestOutput]] = []
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
@@ -148,7 +149,7 @@ async def create_completion(self, request: CompletionRequest,
num_prompts=len(prompts))
# Non-streaming response
- final_res_batch: RequestOutput = [None] * len(prompts)
+ final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 31da27a447c6c..3d5ed328b9d19 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -22,17 +22,15 @@
@dataclass
-class LoRA:
+class LoRAModulePath:
name: str
local_path: str
class OpenAIServing:
- def __init__(self,
- engine: AsyncLLMEngine,
- served_model_names: List[str],
- lora_modules=Optional[List[LoRA]]):
+ def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
+ lora_modules: Optional[List[LoRAModulePath]]):
self.engine = engine
self.served_model_names = served_model_names
if lora_modules is None:
@@ -158,7 +156,9 @@ def create_streaming_error_response(
})
return json_str
- async def _check_model(self, request) -> Optional[ErrorResponse]:
+ async def _check_model(
+ self, request: Union[CompletionRequest, ChatCompletionRequest]
+ ) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
@@ -168,14 +168,16 @@ async def _check_model(self, request) -> Optional[ErrorResponse]:
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
- def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
+ def _maybe_get_lora(
+ self, request: Union[CompletionRequest, ChatCompletionRequest]
+ ) -> Optional[LoRARequest]:
if request.model in self.served_model_names:
return None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora
# if _check_model has been called earlier, this will be unreachable
- raise ValueError("The model `{request.model}` does not exist.")
+ raise ValueError(f"The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize(
self,
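Illustrative sketch (not part of the patch): ConfigDict(extra="forbid") makes pydantic reject unknown fields at validation time, which is what the new test_extra_fields server test asserts on. A self-contained example with a stripped-down message model:

from pydantic import BaseModel, ConfigDict, ValidationError


class OpenAIBaseModel(BaseModel):
    # Mirrors the patch: the OpenAI API does not allow extra fields.
    model_config = ConfigDict(extra="forbid")


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str


try:
    ChatMessage(role="system", content="hi", extra_field="0")
except ValidationError as exc:
    # pydantic labels the rejected field as "extra_forbidden".
    print(exc.errors()[0]["type"])  # extra_forbidden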
From 87f545ba6fdbbbe9813736fc398874563e2604a7 Mon Sep 17 00:00:00 2001
From: Roy
Date: Sat, 27 Apr 2024 13:45:02 +0800
Subject: [PATCH 136/413] [Misc] Fix logger format typo (#4396)
---
vllm/engine/metrics.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index d3560f5fefff1..eb54f5641171e 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -230,8 +230,8 @@ def log(self, stats: Stats) -> None:
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
- "Pending: %d reqs, GPU KV cache usage: %.1f%, "
- "CPU KV cache usage: %.1f%",
+ "Pending: %d reqs, GPU KV cache usage: %.1f%%, "
+ "CPU KV cache usage: %.1f%%",
prompt_throughput,
generation_throughput,
stats.num_running,
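Illustrative sketch (not part of the patch): in %-style logging format strings a literal percent sign must be written as "%%"; the previous "%.1f%," left the lone "%" starting an invalid conversion specifier, which fails when the log record is rendered. A minimal reproduction:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("vllm.metrics")

# Correct: "%%" renders a literal percent sign.
logger.info("GPU KV cache usage: %.1f%%, CPU KV cache usage: %.1f%%", 42.5, 3.0)

# Broken (pre-fix): the lone "%" starts an invalid conversion specifier.
# logger.info("GPU KV cache usage: %.1f%, CPU KV cache usage: %.1f%", 42.5, 3.0)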
From 18d23f642af9c56f45ccf684a22f386fc54c6ead Mon Sep 17 00:00:00 2001
From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Date: Sat, 27 Apr 2024 02:37:40 -0400
Subject: [PATCH 137/413] [ROCm][Hardware][AMD] Enable group query attention
for triton FA (#4406)
---
vllm/attention/backends/rocm_flash_attn.py | 53 +++++++++-----------
vllm/attention/ops/triton_flash_attention.py | 24 ++++-----
2 files changed, 36 insertions(+), 41 deletions(-)
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 7c5863a030ff5..934acea0a3d60 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -253,36 +253,31 @@ def forward(
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
- if self.use_triton_flash_attn or self.use_naive_attn:
+ if self.use_triton_flash_attn:
+ out, _ = self.attn_func(
+ query,
+ key,
+ value,
+ None,
+ prefill_meta.seq_start_loc,
+ prefill_meta.seq_start_loc,
+ prefill_meta.max_prompt_len,
+ prefill_meta.max_prompt_len,
+ True,
+ self.scale,
+ )
+ elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
- if self.use_naive_attn:
- out = self.attn_func(
- query,
- key,
- value,
- prefill_meta.prompt_lens,
- self.scale,
- )
- assert output[:num_prefill_tokens].shape == out.shape
- output[:num_prefill_tokens] = out
- else:
- out, _ = self.attn_func(
- query,
- key,
- value,
- None,
- prefill_meta.seq_start_loc,
- prefill_meta.seq_start_loc,
- prefill_meta.max_prompt_len,
- prefill_meta.max_prompt_len,
- True,
- self.scale,
- )
- assert output[:num_prefill_tokens].shape == out.shape
- output[:num_prefill_tokens] = out
+ out = self.attn_func(
+ query,
+ key,
+ value,
+ prefill_meta.prompt_lens,
+ self.scale,
+ )
else:
out = self.attn_func(
q=query,
@@ -295,8 +290,10 @@ def forward(
softmax_scale=self.scale,
causal=True,
)
- assert output[:num_prefill_tokens].shape == out.shape
- output[:num_prefill_tokens] = out
+
+ # common code for prefill
+ assert output[:num_prefill_tokens].shape == out.shape
+ output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index e160411859f0b..1147664183ff1 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -293,7 +293,7 @@ def _attn_fwd_inner(
num_warps=4,
),
],
- key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
+ key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
)
@triton.jit
def attn_fwd(
@@ -330,8 +330,8 @@ def attn_fwd(
philox_seed,
philox_offset_base,
encoded_softmax,
- hq,
- hk,
+ HQ: tl.constexpr,
+ HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
@@ -403,7 +403,7 @@ def attn_fwd(
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
- # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
+ # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
@@ -414,11 +414,9 @@ def attn_fwd(
# TODO: Should dropout and return encoded softmax be handled here?
return
- is_mqa = hq != hk
- if is_mqa: # noqa: SIM108
- off_h_k = off_h_q % hk
- else:
- off_h_k = off_h_q
+ # If MQA / GQA, set the K and V head offsets appropriately.
+ GROUP_SIZE: tl.constexpr = HQ // HK
+ off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
@@ -471,7 +469,7 @@ def attn_fwd(
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base \
- + (off_z * hq + off_h_q) \
+ + (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k
else:
batch_philox_offset = 0
@@ -624,7 +622,7 @@ def attn_fwd(
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
- # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
+ # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
@@ -784,8 +782,8 @@ def forward(
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
- hq=nheads_q,
- hk=nheads_k,
+ HQ=nheads_q,
+ HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
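Illustrative sketch (not part of the patch): with grouped-query attention the kernel maps each query head to its KV head by integer division over the group size, replacing the old modulo-based MQA-only path. A small example with hypothetical head counts:

# Hypothetical head counts: 32 query heads grouped over 8 KV heads.
HQ, HK = 32, 8
GROUP_SIZE = HQ // HK  # 4 query heads share each KV head

mapping = {}
for off_h_q in range(HQ):
    # Same rule as the kernel: integer division instead of the old modulo.
    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
    mapping[off_h_q] = off_h_k

# Query heads 0-3 -> KV head 0, 4-7 -> KV head 1, ..., 28-31 -> KV head 7.
assert mapping[5] == 1 and mapping[31] == 7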
From eefeb16464af5f3a61e3052d1a4128480bff7f47 Mon Sep 17 00:00:00 2001
From: Austin Veselka <50646302+FurtherAI@users.noreply.github.com>
Date: Sat, 27 Apr 2024 02:03:48 -0500
Subject: [PATCH 138/413] [Kernel] Full Tensor Parallelism for LoRA Layers
(#3524)
Co-authored-by: Antoni Baum
---
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu | 1 +
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu | 1 +
csrc/punica/bgmv/bgmv_config.h | 78 +++++++
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu | 1 +
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu | 1 +
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu | 1 +
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu | 1 +
csrc/punica/bgmv/bgmv_impl.cuh | 5 +-
csrc/punica/bgmv/generator.py | 1 +
csrc/punica/punica_ops.cc | 2 +-
tests/lora/test_layers.py | 29 ++-
tests/lora/test_punica.py | 51 ++++-
vllm/config.py | 1 +
vllm/engine/arg_utils.py | 10 +
vllm/lora/fully_sharded_layers.py | 262 ++++++++++++++++++++++++
vllm/lora/layers.py | 243 +++++++++++++---------
vllm/lora/models.py | 6 +-
vllm/lora/punica.py | 43 ++++
vllm/lora/utils.py | 60 +++++-
19 files changed, 686 insertions(+), 111 deletions(-)
create mode 100644 vllm/lora/fully_sharded_layers.py
diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
index c642e94925fe5..86846c274c90f 100644
--- a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
+++ b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
index 0607cebfeac40..de39c3121f5d3 100644
--- a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
+++ b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index fec484d693055..19c058cacfbc4 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
+// Used for defining kernels going from the variety of
+// dim in to the narrow dim out
+ // Using it for the fully sharded column
+ // parallel LoRA A which splits the rank dim
+#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
+ f(in_T, out_T, W_T, 128, narrow) \
+ f(in_T, out_T, W_T, 256, narrow) \
+ f(in_T, out_T, W_T, 512, narrow) \
+ f(in_T, out_T, W_T, 640, narrow) \
+ f(in_T, out_T, W_T, 768, narrow) \
+ f(in_T, out_T, W_T, 1024, narrow) \
+ f(in_T, out_T, W_T, 1152, narrow) \
+ f(in_T, out_T, W_T, 1280, narrow) \
+ f(in_T, out_T, W_T, 1536, narrow) \
+ f(in_T, out_T, W_T, 1728, narrow) \
+ f(in_T, out_T, W_T, 1792, narrow) \
+ f(in_T, out_T, W_T, 2048, narrow) \
+ f(in_T, out_T, W_T, 2304, narrow) \
+ f(in_T, out_T, W_T, 2560, narrow) \
+ f(in_T, out_T, W_T, 2752, narrow) \
+ f(in_T, out_T, W_T, 2816, narrow) \
+ f(in_T, out_T, W_T, 3072, narrow) \
+ f(in_T, out_T, W_T, 3456, narrow) \
+ f(in_T, out_T, W_T, 3584, narrow) \
+ f(in_T, out_T, W_T, 4096, narrow) \
+ f(in_T, out_T, W_T, 4608, narrow) \
+ f(in_T, out_T, W_T, 5120, narrow) \
+ f(in_T, out_T, W_T, 5504, narrow) \
+ f(in_T, out_T, W_T, 5632, narrow) \
+ f(in_T, out_T, W_T, 6144, narrow) \
+ f(in_T, out_T, W_T, 6848, narrow) \
+ f(in_T, out_T, W_T, 6912, narrow) \
+ f(in_T, out_T, W_T, 7168, narrow) \
+ f(in_T, out_T, W_T, 8192, narrow) \
+ f(in_T, out_T, W_T, 9216, narrow) \
+ f(in_T, out_T, W_T, 10240, narrow) \
+ f(in_T, out_T, W_T, 11008, narrow) \
+ f(in_T, out_T, W_T, 12288, narrow) \
+ f(in_T, out_T, W_T, 13696, narrow) \
+ f(in_T, out_T, W_T, 13824, narrow) \
+ f(in_T, out_T, W_T, 14336, narrow) \
+ f(in_T, out_T, W_T, 15360, narrow) \
+ f(in_T, out_T, W_T, 16384, narrow) \
+ f(in_T, out_T, W_T, 20480, narrow) \
+ f(in_T, out_T, W_T, 22016, narrow) \
+ f(in_T, out_T, W_T, 24576, narrow) \
+ f(in_T, out_T, W_T, 27392, narrow) \
+ f(in_T, out_T, W_T, 28672, narrow) \
+ f(in_T, out_T, W_T, 32000, narrow) \
+ f(in_T, out_T, W_T, 32256, narrow) \
+ f(in_T, out_T, W_T, 32512, narrow) \
+ f(in_T, out_T, W_T, 32768, narrow) \
+ f(in_T, out_T, W_T, 33024, narrow) \
+ f(in_T, out_T, W_T, 36864, narrow) \
+ f(in_T, out_T, W_T, 43264, narrow) \
+ f(in_T, out_T, W_T, 49152, narrow) \
+ f(in_T, out_T, W_T, 64000, narrow) \
+ f(in_T, out_T, W_T, 64256, narrow) \
+ f(in_T, out_T, W_T, 64512, narrow) \
+ f(in_T, out_T, W_T, 102400, narrow) \
+ f(in_T, out_T, W_T, 102656, narrow) \
+ f(in_T, out_T, W_T, 102912, narrow) \
+ f(in_T, out_T, W_T, 128000, narrow) \
+ f(in_T, out_T, W_T, 128256, narrow) \
+ f(in_T, out_T, W_T, 128512, narrow) \
+// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
+
+
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
+
+#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
+ FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
+ FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
+ FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
+ f(in_T, out_T, W_T, 8, 64) \
+ f(in_T, out_T, W_T, 16, 64) \
+ f(in_T, out_T, W_T, 32, 64) \
+ f(in_T, out_T, W_T, 64, 64)
+
// clang-format on
diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
index f1db6df5f7338..d225a1eaa82b0 100644
--- a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
+++ b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
index c01ddd009d74e..b37d288a75561 100644
--- a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
+++ b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
index f45183ffd3486..a1ab2deecbabf 100644
--- a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
+++ b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
index 4097743488087..0b35bf5699898 100644
--- a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
+++ b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh
index 995de26e8bada..dad8805c750cb 100644
--- a/csrc/punica/bgmv/bgmv_impl.cuh
+++ b/csrc/punica/bgmv/bgmv_impl.cuh
@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- if constexpr (feat_in < feat_out) {
+ if constexpr (feat_in <= feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
+#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
+ INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
+
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py
index 9bf7f6358880f..972df5a7208c2 100644
--- a/csrc/punica/bgmv/generator.py
+++ b/csrc/punica/bgmv/generator.py
@@ -10,6 +10,7 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501
for input_dtype in DTYPES:
diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc
index a1eaa90e85f27..8797fde85744a 100644
--- a/csrc/punica/punica_ops.cc
+++ b/csrc/punica/punica_ops.cc
@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
+ FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}
-
return true;
}
diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 1616fdfd4cff9..0eb04f4ccd133 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -8,6 +8,10 @@
import torch.nn.functional as F
from vllm.config import LoRAConfig
+from vllm.lora.fully_sharded_layers import (
+ ColumnParallelLinearWithShardedLoRA,
+ MergedColumnParallelLinearWithShardedLoRA,
+ MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -524,13 +528,16 @@ def _pretest():
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
+@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
+def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
+ device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
+ fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
def create_random_linear_parallel_layer():
@@ -540,14 +547,17 @@ def create_random_linear_parallel_layer():
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
- lora_linear = RowParallelLinearWithLoRA(linear)
+ lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
+ else RowParallelLinearWithShardedLoRA(linear))
else:
linear = ColumnParallelLinear(4096,
4096,
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
- lora_linear = ColumnParallelLinearWithLoRA(linear)
+ lora_linear = (ColumnParallelLinearWithLoRA(linear)
+ if not fully_shard else
+ ColumnParallelLinearWithShardedLoRA(linear))
lora_linear.create_lora_weights(max_loras, lora_config)
return linear, lora_linear
@@ -629,13 +639,16 @@ def create_random_linear_parallel_layer():
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
+@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
+def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
+ device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
+ fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
def create_column_parallel_packed_layer():
@@ -644,7 +657,9 @@ def create_column_parallel_packed_layer():
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
- lora_linear = MergedColumnParallelLinearWithLoRA(linear)
+ lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
+ if not fully_shard else
+ MergedColumnParallelLinearWithShardedLoRA(linear))
elif repeats == 3:
linear = QKVParallelLinear(4096,
64,
@@ -652,7 +667,9 @@ def create_column_parallel_packed_layer():
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
- lora_linear = MergedQKVParallelLinearWithLora(linear)
+ lora_linear = (MergedQKVParallelLinearWithLora(linear)
+ if not fully_shard else
+ MergedQKVParallelLinearWithShardedLora(linear))
else:
linear = QKVParallelLinear(4096,
64,
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index f3b9bd5912967..fd2a1b75f460c 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -34,11 +34,14 @@ def _lora_ref_impl(
for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
xi = x[i].unsqueeze(0).to(torch.float32)
wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
- wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
+ if wb_T_all is not None:
+ wb = wb_T_all[lora_idx, layer_idx].transpose(-1,
+ -2).to(torch.float32)
tmp = xi @ wa
y_stage_1[i] = tmp.squeeze(0)
- y_final[i] += (tmp @ wb).squeeze(0) * s
+ y_final[i] += ((tmp @ wb).squeeze(0) *
+ s if wb_T_all is not None else y_stage_1[i])
return y_final, y_stage_1
@@ -91,12 +94,56 @@ def _lora_ref_impl(
128000,
128256,
]
+H2 = [64] + H2
+R = [1, 2, 4]
SEED = [0xabcdabcd987]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
+@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
+@pytest.mark.parametrize("h1", H1)
+@pytest.mark.parametrize("r", R)
+@pytest.mark.parametrize("seed", SEED)
+@torch.inference_mode()
+def test_lora_a_extra_shapes(dtype_str, h1, r, seed):
+ torch.manual_seed(seed)
+ num_loras = 4
+ num_layers = 1
+ bs = 32
+ dtype = getattr(torch, dtype_str)
+ device = torch.device("cuda")
+
+ wa_T_all = torch.randn(num_loras,
+ num_layers,
+ r,
+ h1,
+ dtype=dtype,
+ device=device)
+ indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+
+ for layer_idx in range(num_layers):
+ x = torch.randn(bs, h1, dtype=dtype, device=device)
+ y = torch.randn(bs, r, dtype=dtype, device=device)
+
+ y_ref = y.clone()
+ _lora_ref_impl(
+ y_ref,
+ x,
+ wa_T_all,
+ None,
+ indices,
+ layer_idx,
+ 1.0,
+ )
+
+ y_our = y.clone()
+ punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0)
+
+ assert_close(y_ref, y_our)
+
+
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
diff --git a/vllm/config.py b/vllm/config.py
index 887a73d9462dc..aedb589247646 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -862,6 +862,7 @@ def __repr__(self) -> str:
class LoRAConfig:
max_lora_rank: int
max_loras: int
+ fully_sharded_loras: bool = False
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_extra_vocab_size: int = 256
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 6a6ac49ae3211..bd6437ee44c28 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -52,6 +52,7 @@ class EngineArgs:
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
+ fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
@@ -376,6 +377,14 @@ def add_cli_args(
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_num_seqs. '
'Defaults to max_num_seqs.'))
+ parser.add_argument(
+ '--fully-sharded-loras',
+ action='store_true',
+ help=('By default, only half of the LoRA computation is '
+ 'sharded with tensor parallelism. '
+ 'Enabling this will use the fully sharded layers. '
+ 'At high sequence length, max rank or '
+ 'tensor parallel size, this is likely faster.'))
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
@@ -509,6 +518,7 @@ def create_engine_config(self, ) -> EngineConfig:
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
+ fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
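The new option threads straight through to LoRAConfig; a minimal sketch of constructing the config directly, assuming a build that includes this patch and using otherwise arbitrary values:

from vllm.config import LoRAConfig

# Hypothetical values; the only field introduced by this patch is
# fully_sharded_loras, which defaults to False (only LoRA B is sharded).
lora_config = LoRAConfig(max_lora_rank=16,
                         max_loras=4,
                         fully_sharded_loras=True)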
diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py
new file mode 100644
index 0000000000000..1720566840bb1
--- /dev/null
+++ b/vllm/lora/fully_sharded_layers.py
@@ -0,0 +1,262 @@
+# pylint: disable=unused-argument
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from vllm.config import LoRAConfig
+from vllm.distributed.communication_op import (
+ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
+from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
+ MergedColumnParallelLinearWithLoRA,
+ MergedQKVParallelLinearWithLora,
+ RowParallelLinearWithLoRA)
+from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
+
+if TYPE_CHECKING:
+ pass
+
+
+def _fully_sharded_can_replace(can_replace):
+ """
+    Decorator which adds the condition that fully sharded LoRAs are enabled.
+    Intended to wrap can_replace_layer().
+ """
+
+ def dec(*args, **kwargs):
+ return (can_replace(*args, **kwargs)
+ and kwargs['lora_config'].fully_sharded_loras)
+
+ return dec
+
+
+# these layers are based on the tensor parallelism strategy given in
+# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
+# https://arxiv.org/abs/2311.03285.
+
+
+class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
+ """
+ Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
+
+ Based on S-LoRA, slicing happens along the rank dim.
+ """
+
+ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+ tp_rank = get_tensor_model_parallel_rank()
+ shard_size = self.lora_a_stacked.shape[2]
+ start_idx = tp_rank * shard_size
+ lora_a = lora_a[:, start_idx:start_idx + shard_size]
+ return lora_a
+
+ def apply_weights(self, x: torch.Tensor,
+ bias: Optional[torch.Tensor]) -> torch.Tensor:
+ output = self.base_layer.linear_method.apply_weights(
+ self.base_layer, x, bias)
+
+ x = x.view(-1, x.shape[-1])
+ output, out_orig_shape = output.view(-1,
+ output.shape[-1]), output.shape
+ buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
+ dtype=torch.float32,
+ device=x.device)
+
+ bgmv(buffer, x, self.lora_a_stacked,
+ self.indices[:self.indices_len[0]], 0, 1.0)
+ buffer = tensor_model_parallel_all_gather(buffer)
+ bgmv(output, buffer, self.lora_b_stacked,
+ self.indices[:self.indices_len[0]], 0, 1.0)
+ # now have column partitioned output
+
+ output = output.view(*out_orig_shape)
+ return output
+
+ @classmethod
+ @_fully_sharded_can_replace
+ def can_replace_layer(cls, source_layer: nn.Module,
+ lora_config: LoRAConfig, packed_modules_list: List,
+ model_config: Optional[PretrainedConfig]) -> bool:
+ # specifying kwargs so they can be easily accessed in decorator
+ return super().can_replace_layer(
+ source_layer=source_layer,
+ lora_config=lora_config,
+ packed_modules_list=packed_modules_list,
+ model_config=model_config,
+ decorate=False,
+ )
+
+
+def _mcp_apply_weights(x, bias, layer):
+ """
+ MergedColumnParallelLinearWithShardedLoRA and
+    MergedQKVParallelLinearWithShardedLora share the same
+    LoRA weight application method.
+
+    The main difference is the per-slice step of shard_size for lora_b,
+    which can vary for MergedQKVParallelLinearWithShardedLora but is
+    constant for MergedColumnParallelLinearWithShardedLoRA.
+ """
+ # expecting 2 for column parallel and 3 for qkv
+ n = len(layer.lora_a_stacked)
+ output = layer.base_layer.linear_method.apply_weights(
+ layer.base_layer, x, bias)
+
+ x = x.view(-1, x.shape[-1])
+ output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
+ buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
+ dtype=torch.float32,
+ device=x.device)
+ for idx in range(n):
+ bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
+ layer.indices[:layer.indices_len[0]], 0, 1.0)
+
+ buffers = tensor_model_parallel_all_gather(buffers)
+ left_offset = 0
+ for idx in range(n):
+ shard_size = layer.lora_b_stacked[idx].shape[2]
+ dispatch_bgmv_low_level(output, buffers[idx],
+ layer.lora_b_stacked[idx],
+ layer.indices[:layer.indices_len[0]], 0, 1.0,
+ left_offset, shard_size)
+ left_offset += shard_size
+
+ output = output.view(*out_orig_shape)
+ # now have column partitioned and packed output
+ return output
+
+
+class MergedColumnParallelLinearWithShardedLoRA(
+ MergedColumnParallelLinearWithLoRA):
+ """
+ Differs from MergedColumnParallelLinearWithLoRA by slicing the
+ LoRA A's also.
+
+ Based on S-LoRA, slicing happens along the rank dim.
+ """
+
+ def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+ output_shard_size = self.lora_a_stacked[0].shape[2]
+ output_start_idx = self.tp_rank * output_shard_size
+ lora_a = [
+ lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
+ for i in range(2)
+ ]
+ return lora_a
+
+ def apply_weights(self, x: torch.Tensor,
+ bias: Optional[torch.Tensor]) -> torch.Tensor:
+ return _mcp_apply_weights(x, bias, self)
+
+ @classmethod
+ @_fully_sharded_can_replace
+ def can_replace_layer(cls, source_layer: nn.Module,
+ lora_config: LoRAConfig, packed_modules_list: List,
+ model_config: Optional[PretrainedConfig]) -> bool:
+ # specifying kwargs so they can be easily accessed in decorator
+ return super().can_replace_layer(
+ source_layer=source_layer,
+ lora_config=lora_config,
+ packed_modules_list=packed_modules_list,
+ model_config=model_config,
+ decorate=False,
+ )
+
+
+class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
+ """
+    Differs from MergedQKVParallelLinearWithLora by slicing the
+ LoRA A's also.
+
+ Based on S-LoRA, slicing happens along the rank dim.
+ """
+
+ def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+ shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
+ start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
+ lora_a = [
+ lora_a[i][:, start_idx[i]:start_idx[i] +
+ shard_size[i]] if lora_a[i] is not None else None
+ for i in range(3)
+ ]
+ return lora_a
+
+ def apply_weights(self, x: torch.Tensor,
+ bias: Optional[torch.Tensor]) -> torch.Tensor:
+ return _mcp_apply_weights(x, bias, self)
+
+ @classmethod
+ @_fully_sharded_can_replace
+ def can_replace_layer(cls, source_layer: nn.Module,
+ lora_config: LoRAConfig, packed_modules_list: List,
+ model_config: Optional[PretrainedConfig]) -> bool:
+ # specifying kwargs so they can be easily accessed in decorator
+ return super().can_replace_layer(
+ source_layer=source_layer,
+ lora_config=lora_config,
+ packed_modules_list=packed_modules_list,
+ model_config=model_config,
+ decorate=False,
+ )
+
+
+class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
+ """
+ Differs from RowParallelLinearWithLoRA by slicing the
+ LoRA B's also.
+
+ Based on S-LoRA, slicing happens along the output dim.
+ This yields a combined partial sum from the row parallel base
+ layer and column partitioned output from the LoRA.
+ """
+
+ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+ shard_size = self.lora_b_stacked.shape[2]
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
+ lora_b = lora_b[:, start_idx:end_idx]
+ return lora_b
+
+ def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+ output = self.base_layer.linear_method.apply_weights(
+ self.base_layer, x)
+
+ x = x.view(-1, x.shape[-1])
+ output, out_orig_shape = output.view(-1,
+ output.shape[-1]), output.shape
+ buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
+ dtype=torch.float32,
+ device=x.device)
+ bgmv(buffer, x, self.lora_a_stacked,
+ self.indices[:self.indices_len[0]], 0, 1.0)
+ buffer = tensor_model_parallel_all_reduce(buffer)
+
+        # Following S-LoRA, this allows fusing the all_gather and all_reduce
+        # by adding the column-partitioned LoRA output to a slice of the
+        # output tensor, which is a partial sum due to row parallelism. All
+        # that remains is a standard all_reduce. Note that the output is not
+        # the same as that of a normal row-parallel layer; it must be
+        # reduced before being used.
+ shard_size = self.lora_b_stacked.shape[2]
+ start_idx = self.tp_rank * shard_size
+ dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
+ self.indices[:self.indices_len[0]], 0, 1.0,
+ start_idx, shard_size)
+
+ output = output.view(*out_orig_shape)
+ return output
+
+ @classmethod
+ @_fully_sharded_can_replace
+ def can_replace_layer(cls, source_layer: nn.Module,
+ lora_config: LoRAConfig, packed_modules_list: List,
+ model_config: Optional[PretrainedConfig]) -> bool:
+ # specifying kwargs so they can be easily accessed in decorator
+ return super().can_replace_layer(
+ source_layer=source_layer,
+ lora_config=lora_config,
+ packed_modules_list=packed_modules_list,
+ model_config=model_config,
+ decorate=False,
+ )
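The tensor-parallel math above can be checked with a single-process PyTorch sketch; torch.cat stands in for the all_gather collective, the sizes are made up, and no vLLM or distributed APIs are used:

import torch

torch.manual_seed(0)
hidden, rank, out, tp = 32, 8, 64, 2

x = torch.randn(4, hidden)          # a small batch of activations
lora_a = torch.randn(rank, hidden)  # LoRA A, shape [r, H1]
lora_b = torch.randn(out, rank)     # LoRA B, shape [H2, r]

# Reference: the unsharded LoRA delta.
ref = x @ lora_a.t() @ lora_b.t()

# Fully sharded: each TP rank holds rank/tp rows of A (slice_lora_a) and, as
# before, out/tp rows of B (the column-parallel slice_lora_b).
a_shards = lora_a.chunk(tp, dim=0)
b_shards = lora_b.chunk(tp, dim=0)

buffers = [x @ a.t() for a in a_shards]      # per-rank partial buffers, [B, r/tp]
gathered = torch.cat(buffers, dim=-1)        # stands in for tensor_model_parallel_all_gather
cols = [gathered @ b.t() for b in b_shards]  # each rank's column-partitioned LoRA output
assert torch.allclose(torch.cat(cols, dim=-1), ref, atol=1e-4)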
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 4eaf73fbcfda4..b3609666b2ec7 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -1,8 +1,7 @@
# pylint: disable=unused-argument
-import inspect
import math
from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
import torch.nn as nn
@@ -16,6 +15,7 @@
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
+from vllm.distributed.utils import divide
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -23,7 +23,7 @@
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
- ParallelLMHead, VocabParallelEmbedding)
+ VocabParallelEmbedding)
if TYPE_CHECKING:
pass
@@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
raise ValueError(f"Unsupported base layer: {base_layer}")
+def _not_fully_sharded_can_replace(can_replace):
+ """
+    Decorator which adds the condition that fully sharded LoRAs are not used.
+    Intended to wrap can_replace_layer().
+ """
+
+ def dec(*args, **kwargs):
+ decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
+ condition = (not kwargs['lora_config'].fully_sharded_loras
+ if decorate else True)
+ return can_replace(*args, **kwargs) and condition
+
+ return dec
+
+
def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
@@ -130,6 +145,14 @@ def __post_init__(self):
class BaseLayerWithLoRA(nn.Module):
+ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+ """Slice lora a if splitting for tensor parallelism."""
+ ...
+
+ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+ """Slice lora b if splitting with tensor parallelism."""
+ ...
+
def create_lora_weights(
self,
max_loras: int,
@@ -317,6 +340,11 @@ def can_replace_layer(cls, source_layer: nn.Module,
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
+ """
+ LoRA on top of ColumnParallelLinear layer.
+
+ LoRA B is sliced for tensor parallelism.
+ """
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
@@ -331,10 +359,15 @@ def create_lora_weights(
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
+ self.lora_config = lora_config
+ self.tp_size = get_tensor_model_parallel_world_size()
+ lora_a_output_size_per_partition = (
+ lora_config.max_lora_rank if not lora_config.fully_sharded_loras
+ else divide(lora_config.max_lora_rank, self.tp_size))
self.lora_a_stacked = torch.zeros(
max_loras,
1,
- lora_config.max_lora_rank,
+ lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
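For intuition, with hypothetical sizes:

# Hypothetical sizes: with fully_sharded_loras enabled, each tensor-parallel
# rank stores only max_lora_rank / tp_size rows of LoRA A; divide() also
# asserts even divisibility.
max_lora_rank, tp_size = 16, 4
lora_a_output_size_per_partition = max_lora_rank // tp_size  # 4 rows of A per rank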
@@ -357,6 +390,17 @@ def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
+ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+ return lora_a
+
+ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+ tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+ shard_size = self.output_dim
+ start_idx = tensor_model_parallel_rank * shard_size
+ end_idx = (tensor_model_parallel_rank + 1) * shard_size
+ lora_b = lora_b[:, start_idx:end_idx]
+ return lora_b
+
def set_lora(
self,
index: int,
@@ -365,12 +409,11 @@ def set_lora(
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
+
if self.tp_size > 1:
- tensor_model_parallel_rank = get_tensor_model_parallel_rank()
- shard_size = self.output_dim
- start_idx = tensor_model_parallel_rank * shard_size
- end_idx = (tensor_model_parallel_rank + 1) * shard_size
- lora_b = lora_b[:, start_idx:end_idx]
+ lora_a = self.slice_lora_a(lora_a)
+ lora_b = self.slice_lora_b(lora_b)
+
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
@@ -426,6 +469,7 @@ def forward(self, input_):
return output, output_bias
@classmethod
+ @_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
@@ -451,6 +495,7 @@ def create_lora_weights(
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
+ self.lora_config = lora_config
n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices
and self.base_layer.output_sizes[0]
@@ -459,12 +504,17 @@ def create_lora_weights(
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size.")
self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+
+ lora_a_output_size_per_partition = (
+ lora_config.max_lora_rank if not lora_config.fully_sharded_loras
+ else divide(lora_config.max_lora_rank, self.tp_size))
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
- lora_config.max_lora_rank,
+ lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
@@ -489,6 +539,18 @@ def reset_lora(self, index: int):
self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0
+ def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+ return lora_a
+
+ def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
+ shard_size = self.output_dim
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
+ lora_b = [
+ lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
+ ]
+ return lora_b
+
def set_lora(
self,
index: int,
@@ -499,13 +561,8 @@ def set_lora(
self.reset_lora(index)
if self.tp_size > 1:
- tensor_model_parallel_rank = get_tensor_model_parallel_rank()
- shard_size = self.output_dim
- start_idx = tensor_model_parallel_rank * shard_size
- end_idx = (tensor_model_parallel_rank + 1) * shard_size
- lora_b = lora_b[0][:,
- start_idx:end_idx], lora_b[1][:,
- start_idx:end_idx]
+ lora_a = self.slice_lora_a(lora_a)
+ lora_b = self.slice_lora_b(lora_b)
if lora_a[0] is not None:
self.lora_a_stacked[0][
@@ -536,6 +593,7 @@ def apply(self, x: torch.Tensor,
return output
@classmethod
+ @_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
@@ -627,21 +685,25 @@ def create_lora_weights(
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
+ self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
- tp_rank = get_tensor_model_parallel_rank()
+ self.tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
- self.q_shard_id = tp_rank
- self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+ self.q_shard_id = self.tp_rank
+ self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
+ lora_a_output_size_per_partition = (
+ lora_config.max_lora_rank if not lora_config.fully_sharded_loras
+ else divide(lora_config.max_lora_rank, self.tp_size))
# q, k, v
self.lora_a_stacked = (
torch.zeros(
max_loras,
1,
- lora_config.max_lora_rank,
+ lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
@@ -649,7 +711,7 @@ def create_lora_weights(
torch.zeros(
max_loras,
1,
- lora_config.max_lora_rank,
+ lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
@@ -657,7 +719,7 @@ def create_lora_weights(
torch.zeros(
max_loras,
1,
- lora_config.max_lora_rank,
+ lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
@@ -705,6 +767,25 @@ def reset_lora(self, index: int):
self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0
+ def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+ return lora_a
+
+ def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
+ if lora_b[0] is not None:
+ lora_b_q = lora_b[0][:, self.q_proj_shard_size *
+ self.q_shard_id:self.q_proj_shard_size *
+ (self.q_shard_id + 1)]
+ if lora_b[1] is not None:
+ lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
+ self.kv_shard_id:self.kv_proj_shard_size *
+ (self.kv_shard_id + 1)]
+ if lora_b[2] is not None:
+ lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
+ self.kv_shard_id:self.kv_proj_shard_size *
+ (self.kv_shard_id + 1)]
+ lora_b = [lora_b_q, lora_b_k, lora_b_v]
+ return lora_b
+
def set_lora(
self,
index: int,
@@ -715,40 +796,24 @@ def set_lora(
self.reset_lora(index)
if self.tp_size > 1:
- if lora_b[0] is not None:
- lora_b_q = lora_b[0][:, self.q_proj_shard_size *
- self.q_shard_id:self.q_proj_shard_size *
- (self.q_shard_id + 1)]
- self.lora_b_stacked[0][
- index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
- lora_b_q.T, non_blocking=True)
- if lora_b[1] is not None:
- lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
- self.kv_shard_id:self.kv_proj_shard_size *
- (self.kv_shard_id + 1)]
- self.lora_b_stacked[1][
- index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
- lora_b_k.T, non_blocking=True)
- if lora_b[2] is not None:
- lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
- self.kv_shard_id:self.kv_proj_shard_size *
- (self.kv_shard_id + 1)]
- self.lora_b_stacked[2][
- index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
- lora_b_v.T, non_blocking=True)
- else:
- if lora_b[0] is not None:
- self.lora_b_stacked[0][
- index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
- lora_b[0].T, non_blocking=True)
- if lora_b[1] is not None:
- self.lora_b_stacked[1][
- index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
- lora_b[1].T, non_blocking=True)
- if lora_b[2] is not None:
- self.lora_b_stacked[2][
- index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
- lora_b[2].T, non_blocking=True)
+ lora_a = self.slice_lora_a(lora_a)
+ lora_b = self.slice_lora_b(lora_b)
+
+ if lora_b[0] is not None:
+ lora_b_q = lora_b[0]
+ self.lora_b_stacked[0][
+ index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
+ lora_b_q.T, non_blocking=True)
+ if lora_b[1] is not None:
+ lora_b_k = lora_b[1]
+ self.lora_b_stacked[1][
+ index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
+ lora_b_k.T, non_blocking=True)
+ if lora_b[2] is not None:
+ lora_b_v = lora_b[2]
+ self.lora_b_stacked[2][
+ index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
+ lora_b_v.T, non_blocking=True)
if lora_a[0] is not None:
self.lora_a_stacked[0][
@@ -777,6 +842,7 @@ def apply(self, x: torch.Tensor,
return output
@classmethod
+ @_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
@@ -798,6 +864,8 @@ def create_lora_weights(
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
+ self.lora_config = lora_config
+ self.tp_rank = get_tensor_model_parallel_rank()
self.lora_a_stacked = torch.zeros(
(
max_loras,
@@ -808,11 +876,16 @@ def create_lora_weights(
dtype=lora_config.lora_dtype,
device=self.device,
)
+ tp_size = get_tensor_model_parallel_world_size()
+ lora_b_output_size_per_partition = (
+ self.output_size if not lora_config.fully_sharded_loras else
+ divide(self.output_size, tp_size))
+
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
- self.output_size,
+ lora_b_output_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
@@ -826,6 +899,17 @@ def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
+ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+ tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+ shard_size = self.input_size
+ start_idx = tensor_model_parallel_rank * shard_size
+ end_idx = (tensor_model_parallel_rank + 1) * shard_size
+ lora_a = lora_a[start_idx:end_idx, :]
+ return lora_a
+
+ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+ return lora_b
+
def set_lora(
self,
index: int,
@@ -834,12 +918,10 @@ def set_lora(
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
+
if self.base_layer.tp_size > 1:
- tensor_model_parallel_rank = get_tensor_model_parallel_rank()
- shard_size = self.input_size
- start_idx = tensor_model_parallel_rank * shard_size
- end_idx = (tensor_model_parallel_rank + 1) * shard_size
- lora_a = lora_a[start_idx:end_idx, :]
+ lora_a = self.slice_lora_a(lora_a)
+ lora_b = self.slice_lora_b(lora_b)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -915,6 +997,7 @@ def weight(self):
self.base_layer, "weight") else self.base_layer.qweight
@classmethod
+ @_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
@@ -1096,37 +1179,3 @@ def can_replace_layer(cls, source_layer: nn.Module,
model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor.
return False
-
-
-_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
- cls
- for cls in globals().values() if inspect.isclass(cls)
- and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
-}
-
-
-def from_layer(layer: nn.Module,
- max_loras: int,
- lora_config: LoRAConfig,
- packed_modules_list: List,
- model_config: Optional[PretrainedConfig] = None) -> nn.Module:
- for lora_cls in _all_lora_classes:
- if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
- model_config):
- ret = lora_cls(layer)
- ret.create_lora_weights(max_loras, lora_config, model_config)
- return ret
- return layer
-
-
-def from_layer_logits_processor(
- layer: LogitsProcessor,
- lm_head: ParallelLMHead,
- max_loras: int,
- lora_config: LoRAConfig,
- model_config: Optional[PretrainedConfig] = None,
-) -> LogitsProcessorWithLoRA:
- ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
- lm_head.weight.dtype, lm_head.weight.device)
- ret.create_lora_weights(max_loras, lora_config, model_config)
- return ret
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 6a077e9b0c755..50d7e9133e0e8 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -11,10 +11,10 @@
from vllm.config import LoRAConfig
from vllm.logger import init_logger
-from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
- from_layer_logits_processor)
+from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
-from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
+from vllm.lora.utils import (from_layer, from_layer_logits_processor,
+ parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache, is_pin_memory_available
logger = init_logger(__name__)
diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py
index fc74269e55876..c87bed54726fc 100644
--- a/vllm/lora/punica.py
+++ b/vllm/lora/punica.py
@@ -49,6 +49,49 @@ def bgmv(
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
+def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
+ w_t_all: torch.Tensor, indicies: torch.LongTensor,
+ layer_idx: int, scale: float, y_offset: int,
+ y_slice_size: int):
+ """
+    Same as `bgmv` but allows operating on a slice of y.
+    Pass the whole y and specify y_offset and y_slice_size.
+
+ Semantics:
+ y[i] += (
+ x[i].unsqueeze(0)
+ @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+ * scale
+ ).squeeze(0)
+
+ Args:
+ y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+ x: Shape: `[B, H1]`. Input vectors.
+ w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
+ all of the transposed LoRA matrices.
+ indicies: Shape: `[B]`. Indices of the LoRA weights.
+ layer_idx: Layer index of LoRA weights.
+ scale: Scaling factor.
+ y_offset: Offset to apply to the starting column of y.
+ y_slice_size: Size of the y column slice.
+ """
+ try:
+ import vllm._punica_C as punica_kernels
+ except ImportError as e:
+ _raise_import_error(e)
+ punica_kernels.dispatch_bgmv_low_level(
+ y,
+ x,
+ w_t_all,
+ indicies,
+ layer_idx,
+ scale,
+ x.size(1),
+ y_slice_size,
+ y_offset,
+ )
+
+
def add_lora(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
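For readers without the CUDA kernel at hand, the documented semantics can be written as a plain-PyTorch reference; this is an illustrative stand-in for the punica kernel, not the kernel itself:

import torch

def dispatch_bgmv_low_level_ref(y, x, w_t_all, indices, layer_idx, scale,
                                y_offset, y_slice_size):
    # Accumulate each request's LoRA product into a column slice of y,
    # mirroring the docstring above (accumulation done in float32).
    for i in range(x.shape[0]):
        w = w_t_all[indices[i], layer_idx].transpose(-1, -2).to(torch.float32)
        delta = (x[i].unsqueeze(0).to(torch.float32) @ w).squeeze(0) * scale
        y[i, y_offset:y_offset + y_slice_size] += delta.to(y.dtype)
    return y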
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 39e08f0412e4a..9942a5fd40dec 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -1,11 +1,69 @@
-from typing import Tuple
+from typing import List, Optional, Set, Tuple, Type
from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import LoRAConfig
from vllm.logger import init_logger
+from vllm.lora.fully_sharded_layers import (
+ ColumnParallelLinearWithShardedLoRA,
+ MergedColumnParallelLinearWithShardedLoRA,
+ MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
+# Imported for _all_lora_classes below.
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+ LogitsProcessorWithLoRA,
+ MergedColumnParallelLinearWithLoRA,
+ MergedQKVParallelLinearWithLora,
+ QKVParallelLinearWithLora,
+ RowParallelLinearWithLoRA,
+ VocabParallelEmbeddingWithLoRA)
+# yapf: enable
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__)
+_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
+ VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
+ MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
+ MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
+ LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
+ MergedColumnParallelLinearWithShardedLoRA,
+ MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
+}
+
+
+def from_layer(layer: nn.Module,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ packed_modules_list: List,
+ model_config: Optional[PretrainedConfig] = None) -> nn.Module:
+ for lora_cls in _all_lora_classes:
+ # specifying kwargs so they can be easily accessed in decorator
+ if lora_cls.can_replace_layer(source_layer=layer,
+ lora_config=lora_config,
+ packed_modules_list=packed_modules_list,
+ model_config=model_config):
+ ret = lora_cls(layer)
+ ret.create_lora_weights(max_loras, lora_config, model_config)
+ return ret
+ return layer
+
+
+def from_layer_logits_processor(
+ layer: LogitsProcessor,
+ lm_head: ParallelLMHead,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: Optional[PretrainedConfig] = None,
+) -> LogitsProcessorWithLoRA:
+ ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
+ lm_head.weight.dtype, lm_head.weight.device)
+ ret.create_lora_weights(max_loras, lora_config, model_config)
+ return ret
+
def replace_submodule(model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module:
From 3da24c2df735354ccb463650c29cca8ce506fa07 Mon Sep 17 00:00:00 2001
From: Caio Mendes
Date: Sat, 27 Apr 2024 07:08:15 -0300
Subject: [PATCH 139/413] [Model] Phi-3 4k sliding window temp. fix (#4380)
---
vllm/core/block_manager_v1.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index be093922b84f2..1fac2636e86fa 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -1,4 +1,5 @@
"""A block manager that manages token blocks."""
+import math
from abc import ABC, abstractmethod
from itertools import count, takewhile
from os.path import commonprefix
@@ -220,9 +221,9 @@ def __init__(
self.block_sliding_window = None
if sliding_window is not None:
- assert sliding_window % block_size == 0, (sliding_window,
- block_size)
- self.block_sliding_window = sliding_window // block_size
+ # Round up to nearest block size to regularize sliding window
+ # allocation sizes.
+ self.block_sliding_window = math.ceil(sliding_window / block_size)
self.watermark = watermark
assert watermark >= 0.0
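Illustrative arithmetic for the rounding above, with assumed values rather than any particular model config:

import math

block_size = 16
sliding_window = 2047  # hypothetical window that is not a multiple of block_size
# Previously this combination tripped the assert; it is now rounded up to
# whole blocks.
block_sliding_window = math.ceil(sliding_window / block_size)  # -> 128 blocks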
From 7134303cbbb7c82cdfcb0c87d59bb48fe6ad642b Mon Sep 17 00:00:00 2001
From: Roy
Date: Sat, 27 Apr 2024 19:30:08 +0800
Subject: [PATCH 140/413] [Bugfix][Core] Fix get decoding config from ray
(#4335)
---
tests/async_engine/test_async_llm_engine.py | 2 +
tests/async_engine/test_openapi_server_ray.py | 157 ++++++++++++++++++
vllm/engine/async_llm_engine.py | 10 +-
vllm/engine/llm_engine.py | 4 +
vllm/entrypoints/openai/serving_chat.py | 2 +-
vllm/entrypoints/openai/serving_completion.py | 2 +-
6 files changed, 174 insertions(+), 3 deletions(-)
create mode 100644 tests/async_engine/test_openapi_server_ray.py
diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py
index cb125a7bfec30..b69cdc0a21409 100644
--- a/tests/async_engine/test_async_llm_engine.py
+++ b/tests/async_engine/test_async_llm_engine.py
@@ -91,4 +91,6 @@ async def test_new_requests_event():
assert engine.engine.step_calls == old_step_calls + 1
engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+ assert engine.get_model_config() is not None
assert engine.get_tokenizer() is not None
+ assert engine.get_decoding_config() is not None
diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py
new file mode 100644
index 0000000000000..4b97af88012b9
--- /dev/null
+++ b/tests/async_engine/test_openapi_server_ray.py
@@ -0,0 +1,157 @@
+# imports for guided decoding tests
+import os
+import subprocess
+import sys
+import time
+
+import openai # use the official client for correctness check
+import pytest
+# using Ray for overall ease of process management, parallel requests,
+# and debugging.
+import ray
+import requests
+
+MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for the server to start
+# any model with a chat template should work here
+MODEL_NAME = "facebook/opt-125m"
+
+
+@ray.remote(num_gpus=1)
+class ServerRunner:
+
+ def __init__(self, args):
+ env = os.environ.copy()
+ env["PYTHONUNBUFFERED"] = "1"
+ self.proc = subprocess.Popen(
+ ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
+ env=env,
+ stdout=sys.stdout,
+ stderr=sys.stderr,
+ )
+ self._wait_for_server()
+
+ def ready(self):
+ return True
+
+ def _wait_for_server(self):
+ # run health check
+ start = time.time()
+ while True:
+ try:
+ if requests.get(
+ "http://localhost:8000/health").status_code == 200:
+ break
+ except Exception as err:
+ if self.proc.poll() is not None:
+ raise RuntimeError("Server exited unexpectedly.") from err
+
+ time.sleep(0.5)
+ if time.time() - start > MAX_SERVER_START_WAIT_S:
+ raise RuntimeError(
+ "Server failed to start in time.") from err
+
+ def __del__(self):
+ if hasattr(self, "proc"):
+ self.proc.terminate()
+
+
+@pytest.fixture(scope="session")
+def server():
+ ray.init()
+ server_runner = ServerRunner.remote([
+ "--model",
+ MODEL_NAME,
+ # use half precision for speed and memory savings in CI environment
+ "--dtype",
+ "float16",
+ "--max-model-len",
+ "2048",
+ "--enforce-eager",
+ "--engine-use-ray"
+ ])
+ ray.get(server_runner.ready.remote())
+ yield server_runner
+ ray.shutdown()
+
+
+@pytest.fixture(scope="session")
+def client():
+ client = openai.AsyncOpenAI(
+ base_url="http://localhost:8000/v1",
+ api_key="token-abc123",
+ )
+ yield client
+
+
+@pytest.mark.asyncio
+async def test_check_models(server, client: openai.AsyncOpenAI):
+ models = await client.models.list()
+ models = models.data
+ served_model = models[0]
+ assert served_model.id == MODEL_NAME
+ assert all(model.root == MODEL_NAME for model in models)
+
+
+@pytest.mark.asyncio
+async def test_single_completion(server, client: openai.AsyncOpenAI):
+ completion = await client.completions.create(model=MODEL_NAME,
+ prompt="Hello, my name is",
+ max_tokens=5,
+ temperature=0.0)
+
+ assert completion.id is not None
+ assert completion.choices is not None and len(completion.choices) == 1
+ assert completion.choices[0].text is not None and len(
+ completion.choices[0].text) >= 5
+ assert completion.choices[0].finish_reason == "length"
+ assert completion.usage == openai.types.CompletionUsage(
+ completion_tokens=5, prompt_tokens=6, total_tokens=11)
+
+ # test using token IDs
+ completion = await client.completions.create(
+ model=MODEL_NAME,
+ prompt=[0, 0, 0, 0, 0],
+ max_tokens=5,
+ temperature=0.0,
+ )
+ assert completion.choices[0].text is not None and len(
+ completion.choices[0].text) >= 5
+
+
+@pytest.mark.asyncio
+async def test_single_chat_session(server, client: openai.AsyncOpenAI):
+ messages = [{
+ "role": "system",
+ "content": "you are a helpful assistant"
+ }, {
+ "role": "user",
+ "content": "what is 1+1?"
+ }]
+
+ # test single completion
+ chat_completion = await client.chat.completions.create(model=MODEL_NAME,
+ messages=messages,
+ max_tokens=10,
+ logprobs=True,
+ top_logprobs=5)
+ assert chat_completion.id is not None
+ assert chat_completion.choices is not None and len(
+ chat_completion.choices) == 1
+ assert chat_completion.choices[0].message is not None
+ assert chat_completion.choices[0].logprobs is not None
+ assert chat_completion.choices[0].logprobs.top_logprobs is not None
+ assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+ message = chat_completion.choices[0].message
+ assert message.content is not None and len(message.content) >= 10
+ assert message.role == "assistant"
+ messages.append({"role": "assistant", "content": message.content})
+
+ # test multi-turn dialogue
+ messages.append({"role": "user", "content": "express your result in json"})
+ chat_completion = await client.chat.completions.create(
+ model=MODEL_NAME,
+ messages=messages,
+ max_tokens=10,
+ )
+ message = chat_completion.choices[0].message
+ assert message.content is not None and len(message.content) >= 0
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 89ee3f0db491c..7c1eb2ecbe550 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -7,7 +7,7 @@
from transformers import PreTrainedTokenizer
-from vllm.config import ModelConfig
+from vllm.config import DecodingConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
@@ -697,6 +697,14 @@ async def get_model_config(self) -> ModelConfig:
else:
return self.engine.get_model_config()
+ async def get_decoding_config(self) -> DecodingConfig:
+ """Get the decoding configuration of the vLLM engine."""
+ if self.engine_use_ray:
+ return await self.engine.get_decoding_config.remote( # type: ignore
+ )
+ else:
+ return self.engine.get_decoding_config()
+
async def do_log_stats(self) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote() # type: ignore
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 741d3bcd80890..292504631b06d 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -467,6 +467,10 @@ def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config
+ def get_decoding_config(self) -> DecodingConfig:
+ """Gets the decoding configuration."""
+ return self.decoding_config
+
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 629dd929dc1af..5ed042ef386ea 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -101,7 +101,7 @@ async def create_chat_completion(
request, prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
- decoding_config = self.engine.engine.decoding_config
+ decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 7904bb698c45a..6a7f29c4c96f2 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -89,7 +89,7 @@ async def create_completion(self, request: CompletionRequest,
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
- decoding_config = self.engine.engine.decoding_config
+ decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
From dfea17314827845d55dabb03ebe905f58e6682e4 Mon Sep 17 00:00:00 2001
From: Ruoyu Qin
Date: Sun, 28 Apr 2024 00:48:37 +0800
Subject: [PATCH 141/413] [Bugfix] Abort requests when the connection to
/v1/completions is interrupted (#4363)
---
.../test_merge_async_iterators.py | 41 +++++++++++++++++++
vllm/utils.py | 17 +++++---
2 files changed, 53 insertions(+), 5 deletions(-)
create mode 100644 tests/async_engine/test_merge_async_iterators.py
diff --git a/tests/async_engine/test_merge_async_iterators.py b/tests/async_engine/test_merge_async_iterators.py
new file mode 100644
index 0000000000000..ea453526c77f8
--- /dev/null
+++ b/tests/async_engine/test_merge_async_iterators.py
@@ -0,0 +1,41 @@
+import asyncio
+from typing import AsyncIterator, Tuple
+
+import pytest
+
+from vllm.utils import merge_async_iterators
+
+
+@pytest.mark.asyncio
+async def test_merge_async_iterators():
+
+ async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
+ try:
+ while True:
+ yield f"item from iterator {idx}"
+ await asyncio.sleep(0.1)
+ except asyncio.CancelledError:
+ pass
+
+ iterators = [mock_async_iterator(i) for i in range(3)]
+ merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
+ *iterators)
+
+ async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
+ async for idx, output in generator:
+ print(f"idx: {idx}, output: {output}")
+
+ task = asyncio.create_task(stream_output(merged_iterator))
+ await asyncio.sleep(0.5)
+ task.cancel()
+ with pytest.raises(asyncio.CancelledError):
+ await task
+
+ for iterator in iterators:
+ try:
+ await asyncio.wait_for(anext(iterator), 1)
+ except StopAsyncIteration:
+ # All iterators should be cancelled and print this message.
+ print("Iterator was cancelled normally")
+ except (Exception, asyncio.CancelledError) as e:
+ raise AssertionError() from e
diff --git a/vllm/utils.py b/vllm/utils.py
index 76c2fc66e47c3..88447878f1706 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -225,11 +225,18 @@ async def producer(i: int, iterator: AsyncIterator[T]):
]
async def consumer():
- while not all(finished) or not queue.empty():
- item = await queue.get()
- if isinstance(item, Exception):
- raise item
- yield item
+ try:
+ while not all(finished) or not queue.empty():
+ item = await queue.get()
+ if isinstance(item, Exception):
+ raise item
+ yield item
+ except (Exception, asyncio.CancelledError) as e:
+ for task in _tasks:
+ # NOTE: Pass the error msg in cancel()
+ # when only Python 3.9+ is supported.
+ task.cancel()
+ raise e
await asyncio.gather(*_tasks)
return consumer()
From 81661da7b2d446cff7065fd6b34f1b7051098d24 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sat, 27 Apr 2024 09:52:46 -0700
Subject: [PATCH 142/413] [BugFix] Fix `min_tokens` when `eos_token_id` is None
(#4389)
Co-authored-by: DefTruth <31974251+deftruth@users.noreply.github.com>
---
tests/samplers/test_sampler.py | 9 +++------
vllm/engine/llm_engine.py | 5 +++--
vllm/model_executor/layers/sampler.py | 14 ++++++--------
vllm/sampling_params.py | 4 ++--
4 files changed, 14 insertions(+), 18 deletions(-)
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 6f2145f8cdcf4..7859f0b21812f 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -207,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sampling_params(min_tokens,
eos_token_id=0,
*,
- stop_token_ids: Optional[List[str]] = None,
+ stop_token_ids: Optional[List[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
@@ -216,7 +216,7 @@ def create_sampling_params(min_tokens,
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs=prompt_logprobs,
)
- sampling_params.eos_token_id = eos_token_id
+ sampling_params.all_stop_token_ids.add(eos_token_id)
return sampling_params
def create_sequence_data(num_input=3, num_generated=0):
@@ -461,10 +461,7 @@ def run_test_case(*,
for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_row)):
- tokens_to_check = [sampling_params.eos_token_id]
- if sampling_params.stop_token_ids:
- tokens_to_check.extend(sampling_params.stop_token_ids)
- tokens_to_check = set(tokens_to_check)
+ tokens_to_check = sampling_params.all_stop_token_ids
if should_penalize:
for token_id in tokens_to_check:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 292504631b06d..7e9553d839cea 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -431,9 +431,10 @@ def add_request(
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
- # inject the eos token id into the sampling_params to support min_tokens
+ # Add the eos token id into the sampling_params to support min_tokens
# processing
- sampling_params.eos_token_id = seq.eos_token_id
+ if seq.eos_token_id is not None:
+ sampling_params.all_stop_token_ids.add(seq.eos_token_id)
sampling_params.update_from_generation_config(
self.generation_config_fields)
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 2ffa8227cc4ed..4ef25edecfd24 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -169,19 +169,17 @@ def _apply_min_tokens_penalty(
start_idx = sample_indices[0]
min_tokens = sampling_params.min_tokens
- if min_tokens > 0:
+ token_ids_to_penalize = sampling_params.all_stop_token_ids
+ if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize = []
- for i, seq_id in enumerate(seq_ids):
+ for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
- seqs_to_penalize.append(i)
+ seqs_to_penalize.append(j)
if seqs_to_penalize:
# convert to the index into logits
- seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
- # use set() to remove any duplicates
- token_ids_to_penalize = set(sampling_params.stop_token_ids +
- [sampling_params.eos_token_id])
+ seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
# itertools.product pairs each seq index with every token id
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))
@@ -645,7 +643,7 @@ def _sample(
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
- sampled_token_ids_tensor: A tensor of sampled token ids.
+ sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return _sample_with_torch(
probs,
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index dc0e60344d858..0ed6a01a62212 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -185,8 +185,8 @@ def __init__(
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
- # injected by the engine
- self.eos_token_id = None
+ # eos_token_id is added to this by the engine
+ self.all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None:
if self.n < 1:
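A toy sketch of the behaviour this set supports, with a hand-picked vocabulary and token ids (pure PyTorch, not engine code): while fewer than min_tokens tokens have been generated, every id in all_stop_token_ids is masked out of the logits.

import torch

vocab_size, eos_token_id = 10, 2
stop_token_ids = {7}
all_stop_token_ids = set(stop_token_ids)
if eos_token_id is not None:  # the engine now skips this when eos is None
    all_stop_token_ids.add(eos_token_id)

logits = torch.zeros(vocab_size)
num_generated, min_tokens = 1, 4
if num_generated < min_tokens and all_stop_token_ids:
    logits[list(all_stop_token_ids)] = -float("inf")  # stop/eos cannot be sampled yet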
From d6e520e1700f78de2d5efdb8607a76cbab61182e Mon Sep 17 00:00:00 2001
From: Prashant Gupta
Date: Sat, 27 Apr 2024 09:59:55 -0700
Subject: [PATCH 143/413] [Core] Support offline use of local cache for models
(#4374)
Signed-off-by: Prashant Gupta
Co-authored-by: Travis Johnson
---
tests/model_executor/weight_utils.py | 30 +++++++++-
vllm/model_executor/model_loader/loader.py | 5 +-
.../model_loader/weight_utils.py | 59 +++++++++++--------
vllm/transformers_utils/tokenizer.py | 2 +
4 files changed, 69 insertions(+), 27 deletions(-)
diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py
index b0086dd7a7d71..c8b9bed691bba 100644
--- a/tests/model_executor/weight_utils.py
+++ b/tests/model_executor/weight_utils.py
@@ -1,9 +1,12 @@
import os
+import tempfile
import huggingface_hub.constants
import pytest
+from huggingface_hub.utils import LocalEntryNotFoundError
-from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer
+from vllm.model_executor.model_loader.weight_utils import (
+ download_weights_from_hf, enable_hf_transfer)
def test_hf_transfer_auto_activation():
@@ -22,5 +25,30 @@ def test_hf_transfer_auto_activation():
HF_TRANFER_ACTIVE)
+def test_download_weights_from_hf():
+ with tempfile.TemporaryDirectory() as tmpdir:
+        # assert LocalEntryNotFoundError is raised
+        # if offline mode is set and the model is not cached
+ huggingface_hub.constants.HF_HUB_OFFLINE = True
+ with pytest.raises(LocalEntryNotFoundError):
+ download_weights_from_hf("facebook/opt-125m",
+ allow_patterns=["*.safetensors", "*.bin"],
+ cache_dir=tmpdir)
+
+ # download the model
+ huggingface_hub.constants.HF_HUB_OFFLINE = False
+ download_weights_from_hf("facebook/opt-125m",
+ allow_patterns=["*.safetensors", "*.bin"],
+ cache_dir=tmpdir)
+
+ # now it should work offline
+ huggingface_hub.constants.HF_HUB_OFFLINE = True
+ assert download_weights_from_hf(
+ "facebook/opt-125m",
+ allow_patterns=["*.safetensors", "*.bin"],
+ cache_dir=tmpdir) is not None
+
+
if __name__ == "__main__":
test_hf_transfer_auto_activation()
+ test_download_weights_from_hf()
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index ad80243019a65..70e64167f8698 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
+import huggingface_hub
import torch
from torch import nn
@@ -131,7 +132,9 @@ def _maybe_download_from_modelscope(
model_path = snapshot_download(
model_id=model,
cache_dir=self.load_config.download_dir,
- revision=revision)
+ local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+ revision=revision,
+ )
else:
model_path = model
return model_path
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index c0905b9051314..c1abde9af7701 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig,
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, load_config.download_dir):
- hf_folder = snapshot_download(model_name_or_path,
- revision=model_config.revision,
- allow_patterns="*.json",
- cache_dir=load_config.download_dir,
- tqdm_class=DisabledTqdm)
+ hf_folder = snapshot_download(
+ model_name_or_path,
+ revision=model_config.revision,
+ allow_patterns="*.json",
+ cache_dir=load_config.download_dir,
+ local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+ tqdm_class=DisabledTqdm,
+ )
else:
hf_folder = model_name_or_path
@@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig,
return quant_cls.from_config(config)
-def download_weights_from_hf(model_name_or_path: str,
- cache_dir: Optional[str],
- allow_patterns: List[str],
- revision: Optional[str] = None) -> str:
+def download_weights_from_hf(
+ model_name_or_path: str,
+ cache_dir: Optional[str],
+ allow_patterns: List[str],
+ revision: Optional[str] = None,
+) -> str:
"""Download model weights from Hugging Face Hub.
-
+
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
@@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str,
Returns:
str: The path to the downloaded model weights.
"""
- # Before we download we look at that is available:
- fs = HfFileSystem()
- file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
-
- # depending on what is available we download different things
- for pattern in allow_patterns:
- matching = fnmatch.filter(file_list, pattern)
- if len(matching) > 0:
- allow_patterns = [pattern]
- break
+ if not huggingface_hub.constants.HF_HUB_OFFLINE:
+        # Before we download, we look at what is available:
+ fs = HfFileSystem()
+ file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
+
+ # depending on what is available we download different things
+ for pattern in allow_patterns:
+ matching = fnmatch.filter(file_list, pattern)
+ if len(matching) > 0:
+ allow_patterns = [pattern]
+ break
logger.info("Using model weights format %s", allow_patterns)
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
- hf_folder = snapshot_download(model_name_or_path,
- allow_patterns=allow_patterns,
- cache_dir=cache_dir,
- tqdm_class=DisabledTqdm,
- revision=revision)
+ hf_folder = snapshot_download(
+ model_name_or_path,
+ allow_patterns=allow_patterns,
+ cache_dir=cache_dir,
+ tqdm_class=DisabledTqdm,
+ revision=revision,
+ local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+ )
return hf_folder
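A usage sketch mirroring the new test, assuming network access is available for the first call: populate a local cache once, then resolve the same weights fully offline.

import tempfile

import huggingface_hub.constants

from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf

with tempfile.TemporaryDirectory() as tmpdir:
    huggingface_hub.constants.HF_HUB_OFFLINE = False
    download_weights_from_hf("facebook/opt-125m",
                             cache_dir=tmpdir,
                             allow_patterns=["*.safetensors", "*.bin"])

    huggingface_hub.constants.HF_HUB_OFFLINE = True  # served from the cache only
    local_path = download_weights_from_hf("facebook/opt-125m",
                                          cache_dir=tmpdir,
                                          allow_patterns=["*.safetensors", "*.bin"])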
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 2fcddc3bea5ab..fa4693cb7dac1 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -1,6 +1,7 @@
import os
from typing import Optional, Union
+import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
@@ -76,6 +77,7 @@ def get_tokenizer(
model_id=tokenizer_name,
cache_dir=download_dir,
revision=revision,
+ local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"])
tokenizer_name = tokenizer_path
From ba4be44c32761d30f1e17656b863d2cc078af9e4 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sat, 27 Apr 2024 11:17:45 -0700
Subject: [PATCH 144/413] [BugFix] Fix return type of executor execute_model
methods (#4402)
---
vllm/executor/cpu_executor.py | 2 +-
vllm/executor/distributed_gpu_executor.py | 7 ++++---
vllm/executor/executor_base.py | 2 +-
vllm/executor/gpu_executor.py | 2 +-
vllm/executor/neuron_executor.py | 2 +-
vllm/executor/ray_gpu_executor.py | 2 +-
6 files changed, 9 insertions(+), 8 deletions(-)
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index aa810f9743395..e4436b2144bd3 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -109,7 +109,7 @@ async def execute_model_async(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
- ) -> SamplerOutput:
+ ) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
index 9dccfa4946391..4c922ef63ee04 100644
--- a/vllm/executor/distributed_gpu_executor.py
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
-from typing import Any, Dict, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
@@ -52,7 +52,7 @@ def initialize_cache(self, num_gpu_blocks: int,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
- def execute_model(self, *args, **kwargs) -> SamplerOutput:
+ def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
all_outputs = self._run_workers("execute_model",
driver_args=args,
driver_kwargs=kwargs)
@@ -105,7 +105,8 @@ async def _run_workers_async(
"""Runs the given method on all workers."""
raise NotImplementedError
- async def execute_model_async(self, *args, **kwargs) -> SamplerOutput:
+ async def execute_model_async(self, *args,
+ **kwargs) -> List[SamplerOutput]:
all_outputs = await self._run_workers_async("execute_model",
driver_args=args,
driver_kwargs=kwargs)
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 1838c34be2fda..c36aa18fb25bb 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -112,7 +112,7 @@ async def execute_model_async(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
- ) -> SamplerOutput:
+ ) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index d2c60a3b68e14..5ac62f02b99c7 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -163,7 +163,7 @@ async def execute_model_async(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
- ) -> SamplerOutput:
+ ) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index 5a137d1bdcb3b..f406287f3c1d8 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -84,7 +84,7 @@ async def execute_model_async(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
- ) -> SamplerOutput:
+ ) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list, )
return output
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 1082984828357..b6bcda4e6b18c 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -188,7 +188,7 @@ def execute_model(self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
- num_lookahead_slots: int = 0) -> SamplerOutput:
+ num_lookahead_slots: int = 0) -> List[SamplerOutput]:
all_outputs = self._run_workers(
"execute_model",
driver_kwargs={
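The corrected annotations match what the driver worker has always returned at runtime: a list of SamplerOutput objects, one per model step, rather than a bare SamplerOutput. A small sketch of a caller that depends on the fixed type; SamplerOutput is stubbed here and the loop is illustrative, not vLLM engine code:

    from typing import List

    class SamplerOutput:  # stand-in for vllm.sequence.SamplerOutput
        ...

    def count_steps(outputs: List[SamplerOutput]) -> int:
        # Iterating over the per-step outputs is only well-typed once the
        # executor methods are annotated as returning List[SamplerOutput].
        steps = 0
        for _step_output in outputs:
            steps += 1
        return steps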
From 4ea1f9678dd93f02424ab3de2149f83a490e6c6f Mon Sep 17 00:00:00 2001
From: Robert Shaw
<114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Sat, 27 Apr 2024 14:35:33 -0400
Subject: [PATCH 145/413] [BugFix] Resolved Issues For LinearMethod -->
QuantConfig (#4418)
---
vllm/model_executor/models/bloom.py | 1 -
vllm/model_executor/models/falcon.py | 1 -
vllm/model_executor/models/gpt2.py | 1 -
vllm/model_executor/models/gpt_bigcode.py | 1 -
vllm/model_executor/models/gpt_j.py | 1 -
vllm/model_executor/models/gpt_neox.py | 1 -
vllm/model_executor/models/mpt.py | 1 -
vllm/model_executor/models/opt.py | 1 -
vllm/model_executor/models/phi.py | 1 -
vllm/model_executor/models/starcoder2.py | 1 -
10 files changed, 10 deletions(-)
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index b425af4863c36..1d7e5d2517c72 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -139,7 +139,6 @@ def __init__(
4 * hidden_size,
quant_config=quant_config,
)
- quant_config = getattr(quant_config, "quant_config", None)
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 4be1f064cdd3e..08dd69923dc6d 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -203,7 +203,6 @@ def __init__(
bias=config.bias,
skip_bias_add=True,
quant_config=quant_config)
- quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index ac1dce6dec8a6..75eaebf0dbd15 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -107,7 +107,6 @@ def __init__(
bias=True,
quant_config=quant_config,
)
- quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index e52ac679f5d03..d057fd928fdb5 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -128,7 +128,6 @@ def __init__(
bias=True,
quant_config=quant_config,
)
- quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index 287f4186f7469..8d7fe8a5beef7 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -120,7 +120,6 @@ def __init__(
hidden_size,
quant_config=quant_config,
)
- quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
intermediate_size)
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index cbc5115bd377b..bab563b9c5a39 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -119,7 +119,6 @@ def __init__(
config.hidden_size,
quant_config=quant_config,
)
- quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 8c5e7e77c9306..6fa5c5bd3014a 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -146,7 +146,6 @@ def __init__(
bias=not config.no_bias,
quant_config=quant_config,
)
- quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, intermediate_size)
self.down_proj = RowParallelLinear(
intermediate_size,
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 838a2f0adc4d1..336f765ababaa 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -130,7 +130,6 @@ def __init__(
bias=config.enable_bias,
quant_config=quant_config,
)
- quant_config = getattr(quant_config, "quant_config", None)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
self.fc2 = RowParallelLinear(
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 7a9b8dcd6a509..4a45879201af3 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -142,7 +142,6 @@ def __init__(self,
config.hidden_size,
quant_config=quant_config,
)
- quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
def forward(self, hidden_states):
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 29d887b21032b..33998e2aad5c5 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -136,7 +136,6 @@ def __init__(self,
bias=config.use_bias,
quant_config=quant_config,
)
- quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)
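Every deleted line is a leftover from the LinearMethod era, when models unwrapped a nested quant_config attribute before calling get_act_fn. After the LinearMethod --> QuantConfig refactor that attribute no longer exists, so the getattr silently replaced the config with None. A self-contained illustration of the failure mode; DummyQuantConfig is a hypothetical stand-in, not a vLLM class:

    class DummyQuantConfig:
        """Hypothetical stand-in for a vLLM QuantizationConfig."""

    cfg = DummyQuantConfig()

    # Old pattern removed by the patch: the nested attribute does not exist
    # on the new-style config, so activations always received None.
    legacy_cfg = getattr(cfg, "quant_config", None)
    assert legacy_cfg is None

    # New pattern: the config object itself is passed straight to get_act_fn.
    assert cfg is not None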
From 9c7306ac114da3e31a5ff040a76f6c640354cce8 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Sun, 28 Apr 2024 18:58:30 +0800
Subject: [PATCH 146/413] [Misc] fix typo in llm_engine init logging (#4428)
---
vllm/engine/llm_engine.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 7e9553d839cea..53a680580390f 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -101,7 +101,7 @@ def __init__(
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, "
- "tensor_parallel_size=%d, disable_custom_all_reduce=%s"
+ "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d)",
From bf480c53027cb009427581af0100de75ac7a2e5f Mon Sep 17 00:00:00 2001
From: Ronen Schaffer
Date: Mon, 29 Apr 2024 01:59:33 +0300
Subject: [PATCH 147/413] Add more Prometheus metrics (#2764)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw
---
examples/production_monitoring/grafana.json | 283 ++++++++++++++++++++
requirements-common.txt | 1 +
vllm/core/scheduler.py | 2 +-
vllm/engine/llm_engine.py | 171 ++++++++----
vllm/engine/metrics.py | 221 +++++++++++----
vllm/sequence.py | 18 +-
6 files changed, 582 insertions(+), 114 deletions(-)
diff --git a/examples/production_monitoring/grafana.json b/examples/production_monitoring/grafana.json
index 071f134c6e5e0..5e9bd5bd03869 100644
--- a/examples/production_monitoring/grafana.json
+++ b/examples/production_monitoring/grafana.json
@@ -873,6 +873,289 @@
],
"title": "Cache Utilization",
"type": "timeseries"
+ },
+ {
+ "type": "heatmap",
+ "title": "Request Prompt Length",
+ "description": "Heatmap of request prompt length",
+ "gridPos": {
+ "x": 0,
+ "y": 24,
+ "w": 12,
+ "h": 8
+ },
+ "datasource": {
+ "uid": "prometheus",
+ "type": "prometheus"
+ },
+ "id": 12,
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "prometheus"
+ },
+ "refId": "A",
+ "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))",
+ "range": true,
+ "instant": false,
+ "editorMode": "builder",
+ "legendFormat": "{{le}}",
+ "useBackend": false,
+ "disableTextWrap": false,
+ "fullMetaSearch": false,
+ "includeNullMetadata": true,
+ "format": "heatmap"
+ }
+ ],
+ "options": {
+ "calculate": false,
+ "yAxis": {
+ "axisPlacement": "left",
+ "reverse": false,
+ "unit": "none",
+ "axisLabel": "Prompt Length"
+ },
+ "rowsFrame": {
+ "layout": "auto",
+ "value": "Request count"
+ },
+ "color": {
+ "mode": "scheme",
+ "fill": "dark-orange",
+ "scale": "exponential",
+ "exponent": 0.5,
+ "scheme": "Spectral",
+ "steps": 64,
+ "reverse": false,
+ "min": 0
+ },
+ "cellGap": 1,
+ "filterValues": {
+ "le": 1e-9
+ },
+ "tooltip": {
+ "show": true,
+ "yHistogram": true
+ },
+ "legend": {
+ "show": true
+ },
+ "exemplars": {
+ "color": "rgba(255,0,255,0.7)"
+ },
+ "cellValues": {
+ "unit": "none"
+ }
+ },
+ "fieldConfig": {
+ "defaults": {
+ "custom": {
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "hideFrom": {
+ "tooltip": false,
+ "viz": false,
+ "legend": false
+ }
+ }
+ },
+ "overrides": []
+ },
+ "pluginVersion": "10.2.0"
+ },
+ {
+ "datasource": {
+ "uid": "prometheus",
+ "type": "prometheus"
+ },
+ "type": "heatmap",
+ "title": "Request Generation Length",
+ "description": "Heatmap of request generation length",
+ "gridPos": {
+ "x": 12,
+ "y": 24,
+ "w": 12,
+ "h": 8
+ },
+ "id": 13,
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "prometheus"
+ },
+ "refId": "A",
+ "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))",
+ "range": true,
+ "instant": false,
+ "editorMode": "builder",
+ "legendFormat": "{{le}}",
+ "useBackend": false,
+ "disableTextWrap": false,
+ "fullMetaSearch": false,
+ "includeNullMetadata": true,
+ "format": "heatmap"
+ }
+ ],
+ "options": {
+ "calculate": false,
+ "yAxis": {
+ "axisPlacement": "left",
+ "reverse": false,
+ "unit": "none",
+ "axisLabel": "Generation Length"
+ },
+ "rowsFrame": {
+ "layout": "auto",
+ "value": "Request count"
+ },
+ "color": {
+ "mode": "scheme",
+ "fill": "dark-orange",
+ "scale": "exponential",
+ "exponent": 0.5,
+ "scheme": "Spectral",
+ "steps": 64,
+ "reverse": false,
+ "min": 0
+ },
+ "cellGap": 1,
+ "filterValues": {
+ "le": 1e-9
+ },
+ "tooltip": {
+ "show": true,
+ "yHistogram": true
+ },
+ "legend": {
+ "show": true
+ },
+ "exemplars": {
+ "color": "rgba(255,0,255,0.7)"
+ },
+ "cellValues": {
+ "unit": "none"
+ }
+ },
+ "fieldConfig": {
+ "defaults": {
+ "custom": {
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "hideFrom": {
+ "tooltip": false,
+ "viz": false,
+ "legend": false
+ }
+ }
+ },
+ "overrides": []
+ },
+ "pluginVersion": "10.2.0"
+ },
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "prometheus"
+ },
+ "fieldConfig": {
+ "defaults": {
+ "custom": {
+ "drawStyle": "line",
+ "lineInterpolation": "linear",
+ "barAlignment": 0,
+ "lineWidth": 1,
+ "fillOpacity": 0,
+ "gradientMode": "none",
+ "spanNulls": false,
+ "insertNulls": false,
+ "showPoints": "auto",
+ "pointSize": 5,
+ "stacking": {
+ "mode": "none",
+ "group": "A"
+ },
+ "axisPlacement": "auto",
+ "axisLabel": "",
+ "axisColorMode": "text",
+ "axisBorderShow": false,
+ "scaleDistribution": {
+ "type": "linear"
+ },
+ "axisCenteredZero": false,
+ "hideFrom": {
+ "tooltip": false,
+ "viz": false,
+ "legend": false
+ },
+ "thresholdsStyle": {
+ "mode": "off"
+ }
+ },
+ "color": {
+ "mode": "palette-classic"
+ },
+ "mappings": [],
+ "thresholds": {
+ "mode": "absolute",
+ "steps": [
+ {
+ "color": "green",
+ "value": null
+ },
+ {
+ "color": "red",
+ "value": 80
+ }
+ ]
+ }
+ },
+ "overrides": []
+ },
+ "gridPos": {
+ "h": 8,
+ "w": 12,
+ "x": 0,
+ "y": 32
+ },
+ "id": 11,
+ "options": {
+ "tooltip": {
+ "mode": "single",
+ "sort": "none"
+ },
+ "legend": {
+ "showLegend": true,
+ "displayMode": "list",
+ "placement": "bottom",
+ "calcs": []
+ }
+ },
+ "targets": [
+ {
+ "datasource": {
+ "type": "prometheus",
+ "uid": "prometheus"
+ },
+ "disableTextWrap": false,
+ "editorMode": "builder",
+ "expr": "sum by(finished_reason) (increase(vllm:request_success_total{model_name=\"$model_name\"}[$__rate_interval]))",
+ "fullMetaSearch": false,
+ "includeNullMetadata": true,
+ "instant": false,
+ "interval": "",
+ "legendFormat": "__auto",
+ "range": true,
+ "refId": "A",
+ "useBackend": false
+ }
+ ],
+ "title": "Finish Reason",
+ "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.",
+ "type": "timeseries"
}
],
"refresh": "",
diff --git a/requirements-common.txt b/requirements-common.txt
index e9db261c6aec9..3abb828116680 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -12,6 +12,7 @@ openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
+prometheus-fastapi-instrumentator >= 7.0.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.9.8
outlines == 0.0.34 # Requires torch >= 2.1.0
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 7439f7dc33e8d..024b7e7013441 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -320,7 +320,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
for seq_group in state_queue:
if not request_ids:
# Using 'break' here may add two extra iterations,
- # but is acceptable to reduce complexity .
+ # but is acceptable to reduce complexity.
break
if seq_group.request_id in request_ids:
# Appending aborted group into pending list.
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 53a680580390f..835803fd4e75d 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -22,7 +22,8 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
- SequenceGroup, SequenceGroupMetadata)
+ SequenceGroup, SequenceGroupMetadata,
+ SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@@ -217,7 +218,8 @@ def __init__(
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
- labels=dict(model_name=model_config.model))
+ labels=dict(model_name=model_config.model),
+ max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
# Create sequence output processor, e.g. for beam search or
@@ -619,59 +621,109 @@ def _get_stats(
"""
now = time.time()
- # KV Cache Usage in %.
+ # System State
+ # Scheduler State
+ num_running_sys = len(self.scheduler.running)
+ num_swapped_sys = len(self.scheduler.swapped)
+ num_waiting_sys = len(self.scheduler.waiting)
+
+ # KV Cache Usage in %
num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
- gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
+ gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
num_total_cpu = self.cache_config.num_cpu_blocks
- cpu_cache_usage = 0.
+ cpu_cache_usage_sys = 0.
if num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
)
- cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
-
- # Scheduler State
- num_running = len(self.scheduler.running)
- num_swapped = len(self.scheduler.swapped)
- num_waiting = len(self.scheduler.waiting)
-
- # Iteration stats if we have scheduler output.
- num_prompt_tokens = 0
- num_generation_tokens = 0
- time_to_first_tokens = []
- time_per_output_tokens = []
- time_e2e_requests = []
+ cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
+
+ # Iteration stats
+ num_prompt_tokens_iter = 0
+ num_generation_tokens_iter = 0
+ time_to_first_tokens_iter: List[float] = []
+ time_per_output_tokens_iter: List[float] = []
+
+ # Request stats
+ # Latency
+ time_e2e_requests: List[float] = []
+ # Metadata
+ num_prompt_tokens_requests: List[int] = []
+ num_generation_tokens_requests: List[int] = []
+ best_of_requests: List[int] = []
+ n_requests: List[int] = []
+ finished_reason_requests: List[str] = []
+
+ # NOTE: This loop assumes prefill seq_groups are before
+ # decode seq_groups in scheduled_seq_groups.
if scheduler_outputs is not None:
- prompt_run = scheduler_outputs.num_prefill_groups > 0
-
- # Number of Tokens.
- if prompt_run:
- num_prompt_tokens = sum(
- len(scheduled_seq_group.seq_group.prompt_token_ids)
- for scheduled_seq_group in
- scheduler_outputs.scheduled_seq_groups)
- num_generation_tokens = sum(
- scheduled_seq_group.seq_group.num_seqs()
- for scheduled_seq_group in
- scheduler_outputs.scheduled_seq_groups)
- else:
- num_generation_tokens = scheduler_outputs.num_batched_tokens
-
- # Latency Timings.
- time_last_iters = []
- for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
+ num_generation_tokens_from_prefill_groups = 0.
+ if scheduler_outputs.num_prefill_groups > 0 and len(
+ scheduler_outputs.scheduled_seq_groups
+ ) != scheduler_outputs.num_prefill_groups:
+ print("DETECTED CHUNKED")
+
+ for idx, scheduled_seq_group in enumerate(
+ scheduler_outputs.scheduled_seq_groups):
+ group_was_prefill = idx < scheduler_outputs.num_prefill_groups
seq_group = scheduled_seq_group.seq_group
- # Time since last token.
- # (n.b. updates seq_group.metrics.last_token_time)
- time_last_iters.append(seq_group.get_last_latency(now))
- # Time since arrival for all finished requests.
+
+ # NOTE: a seq_group that completed all of its prefill tokens
+ # in the last iteration will have seq_group.is_prefill() = False
+ # with group_was_prefill = True
+ if group_was_prefill:
+ # Number of prompt tokens.
+ num_prompt_tokens_iter += (
+ scheduled_seq_group.token_chunk_size)
+
+ # If the seq_group just finished the prefill state
+ # get TTFT.
+ if not seq_group.is_prefill():
+ latency = seq_group.get_last_latency(now)
+ time_to_first_tokens_iter.append(latency)
+
+ # One generation token per finished prefill.
+ num_generation_tokens_from_prefill_groups += (
+ seq_group.num_seqs())
+ else:
+ # TPOTs.
+ latency = seq_group.get_last_latency(now)
+ time_per_output_tokens_iter.append(latency)
+
+ # Because of chunked prefill, we can have a single sequence
+ # group that does multiple prompt_runs. To prevent logging
+ # the same metadata more than once per request, we standardize
+ # on logging request level information for finished requests,
+ # which can only happen once.
if seq_group.is_finished():
+ # Latency timings
time_e2e_requests.append(now -
seq_group.metrics.arrival_time)
- time_to_first_tokens = time_last_iters if prompt_run else []
- time_per_output_tokens = [] if prompt_run else time_last_iters
+ # Metadata
+ num_prompt_tokens_requests.append(
+ len(seq_group.prompt_token_ids))
+ num_generation_tokens_requests.extend([
+ seq.get_output_len()
+ for seq in seq_group.get_finished_seqs()
+ ])
+ best_of_requests.append(seq_group.sampling_params.best_of)
+ n_requests.append(seq_group.sampling_params.n)
+ finished_reason_requests.extend([
+ SequenceStatus.get_finished_reason(seq.status)
+ for seq in seq_group.get_finished_seqs()
+ ])
+
+ # Number of generation tokens.
+ # num_batched_tokens equals the number of prompt_tokens plus the
+ # number of decode_tokens in a single iteration. So,
+ # num_generation_tokens = num_batched_tokens - num_prompt_tokens
+ # + num_generation_tokens_from_prefill_groups (since we generate
+ # one token on prefills on iters where the prefill finishes).
+ num_generation_tokens_iter = (
+ scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
+ num_generation_tokens_from_prefill_groups)
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
@@ -683,17 +735,32 @@ def _get_stats(
return Stats(
now=now,
- num_running=num_running,
- num_swapped=num_swapped,
- num_waiting=num_waiting,
- gpu_cache_usage=gpu_cache_usage,
- cpu_cache_usage=cpu_cache_usage,
- num_prompt_tokens=num_prompt_tokens,
- num_generation_tokens=num_generation_tokens,
- time_to_first_tokens=time_to_first_tokens,
- time_per_output_tokens=time_per_output_tokens,
- time_e2e_requests=time_e2e_requests,
+
+ # System stats
+ # Scheduler State
+ num_running_sys=num_running_sys,
+ num_swapped_sys=num_swapped_sys,
+ num_waiting_sys=num_waiting_sys,
+ # KV Cache Usage in %
+ gpu_cache_usage_sys=gpu_cache_usage_sys,
+ cpu_cache_usage_sys=cpu_cache_usage_sys,
+
+ # Iteration stats
+ num_prompt_tokens_iter=num_prompt_tokens_iter,
+ num_generation_tokens_iter=num_generation_tokens_iter,
+ time_to_first_tokens_iter=time_to_first_tokens_iter,
+ time_per_output_tokens_iter=time_per_output_tokens_iter,
spec_decode_metrics=spec_decode_metrics,
+
+ # Request stats
+ # Latency
+ time_e2e_requests=time_e2e_requests,
+ # Metadata
+ num_prompt_tokens_requests=num_prompt_tokens_requests,
+ num_generation_tokens_requests=num_generation_tokens_requests,
+ best_of_requests=best_of_requests,
+ n_requests=n_requests,
+ finished_reason_requests=finished_reason_requests,
)
def add_lora(self, lora_request: LoRARequest) -> bool:
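To make the token accounting above concrete, here is a worked example with made-up numbers: one iteration that processes 512 prompt tokens, finishes prefill for 3 sequence groups (each emitting its first generation token), and decodes 20 tokens for already-running groups:

    num_prompt_tokens_iter = 512
    num_generation_tokens_from_prefill_groups = 3
    num_batched_tokens = num_prompt_tokens_iter + 20  # prompt + decode tokens

    num_generation_tokens_iter = (num_batched_tokens
                                  - num_prompt_tokens_iter
                                  + num_generation_tokens_from_prefill_groups)
    assert num_generation_tokens_iter == 23  # 20 decode + 3 first tokens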
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index eb54f5641171e..45bfad03ec867 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -1,6 +1,8 @@
import time
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Protocol
+from typing import TYPE_CHECKING
+from typing import Counter as CollectionsCounter
+from typing import Dict, List, Optional, Protocol, Union
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@@ -21,8 +23,9 @@
# begin-metrics-definitions
class Metrics:
+ labelname_finish_reason = "finished_reason"
- def __init__(self, labelnames: List[str]):
+ def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
@@ -34,18 +37,20 @@ def __init__(self, labelnames: List[str]):
documentation='information of cache_config')
# System stats
+ # Scheduler State
self.gauge_scheduler_running = Gauge(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
- self.gauge_scheduler_swapped = Gauge(
- name="vllm:num_requests_swapped",
- documentation="Number of requests swapped to CPU.",
- labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
+ self.gauge_scheduler_swapped = Gauge(
+ name="vllm:num_requests_swapped",
+ documentation="Number of requests swapped to CPU.",
+ labelnames=labelnames)
+ # KV Cache Usage in %
self.gauge_gpu_cache_usage = Gauge(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
@@ -55,7 +60,7 @@ def __init__(self, labelnames: List[str]):
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
- # Raw stats from last model iteration
+ # Iteration stats
self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
@@ -80,18 +85,51 @@ def __init__(self, labelnames: List[str]):
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5
])
- self.histogram_e2e_request_latency = Histogram(
+
+ # Request stats
+ # Latency
+ self.histogram_e2e_time_request = Histogram(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
+ # Metadata
+ self.histogram_num_prompt_tokens_request = Histogram(
+ name="vllm:request_prompt_tokens",
+ documentation="Number of prefill tokens processed.",
+ labelnames=labelnames,
+ buckets=build_1_2_5_buckets(max_model_len),
+ )
+ self.histogram_num_generation_tokens_request = Histogram(
+ name="vllm:request_generation_tokens",
+ documentation="Number of generation tokens processed.",
+ labelnames=labelnames,
+ buckets=build_1_2_5_buckets(max_model_len),
+ )
+ self.histogram_best_of_request = Histogram(
+ name="vllm:request_params_best_of",
+ documentation="Histogram of the best_of request parameter.",
+ labelnames=labelnames,
+ buckets=[1, 2, 5, 10, 20],
+ )
+ self.histogram_n_request = Histogram(
+ name="vllm:request_params_n",
+ documentation="Histogram of the n request parameter.",
+ labelnames=labelnames,
+ buckets=[1, 2, 5, 10, 20],
+ )
+ self.counter_request_success = Counter(
+ name="vllm:request_success",
+ documentation="Count of successfully processed requests.",
+ labelnames=labelnames + [Metrics.labelname_finish_reason])
- # Legacy metrics
+ # Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
+ # Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = Gauge(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
@@ -102,24 +140,57 @@ def __init__(self, labelnames: List[str]):
# end-metrics-definitions
+def build_1_2_5_buckets(max_value: int):
+ """
+ Builds a list of buckets with increasing powers of 10 multiplied by
+ mantissa values (1, 2, 5) until the value exceeds the specified maximum.
+
+ Example:
+ >>> build_1_2_5_buckets(100)
+ [1, 2, 5, 10, 20, 50, 100]
+ """
+ mantissa_lst = [1, 2, 5]
+ exponent = 0
+ buckets = []
+ while True:
+ for m in mantissa_lst:
+ value = m * 10**exponent
+ if value <= max_value:
+ buckets.append(value)
+ else:
+ return buckets
+ exponent += 1
+
+
@dataclass
class Stats:
"""Created by LLMEngine for use by StatLogger."""
now: float
- # System stats.
- num_running: int
- num_waiting: int
- num_swapped: int
- gpu_cache_usage: float
- cpu_cache_usage: float
-
- # Raw stats from last model iteration.
- num_prompt_tokens: int
- num_generation_tokens: int
- time_to_first_tokens: List[float]
- time_per_output_tokens: List[float]
+ # System stats (should have _sys suffix)
+ # Scheduler State
+ num_running_sys: int
+ num_waiting_sys: int
+ num_swapped_sys: int
+ # KV Cache Usage in %
+ gpu_cache_usage_sys: float
+ cpu_cache_usage_sys: float
+
+ # Iteration stats (should have _iter suffix)
+ num_prompt_tokens_iter: int
+ num_generation_tokens_iter: int
+ time_to_first_tokens_iter: List[float]
+ time_per_output_tokens_iter: List[float]
+
+ # Request stats (should have _requests suffix)
+ # Latency
time_e2e_requests: List[float]
+ # Metadata
+ num_prompt_tokens_requests: List[int]
+ num_generation_tokens_requests: List[int]
+ best_of_requests: List[int]
+ n_requests: List[int]
+ finished_reason_requests: List[str]
spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@@ -133,7 +204,8 @@ def metrics_info(self) -> Dict[str, str]:
class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
- def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
+ def __init__(self, local_interval: float, labels: Dict[str, str],
+ max_model_len: int) -> None:
# Metadata for logging locally.
self.last_local_log = time.time()
self.local_interval = local_interval
@@ -144,7 +216,8 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
# Prometheus metrics
self.labels = labels
- self.metrics = Metrics(labelnames=list(labels.keys()))
+ self.metrics = Metrics(labelnames=list(labels.keys()),
+ max_model_len=max_model_len)
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config":
@@ -158,34 +231,66 @@ def _local_interval_elapsed(self, now: float) -> bool:
return elapsed_time > self.local_interval
def _log_prometheus(self, stats: Stats) -> None:
- # Set system stat gauges.
- self.metrics.gauge_scheduler_running.labels(**self.labels).set(
- stats.num_running)
- self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
- stats.num_swapped)
- self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
- stats.num_waiting)
- self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
- stats.gpu_cache_usage)
- self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
- stats.cpu_cache_usage)
-
- # Add to token counters.
- self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
- stats.num_prompt_tokens)
- self.metrics.counter_generation_tokens.labels(**self.labels).inc(
- stats.num_generation_tokens)
-
- # Observe request level latencies in histograms.
- for ttft in stats.time_to_first_tokens:
- self.metrics.histogram_time_to_first_token.labels(
- **self.labels).observe(ttft)
- for tpot in stats.time_per_output_tokens:
- self.metrics.histogram_time_per_output_token.labels(
- **self.labels).observe(tpot)
- for e2e in stats.time_e2e_requests:
- self.metrics.histogram_e2e_request_latency.labels(
- **self.labels).observe(e2e)
+ # System state data
+ self._log_gauge(self.metrics.gauge_scheduler_running,
+ stats.num_running_sys)
+ self._log_gauge(self.metrics.gauge_scheduler_swapped,
+ stats.num_swapped_sys)
+ self._log_gauge(self.metrics.gauge_scheduler_waiting,
+ stats.num_waiting_sys)
+ self._log_gauge(self.metrics.gauge_gpu_cache_usage,
+ stats.gpu_cache_usage_sys)
+ self._log_gauge(self.metrics.gauge_cpu_cache_usage,
+ stats.cpu_cache_usage_sys)
+
+ # Iteration level data
+ self._log_counter(self.metrics.counter_prompt_tokens,
+ stats.num_prompt_tokens_iter)
+ self._log_counter(self.metrics.counter_generation_tokens,
+ stats.num_generation_tokens_iter)
+ self._log_histogram(self.metrics.histogram_time_to_first_token,
+ stats.time_to_first_tokens_iter)
+ self._log_histogram(self.metrics.histogram_time_per_output_token,
+ stats.time_per_output_tokens_iter)
+
+ # Request level data
+ # Latency
+ self._log_histogram(self.metrics.histogram_e2e_time_request,
+ stats.time_e2e_requests)
+ # Metadata
+ finished_reason_counter = CollectionsCounter(
+ stats.finished_reason_requests)
+ self._log_counter_labels(self.metrics.counter_request_success,
+ finished_reason_counter,
+ Metrics.labelname_finish_reason)
+ self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
+ stats.num_prompt_tokens_requests)
+ self._log_histogram(
+ self.metrics.histogram_num_generation_tokens_request,
+ stats.num_generation_tokens_requests)
+ self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
+ self._log_histogram(self.metrics.histogram_best_of_request,
+ stats.best_of_requests)
+
+ def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
+ # Convenience function for logging to gauge.
+ gauge.labels(**self.labels).set(data)
+
+ def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
+ # Convenience function for logging to counter.
+ counter.labels(**self.labels).inc(data)
+
+ def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
+ label_key: str) -> None:
+ # Convenience function for collection counter of labels.
+ for label, count in data.items():
+ counter.labels(**{**self.labels, label_key: label}).inc(count)
+
+ def _log_histogram(self, histogram: Histogram,
+ data: Union[List[int], List[float]]) -> None:
+ # Convenience function for logging list to histogram.
+ for datum in data:
+ histogram.labels(**self.labels).observe(datum)
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
@@ -210,8 +315,8 @@ def log(self, stats: Stats) -> None:
self._log_prometheus(stats)
# Save tracked stats for token counters.
- self.num_prompt_tokens.append(stats.num_prompt_tokens)
- self.num_generation_tokens.append(stats.num_generation_tokens)
+ self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
+ self.num_generation_tokens.append(stats.num_generation_tokens_iter)
# Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now):
@@ -234,11 +339,11 @@ def log(self, stats: Stats) -> None:
"CPU KV cache usage: %.1f%%",
prompt_throughput,
generation_throughput,
- stats.num_running,
- stats.num_swapped,
- stats.num_waiting,
- stats.gpu_cache_usage * 100,
- stats.cpu_cache_usage * 100,
+ stats.num_running_sys,
+ stats.num_swapped_sys,
+ stats.num_waiting_sys,
+ stats.gpu_cache_usage_sys * 100,
+ stats.cpu_cache_usage_sys * 100,
)
# Reset tracked stats for next interval.
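The bucket helper added above is small enough to check by hand. Reproducing it verbatim: for the docstring's example of 100, and for a hypothetical max_model_len of 4096, the request-length histograms get the buckets shown in the asserts:

    def build_1_2_5_buckets(max_value: int):
        mantissa_lst = [1, 2, 5]
        exponent = 0
        buckets = []
        while True:
            for m in mantissa_lst:
                value = m * 10**exponent
                if value <= max_value:
                    buckets.append(value)
                else:
                    return buckets
            exponent += 1

    assert build_1_2_5_buckets(100) == [1, 2, 5, 10, 20, 50, 100]
    assert build_1_2_5_buckets(4096) == [1, 2, 5, 10, 20, 50, 100,
                                         200, 500, 1000, 2000]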
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 567fca5709518..0e931ebbb6571 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -442,15 +442,27 @@ def prompt_token_ids(self) -> List[int]:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
- def get_last_latency(self, now: float) -> float:
- """Gets last token latency for Request level timings."""
+ def get_last_latency(self, now: float) -> Optional[float]:
+ """Sets the last token time for Request level timings."""
+ # If still in prefill phase, raise Error.
+ if self.is_prefill():
+ raise ValueError(
+ "seq_group.get_last_latency() should not be called "
+ "if the seq_group is in prefill phase.")
+
+ # Otherwise return token latency.
latency = now - self.metrics.last_token_time
self.metrics.last_token_time = now
return latency
def maybe_set_first_token_time(self, time: float) -> None:
"""Sets the first token time for Request level timings."""
- if self.metrics.first_token_time is None:
+ # Note: in a case where a sequence_group is swapped and
+ # recomputed, the time between iterations is counted
+ # in TPOT, rather than recalculating TTFT (since, from the
+ # POV of the user, there is simply a long generation delay).
+ if (self.metrics.first_token_time is None
+ and self.get_seqs()[0].get_output_len() == 1):
self.metrics.first_token_time = time
def maybe_set_first_scheduled_time(self, time: float) -> None:
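The new guard in maybe_set_first_token_time means TTFT is recorded only on the step that produces the very first output token; a later swap-and-recompute does not reset it, and the extra delay is attributed to TPOT instead. A minimal model of that rule with plain classes (not vLLM's SequenceGroup) and made-up timestamps:

    class RequestMetrics:
        def __init__(self, arrival_time: float):
            self.arrival_time = arrival_time
            self.first_token_time = None

    def maybe_set_first_token_time(metrics: RequestMetrics,
                                   output_len: int, now: float) -> None:
        # Only the step that produces the first output token sets TTFT.
        if metrics.first_token_time is None and output_len == 1:
            metrics.first_token_time = now

    m = RequestMetrics(arrival_time=100.0)
    maybe_set_first_token_time(m, output_len=1, now=100.2)  # first token
    ttft = m.first_token_time - m.arrival_time              # ~0.2 s TTFT

    # A recomputation later in the request leaves TTFT untouched; the gap
    # shows up in the per-output-token latencies instead.
    maybe_set_first_token_time(m, output_len=7, now=105.0)
    assert m.first_token_time == 100.2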
From 03dd7d52bfcc4f21ba964a0cfc3fb6e7a47fb242 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Sun, 28 Apr 2024 16:32:07 -0700
Subject: [PATCH 148/413] [CI] clean docker cache for neuron (#4441)
---
.buildkite/run-neuron-test.sh | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh
index 8ba03b78e8dbf..252c0f7fecd12 100644
--- a/.buildkite/run-neuron-test.sh
+++ b/.buildkite/run-neuron-test.sh
@@ -4,6 +4,20 @@ set -e
# Try building the docker image
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
+
+# prune old image and containers to save disk space, and only once a day
+# by using a timestamp file in tmp.
+if [ -f /tmp/neuron-docker-build-timestamp ]; then
+ last_build=$(cat /tmp/neuron-docker-build-timestamp)
+ current_time=$(date +%s)
+ if [ $((current_time - last_build)) -gt 86400 ]; then
+ docker system prune -f
+ echo $current_time > /tmp/neuron-docker-build-timestamp
+ fi
+else
+ echo $(date +%s) > /tmp/neuron-docker-build-timestamp
+fi
+
docker build -t neuron -f Dockerfile.neuron .
# Setup cleanup
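For readers who prefer Python to shell, the once-a-day throttle implemented above looks roughly like this; a sketch only, reusing the same timestamp path, while the actual CI step stays in bash:

    import os
    import time
    from typing import Optional

    STAMP_FILE = "/tmp/neuron-docker-build-timestamp"
    ONE_DAY_S = 86400

    def should_prune(now: Optional[float] = None) -> bool:
        """Return True when a docker prune is due (at most once per day)."""
        now = time.time() if now is None else now
        if not os.path.exists(STAMP_FILE):
            # First run: record the timestamp but skip pruning, as the script does.
            with open(STAMP_FILE, "w") as f:
                f.write(str(int(now)))
            return False
        with open(STAMP_FILE) as f:
            last_build = int(f.read().strip())
        if now - last_build > ONE_DAY_S:
            with open(STAMP_FILE, "w") as f:
                f.write(str(int(now)))
            return True
        return False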
From df29793dc73a83f3c86c19de967adffda1a28a93 Mon Sep 17 00:00:00 2001
From: SangBin Cho
Date: Mon, 29 Apr 2024 11:01:26 +0900
Subject: [PATCH 149/413] [mypy][5/N] Support all typing on model executor
(#4427)
---
.github/workflows/mypy.yaml | 2 +-
format.sh | 2 +-
.../lm_format_enforcer_decoding.py | 1 +
vllm/model_executor/layers/linear.py | 12 ++++-
.../layers/quantization/__init__.py | 4 +-
.../layers/quantization/base_config.py | 14 ++++--
.../layers/quantization/squeezellm.py | 5 +-
.../model_executor/layers/rotary_embedding.py | 4 +-
vllm/model_executor/layers/sampler.py | 47 +++++++++++--------
.../model_executor/model_loader/tensorizer.py | 4 +-
10 files changed, 61 insertions(+), 34 deletions(-)
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index 089c7d18ad6f2..a19be8525f902 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -43,8 +43,8 @@ jobs:
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
+ mypy vllm/model_executor --config-file pyproject.toml
# TODO(sang): Fix nested dir
- mypy vllm/model_executor/*.py --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
diff --git a/format.sh b/format.sh
index 4ac1842daef0a..bd12e61d77806 100755
--- a/format.sh
+++ b/format.sh
@@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
-mypy vllm/model_executor/*.py --config-file pyproject.toml
+mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
index 0d74a5f8e81ff..d0a5ca5592f9d 100644
--- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
+++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
+ raise AssertionError(f"Unsupported schema type {schema}")
@lru_cache
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 1bd6c42ab3fd8..4d43ed4c5f14a 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -128,7 +128,8 @@ def __init__(
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
- self.quant_method = UnquantizedLinearMethod()
+ self.quant_method: Optional[
+ QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self)
@@ -160,6 +161,8 @@ def __init__(
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
+ # All the linear layer supports quant method.
+ assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
@@ -173,6 +176,7 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
+ assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
@@ -221,6 +225,8 @@ def __init__(
self.output_size_per_partition = divide(output_size, tp_size)
if output_sizes is None:
output_sizes = [output_size]
+ # All the linear layer supports quant method.
+ assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
[x // tp_size for x in output_sizes],
@@ -255,6 +261,7 @@ def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
+ assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
@@ -579,6 +586,8 @@ def __init__(
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
+ # All the linear layer supports quant method.
+ assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size_per_partition,
[self.output_size],
@@ -624,6 +633,7 @@ def forward(self, input_):
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
+ assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index 0820f17c5c50d..70e0a7cfe3e3b 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Dict, Type
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
@@ -9,7 +9,7 @@
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
-QUANTIZATION_METHODS = {
+QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"fp8": Fp8Config,
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index b755b1328504a..ff5cf0b2bd61a 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
import torch
from torch import nn
@@ -76,8 +76,16 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"quantization config.")
@abstractmethod
- def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
- """Get the quantize method to use for the quantized layer."""
+ def get_quant_method(
+ self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
+ """Get the quantize method to use for the quantized layer.
+
+ Args:
+ layer: The layer for the quant method.
+ Returns:
+ The quantize method. None if the given layer doesn't support quant
+ method.
+ """
raise NotImplementedError
@abstractmethod
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
index 971078fe25a9b..207dbcee8afc5 100644
--- a/vllm/model_executor/layers/quantization/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -52,11 +52,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
return cls(weight_bits)
def get_quant_method(
- self,
- layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
+ self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self)
- return
+ return None
def get_scaled_act_names(self) -> List[str]:
return []
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index b8361af61ae3f..25365a9b50a1f 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -431,8 +431,8 @@ def forward(
torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
- self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
- idx.device)
+ self.long_short_cos_sin_cache: torch.Tensor = (
+ self.long_short_cos_sin_cache.to(idx.device))
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 4ef25edecfd24..d79c99e5d0a45 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -13,6 +13,9 @@
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceGroupOutput, SequenceOutput)
+# (num_token_ids, num_parent_ids) per sequence group.
+SampleResultType = List[Tuple[List[int], List[int]]]
+
class Sampler(nn.Module):
"""Samples the next tokens from the model's outputs.
@@ -155,7 +158,7 @@ def _apply_min_tokens_penalty(
have not been generated yet
"""
# list of indices in logits that will be set to -inf
- logits_to_penalize = []
+ logits_to_penalize: List[Tuple[int, int]] = []
logits_applied = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
@@ -269,7 +272,7 @@ def _apply_min_p(
def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
"""Run greedy sampling on a given samples.
Args:
@@ -284,7 +287,7 @@ def _greedy_sample(
"""
samples = samples.tolist()
sample_idx = 0
- results = []
+ results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
@@ -304,7 +307,7 @@ def _greedy_sample(
def _random_sample(
selected_seq_groups: List[SequenceGroupToSample],
random_samples: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
"""Run random sampling on a given samples.
Args:
@@ -320,7 +323,7 @@ def _random_sample(
# Find the maximum best_of value of the prompt phase requests.
random_samples = random_samples.cpu()
sample_idx = 0
- results = []
+ results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
@@ -348,7 +351,7 @@ def _random_sample(
def _beam_search_sample(
selected_seq_groups: List[SequenceGroupToSample],
logprobs: torch.Tensor,
-) -> List[Tuple[List[int], List[int]]]:
+) -> SampleResultType:
"""Run beam sampling on a given samples.
Args:
@@ -370,7 +373,7 @@ def _beam_search_sample(
# NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
sample_idx = 0
- results = []
+ results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
@@ -391,16 +394,16 @@ def _beam_search_sample(
next_token_ids = next_token_ids.tolist()
else:
# Generation phase.
- cumulative_logprobs = [
+ cumulative_logprobs: List[int] = [
seq_group.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids
]
- cumulative_logprobs = torch.tensor(
+ cumulative_logprobs_tensor = torch.tensor(
cumulative_logprobs,
dtype=torch.float,
device=seq_group_logprobs.device)
seq_group_logprobs = (seq_group_logprobs +
- cumulative_logprobs.unsqueeze(dim=1))
+ cumulative_logprobs_tensor.unsqueeze(dim=1))
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
2 * beam_width)
topk_ids = topk_ids.tolist()
@@ -452,8 +455,10 @@ def _sample_with_torch(
sampling_metadata: SamplingMetadata,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
- categorized_seq_group_ids = {t: [] for t in SamplingType}
+) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
+ categorized_seq_group_ids: Dict[SamplingType,
+ List[int]] = {t: []
+ for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params
@@ -555,8 +560,10 @@ def _sample_with_triton_kernel(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
-) -> List[Tuple[List[int], List[int]]]:
- categorized_seq_group_ids = {t: [] for t in SamplingType}
+) -> SampleResultType:
+ categorized_seq_group_ids: Dict[SamplingType,
+ List[int]] = {t: []
+ for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params
@@ -632,7 +639,7 @@ def _sample(
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
-) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
+) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
@@ -680,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
def _get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
- sample_results: List[Tuple[List[int], List[int]]],
+ sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample lobprobs and prompt logprobs.
@@ -751,8 +758,8 @@ def _get_logprobs(
assert len(next_token_ids) == len(query_indices)
if len(query_indices) == 0:
- empty_sampled_logprob = []
- empty_prompt_logprob = None
+ empty_sampled_logprob: SampleLogprobs = []
+ empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob]
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
@@ -965,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output(
- sample_results: List[Tuple[List[int], List[int]]],
+ sample_results: SampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
@@ -1009,7 +1016,7 @@ def _build_sampler_output(
)
-def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
+def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
"""Get a list of next prompt tokens to compute logprob from a
given sequence group.
diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py
index 8fc6d16672117..2d654b2fefb8d 100644
--- a/vllm/model_executor/model_loader/tensorizer.py
+++ b/vllm/model_executor/model_loader/tensorizer.py
@@ -64,7 +64,7 @@ def _construct_tensorizer_args(self) -> "TensorizerArgs":
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
- return TensorizerArgs(**tensorizer_args)
+ return TensorizerArgs(**tensorizer_args) # type: ignore
def verify_with_parallel_config(
self,
@@ -270,8 +270,10 @@ def __init__(self, tensorizer_config: TensorizerConfig,
self.model = self._init_model()
def _init_model(self):
+ assert self.tensorizer_config.hf_config is not None
model_args = self.tensorizer_config.hf_config
model_args.torch_dtype = self.tensorizer_config.dtype
+ assert self.tensorizer_config.model_class is not None
with no_init_or_tensor():
return self.tensorizer_config.model_class(
config=model_args,
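Among the typing changes, sampler.py gains an explicit alias for the sampler's result shape. A short, purely illustrative example of a value of that type; the token and parent ids are made up:

    from typing import List, Tuple

    # One (next_token_ids, parent_ids) pair per sequence group, as defined
    # in vllm/model_executor/layers/sampler.py above.
    SampleResultType = List[Tuple[List[int], List[int]]]

    sample_results: SampleResultType = [
        ([42], [0]),       # greedy group: one token sampled for sequence 0
        ([7, 9], [0, 0]),  # beam search group: two candidates, both from parent 0
    ]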
From 73c8d677e57e42374bcfb2271b8f1cf7f2c0a486 Mon Sep 17 00:00:00 2001
From: Robert Shaw
<114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Mon, 29 Apr 2024 12:35:34 -0400
Subject: [PATCH 150/413] [Kernel] Marlin Expansion: Support AutoGPTQ Models
with Marlin (#3922)
Co-authored-by: alexm
Co-authored-by: mgoin
---
CMakeLists.txt | 2 +
csrc/ops.h | 18 +
csrc/pybind.cpp | 2 +
csrc/quantization/gptq_marlin/gptq_marlin.cu | 1520 +++++++++++++++++
csrc/quantization/gptq_marlin/gptq_marlin.cuh | 74 +
.../gptq_marlin/gptq_marlin_repack.cu | 324 ++++
tests/models/test_gptq_marlin.py | 93 +
tests/models/test_marlin.py | 45 +-
tests/models/utils.py | 29 +
.../test_autogptq_marlin_configs.py | 64 -
tests/quantization/test_configs.py | 73 +
vllm/config.py | 39 +-
.../layers/quantization/__init__.py | 3 +
.../layers/quantization/gptq_marlin.py | 444 +++++
14 files changed, 2626 insertions(+), 104 deletions(-)
create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin.cu
create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin.cuh
create mode 100644 csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
create mode 100644 tests/models/test_gptq_marlin.py
create mode 100644 tests/models/utils.py
delete mode 100644 tests/quantization/test_autogptq_marlin_configs.py
create mode 100644 tests/quantization/test_configs.py
create mode 100644 vllm/model_executor/layers/quantization/gptq_marlin.py
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e9262b57d0867..1558dbf313ce7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -177,6 +177,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
+ "csrc/quantization/gptq_marlin/gptq_marlin.cu"
+ "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu")
endif()
diff --git a/csrc/ops.h b/csrc/ops.h
index 03bb1e24dc68e..04b97d1784cd2 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -124,6 +124,24 @@ torch::Tensor marlin_gemm(
int64_t size_m,
int64_t size_n,
int64_t size_k);
+
+torch::Tensor gptq_marlin_gemm(
+ torch::Tensor &a,
+ torch::Tensor &b_q_weight,
+ torch::Tensor &b_scales,
+ torch::Tensor &g_idx,
+ torch::Tensor &perm,
+ torch::Tensor &workspace,
+ int64_t size_m,
+ int64_t size_n,
+ int64_t size_k,
+ bool is_k_full);
+
+torch::Tensor gptq_marlin_repack(
+ torch::Tensor &b_q_weight,
+ torch::Tensor &perm,
+ int64_t size_k,
+ int64_t size_n);
#endif
void squeezellm_gemm(
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index 2250c7f69f0ab..9839bfc0331c4 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -67,6 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
+ ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
+ ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
new file mode 100644
index 0000000000000..9902f55167d89
--- /dev/null
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -0,0 +1,1520 @@
+/*
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Adapted from https://github.com/IST-DASLab/marlin
+ */
+
+#include "gptq_marlin.cuh"
+
+template <typename T> inline std::string str(T x) { return std::to_string(x); }
+
+namespace gptq_marlin {
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+
+__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
+ int const *__restrict__ perm_int_ptr,
+ int4 *__restrict__ out_int4_ptr, int size_m,
+ int size_k, int block_rows) {}
+
+template <const int threads,          // number of threads in a threadblock
+          const int thread_m_blocks,  // number of 16x16 blocks in the m
+                                      // dimension (batchsize) of the threadblock
+          const int thread_n_blocks,  // same for n dimension (output)
+          const int thread_k_blocks,  // same for k dimension (reduction)
+          const int stages,           // number of stages for the async global->shared
+                                      // fetch pipeline
+          const bool has_act_order,   // whether act_order is enabled
+          const int group_blocks = -1 // number of consecutive 16x16 blocks with
+                                      // a separate quantization scale
+          >
+__global__ void
+Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
+ const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
+ int4 *__restrict__ C, // fp16 output buffer of shape mxn
+ const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape
+ // (k/groupsize)xn
+ const int *__restrict__ g_idx, // int32 group indices of shape k
+ int num_groups, // number of scale groups per output channel
+ int prob_m, // batch dimension m
+ int prob_n, // output dimension n
+ int prob_k, // reduction dimension k
+ int *locks // extra global storage for barrier synchronization
+) {}
+
+} // namespace gptq_marlin
+
+torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
+ torch::Tensor &b_scales, torch::Tensor &g_idx,
+ torch::Tensor &perm, torch::Tensor &workspace,
+ int64_t size_m, int64_t size_n, int64_t size_k,
+ bool is_k_full) {
+ TORCH_CHECK_NOT_IMPLEMENTED(false,
+ "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
+ return torch::empty({1, 1});
+}
+
+#else
+
+// Matrix fragments for tensor core instructions; their precise layout is
+// documented here:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+using FragA = Vec<half2, 4>;
+using FragB = Vec<half2, 2>;
+using FragC = Vec<float, 4>;
+using FragS = Vec<half2, 1>; // quantization scales
+
+// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
+// output/accumulation.
+__device__ inline void mma(const FragA &a_frag, const FragB &frag_b,
+ FragC &frag_c) {
+ const uint32_t *a = reinterpret_cast<const uint32_t *>(&a_frag);
+ const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
+ float *c = reinterpret_cast<float *>(&frag_c);
+ asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+ "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+ : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]),
+ "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+}
+
+// Instruction for loading a full 16x16 matrix fragment of operand A from shared
+// memory, directly in tensor core layout.
+__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
+ uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+ : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+ : "r"(smem));
+}
+
+// Lookup-table based 3-input logical operation; explicitly used for
+// dequantization as the compiler does not seem to automatically recognize it in
+// all cases.
+template <int lut> __device__ inline int lop3(int a, int b, int c) {
+ int res;
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+ : "=r"(res)
+ : "r"(a), "r"(b), "r"(c), "n"(lut));
+ return res;
+}
+
+// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
+// values. We mostly follow the strategy in the link below, with some small
+// changes:
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+__device__ inline FragB dequant(int q) {
+ const int LO = 0x000f000f;
+ const int HI = 0x00f000f0;
+ const int EX = 0x64006400;
+ // Guarantee that the `(a & b) | c` operations are LOP3s.
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+ int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+ // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
+ // directly into `SUB` and `ADD`.
+ const int SUB = 0x64086408;
+ const int MUL = 0x2c002c00;
+ const int ADD = 0xd480d480;
+ FragB frag_b;
+ frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
+ *reinterpret_cast<const half2 *>(&SUB));
+ frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
+ *reinterpret_cast<const half2 *>(&MUL),
+ *reinterpret_cast<const half2 *>(&ADD));
+ return frag_b;
+}
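For reference, the LOP3/magic-constant sequence above is arithmetically equivalent to the scalar unpacking sketched below. This is an illustrative host-side helper, not part of the patch: it assumes the interleaved nibble order {0, 2, 4, 6, 1, 3, 5, 7} produced by the repack kernel later in this patch and the symmetric -8 zero point mentioned in the comment, and the function name dequant_reference is made up.

// Host-side scalar sketch: unpack one 32-bit word holding 8 interleaved
// 4-bit weights, undo the -8 zero point and apply a per-group scale.
#include <array>
#include <cstdint>
#include <cstdio>

std::array<float, 8> dequant_reference(uint32_t q, float s) {
  // Nibble i of the packed word stores logical element pack_idx[i].
  constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
  std::array<float, 8> out{};
  for (int i = 0; i < 8; i++) {
    int nibble = static_cast<int>((q >> (i * 4)) & 0xf);
    out[pack_idx[i]] = static_cast<float>(nibble - 8) * s;  // fused zero point
  }
  return out;
}

int main() {
  for (float w : dequant_reference(0x76543210u, 0.5f)) std::printf("%g ", w);
  std::printf("\n");
  return 0;
}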
+
+// Multiply dequantized values by the corresponding quantization scale; used
+// only for grouped quantization.
+__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
+ half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
+ frag_b[0] = __hmul2(frag_b[0], s);
+ frag_b[1] = __hmul2(frag_b[1], s);
+}
+
+// Same as above, but for act_order (each K is multiplied individually)
+__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
+ FragS &frag_s_3, FragS &frag_s_4, int i) {
+ __half2 s_val_1_2;
+ s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i];
+ s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i];
+
+ __half2 s_val_3_4;
+ s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i];
+ s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i];
+
+ frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
+ frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
+}
+
+// Wait until barrier reaches `count`, then lock for current threadblock.
+__device__ inline void barrier_acquire(int *lock, int count) {
+ if (threadIdx.x == 0) {
+ int state = -1;
+ do
+ // Guarantee that subsequent writes by this threadblock will be visible
+ // globally.
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+ : "=r"(state)
+ : "l"(lock));
+ while (state != count);
+ }
+ __syncthreads();
+}
+
+// Release barrier and increment visitation count.
+__device__ inline void barrier_release(int *lock, bool reset = false) {
+ __syncthreads();
+ if (threadIdx.x == 0) {
+ if (reset) {
+ lock[0] = 0;
+ return;
+ }
+ int val = 1;
+ // Make sure that all writes since acquiring this barrier are visible
+ // globally, while releasing the barrier.
+ asm volatile("fence.acq_rel.gpu;\n");
+ asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
+ :
+ : "l"(lock), "r"(val));
+ }
+}
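The acquire/release pair above serializes the partial reductions of all threadblocks that share a column slice: block `slice_idx` spins until the lock reaches its index, does its work, then bumps the counter. A minimal host-side analogue with std::atomic (purely illustrative; the names are invented and it models only the ordering, not the GPU memory fences):

// Host-side sketch of the ticket-style serialization used by
// barrier_acquire/barrier_release.
#include <atomic>
#include <thread>
#include <vector>

void serialized_reduce(std::atomic<int> &lock, int idx, float &acc, float partial) {
  while (lock.load(std::memory_order_acquire) != idx) { /* spin */ }
  acc += partial;                                // critical section
  lock.fetch_add(1, std::memory_order_release);  // release to the next block
}

int main() {
  std::atomic<int> lock{0};
  float acc = 0.f;
  std::vector<std::thread> blocks;
  for (int i = 0; i < 4; i++)
    blocks.emplace_back(serialized_reduce, std::ref(lock), i, std::ref(acc), 1.0f);
  for (auto &t : blocks) t.join();
  return acc == 4.0f ? 0 : 1;
}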
+
+// For a given "a" of size [M,K] performs a permutation of the K columns based
+// on the given "perm" indices.
+__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
+ int const *__restrict__ perm_int_ptr,
+ int4 *__restrict__ out_int4_ptr, int size_m,
+ int size_k, int block_rows) {
+
+ int start_row = block_rows * blockIdx.x;
+ int finish_row = start_row + block_rows;
+ if (finish_row > size_m) {
+ finish_row = size_m;
+ }
+ int cur_block_rows = finish_row - start_row;
+
+ int row_stride = size_k * sizeof(half) / 16;
+
+ auto permute_row = [&](int row) {
+ int iters = size_k / default_threads;
+ int rest = size_k % default_threads;
+
+ int offset = row * row_stride;
+
+ half const *a_row_half =
+ reinterpret_cast<half const *>(a_int4_ptr + offset);
+ half *out_half = reinterpret_cast<half *>(out_int4_ptr + offset);
+
+ int base_k = 0;
+
+ for (int i = 0; i < iters; i++) {
+ int cur_k = base_k + threadIdx.x;
+ int src_pos = perm_int_ptr[cur_k];
+
+ out_half[cur_k] = a_row_half[src_pos];
+
+ base_k += default_threads;
+ }
+
+ if (rest) {
+ if (threadIdx.x < rest) {
+ int cur_k = base_k + threadIdx.x;
+ int src_pos = perm_int_ptr[cur_k];
+
+ out_half[cur_k] = a_row_half[src_pos];
+ }
+ }
+ };
+
+ for (int i = 0; i < cur_block_rows; i++) {
+ int cur_row = start_row + i;
+ if (cur_row < size_m) {
+ permute_row(cur_row);
+ }
+ }
+}
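Functionally, the kernel above just gathers columns of A according to perm, one row per iteration. A plain CPU reference (illustrative only; fp16 values are treated as opaque 16-bit words) makes the data movement explicit:

// CPU reference: out[r][c] = a[r][perm[c]] for an M x K row-major matrix.
#include <cstdint>
#include <vector>

std::vector<uint16_t> permute_cols_reference(const std::vector<uint16_t> &a,
                                             const std::vector<int> &perm,
                                             int size_m, int size_k) {
  std::vector<uint16_t> out(a.size());
  for (int r = 0; r < size_m; r++)
    for (int c = 0; c < size_k; c++)
      out[r * size_k + c] = a[r * size_k + perm[c]];
  return out;
}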
+
+template <const int threads, // number of threads in a threadblock
+ const int thread_m_blocks, // number of 16x16 blocks in the m
+ // dimension (batchsize) of the threadblock
+ const int thread_n_blocks, // same for n dimension (output)
+ const int thread_k_blocks, // same for k dimension (reduction)
+ const int stages, // number of stages for the async global->shared
+ // fetch pipeline
+ const bool has_act_order, // whether act_order is enabled
+ const int group_blocks = -1 // number of consecutive 16x16 blocks with
+ // a separate quantization scale
+ >
+__global__ void
+Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
+ const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
+ int4 *__restrict__ C, // fp16 output buffer of shape mxn
+ const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape
+ // (k/groupsize)xn
+ const int *__restrict__ g_idx, // int32 group indices of shape k
+ int num_groups, // number of scale groups per output channel
+ int prob_m, // batch dimension m
+ int prob_n, // output dimension n
+ int prob_k, // reduction dimension k
+ int *locks // extra global storage for barrier synchronization
+) {
+ // Each threadblock processes one "stripe" of the B matrix with (roughly) the
+ // same size, which might involve multiple column "slices" (of width 16 *
+ // `thread_n_blocks`). Stripes are defined as shown in the following example
+ // with a 3x3 tile matrix and 5 SMs:
+ // 0 1 3
+ // 0 2 3
+ // 1 2 4
+ // While this kind of partitioning makes things somewhat more complicated, it
+ // ensures good utilization of all SMs for many kinds of shape and GPU
+ // configurations, while requiring as few slow global cross-threadblock
+ // reductions as possible.
+
+ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
+ // better partitioning with fewer reductions
+ int parallel = 1;
+ if (prob_m > 16 * thread_m_blocks) {
+ parallel = prob_m / (16 * thread_m_blocks);
+ prob_m = 16 * thread_m_blocks;
+ }
+
+ int k_tiles = prob_k / 16 / thread_k_blocks;
+ int n_tiles = prob_n / 16 / thread_n_blocks;
+ int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
+
+ if constexpr (!has_act_order && group_blocks != -1) {
+ if (group_blocks >= thread_k_blocks) {
+ // Ensure that the number of tiles in each stripe is a multiple of the
+ // groupsize; this avoids an annoying special case where a stripe starts
+ // in the middle of a group.
+ iters = (group_blocks / thread_k_blocks) *
+ div_ceil(iters, (group_blocks / thread_k_blocks));
+ }
+ }
+
+ int slice_row = (iters * blockIdx.x) % k_tiles;
+ int slice_col_par = (iters * blockIdx.x) / k_tiles;
+ int slice_col = slice_col_par;
+ int slice_iters; // number of threadblock tiles in the current slice
+ int slice_count =
+ 0; // total number of active threadblocks in the current slice
+ int slice_idx; // index of threadblock in current slice; numbered bottom to
+ // top
+
+ // We can easily implement parallel problem execution by just remapping
+ // indices and advancing global pointers
+ if (slice_col_par >= n_tiles) {
+ A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
+ C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
+ locks += (slice_col_par / n_tiles) * n_tiles;
+ slice_col = slice_col_par % n_tiles;
+ }
+
+ // Compute all information about the current slice which is required for
+ // synchronization.
+ auto init_slice = [&]() {
+ slice_iters =
+ iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
+ if (slice_iters < 0 || slice_col_par >= n_tiles * parallel)
+ slice_iters = 0;
+ if (slice_iters == 0)
+ return;
+ if (slice_row + slice_iters > k_tiles)
+ slice_iters = k_tiles - slice_row;
+ slice_count = 1;
+ slice_idx = 0;
+ int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
+ if (col_first <= k_tiles * (slice_col_par + 1)) {
+ int col_off = col_first - k_tiles * slice_col_par;
+ slice_count = div_ceil(k_tiles - col_off, iters);
+ if (col_off > 0)
+ slice_count++;
+ int delta_first = iters * blockIdx.x - col_first;
+ if (delta_first < 0 || (col_off == 0 && delta_first == 0))
+ slice_idx = slice_count - 1;
+ else {
+ slice_idx = slice_count - 1 - delta_first / iters;
+ if (col_off > 0)
+ slice_idx--;
+ }
+ }
+ if (slice_col == n_tiles) {
+ A += 16 * thread_m_blocks * prob_k / 8;
+ C += 16 * thread_m_blocks * prob_n / 8;
+ locks += n_tiles;
+ slice_col = 0;
+ }
+ };
+ init_slice();
+
+ // A sizes/strides
+
+ // stride of the A matrix in global memory
+ int a_gl_stride = prob_k / 8;
+ // stride of an A matrix tile in shared memory
+ constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
+ // delta between subsequent A tiles in global memory
+ constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
+ // between subsequent accesses within a tile
+ int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
+ // between shared memory writes
+ constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
+ // between shared memory tile reads
+ constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
+ // within a shared memory tile
+ constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
+ // overall size of a tile
+ constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
+ // number of shared write iterations for a tile
+ constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
+
+ // B sizes/strides
+ int b_gl_stride = 16 * prob_n / 32;
+ constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
+ int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
+ int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
+ constexpr int b_sh_wr_delta = threads;
+ constexpr int b_sh_rd_delta = threads;
+ constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
+ constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
+
+ // Scale sizes/strides without act_order
+ int s_gl_stride = prob_n / 8;
+ constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
+ constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks
+ ? thread_k_blocks / group_blocks
+ : 1;
+ constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
+ int s_gl_rd_delta = s_gl_stride;
+
+ // Scale size/strides with act_order
+ constexpr int tb_k = 16 * thread_k_blocks;
+ constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
+ // constexpr int act_s_row_stride = 1;
+ // int act_s_col_stride = act_s_row_stride * num_groups;
+ int act_s_col_stride = 1;
+ int act_s_col_warp_stride = act_s_col_stride * 8;
+ int tb_n_warps = thread_n_blocks / 4;
+ int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
+
+ // Global A read index of current thread.
+ int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
+ (threadIdx.x % a_gl_rd_delta_o);
+ a_gl_rd += a_gl_rd_delta_o * slice_row;
+ // Shared write index of current thread.
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
+ (threadIdx.x % a_gl_rd_delta_o);
+ // Shared read index.
+ int a_sh_rd =
+ a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
+ a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
+
+ int b_gl_rd =
+ b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
+ b_gl_rd += b_sh_stride * slice_col;
+ b_gl_rd += b_gl_rd_delta_o * slice_row;
+ int b_sh_wr = threadIdx.x;
+ int b_sh_rd = threadIdx.x;
+
+ // For act_order
+ constexpr int k_iter_size = tb_k / b_sh_wr_iters;
+ int slice_k_start = tb_k * slice_row;
+ int slice_k_finish = slice_k_start + tb_k * slice_iters;
+ int slice_k_start_shared_fetch = slice_k_start;
+ int slice_n_offset = act_s_col_tb_stride * slice_col;
+
+ // No act_order
+ int s_gl_rd;
+ if constexpr (!has_act_order) {
+ s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
+ s_sh_stride * slice_col + threadIdx.x;
+ }
+ int s_sh_wr = threadIdx.x;
+ bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
+
+ // We use a different scale layout for grouped and column-wise quantization as
+ // we scale a `half2` tile in column-major layout in the former and in
+ // row-major in the latter case.
+ int s_sh_rd;
+ if constexpr (group_blocks != -1)
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+ (threadIdx.x % 32) / 4;
+ else
+ s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
+ (threadIdx.x % 32) % 4;
+
+ // Precompute which thread should not read memory in which iterations; this is
+ // needed if there are more threads than required for a certain tilesize or
+ // when the batchsize is not a multiple of 16.
+ bool a_sh_wr_pred[a_sh_wr_iters];
+#pragma unroll
+ for (int i = 0; i < a_sh_wr_iters; i++)
+ a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
+
+ // To ensure that writing and reading A tiles to/from shared memory, the
+ // latter in fragment format, is fully bank conflict free, we need to use a
+ // rather fancy XOR-based layout. The key here is that neither reads nor
+ // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
+ // same shared memory banks. Further, it seems (based on NSight-Compute) that
+ // each warp must also write a consecutive memory segment?
+ auto transform_a = [&](int i) {
+ int row = i / a_gl_rd_delta_o;
+ return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
+ };
+ // Since the computation of this remapping is non-trivial and, due to our main
+ // loop unrolls, all shared memory accesses are static, we simply precompute
+ // both transformed reads and writes.
+ int a_sh_wr_trans[a_sh_wr_iters];
+#pragma unroll
+ for (int i = 0; i < a_sh_wr_iters; i++)
+ a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
+ int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
+#pragma unroll
+ for (int i = 0; i < b_sh_wr_iters; i++) {
+#pragma unroll
+ for (int j = 0; j < thread_m_blocks; j++)
+ a_sh_rd_trans[i][j] =
+ transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
+ }
+
+ // Since B-accesses have non-constant stride they have to be computed at
+ // runtime; we break dependencies between subsequent accesses within a tile by
+ // maintaining multiple pointers (we have enough registers), a tiny
+ // optimization.
+ const int4 *B_ptr[b_sh_wr_iters];
+#pragma unroll
+ for (int i = 0; i < b_sh_wr_iters; i++)
+ B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
+
+ extern __shared__ int4 sh[];
+ // Shared memory storage for global fetch pipelines.
+ int4 *sh_a = sh;
+ int4 *sh_b = sh_a + (stages * a_sh_stage);
+ int4 *sh_g_idx = sh_b + (stages * b_sh_stage);
+ int4 *sh_s = sh_g_idx + (stages * g_idx_stage);
+
+ // Register storage for double buffer of shared memory reads.
+ FragA frag_a[2][thread_m_blocks];
+ I4 frag_b_quant[2];
+ FragC frag_c[thread_m_blocks][4][2];
+ FragS frag_s[2][4]; // No act-order
+ FragS act_frag_s[2][4][4]; // For act-order
+
+ // Zero accumulators.
+ auto zero_accums = [&]() {
+#pragma unroll
+ for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
+ reinterpret_cast<float *>(frag_c)[i] = 0;
+ };
+
+ int sh_first_group_id = -1;
+ int sh_num_groups = -1;
+ constexpr int sh_max_num_groups = 32;
+
+ auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
+ int last_group_id) {
+ sh_first_group_id = first_group_id;
+ sh_num_groups = last_group_id - first_group_id + 1;
+
+ if (sh_num_groups < sh_max_num_groups) {
+ sh_num_groups = sh_max_num_groups;
+ }
+
+ if (sh_first_group_id + sh_num_groups > num_groups) {
+ sh_num_groups = num_groups - sh_first_group_id;
+ }
+
+ int row_offset = first_group_id * s_gl_stride;
+
+ if (is_async) {
+ for (int i = 0; i < sh_num_groups; i++) {
+ if (threadIdx.x < s_sh_stride) {
+ cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
+ &scales_ptr[row_offset + (i * s_gl_stride) +
+ slice_n_offset + threadIdx.x]);
+ }
+ }
+ } else {
+ for (int i = 0; i < sh_num_groups; i++) {
+ if (threadIdx.x < s_sh_stride) {
+ sh_s[(i * s_sh_stride) + threadIdx.x] =
+ scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
+ threadIdx.x];
+ }
+ }
+ }
+ };
+ // Asynchronously fetch the next A, B and s tile from global to the next
+ // shared memory pipeline location.
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
+ if (pred) {
+ int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
+#pragma unroll
+ for (int i = 0; i < a_sh_wr_iters; i++) {
+ cp_async4_pred(
+ &sh_a_stage[a_sh_wr_trans[i]],
+ &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
+ a_sh_wr_pred[i]);
+ }
+ int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
+#pragma unroll
+ for (int i = 0; i < b_sh_wr_iters; i++) {
+ cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
+ B_ptr[i] += b_gl_rd_delta_o;
+ }
+
+ if constexpr (has_act_order) {
+ // Fetch g_idx thread-block portion
+ int full_pipe = a_off;
+ int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
+ if (cur_k < prob_k && cur_k < slice_k_finish) {
+ int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
+
+ int4 const *cur_g_idx_stage_ptr =
+ reinterpret_cast<int4 const *>(&g_idx[cur_k]);
+
+ if (threadIdx.x < g_idx_stage) {
+ cp_async4_pred(&sh_g_idx_stage[threadIdx.x],
+ &cur_g_idx_stage_ptr[threadIdx.x]);
+ }
+ }
+ } else {
+ if constexpr (group_blocks != -1) {
+ int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
+
+ if constexpr (group_blocks >= thread_k_blocks) {
+ // Only fetch scales if this tile starts a new group
+ if (pipe % (group_blocks / thread_k_blocks) == 0) {
+ if (s_sh_wr_pred) {
+ cp_async4_stream(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
+ }
+ s_gl_rd += s_gl_rd_delta;
+ }
+ } else {
+ for (int i = 0; i < s_tb_groups; i++) {
+ if (s_sh_wr_pred) {
+ cp_async4_stream(&sh_s_stage[i * s_sh_stride + s_sh_wr],
+ &scales_ptr[s_gl_rd]);
+ }
+ s_gl_rd += s_gl_rd_delta;
+ }
+ }
+ }
+ }
+ }
+ // Insert a fence even when we are winding down the pipeline to ensure that
+ // waiting is also correct at this point.
+ cp_async_fence();
+ };
+
+ // Wait until the next thread tile has been loaded to shared memory.
+ auto wait_for_stage = [&]() {
+ // We only have `stages - 2` active fetches since we are double buffering
+ // and can only issue the next fetch when it is guaranteed that the previous
+ // shared memory load is fully complete (as it may otherwise be
+ // overwritten).
+ cp_async_wait<stages - 2>();
+ __syncthreads();
+ };
+
+ // Load the next sub-tile from the current location in the shared memory pipe
+ // into the current register buffer.
+ auto fetch_to_registers = [&](int k, int pipe) {
+ int4 *sh_a_stage = sh_a + a_sh_stage * pipe;
+#pragma unroll
+ for (int i = 0; i < thread_m_blocks; i++)
+ ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
+ int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
+ frag_b_quant[k % 2] = *reinterpret_cast<I4 *>(
+ &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
+ };
+
+ bool is_same_group[stages];
+ int same_group_id[stages];
+
+ auto init_same_group = [&](int pipe) {
+ int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
+ int *sh_g_idx_int_ptr = reinterpret_cast<int *>(sh_g_idx_stage);
+
+ int group_id_1 = sh_g_idx_int_ptr[0];
+ int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
+
+ is_same_group[pipe] = group_id_1 == group_id_2;
+ same_group_id[pipe] = group_id_1;
+ };
+
+ auto fetch_scales_to_registers = [&](int k, int full_pipe) {
+ int pipe = full_pipe % stages;
+
+ if constexpr (!has_act_order) {
+ // No act-order case
+ if constexpr (group_blocks != -1) {
+ if constexpr (group_blocks >= thread_k_blocks) {
+ int4 *sh_s_stage =
+ sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
+ (pipe / (group_blocks / thread_k_blocks)));
+ reinterpret_cast<int4 *>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
+ } else {
+ int warp_id = threadIdx.x / 32;
+ int n_warps = thread_n_blocks / 4;
+
+ int warp_row = warp_id / n_warps;
+
+ int cur_k = warp_row * 16;
+ cur_k += k_iter_size * (k % b_sh_wr_iters);
+
+ int k_blocks = cur_k / 16;
+ int cur_group_id = k_blocks / group_blocks;
+
+ int4 *sh_s_stage = sh_s + s_sh_stage * pipe;
+
+ reinterpret_cast<int4 *>(&frag_s[k % 2])[0] =
+ sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
+ }
+ }
+
+ return;
+ }
+
+ // Act-order case
+
+ // Determine K of the "current" thread-block
+ int cur_k = slice_k_start + tb_k * full_pipe;
+ if (cur_k >= prob_k || cur_k >= slice_k_finish) {
+ return;
+ }
+
+ // Reset (to current thread-block) since we read g_idx portion from the
+ // shared memory
+ cur_k = 0;
+
+ // Progress to current iteration
+ cur_k += k_iter_size * (k % b_sh_wr_iters);
+
+ // Determine "position" inside the thread-block (based on warp and
+ // thread-id)
+ int warp_id = threadIdx.x / 32;
+ int n_warps =
+ thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
+
+ int warp_row = warp_id / n_warps;
+ int warp_col = warp_id % n_warps;
+
+ cur_k += warp_row * 16;
+
+ int th_id = threadIdx.x % 32;
+ cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
+
+ int s_col_shift =
+ /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +
+ (th_id / 4) * act_s_col_stride;
+
+ if (is_same_group[pipe]) {
+ if (k % 2 == 0) {
+ *(reinterpret_cast<int4 *>(&(act_frag_s[k % 2][0][0]))) =
+ sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +
+ s_col_shift];
+ } else {
+ *(reinterpret_cast<int4 *>(&(act_frag_s[k % 2][0][0]))) =
+ *(reinterpret_cast<int4 *>(&(act_frag_s[(k - 1) % 2][0][0])));
+ }
+
+ for (int i = 1; i < 4; i++) {
+ *(reinterpret_cast<int4 *>(&(act_frag_s[k % 2][i][0]))) =
+ *(reinterpret_cast<int4 *>(&(act_frag_s[k % 2][0][0])));
+ }
+ return;
+ }
+
+ int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
+ int *sh_g_idx_int_ptr = reinterpret_cast<int *>(sh_g_idx_stage);
+
+ constexpr int k_frag_offsets[4] = {0, 1, 8,
+ 9}; // Tensor core offsets per thread
+
+#pragma unroll
+ for (int i = 0; i < 4; i++) {
+
+ int actual_k = cur_k + k_frag_offsets[i];
+
+ int group_id = sh_g_idx_int_ptr[actual_k];
+ int rel_group_id = group_id - sh_first_group_id;
+
+ *(reinterpret_cast<int4 *>(&(act_frag_s[k % 2][i][0]))) =
+ sh_s[rel_group_id * s_sh_stride + s_col_shift];
+ }
+ };
+
+ // Execute the actual tensor core matmul of a sub-tile.
+ auto matmul = [&](int k) {
+// We have the m dimension as the inner loop in order to encourage overlapping
+// dequantization and matmul operations.
+#pragma unroll
+ for (int j = 0; j < 4; j++) {
+ int b_quant = frag_b_quant[k % 2][j];
+ int b_quant_shift = b_quant >> 8;
+
+ FragB frag_b0 = dequant(b_quant);
+
+ // Apply scale to frag_b0
+ if constexpr (has_act_order) {
+ scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+ act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0);
+ } else {
+ if constexpr (group_blocks != -1) {
+ scale(frag_b0, frag_s[k % 2][j], 0);
+ }
+ }
+
+ FragB frag_b1 = dequant(b_quant_shift);
+
+ // Apply scale to frag_b1
+ if constexpr (has_act_order) {
+ scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
+ act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1);
+
+ } else {
+ if constexpr (group_blocks != -1) {
+ scale(frag_b1, frag_s[k % 2][j], 1);
+ }
+ }
+
+#pragma unroll
+ for (int i = 0; i < thread_m_blocks; i++) {
+ mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
+ mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
+ }
+ }
+ };
+
+ // Since we slice across the k dimension of a tile in order to increase the
+ // number of warps while keeping the n dimension of a tile reasonable, we have
+ // multiple warps that accumulate their partial sums of the same output
+ // location, which we have to reduce over in the end. We do this in shared memory.
+ auto thread_block_reduce = [&]() {
+ constexpr int red_off = threads / b_sh_stride / 2;
+ if (red_off >= 1) {
+ int red_idx = threadIdx.x / b_sh_stride;
+ constexpr int red_sh_stride = b_sh_stride * 4 * 2;
+ constexpr int red_sh_delta = b_sh_stride;
+ int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
+ (threadIdx.x % b_sh_stride);
+
+ // Parallel logarithmic shared memory reduction. We make sure to avoid any
+ // unnecessary read or write iterations, e.g., for two warps we write only
+ // once by warp 1 and read only once by warp 0.
+
+#pragma unroll
+ for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
+#pragma unroll
+ for (int i = red_off; i > 0; i /= 2) {
+ if (i <= red_idx && red_idx < 2 * i) {
+#pragma unroll
+ for (int j = 0; j < 4 * 2; j++) {
+ int red_sh_wr =
+ red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
+ if (i < red_off) {
+ float *c_rd = reinterpret_cast<float *>(
+ &sh[red_sh_delta * j + red_sh_rd]);
+ float *c_wr = reinterpret_cast<float *>(&sh[red_sh_wr]);
+#pragma unroll
+ for (int k = 0; k < 4; k++)
+ reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + j][k] +=
+ c_rd[k] + c_wr[k];
+ }
+ sh[red_sh_wr] =
+ reinterpret_cast<int4 *>(&frag_c)[4 * 2 * m_block + j];
+ }
+ }
+ __syncthreads();
+ }
+ if (red_idx == 0) {
+#pragma unroll
+ for (int i = 0; i < 4 * 2; i++) {
+ float *c_rd =
+ reinterpret_cast<float *>(&sh[red_sh_delta * i + red_sh_rd]);
+#pragma unroll
+ for (int j = 0; j < 4; j++)
+ reinterpret_cast<FragC *>(frag_c)[4 * 2 * m_block + i][j] +=
+ c_rd[j];
+ }
+ }
+ __syncthreads();
+ }
+ }
+ };
+
+ // Since multiple threadblocks may process parts of the same column slice, we
+ // finally have to globally reduce over the results. As the striped partitioning
+ // minimizes the number of such reductions and our outputs are usually rather
+ // small, we perform this reduction serially in L2 cache.
+ auto global_reduce = [&](bool first = false, bool last = false) {
+ // We are very careful here to reduce directly in the output buffer to
+ // maximize L2 cache utilization in this step. To do this, we write out
+ // results in FP16 (but still reduce with FP32 compute).
+ constexpr int active_threads = 32 * thread_n_blocks / 4;
+ if (threadIdx.x < active_threads) {
+ int c_gl_stride = prob_n / 8;
+ int c_gl_wr_delta_o = 8 * c_gl_stride;
+ int c_gl_wr_delta_i = 4 * (active_threads / 32);
+ int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
+ 4 * (threadIdx.x / 32) + threadIdx.x % 4;
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
+ constexpr int c_sh_wr_delta = active_threads;
+ int c_sh_wr = threadIdx.x;
+
+ int row = (threadIdx.x % 32) / 4;
+
+ if (!first) {
+// Interestingly, doing direct global accesses here really seems to mess up the
+// compiler and lead to slowdowns, hence we also use async-copies even though
+// these fetches are not actually asynchronous.
+#pragma unroll
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
+ cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
+ &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
+ c_gl_wr_delta_i * (i % 2)],
+ i < (thread_m_blocks - 1) * 4 ||
+ 8 * (i / 2) + row < prob_m);
+ }
+ cp_async_fence();
+ cp_async_wait<0>();
+ }
+
+#pragma unroll
+ for (int i = 0; i < thread_m_blocks * 4; i++) {
+ if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
+ if (!first) {
+ int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+#pragma unroll
+ for (int j = 0; j < 2 * 4; j++) {
+ reinterpret_cast<float *>(
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
+ __half2float(reinterpret_cast<__half *>(&c_red)[j]);
+ }
+ }
+ if (!last) {
+ int4 c;
+#pragma unroll
+ for (int j = 0; j < 2 * 4; j++) {
+ reinterpret_cast<__half *>(&c)[j] =
+ __float2half(reinterpret_cast<float *>(
+ &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
+ }
+ C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
+ c;
+ }
+ }
+ }
+ }
+ };
+
+ // Write out the reduced final result in the correct layout. We only actually
+ // reshuffle matrix fragments in this step, the reduction above is performed
+ // in fragment layout.
+ auto write_result = [&]() {
+ int c_gl_stride = prob_n / 8;
+ constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
+ int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
+ constexpr int c_sh_rd_delta =
+ c_sh_stride * (threads / (2 * thread_n_blocks));
+
+ int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
+ (threadIdx.x % (2 * thread_n_blocks));
+ c_gl_wr += (2 * thread_n_blocks) * slice_col;
+ int c_sh_wr =
+ (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
+ c_sh_wr += 32 * (threadIdx.x / 32);
+ int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
+ (threadIdx.x % (2 * thread_n_blocks));
+
+ int c_gl_wr_end = c_gl_stride * prob_m;
+
+ // We first reorder in shared memory to guarantee the most efficient final
+ // global write patterns
+ auto write = [&](int idx, float c0, float c1, FragS &s) {
+ half2 res = __halves2half2(__float2half(c0), __float2half(c1));
+
+ // For per-column quantization we finally apply the scale here
+ if constexpr (!has_act_order && group_blocks == -1) {
+ res = __hmul2(res, s[0]);
+ }
+
+ ((half2 *)sh)[idx] = res;
+ };
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
+#pragma unroll
+ for (int i = 0; i < thread_m_blocks; i++) {
+#pragma unroll
+ for (int j = 0; j < 4; j++) {
+ int wr = c_sh_wr + 8 * j;
+ write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
+ frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
+ write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
+ frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
+ write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
+ frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
+ write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
+ frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
+ }
+ c_sh_wr += 16 * (4 * c_sh_stride);
+ }
+ }
+ __syncthreads();
+
+#pragma unroll
+ for (int i = 0;
+ i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
+ i++) {
+ if (c_gl_wr < c_gl_wr_end) {
+ C[c_gl_wr] = sh[c_sh_rd];
+ c_gl_wr += c_gl_wr_delta;
+ c_sh_rd += c_sh_rd_delta;
+ }
+ }
+ };
+
+ // Start global fetch and register load pipelines.
+ auto start_pipes = [&]() {
+
+#pragma unroll
+ for (int i = 0; i < stages - 1; i++) {
+ if (has_act_order && i == 0) {
+ int last_g_idx = slice_k_start + stages * tb_k * 2;
+ if (last_g_idx >= prob_k) {
+ last_g_idx = prob_k - 1;
+ }
+ fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
+ }
+ fetch_to_shared(i, i, i < slice_iters);
+ }
+
+ zero_accums();
+ wait_for_stage();
+ init_same_group(0);
+ fetch_to_registers(0, 0);
+ fetch_scales_to_registers(0, 0);
+ a_gl_rd += a_gl_rd_delta_o * (stages - 1);
+ slice_k_start_shared_fetch += tb_k * (stages - 1);
+ };
+ if (slice_iters) {
+ start_pipes();
+ }
+
+ // Main loop.
+ while (slice_iters) {
+ // We unroll over both the global fetch and the register load pipeline to
+ // ensure all shared memory accesses are static. Note that both pipelines
+ // have even length meaning that the next iteration will always start at
+ // index 0.
+#pragma unroll
+ for (int pipe = 0; pipe < stages;) {
+#pragma unroll
+ for (int k = 0; k < b_sh_wr_iters; k++) {
+ fetch_to_registers(k + 1, pipe % stages);
+ fetch_scales_to_registers(k + 1, pipe);
+ if (k == b_sh_wr_iters - 2) {
+ fetch_to_shared((pipe + stages - 1) % stages, pipe,
+ slice_iters >= stages);
+ pipe++;
+ wait_for_stage();
+ init_same_group(pipe % stages);
+ }
+ matmul(k);
+ }
+ slice_iters--;
+ if (slice_iters == 0) {
+ break;
+ }
+ }
+
+ a_gl_rd += a_gl_rd_delta_o * stages;
+ slice_k_start += tb_k * stages;
+ slice_k_start_shared_fetch += tb_k * stages;
+
+ if constexpr (has_act_order) {
+ int first_group_id = g_idx[slice_k_start];
+ int last_g_idx = slice_k_start + stages * tb_k * 2;
+ if (last_g_idx >= prob_k) {
+ last_g_idx = prob_k - 1;
+ }
+ int last_group_id = g_idx[last_g_idx];
+ if (last_group_id >= sh_first_group_id + sh_num_groups) {
+ fetch_scales_to_shared(false, first_group_id, last_group_id);
+ __syncthreads();
+ }
+ }
+
+ // Process results and, if necessary, proceed to the next column slice.
+ // While this pattern may not be the most readable, other ways of writing
+ // the loop seemed to noticeably worsen performance after compilation.
+ if (slice_iters == 0) {
+ cp_async_wait<0>();
+ bool last = slice_idx == slice_count - 1;
+ // For per-column scales, we only fetch them here in the final step before
+ // write-out
+ if constexpr (!has_act_order && group_blocks == -1) {
+ if (last) {
+ if (s_sh_wr_pred) {
+ cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+ }
+ cp_async_fence();
+ }
+ }
+
+ thread_block_reduce();
+ if constexpr (!has_act_order && group_blocks == -1) {
+ if (last) {
+ cp_async_wait<0>();
+ __syncthreads();
+ if (threadIdx.x / 32 < thread_n_blocks / 4) {
+ reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+ reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+ }
+ }
+ }
+
+ if (slice_count > 1) { // only globally reduce if there is more than one
+ // block in a slice
+ barrier_acquire(&locks[slice_col], slice_idx);
+ global_reduce(slice_idx == 0, last);
+ barrier_release(&locks[slice_col], last);
+ }
+ if (last) // only the last block in a slice actually writes the result
+ write_result();
+ slice_row = 0;
+ slice_col_par++;
+ slice_col++;
+ init_slice();
+ if (slice_iters) {
+ a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
+ (threadIdx.x % a_gl_rd_delta_o);
+#pragma unroll
+ for (int i = 0; i < b_sh_wr_iters; i++)
+ B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
+ if (slice_col == 0) {
+#pragma unroll
+ for (int i = 0; i < b_sh_wr_iters; i++)
+ B_ptr[i] -= b_gl_stride;
+ }
+
+ // Update slice k/n for scales loading
+ if constexpr (has_act_order) {
+ slice_k_start = tb_k * slice_row;
+ slice_k_finish = slice_k_start + tb_k * slice_iters;
+ slice_k_start_shared_fetch = slice_k_start;
+ slice_n_offset = act_s_col_tb_stride * slice_col;
+
+ } else {
+ s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
+ }
+
+ // if (blockIdx.x == 0 && threadIdx.x == 0) {
+ // printf("Move\n");
+ // }
+ start_pipes();
+ }
+ }
+ }
+}
+
+#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
+ HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
+ else if (thread_m_blocks == THREAD_M_BLOCKS && \
+ thread_n_blocks == THREAD_N_BLOCKS && \
+ thread_k_blocks == THREAD_K_BLOCKS && \
+ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
+ num_threads == NUM_THREADS) { \
+ cudaFuncSetAttribute( \
+ Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
+ pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
+ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
+ Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
+ pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
+ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
+ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
+ prob_k, locks); \
+ }
+
+typedef struct {
+ int thread_k;
+ int thread_n;
+ int num_threads;
+} thread_config_t;
+
+thread_config_t small_batch_thread_configs[] = {
+ // Ordered by priority
+
+ // thread_k, thread_n, num_threads
+ {128, 128, 256}, // Default
+ {128, 64, 128}, // Reduce N 2X, same K
+ {64, 256, 256}, // Reduce K 2X, increase N 2X
+ {64, 128, 128}, // Reduce K 2X, same N
+};
+
+thread_config_t large_batch_thread_configs[] = {
+ // Ordered by priority
+
+ // thread_k, thread_n, num_threads
+ {64, 256, 256}, // Default
+ {128, 64, 128}, // Reduce N 4X, increase K 2X
+ {64, 128, 128}, // Reduce N 2X, same K
+ // {128, 64, 128}, // Reduce N 4X, increase K 2X
+};
+
+bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
+ int prob_k) {
+ // Sanity
+ if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
+ th_config.num_threads == -1) {
+ return false;
+ }
+
+ // Verify K/N are divisible by thread K/N
+ if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
+ return false;
+ }
+
+ // Verify min for thread K/N
+ if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
+ return false;
+ }
+
+ // num_threads must be at least 128 (= 4 warps)
+ if (th_config.num_threads < 128) {
+ return false;
+ }
+
+ return true;
+}
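As a quick illustration of the fallback behaviour, the sketch below (hypothetical standalone check, assuming the thread_config_t and is_valid_config definitions above plus min_thread_n/min_thread_k = 64 from gptq_marlin.cuh are visible) shows when the default config survives and when the selector must fall through the priority list:

#include <cassert>

void check_config_fallback() {
  thread_config_t def{/*thread_k=*/128, /*thread_n=*/128, /*num_threads=*/256};
  // Default config is valid when thread_k and thread_n divide the problem.
  assert(is_valid_config(def, /*prob_m=*/16, /*prob_n=*/4096, /*prob_k=*/4096));
  // 4160 is not divisible by thread_n = 128, so the selector would fall
  // through to a config with thread_n = 64.
  assert(!is_valid_config(def, /*prob_m=*/16, /*prob_n=*/4160, /*prob_k=*/4096));
}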
+
+thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
+
+ // TODO: Enable if needed after some more testing
+ if (prob_m <= 0) {
+ for (auto th_config : small_batch_thread_configs) {
+ if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
+ return th_config;
+ }
+ }
+
+ } else {
+ for (auto th_config : large_batch_thread_configs) {
+ if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
+ return th_config;
+ }
+ }
+ }
+
+ return thread_config_t{-1, -1, -1};
+}
+
+#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
+ __CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
+ __CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
+ __CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
+ \
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
+ __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
+ \
+ __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
+ __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
+ __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
+ __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
+ \
+ __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
+ __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
+ __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
+ __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
+ \
+ __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
+ __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
+ __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
+ __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
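The __CALL_IF/CALL_IF macros implement a common dispatch idiom: a chain of `else if` branches, opened by an empty `if (false) {}` in the caller (see the main loop below), maps runtime shape parameters onto a fixed set of compile-time kernel instantiations. A stripped-down sketch of the same idiom, with invented names:

#include <cstdio>

template <int THREADS>
void kernel_stub(int m) { std::printf("launch threads=%d, m=%d\n", THREADS, m); }

#define DISPATCH(NUM_THREADS)            \
  else if (num_threads == NUM_THREADS) { \
    kernel_stub<NUM_THREADS>(m);         \
  }

void launch(int num_threads, int m) {
  if (false) {
  }
  DISPATCH(128)
  DISPATCH(256)
  else {
    std::printf("unsupported num_threads = %d\n", num_threads);
  }
}

int main() { launch(256, 16); return 0; }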
+
+void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
+ void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k,
+ void *workspace, bool has_act_order, bool is_k_full,
+ int num_groups, int group_size, int dev = 0,
+ cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1,
+ int sms = -1, int max_par = 16) {
+ TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
+ ", ", prob_n, ", ", prob_k, "]");
+
+ int tot_m = prob_m;
+ int tot_m_blocks = div_ceil(tot_m, 16);
+ int pad = 16 * tot_m_blocks - tot_m;
+
+ if (sms == -1) {
+ cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
+ }
+
+ int max_shared_mem = 0;
+ cudaDeviceGetAttribute(&max_shared_mem,
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+ TORCH_CHECK(max_shared_mem > 0);
+
+ // Set thread config
+ thread_config_t th_config;
+ if (thread_k != -1 && thread_n != -1) {
+ // User-defined config
+ th_config = thread_config_t{thread_k, thread_n, default_threads};
+ } else {
+ // Auto config
+ th_config = determine_thread_config(prob_m, prob_n, prob_k);
+ }
+
+ TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k),
+ "Invalid thread config: thread_k = " + str(th_config.thread_k) +
+ ", thread_n = " + str(th_config.thread_n) +
+ ", num_threads = " + str(th_config.num_threads) +
+ " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " +
+ str(prob_n) + "]");
+
+ int num_threads = th_config.num_threads;
+ thread_k = th_config.thread_k;
+ thread_n = th_config.thread_n;
+
+ int thread_k_blocks = thread_k / 16;
+ int thread_n_blocks = thread_n / 16;
+
+ int blocks = sms;
+
+ TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
+ " is not divisible by thread_n = ", thread_n);
+ TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
+ " is not divisible by thread_k = ", thread_k);
+
+ int group_blocks = 0;
+ if (has_act_order) {
+ if (is_k_full) {
+ TORCH_CHECK(group_size != -1);
+ group_blocks = group_size / 16;
+ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
+ " is not divisible by group_blocks = ", group_blocks);
+ } else {
+ TORCH_CHECK(group_size == 0);
+ group_blocks = 0;
+ }
+
+ } else {
+ if (group_size == -1) {
+ group_blocks = -1;
+ } else {
+ group_blocks = group_size / 16;
+ TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
+ " is not divisible by group_blocks = ", group_blocks);
+ }
+ }
+
+ const int4 *A_ptr = (const int4 *)A;
+ const int4 *B_ptr = (const int4 *)B;
+ int4 *C_ptr = (int4 *)C;
+ const int4 *s_ptr = (const int4 *)s;
+ const int *g_idx_ptr = (const int *)g_idx;
+ const int *perm_ptr = (const int *)perm;
+ int4 *a_tmp_ptr = (int4 *)a_tmp;
+
+ int *locks = (int *)workspace;
+
+ if (has_act_order) {
+ // Permute A columns
+ int block_rows = div_ceil(prob_m, blocks);
+ permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
+ A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
+ A_ptr = a_tmp_ptr;
+ }
+
+ // If we have a full K, then we can run the non-act-order version of Marlin
+ // (since the weight rows are reordered by increasing group ids, and by having
+ // a full K, we have full original groups)
+ if (is_k_full) {
+ has_act_order = false;
+ }
+
+ // Main loop
+ for (int i = 0; i < tot_m_blocks; i += 4) {
+ int thread_m_blocks = tot_m_blocks - i;
+ prob_m = tot_m - 16 * i;
+ int par = 1;
+ if (thread_m_blocks > 4) {
+ // Note that parallel > 1 currently only works for inputs without any
+ // padding
+ par = (16 * thread_m_blocks - pad) / 64;
+ if (par > max_par)
+ par = max_par;
+ prob_m = 64 * par;
+ i += 4 * (par - 1);
+ thread_m_blocks = 4;
+ }
+
+ // Define kernel configurations
+ if (false) {
+ }
+ CALL_IF(16, 4, 256)
+ CALL_IF(8, 8, 256)
+ CALL_IF(8, 4, 128)
+ CALL_IF(4, 8, 128)
+ else {
+ TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
+ str(prob_n) + ", " + str(prob_k) + "]" +
+ ", has_act_order = " + str(has_act_order) +
+ ", num_groups = " + str(num_groups) +
+ ", group_size = " + str(group_size) +
+ ", thread_m_blocks = " + str(thread_m_blocks) +
+ ", thread_n_blocks = " + str(thread_n_blocks) +
+ ", thread_k_blocks = " + str(thread_k_blocks));
+ }
+
+ A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
+ C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
+ }
+}
+
+} // namespace gptq_marlin
+
+torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
+ torch::Tensor &b_scales, torch::Tensor &g_idx,
+ torch::Tensor &perm, torch::Tensor &workspace,
+ int64_t size_m, int64_t size_n, int64_t size_k,
+ bool is_k_full) {
+ // Verify A
+ TORCH_CHECK(a.size(0) == size_m,
+ "Shape mismatch: a.size(0) = " + str(a.size(0)) +
+ ", size_m = " + str(size_m));
+ TORCH_CHECK(a.size(1) == size_k,
+ "Shape mismatch: a.size(1) = " + str(a.size(1)) +
+ ", size_k = " + str(size_k));
+
+ // Verify B
+ TORCH_CHECK(size_k % gptq_marlin::tile_size == 0,
+ "size_k = " + str(size_k) + " is not divisible by tile_size = " +
+ str(gptq_marlin::tile_size));
+ TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
+ "Shape mismatch: b_q_weight.size(0) = " +
+ str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
+ ", tile_size = " + str(gptq_marlin::tile_size));
+ TORCH_CHECK(
+ b_q_weight.size(1) % gptq_marlin::tile_size == 0,
+ "b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
+ " is not divisible by tile_size = " + str(gptq_marlin::tile_size));
+ int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) *
+ gptq_marlin::pack_factor_4bit;
+ TORCH_CHECK(size_n == actual_size_n,
+ "size_n = " + str(size_n) +
+ ", actual_size_n = " + str(actual_size_n));
+
+ // Verify device and strides
+ TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
+ TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
+
+ TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
+ TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
+
+ TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
+ TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
+
+ TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
+ TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
+
+ TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
+ TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
+
+ // Alloc buffers
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
+ auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
+ torch::Tensor c = torch::empty({size_m, size_n}, options);
+ torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
+
+ // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
+ // auto -1)
+ int thread_k = -1;
+ // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
+ // auto -1)
+ int thread_n = -1;
+ // sms: number of SMs to use for the kernel (can usually be left as auto -1)
+ int sms = -1;
+
+ // Verify g_idx and perm
+ TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
+ (g_idx.size(0) == size_k && perm.size(0) == size_k),
+ "Unexpected g_idx.size(0) = " + str(g_idx.size(0)) +
+ " and perm.size(0) = " + str(perm.size(0)) +
+ ", where size_k = " + str(size_k));
+
+ // Detect groupsize and act_order
+ int num_groups = -1;
+ int group_size = -1;
+ bool has_act_order = g_idx.size(0) != 0;
+
+ int b_rank = b_scales.sizes().size();
+ TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
+ TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
+ " is not size_n = ", size_n);
+ num_groups = b_scales.size(0);
+
+ if (has_act_order) {
+ if (is_k_full) {
+ TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
+ TORCH_CHECK(size_k % num_groups == 0,
+ "size_k = " + str(size_k) +
+ ", is not divisible by num_groups = " + str(num_groups));
+ group_size = size_k / num_groups;
+ } else {
+ group_size = 0;
+ }
+
+ } else {
+ if (num_groups > 1) {
+ TORCH_CHECK(size_k % num_groups == 0,
+ "size_k = " + str(size_k) +
+ ", is not divisible by b_scales.size(0) = " +
+ str(b_scales.size(0)));
+ group_size = size_k / num_groups;
+ } else {
+ group_size = -1;
+ }
+ }
+
+ // Verify workspace size
+ TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0,
+ "size_n = " + str(size_n) +
+ ", is not divisible by min_thread_n = " +
+ str(gptq_marlin::min_thread_n));
+ int min_workspace_size =
+ (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
+ TORCH_CHECK(workspace.numel() >= min_workspace_size,
+ "workspace.numel = " + str(workspace.numel()) +
+ " is below min_workspace_size = " + str(min_workspace_size));
+
+ int dev = a.get_device();
+ gptq_marlin::marlin_cuda(
+ a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
+ g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n,
+ size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups,
+ group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
+ sms, gptq_marlin::max_par);
+
+ return c;
+}
+
+#endif
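A hypothetical C++ call-site sketch of the new op follows (not part of the patch). Shapes are derived from the TORCH_CHECKs above and assume the declaration of gptq_marlin_gemm from csrc/ops.h is visible; in practice b_packed would be produced by gptq_marlin_repack and b_scales must match the group layout, so the zero tensors here only exercise the shape checks.

#include <torch/torch.h>

torch::Tensor run_gptq_marlin_example() {
  const int64_t m = 16, n = 4096, k = 4096, groups = k / 128;
  auto f16 = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);

  auto a         = torch::randn({m, k}, f16);
  auto b_packed  = torch::zeros({k / 16, n * 16 / 8}, i32);  // tile_size = 16, pack_factor_4bit = 8
  auto b_scales  = torch::ones({groups, n}, f16);
  auto g_idx     = torch::empty({0}, i32);                   // empty => no act_order
  auto perm      = torch::empty({0}, i32);
  auto workspace = torch::zeros({(n / 64) * 16}, i32);       // min_thread_n = 64, max_par = 16

  return gptq_marlin_gemm(a, b_packed, b_scales, g_idx, perm, workspace,
                          m, n, k, /*is_k_full=*/true);
}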
diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh
new file mode 100644
index 0000000000000..8cfce6b2575d5
--- /dev/null
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh
@@ -0,0 +1,74 @@
+#pragma once
+
+#include <torch/extension.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <iostream>
+
+namespace gptq_marlin {
+
+// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per
+// scheduler allows some more latency hiding. At the same time, we want relatively few warps to have
+// many registers per warp and small tiles.
+static constexpr int default_threads = 256;
+
+static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
+
+static constexpr int min_thread_n = 64;
+static constexpr int min_thread_k = 64;
+
+static constexpr int tile_size = 16;
+static constexpr int max_par = 16;
+
+static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32-bit int
+
+template <typename T, int n>
+struct Vec {
+ T elems[n];
+ __device__ T& operator[](int i) { return elems[i]; }
+};
+
+using I4 = Vec<int, 4>;
+
+constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
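The constants above fix the storage layout: with 4-bit weights one 32-bit word holds pack_factor_4bit = 8 values, and div_ceil rounds partial tiles up. A few compile-time checks (illustrative only, assuming these constants are in scope, e.g. inside this namespace) make that concrete:

static_assert(pack_factor_4bit == 32 / 4, "8 int4 values per 32-bit word");
static_assert(div_ceil(100, 16) == 7, "100 rows of K occupy 7 tiles of 16");
static_assert(tile_size * tile_size / pack_factor_4bit == 32,
              "a 16x16 block of int4 weights packs into 32 int32 words");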
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ // No support for async
+#else
+
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
+ const int BYTES = 16;
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+ asm volatile("{\n"
+ " .reg .pred p;\n"
+ " setp.ne.b32 p, %0, 0;\n"
+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+ "}\n" ::"r"((int)pred),
+ "r"(smem), "l"(glob_ptr), "n"(BYTES));
+}
+
+__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
+ const int BYTES = 16;
+ uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+ asm volatile("{\n"
+ " .reg .b64 p;\n"
+ " createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
+ " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
+ "}\n" ::"r"(smem),
+ "l"(glob_ptr), "n"(BYTES));
+}
+
+__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); }
+
+template <int n>
+__device__ inline void cp_async_wait() {
+ asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
+}
+
+#endif
+
+} // namespace gptq_marlin
diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
new file mode 100644
index 0000000000000..fa45ce68a0c77
--- /dev/null
+++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
@@ -0,0 +1,324 @@
+#include "gptq_marlin.cuh"
+
+namespace gptq_marlin {
+
+static constexpr int repack_stages = 8;
+
+static constexpr int repack_threads = 256;
+
+static constexpr int tile_k_size = tile_size;
+static constexpr int tile_n_size = tile_k_size * 4;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+
+template <bool const has_perm>
+__global__ void
+marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
+ uint32_t const *__restrict__ perm_ptr,
+ uint32_t *__restrict__ out_ptr, int size_k, int size_n) {}
+
+} // namespace gptq_marlin
+
+torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
+ int64_t size_k, int64_t size_n) {
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
+ return torch::empty({1, 1});
+}
+
+#else
+
+template <bool const has_perm>
+__global__ void
+marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
+ uint32_t const *__restrict__ perm_ptr,
+ uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
+ int k_tiles = size_k / tile_k_size;
+ int n_tiles = size_n / tile_n_size;
+ int block_k_tiles = div_ceil(k_tiles, gridDim.x);
+
+ int start_k_tile = blockIdx.x * block_k_tiles;
+ if (start_k_tile >= k_tiles) {
+ return;
+ }
+
+ int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
+
+ // Wait until the next thread tile has been loaded to shared memory.
+ auto wait_for_stage = [&]() {
+ // We only have `stages - 2` active fetches since we are double buffering
+ // and can only issue the next fetch when it is guaranteed that the previous
+ // shared memory load is fully complete (as it may otherwise be
+ // overwritten).
+ cp_async_wait<repack_stages - 2>();
+ __syncthreads();
+ };
+
+ extern __shared__ int4 sh[];
+
+ constexpr int perm_size = tile_k_size / 4;
+
+ int4 *sh_perm_ptr = sh;
+ int4 *sh_pipe_ptr = sh_perm_ptr;
+ if constexpr (has_perm) {
+ sh_pipe_ptr += perm_size;
+ }
+
+ constexpr int stage_n_threads = tile_n_size / 4;
+ constexpr int stage_k_threads =
+ has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
+ constexpr int stage_size = stage_k_threads * stage_n_threads;
+
+ auto load_perm_to_shared = [&](int k_tile_id) {
+ int first_k_int4 = (k_tile_id * tile_k_size) / 4;
+
+ int4 const *perm_int4_ptr = reinterpret_cast<int4 const *>(perm_ptr);
+
+ if (threadIdx.x < perm_size) {
+ sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
+ }
+ __syncthreads();
+ };
+
+ auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
+ if (n_tile_id >= n_tiles) {
+ cp_async_fence();
+ return;
+ }
+
+ int first_n = n_tile_id * tile_n_size;
+
+ int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe;
+
+ if constexpr (has_perm) {
+ if (threadIdx.x < stage_size) {
+ int k_id = threadIdx.x / stage_n_threads;
+ int n_id = threadIdx.x % stage_n_threads;
+
+ uint32_t const *sh_perm_int_ptr =
+ reinterpret_cast<uint32_t const *>(sh_perm_ptr);
+
+ int src_k = sh_perm_int_ptr[k_id];
+ int src_k_packed = src_k / pack_factor_4bit;
+
+ cp_async4_stream(
+ &sh_ptr[k_id * stage_n_threads + n_id],
+ reinterpret_cast<int4 const *>(&(
+ b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
+ }
+
+ } else {
+ if (threadIdx.x < stage_size) {
+ int k_id = threadIdx.x / stage_n_threads;
+ int n_id = threadIdx.x % stage_n_threads;
+
+ int first_k = k_tile_id * tile_k_size;
+ int first_k_packed = first_k / pack_factor_4bit;
+
+ cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id],
+ reinterpret_cast<int4 const *>(
+ &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
+ first_n + (n_id * 4)])));
+ }
+ }
+
+ cp_async_fence();
+ };
+
+ auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
+ if (n_tile_id >= n_tiles) {
+ return;
+ }
+
+ int warp_id = threadIdx.x / 32;
+ int th_id = threadIdx.x % 32;
+
+ if (warp_id >= 4) {
+ return;
+ }
+
+ int tc_col = th_id / 4;
+ int tc_row = (th_id % 4) * 2;
+
+ constexpr int tc_offsets[4] = {0, 1, 8, 9};
+
+ int cur_n = warp_id * 16 + tc_col;
+
+ constexpr int sh_stride = 64;
+
+ int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
+ uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
+
+ uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
+
+ uint32_t vals[pack_factor_4bit];
+
+ if constexpr (has_perm) {
+ for (int i = 0; i < 4; i++) {
+ int k_idx = tc_row + tc_offsets[i];
+
+ uint32_t src_k = sh_perm_int_ptr[k_idx];
+ uint32_t src_k_pos = src_k % pack_factor_4bit;
+
+ uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
+ uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf;
+
+ uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
+ uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf;
+
+ vals[i] = b1_cur_val;
+ vals[4 + i] = b2_cur_val;
+ }
+
+ } else {
+
+ uint32_t b1_val_1 = sh_stage_int_ptr[cur_n];
+ uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n];
+
+ uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
+ uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];
+
+#pragma unroll
+ for (int i = 0; i < 2; i++) {
+ int cur_elem = tc_row + tc_offsets[i];
+ vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf;
+ vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
+ }
+
+#pragma unroll
+ for (int i = 2; i < 4; i++) {
+ int cur_elem = tc_row + tc_offsets[i] - 8;
+ vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf;
+ vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf;
+ }
+ }
+
+ // Result of:
+ // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+ constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7};
+
+ uint32_t res = 0;
+#pragma unroll
+ for (int i = 0; i < pack_factor_4bit; i++) {
+ res |= vals[pack_idx[i]] << (i * 4);
+ }
+
+ constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit;
+ int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
+
+ out_ptr[out_offset + th_id * 4 + warp_id] = res;
+ };
+
+ auto start_pipes = [&](int k_tile_id, int n_tile_id) {
+#pragma unroll
+ for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
+ fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
+ }
+
+ wait_for_stage();
+ };
+#pragma unroll
+ for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
+ int n_tile_id = 0;
+
+ if constexpr (has_perm) {
+ load_perm_to_shared(k_tile_id);
+ }
+
+ start_pipes(k_tile_id, n_tile_id);
+
+ while (n_tile_id < n_tiles) {
+#pragma unroll
+ for (int pipe = 0; pipe < repack_stages; pipe++) {
+ fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
+ n_tile_id + pipe + repack_stages - 1);
+ repack_tile(pipe, k_tile_id, n_tile_id + pipe);
+ wait_for_stage();
+ }
+ n_tile_id += repack_stages;
+ }
+ }
+}
+
+} // namespace gptq_marlin
+
+torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
+ int64_t size_k, int64_t size_n) {
+ // Verify compatibility with marlin tile of 16x64
+ TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
+ " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
+ TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
+ " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
+
+ // Verify B
+ TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0),
+ "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
+ ", size_k = ", size_k,
+ ", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
+ TORCH_CHECK(b_q_weight.size(1) == size_n,
+ "b_q_weight.size(1) = ", b_q_weight.size(1),
+ " is not size_n = ", size_n);
+
+ // Verify device and strides
+ TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
+ TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
+ TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
+
+ TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
+ TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
+ TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
+
+ // Alloc buffers
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
+ auto options = torch::TensorOptions()
+ .dtype(b_q_weight.dtype())
+ .device(b_q_weight.device());
+ torch::Tensor out = torch::empty(
+ {size_k / gptq_marlin::tile_size,
+ size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit},
+ options);
+
+ // Detect if there is act_order
+ bool has_perm = perm.size(0) != 0;
+
+ // Get ptrs
+ uint32_t const *b_q_weight_ptr =
+      reinterpret_cast<uint32_t const *>(b_q_weight.data_ptr());
+ uint32_t const *perm_ptr =
+      reinterpret_cast<uint32_t const *>(perm.data_ptr());
+  uint32_t *out_ptr = reinterpret_cast<uint32_t *>(out.data_ptr());
+
+ // Get dev info
+ int dev = b_q_weight.get_device();
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+ int blocks;
+ cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+
+ int max_shared_mem = 0;
+ cudaDeviceGetAttribute(&max_shared_mem,
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+ TORCH_CHECK(max_shared_mem > 0);
+
+ if (has_perm) {
+    cudaFuncSetAttribute(
+        gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize,
+        max_shared_mem);
+    gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>
+        <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>(
+            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
+
+ } else {
+    cudaFuncSetAttribute(
+        gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
+        cudaFuncAttributeMaxDynamicSharedMemorySize,
+        max_shared_mem);
+    gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
+        <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>(
+            b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
+ }
+
+ return out;
+}
+
+#endif
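Note on the repack kernel above: repack_tile extracts eight 4-bit values per 32-bit output word and writes them in the interleaved order {0, 2, 4, 6, 1, 3, 5, 7} expected by the tensor-core dequantization path (the pack_idx table). The following is a minimal Python sketch of that packing and its inverse; it is illustrative only, not part of the patch, and the helper names are made up.

PACK_FACTOR_4BIT = 8
PACK_IDX = [0, 2, 4, 6, 1, 3, 5, 7]

def pack_interleaved(vals):
    # Pack eight 4-bit integers into one 32-bit word in the interleaved
    # order used by repack_tile: vals[PACK_IDX[i]] lands in nibble i.
    assert len(vals) == PACK_FACTOR_4BIT
    res = 0
    for i in range(PACK_FACTOR_4BIT):
        res |= (vals[PACK_IDX[i]] & 0xF) << (i * 4)
    return res

def unpack_interleaved(word):
    # Inverse of pack_interleaved: recover the eight 4-bit values in order.
    vals = [0] * PACK_FACTOR_4BIT
    for i in range(PACK_FACTOR_4BIT):
        vals[PACK_IDX[i]] = (word >> (i * 4)) & 0xF
    return vals

assert unpack_interleaved(pack_interleaved(list(range(8)))) == list(range(8))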
diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py
new file mode 100644
index 0000000000000..dc027697ffd4d
--- /dev/null
+++ b/tests/models/test_gptq_marlin.py
@@ -0,0 +1,93 @@
+"""Compares the outputs of gptq vs gptq_marlin
+Note: GPTQ and Marlin do not have bitwise correctness.
+As a result, in this test, we just confirm that the top selected tokens of the
+Marlin/GPTQ models are in the top 3 selections of each other.
+Note: Marlin internally uses locks to synchronize the threads. This can
+result in very slight nondeterminism for Marlin. As a result, we re-run the test
+up to 3 times to see if we pass.
+Note: This test currently fails running with --forked with the following:
+ RuntimeError: Cannot re-initialize CUDA in forked subprocess.
+ To use CUDA with multiprocessing, you must use the 'spawn' start method
+Run `pytest tests/models/test_gptq_marlin.py`.
+"""
+import os
+
+import pytest
+import torch
+
+from tests.models.utils import check_logprobs_close
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+MAX_MODEL_LEN = 1024
+
+capability = torch.cuda.get_device_capability()
+capability = capability[0] * 10 + capability[1]
+gptq_marlin_not_supported = (
+ capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
+
+MODELS = [
+ # act_order==False, group_size=channelwise
+ ("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"),
+ # act_order==False, group_size=128
+ ("TheBloke/Llama-2-7B-GPTQ", "main"),
+
+ # act_order==True, group_size=128
+ ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"),
+ # act_order==True, group_size=64
+ ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
+ # act_order==True, group_size=32
+ ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),
+]
+
+
+@pytest.mark.flaky(reruns=2)
+@pytest.mark.skipif(gptq_marlin_not_supported,
+ reason="gptq_marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(
+ vllm_runner,
+ example_prompts,
+ model,
+ dtype: str,
+ max_tokens: int,
+ num_logprobs: int,
+) -> None:
+ model_name, revision = model
+
+ # Run marlin.
+ gptq_marlin_model = vllm_runner(model_name=model_name,
+ revision=revision,
+ dtype=dtype,
+ quantization="marlin",
+ max_model_len=MAX_MODEL_LEN,
+ tensor_parallel_size=1,
+ disable_custom_all_reduce=True)
+
+ gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
+ example_prompts, max_tokens, num_logprobs)
+ del gptq_marlin_model
+
+ # Run gptq.
+ gptq_model = vllm_runner(model_name=model_name,
+ revision=revision,
+ dtype=dtype,
+ quantization="gptq",
+ max_model_len=MAX_MODEL_LEN,
+ tensor_parallel_size=1,
+ disable_custom_all_reduce=True)
+ gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
+ max_tokens,
+ num_logprobs)
+ del gptq_model
+
+ check_logprobs_close(
+ outputs_0_lst=gptq_outputs,
+ outputs_1_lst=gptq_marlin_outputs,
+ name_0="gptq",
+ name_1="gptq_marlin",
+ )
diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py
index 4fe6daec02520..fa846d43d0e88 100644
--- a/tests/models/test_marlin.py
+++ b/tests/models/test_marlin.py
@@ -10,12 +10,12 @@
Run `pytest tests/models/test_marlin.py`.
"""
-
from dataclasses import dataclass
import pytest
import torch
+from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
capability = torch.cuda.get_device_capability()
@@ -55,43 +55,24 @@ def test_models(
max_tokens: int,
num_logprobs: int,
) -> None:
- marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
+ marlin_model = vllm_runner(model_pair.model_marlin,
+ dtype=dtype,
+ quantization="marlin")
marlin_outputs = marlin_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
-
- # Note: not sure why, but deleting just the model on Ada Lovelace
- # does not free the GPU memory. On Ampere, deleting the just model
- # frees the memory.
del marlin_model
- gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
+ gptq_model = vllm_runner(model_pair.model_gptq,
+ dtype=dtype,
+ quantization="gptq")
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)
-
- # Note: not sure why, but deleting just the model on Ada Lovelace
- # does not free the GPU memory. On Ampere, deleting the just model
- # frees the memory.
del gptq_model
- # loop through the prompts
- for prompt_idx in range(len(example_prompts)):
- gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
- prompt_idx]
- marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
- prompt_idx]
-
- for idx, (gptq_output_id, marlin_output_id) in enumerate(
- zip(gptq_output_ids, marlin_output_ids)):
- # If sequence is not an exact match,
- if marlin_output_id != gptq_output_id:
- # Each predicted token must be in top 5 of the other's
- assert gptq_output_id in marlin_logprobs[idx], (
- f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
- f"Marlin:\t{marlin_output_str!r}")
- assert marlin_output_id in gptq_logprobs[idx], (
- f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
- f"Marlin:\t{marlin_output_str!r}")
-
- # Break out since sequences will now diverge.
- break
+ check_logprobs_close(
+ outputs_0_lst=gptq_outputs,
+ outputs_1_lst=marlin_outputs,
+ name_0="gptq",
+ name_1="marlin",
+ )
diff --git a/tests/models/utils.py b/tests/models/utils.py
new file mode 100644
index 0000000000000..3e49dfb331176
--- /dev/null
+++ b/tests/models/utils.py
@@ -0,0 +1,29 @@
+def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
+ """Compare the logprobs of two sequences generated by different models,
+ which should be similar but not necessarily equal.
+ """
+ # Loop through responses to each prompt.
+ for prompt_idx, (outputs_0,
+ outputs_1) in enumerate(zip(outputs_0_lst,
+ outputs_1_lst)):
+ output_ids_0, output_str_0, logprobs_0 = outputs_0
+ output_ids_1, output_str_1, logprobs_1 = outputs_1
+
+ # Loop through generated tokens.
+ for idx, (output_id_0,
+ output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
+
+            # If the generated tokens diverge at this position,
+ if output_id_0 != output_id_1:
+ # Each predicted token must be in top N logprobs of the other
+ assert output_id_0 in logprobs_1[idx], (
+ f"Test{prompt_idx}:"
+ f"\n{name_0}:\t{output_str_0!r}"
+ f"\n{name_1}:\t{output_str_1!r}")
+ assert output_id_1 in logprobs_0[idx], (
+ f"Test{prompt_idx}:"
+ f"\n{name_0}:\t{output_str_0!r}"
+ f"\n{name_1}:\t{output_str_1!r}")
+
+ # Break out since sequences will now diverge.
+ break
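For reference, a usage sketch of check_logprobs_close with hand-made, hypothetical outputs (the real values come from vllm_runner.generate_greedy_logprobs); each output is an (output_ids, output_str, logprobs) tuple, with logprobs holding one top-N dict per generated position:

from tests.models.utils import check_logprobs_close

outputs_a = [([5, 7, 9], "a b c", [{5: -0.1}, {7: -0.2, 8: -1.5}, {9: -0.3}])]
outputs_b = [([5, 8, 2], "a d e", [{5: -0.1}, {8: -0.2, 7: -1.4}, {2: -0.4}])]

# Passes: the sequences first diverge at position 1, where token 7 (model a)
# is in model b's top logprobs and token 8 (model b) is in model a's.
check_logprobs_close(
    outputs_0_lst=outputs_a,
    outputs_1_lst=outputs_b,
    name_0="model_a",
    name_1="model_b",
)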
diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py
deleted file mode 100644
index 1310b4da218b5..0000000000000
--- a/tests/quantization/test_autogptq_marlin_configs.py
+++ /dev/null
@@ -1,64 +0,0 @@
-"""Tests whether Marlin models can be loaded from the autogptq config.
-
-Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`.
-"""
-
-from dataclasses import dataclass
-
-import pytest
-
-from vllm.config import ModelConfig
-
-
-@dataclass
-class ModelPair:
- model_marlin: str
- model_gptq: str
-
-
-# Model Id // Expected Kernel
-MODELS_QUANT_TYPE = [
- # compat: autogptq <=0.7.1 is_marlin_format: bool
- ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"),
- ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"),
- # compat: autogptq >=0.8.0 use checkpoint_format: str
- ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"),
- ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq")
-]
-
-
-@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE)
-def test_auto_gptq(model_quant_type: str, ) -> None:
- model_path, quant_type = model_quant_type
-
- model_config_no_quant_arg = ModelConfig(
- model_path,
- model_path,
- tokenizer_mode="auto",
- trust_remote_code=False,
- seed=0,
- dtype="float16",
- revision=None,
- quantization=None # case 1
- )
-
- model_config_quant_arg = ModelConfig(
- model_path,
- model_path,
- tokenizer_mode="auto",
- trust_remote_code=False,
- seed=0,
- dtype="float16",
- revision=None,
- quantization="gptq" # case 2
- )
-
- assert model_config_no_quant_arg.quantization == quant_type, (
- f"Expected quant_type == {quant_type} for {model_path}, "
- f"but found {model_config_no_quant_arg.quantization} "
- "for no --quantization None case")
-
- assert model_config_quant_arg.quantization == quant_type, (
- f"Expected quant_type == {quant_type} for {model_path}, "
- f"but found {model_config_quant_arg.quantization} "
- "for --quantization gptq case")
diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py
new file mode 100644
index 0000000000000..6820b2728e3c9
--- /dev/null
+++ b/tests/quantization/test_configs.py
@@ -0,0 +1,73 @@
+"""Tests whether Marlin models can be loaded from the autogptq config.
+
+Run `pytest tests/quantization/test_configs.py --forked`.
+"""
+
+from dataclasses import dataclass
+
+import pytest
+
+from vllm.config import ModelConfig
+
+
+@dataclass
+class ModelPair:
+ model_marlin: str
+ model_gptq: str
+
+
+# Model Id // Quantization Arg // Expected Type
+MODEL_ARG_EXPTYPES = [
+ # AUTOGPTQ
+ # compat: autogptq <=0.7.1 is_marlin_format: bool
+    # Models serialized in Marlin format should always use the Marlin kernel.
+ ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"),
+ ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"),
+ ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"),
+ ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"),
+    # Models serialized in Exllama format.
+ ("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"),
+ ("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"),
+ ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"),
+ ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"),
+ # compat: autogptq >=0.8.0 use checkpoint_format: str
+    # Models serialized in Marlin format should always use the Marlin kernel.
+ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"),
+ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"),
+ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"),
+ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"),
+    # Models serialized in Exllama format.
+ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"),
+ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"),
+ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"),
+ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"),
+
+ # AUTOAWQ
+ ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq"),
+ ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
+ ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "ERROR"),
+ ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
+]
+
+
+@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
+def test_auto_gptq(model_arg_exptype: str) -> None:
+ model_path, quantization_arg, expected_type = model_arg_exptype
+
+ try:
+ model_config = ModelConfig(model_path,
+ model_path,
+ tokenizer_mode="auto",
+ trust_remote_code=False,
+ seed=0,
+ dtype="float16",
+ revision=None,
+ quantization=quantization_arg)
+ found_quantization_type = model_config.quantization
+ except ValueError:
+ found_quantization_type = "ERROR"
+
+ assert found_quantization_type == expected_type, (
+ f"Expected quant_type == {expected_type} for {model_path}, "
+ f"but found {found_quantization_type} "
+ f"for no --quantization {quantization_arg} case")
diff --git a/vllm/config.py b/vllm/config.py
index aedb589247646..a5512c657e038 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -9,11 +9,14 @@
from transformers import PretrainedConfig
from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
+ get_quantization_config)
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
is_neuron)
+GPTQMarlinConfig = get_quantization_config("gptq_marlin")
+
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -138,14 +141,34 @@ def _verify_quantization(self) -> None:
is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
or quant_cfg.get("is_marlin_format", False))
- # Use marlin if the GPTQ model is serialized in marlin format.
- if quant_method == "gptq" and is_format_marlin:
- logger.info("The model is serialized in Marlin format. "
+ # Check which LinearMethod the GPTQ model should use.
+ if quant_method == "gptq":
+ # If serialized in Marlin format, use MarlinLinearMethod.
+ # TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
+ if is_format_marlin:
+ logger.info("The model is serialized in Marlin format. "
+ "Using Marlin kernel.")
+ quant_method = "marlin"
+ if self.quantization == "gptq":
+ self.quantization = quant_method
+
+ # If convertible to Marlin format, use GPTQMarlinLinearMethod
+ # unless the user explicitly specified GPTQLinearMethod.
+ elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg):
+ if self.quantization == "gptq":
+ logger.warning(
+ "The model is convertible to Marlin format, but "
+ "you specified quantization=gptq. Use "
+ "quantization=marlin for faster inference.")
+ else:
+ logger.info(
+ "The model is convertible to Marlin format. "
"Using Marlin kernel.")
- quant_method = "marlin"
- if self.quantization == "gptq":
- self.quantization = quant_method
+ quant_method = "gptq_marlin"
+ if self.quantization == "marlin":
+ self.quantization = quant_method
+ # Verify quantization configurations.
if self.quantization is None:
self.quantization = quant_method
elif self.quantization != quant_method:
@@ -165,7 +188,7 @@ def _verify_quantization(self) -> None:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
- if self.quantization != "marlin":
+ if (self.quantization not in ["marlin", "gptq_marlin"]):
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index 70e0a7cfe3e3b..1c652e347d4ad 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -6,6 +6,8 @@
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+ GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
@@ -15,6 +17,7 @@
"fp8": Fp8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
+ "gptq_marlin": GPTQMarlinConfig,
"marlin": MarlinConfig,
}
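With the registry entry in place, the new method can be looked up like any other. A minimal sketch, mirroring the capability gate used in tests/models/test_gptq_marlin.py and assuming a visible CUDA device:

import torch
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

gptq_marlin_cls = QUANTIZATION_METHODS["gptq_marlin"]

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor
# gptq_marlin requires compute capability >= 80 (Ampere or newer).
print("gptq_marlin supported:",
      capability >= gptq_marlin_cls.get_min_capability())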
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
new file mode 100644
index 0000000000000..7bff0e834483f
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -0,0 +1,444 @@
+import enum
+from enum import Enum
+from typing import Any, Dict, List, Optional
+
+import numpy
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm._C import ops
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+ set_weight_attrs)
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+
+GPTQ_MARLIN_TILE = 16
+GPTQ_MARLIN_MIN_THREAD_N = 64
+GPTQ_MARLIN_MIN_THREAD_K = 128
+GPTQ_MARLIN_MAX_PARALLEL = 16
+
+GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4]
+GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+GPTQ_MARLIN_SUPPORTED_SYM = [True]
+
+
+# Precompute permutations for Marlin weight and scale shuffling
+#
+# Marlin works on [16,64] tiles. The goal of the permutations
+# is to reorder the weight data so that it is compatible
+# with the tensor-core format that is described here:
+# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
+#
+# As a result of this reordering, the vector loads inside the
+# kernel will get the data as it is needed for tensor-core
+# (without the need to use ldmatrix instructions)
+def _get_perms():
+ perm = []
+ for i in range(32):
+ perm1 = []
+ col = i // 4
+ for block in [0, 1]:
+ for row in [
+ 2 * (i % 4),
+ 2 * (i % 4) + 1,
+ 2 * (i % 4 + 4),
+ 2 * (i % 4 + 4) + 1,
+ ]:
+ perm1.append(16 * row + col + 8 * block)
+ for j in range(4):
+ perm.extend([p + 256 * j for p in perm1])
+
+ perm = numpy.array(perm)
+ interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+ perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore
+ perm = torch.from_numpy(perm)
+ scale_perm = []
+ for i in range(8):
+ scale_perm.extend([i + 8 * j for j in range(8)])
+ scale_perm_single = []
+ for i in range(4):
+ scale_perm_single.extend(
+ [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+ return perm, scale_perm, scale_perm_single
+
+
+_perm, _scale_perm, _scale_perm_single = _get_perms()
+
+
+def get_pack_factor(num_bits):
+ assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, (
+ f"Unsupported num_bits = {num_bits}")
+ return 32 // num_bits
+
+
+def marlin_permute_scales(s, size_k, size_n, group_size):
+ if group_size < size_k and group_size != -1:
+ s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
+ else:
+ s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
+ s = s.reshape((-1, size_n)).contiguous()
+
+ return s
+
+
+class GPTQMarlinConfig(QuantizationConfig):
+ """Config class for GPTQ Marlin"""
+
+ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
+ is_sym: bool) -> None:
+ if desc_act and group_size == -1:
+ # In this case, act_order == True is the same as act_order == False
+ # (since we have only one group per output channel)
+ desc_act = False
+
+ self.weight_bits = weight_bits
+ self.group_size = group_size
+ self.desc_act = desc_act
+ self.is_sym = is_sym
+
+ # Verify
+ if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
+ raise ValueError(
+ f"Marlin does not support weight_bits = {self.weight_bits}. "
+ f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
+ "are supported.")
+ if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
+ raise ValueError(
+ f"Marlin does not support group_size = {self.group_size}. "
+ f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
+ "are supported.")
+ if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
+ raise ValueError(
+ f"Marlin does not support is_sym = {self.is_sym}. "
+ f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
+
+ # Init
+ self.pack_factor = get_pack_factor(weight_bits)
+ self.tile_size = GPTQ_MARLIN_TILE
+ self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
+ self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
+ self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
+
+ def __repr__(self) -> str:
+ return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
+ f"group_size={self.group_size}, "
+ f"desc_act={self.desc_act})")
+
+ @classmethod
+ def get_name(cls) -> str:
+ return "gptq_marlin"
+
+ @classmethod
+ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+ return [torch.half]
+
+ @classmethod
+ def get_min_capability(cls) -> int:
+ return 80
+
+ @classmethod
+ def get_config_filenames(cls) -> List[str]:
+ return ["quantize_config.json"]
+
+ @classmethod
+ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
+ weight_bits = cls.get_from_keys(config, ["bits"])
+ group_size = cls.get_from_keys(config, ["group_size"])
+ desc_act = cls.get_from_keys(config, ["desc_act"])
+ is_sym = cls.get_from_keys(config, ["sym"])
+ return cls(weight_bits, group_size, desc_act, is_sym)
+
+ def get_quant_method(
+ self,
+ layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
+ if isinstance(layer, LinearBase):
+ return GPTQMarlinLinearMethod(self)
+ return None
+
+ def get_scaled_act_names(self) -> List[str]:
+ return []
+
+ @classmethod
+ def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
+ # Extract data from quant config.
+ num_bits = quant_config.get("bits", None)
+ group_size = quant_config.get("group_size", None)
+ sym = quant_config.get("sym", None)
+ desc_act = quant_config.get("desc_act", None)
+
+ # If we cannot find the info needed in the config, cannot convert.
+ if (num_bits is None or group_size is None or sym is None
+ or desc_act is None):
+ return False
+
+ # If the capability of the device is too low, cannot convert.
+ major, minor = torch.cuda.get_device_capability()
+ device_capability = major * 10 + minor
+ if device_capability < cls.get_min_capability():
+ return False
+
+ # Otherwise, can convert if model satisfies marlin constraints.
+ return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
+ and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
+ and sym in GPTQ_MARLIN_SUPPORTED_SYM)
+
+
+class GPTQMarlinState(Enum):
+ REPACK = enum.auto()
+ READY = enum.auto()
+
+
+class GPTQMarlinLinearMethod(LinearMethodBase):
+ """Linear method for GPTQ Marlin.
+
+ Args:
+ quant_config: The GPTQ Marlin quantization config.
+ """
+
+ def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+ self.quant_config = quant_config
+
+ def create_weights(
+ self,
+ layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_partition_sizes: List[int],
+ input_size: int,
+ output_size: int,
+ params_dtype: torch.dtype,
+ **extra_weight_attrs,
+ ) -> None:
+ del output_size
+
+ # Normalize group_size
+ if self.quant_config.group_size != -1:
+ group_size = self.quant_config.group_size
+ else:
+ group_size = input_size
+
+ # Validate dtype
+ if params_dtype != torch.float16:
+ raise ValueError(
+ f"The params dtype must be float16, but got {params_dtype}")
+
+ # Validate output_size_per_partition
+ output_size_per_partition = sum(output_partition_sizes)
+ if output_size_per_partition % self.quant_config.min_thread_n != 0:
+ raise ValueError(
+ f"Weight output_size_per_partition = "
+ f"{output_size_per_partition} is not divisible by "
+ f" min_thread_n = {self.quant_config.min_thread_n}.")
+
+ # Validate input_size_per_partition
+ if input_size_per_partition % self.quant_config.min_thread_k != 0:
+ raise ValueError(
+ f"Weight input_size_per_partition = "
+ f"{input_size_per_partition} is not divisible "
+ f"by min_thread_k = {self.quant_config.min_thread_k}.")
+
+ if (group_size < input_size
+ and input_size_per_partition % group_size != 0):
+ raise ValueError(
+ f"Weight input_size_per_partition = {input_size_per_partition}"
+ f" is not divisible by group_size = {group_size}.")
+
+ # Detect sharding of scales/zp
+
+ # By default, no sharding over "input dim"
+ scales_and_zp_size = input_size // group_size
+ scales_and_zp_input_dim = None
+
+ if self.quant_config.desc_act:
+ # Act-order case
+ assert self.quant_config.group_size != -1
+
+ is_k_full = input_size_per_partition == input_size
+
+ else:
+ # No act-order case
+
+            # K is always full because the shards are fully aligned with
+            # the group size and the scales/zp sharding.
+ is_k_full = True
+
+ # If this is a row-parallel case, then shard scales/zp
+ if (input_size != input_size_per_partition
+ and self.quant_config.group_size != -1):
+ scales_and_zp_size = input_size_per_partition // group_size
+ scales_and_zp_input_dim = 0
+
+ # Init buffers
+
+ # Quantized weights
+ qweight = Parameter(
+ torch.empty(
+ input_size_per_partition // self.quant_config.pack_factor,
+ output_size_per_partition,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ )
+ set_weight_attrs(
+ qweight, {
+ **extra_weight_attrs,
+ "input_dim": 0,
+ "output_dim": 1,
+ "packed_dim": 0,
+ "pack_factor": self.quant_config.pack_factor,
+ })
+
+ # Activation order
+ g_idx = Parameter(
+ torch.empty(
+ input_size_per_partition,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ )
+ # Ignore warning from fused linear layers such as QKVParallelLinear.
+ set_weight_attrs(g_idx, {
+ **extra_weight_attrs, "input_dim": 0,
+ "ignore_warning": True
+ })
+
+ g_idx_sort_indices = Parameter(
+ torch.empty(
+ g_idx.shape,
+ dtype=torch.int32,
+ ),
+ requires_grad=False,
+ )
+ set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
+
+ # Scales
+ scales = Parameter(
+ torch.empty(
+ scales_and_zp_size,
+ output_size_per_partition,
+ dtype=params_dtype,
+ ),
+ requires_grad=False,
+ )
+ set_weight_attrs(
+ scales, {
+ **extra_weight_attrs,
+ "input_dim": scales_and_zp_input_dim,
+ "output_dim": 1,
+ })
+
+ # Quantized zero-points
+ qzeros = Parameter(
+ torch.empty(scales_and_zp_size,
+ output_size_per_partition //
+ self.quant_config.pack_factor,
+ dtype=torch.int32,
+ device="meta"),
+ requires_grad=False,
+ )
+ set_weight_attrs(
+ qzeros, {
+ **extra_weight_attrs,
+ "input_dim": scales_and_zp_input_dim,
+ "output_dim": 1,
+ "packed_dim": 1,
+ "pack_factor": self.quant_config.pack_factor,
+ })
+
+ # Allocate marlin workspace
+ max_workspace_size = (
+ output_size_per_partition //
+ self.quant_config.min_thread_n) * self.quant_config.max_parallel
+ workspace = torch.zeros(max_workspace_size,
+ dtype=torch.int,
+ requires_grad=False)
+
+ layer.register_parameter("qweight", qweight)
+ layer.register_parameter("g_idx", g_idx)
+ layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
+ layer.register_parameter("scales", scales)
+ layer.register_parameter("qzeros", qzeros)
+ layer.workspace = workspace
+ layer.input_size_per_partition = input_size_per_partition
+ layer.output_size_per_partition = output_size_per_partition
+ layer.input_size = input_size
+ layer.is_k_full = is_k_full
+ layer.marlin_state = GPTQMarlinState.REPACK
+
+ def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ reshaped_x = x.reshape(-1, x.shape[-1])
+
+ size_m = reshaped_x.shape[0]
+ part_size_n = layer.output_size_per_partition
+ part_size_k = layer.input_size_per_partition
+ full_size_k = layer.input_size
+
+ out_shape = x.shape[:-1] + (part_size_n, )
+
+ if layer.marlin_state == GPTQMarlinState.REPACK:
+ layer.marlin_state = GPTQMarlinState.READY
+
+ # Newly generated tensors need to replace existing tensors that are
+ # already registered as parameters by vLLM (and won't be freed)
+ def replace_tensor(name, new_t):
+ # It is important to use resize_() here since it ensures
+ # the same buffer is reused
+ getattr(layer, name).resize_(new_t.shape)
+ getattr(layer, name).copy_(new_t)
+ del new_t
+
+ cur_device = layer.qweight.device
+
+ # Process act_order
+ if self.quant_config.desc_act:
+ # Get sorting based on g_idx
+ g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
+
+ sorted_g_idx = layer.g_idx[g_idx_sort_indices]
+
+ replace_tensor("g_idx", sorted_g_idx)
+ replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
+
+ else:
+ # Reset g_idx related tensors
+ layer.g_idx = Parameter(torch.empty(0,
+ dtype=torch.int,
+ device=cur_device),
+ requires_grad=False)
+ layer.g_idx_sort_indices = Parameter(torch.empty(
+ 0, dtype=torch.int, device=cur_device),
+ requires_grad=False)
+
+ # Repack weights
+ marlin_qweight = ops.gptq_marlin_repack(
+ layer.qweight,
+ layer.g_idx_sort_indices,
+ part_size_k,
+ part_size_n,
+ )
+ replace_tensor("qweight", marlin_qweight)
+
+ # Permute scales
+ scales_size_k = part_size_k
+ scales_size_n = part_size_n
+ if self.quant_config.desc_act:
+ scales_size_k = full_size_k
+
+ marlin_scales = marlin_permute_scales(layer.scales, scales_size_k,
+ scales_size_n,
+ self.quant_config.group_size)
+ replace_tensor("scales", marlin_scales)
+
+ output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales,
+ layer.g_idx, layer.g_idx_sort_indices,
+ layer.workspace, size_m, part_size_n,
+ part_size_k, layer.is_k_full)
+
+ if bias is not None:
+ output.add_(bias) # In-place add
+
+ return output.reshape(out_shape)
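A small usage sketch for the config class above. The quantize_config dict here is hypothetical (normally it is read from the checkpoint's quantize_config.json), and is_marlin_compatible needs a visible CUDA device since it checks compute capability:

from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig

sample_quant_cfg = {"bits": 4, "group_size": 128, "desc_act": True, "sym": True}

# True on SM80+ devices for this config, False otherwise.
print(GPTQMarlinConfig.is_marlin_compatible(sample_quant_cfg))

cfg = GPTQMarlinConfig.from_config(sample_quant_cfg)
print(cfg)  # GPTQMarlinConfig(weight_bits=4, group_size=128, desc_act=True)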
From ac5ccf0156e1772f3ea89c205704a31219442c55 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Mon, 29 Apr 2024 12:50:01 -0700
Subject: [PATCH 151/413] [CI] hotfix: soft fail neuron test (#4458)
---
.buildkite/test-template.j2 | 1 +
1 file changed, 1 insertion(+)
diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2
index fb1086db77823..5c9515840bb03 100644
--- a/.buildkite/test-template.j2
+++ b/.buildkite/test-template.j2
@@ -25,6 +25,7 @@ steps:
agents:
queue: neuron
command: bash .buildkite/run-neuron-test.sh
+ soft_fail: true
- label: "CPU Test"
command: bash .buildkite/run-cpu-test.sh
From f4f921b7f12c67d3c4b7575caf5ddd9bd4b0b787 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 29 Apr 2024 13:52:22 -0700
Subject: [PATCH 152/413] [Core][Distributed] use cpu group to broadcast
metadata in cpu (#4444)
---
.../tensorize_vllm_model_for_testing.py | 6 +-
tests/worker/test_model_runner.py | 23 ++++---
vllm/distributed/communication_op.py | 69 +++++++++++++------
3 files changed, 63 insertions(+), 35 deletions(-)
diff --git a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
index e4b15fd57add4..0e113ab647e67 100644
--- a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
+++ b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py
@@ -6,14 +6,14 @@
from functools import partial
from typing import Type
-import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig
-from vllm.distributed import initialize_model_parallel
+from vllm.distributed import (init_distributed_environment,
+ initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
@@ -226,7 +226,7 @@ def deserialize():
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"
-torch.distributed.init_process_group(world_size=1, rank=0)
+init_distributed_environment(world_size=1, rank=0, local_rank=0)
initialize_model_parallel()
keyfile = args.keyfile if args.keyfile else None
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index abb401f25c100..56fe6db589f18 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -2,8 +2,10 @@
import torch
from vllm.config import ModelConfig, SchedulerConfig
+from vllm.distributed.parallel_state import init_distributed_environment
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.utils import get_open_port
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@@ -249,19 +251,18 @@ def test_empty_seq_group():
assert len(return_prompt_lens) == 0
-@pytest.mark.parametrize("batch_size", list(range(2, 128)))
-@pytest.mark.parametrize("enforce_eager", [True, False])
-def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
-
- def get_world_size(group=None):
- return 1
+@pytest.fixture
+def distributed_init():
+ init_distributed_environment(
+ world_size=1,
+ rank=0,
+ distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
+ local_rank=0)
- def mock_get_process_group_ranks(group=None):
- return [0]
- monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
- monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
- mock_get_process_group_ranks)
+@pytest.mark.parametrize("batch_size", list(range(2, 128)))
+@pytest.mark.parametrize("enforce_eager", [True, False])
+def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
model_config = ModelConfig(
"facebook/opt-125m",
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index a3e93691a1e8e..8b2c26c3a8afb 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -4,7 +4,8 @@
import torch
from torch.distributed import ProcessGroup
-from .parallel_state import (get_tensor_model_parallel_group,
+from .parallel_state import (get_cpu_world_group,
+ get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
is_pynccl_enabled_for_all_reduce)
@@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any],
TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+def _split_tensor_dict(
+ tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
+) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+ """Split the tensor dictionary into two parts:
+ 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
+ by its metadata.
+ 2. A list of tensors.
+ """
+ metadata_list = []
+ tensor_list = []
+ for key, value in tensor_dict.items():
+ if isinstance(value, torch.Tensor):
+ # Note(youkaichao): currently this only supports broadcasting
+ # tensors on cuda. In the future, we can add device as a field in
+ # TensorMetadata to support broadcasting tensors on different
+ # devices.
+ assert value.is_cuda, (
+ f"Tensor {key}: {value} is not on cuda. Currently we only "
+ f"support broadcasting tensors on cuda.")
+ metadata_list.append((key, TensorMetadata(value.dtype,
+ value.size())))
+ tensor_list.append(value)
+ else:
+ metadata_list.append((key, value))
+ return metadata_list, tensor_list
+
+
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
+ metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
- """Broadcast the input tensor dictionary."""
+ """Broadcast the input tensor dictionary.
+ `group` is used to broadcast the tensors, while `metadata_group` is used
+ to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
+ dtypes).
+ """
group = group or torch.distributed.group.WORLD
+ metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
@@ -161,27 +195,20 @@ def broadcast_tensor_dict(
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
- for key, value in tensor_dict.items():
- if isinstance(value, torch.Tensor):
- assert value.is_cuda, (
- f"Tensor {key}: {value} is not on cuda. Currently we only "
- f"support broadcasting tensors on cuda.")
- metadata_list.append(
- (key, TensorMetadata(value.dtype, value.size())))
- else:
- metadata_list.append((key, value))
+ metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+ # `metadata_list` lives in CPU memory.
+ # `broadcast_object_list` involves serialization and deserialization,
+ # all happening on CPU. Therefore, we can use the CPU group.
torch.distributed.broadcast_object_list([metadata_list],
src=src,
- group=group)
+ group=metadata_group)
async_handles = []
- for key, value in metadata_list:
- if isinstance(value, TensorMetadata):
- tensor = tensor_dict[key]
- async_handles.append(
- torch.distributed.broadcast(tensor,
- src=src,
- group=group,
- async_op=True))
+ for tensor in tensor_list:
+ async_handles.append(
+ torch.distributed.broadcast(tensor,
+ src=src,
+ group=group,
+ async_op=True))
for async_handle in async_handles:
async_handle.wait()
@@ -189,7 +216,7 @@ def broadcast_tensor_dict(
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
- group=group)
+ group=metadata_group)
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
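To make the split concrete, here is a minimal sketch of what _split_tensor_dict produces before the two-stage broadcast (metadata over the CPU group, tensors over the GPU group). It calls the private helper directly, purely for illustration, and assumes a CUDA device:

import torch
from vllm.distributed.communication_op import _split_tensor_dict

tensor_dict = {
    "input_ids": torch.ones(4, dtype=torch.long, device="cuda"),
    "is_prompt": True,
}
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# metadata_list ~ [("input_ids", TensorMetadata(torch.int64, torch.Size([4]))),
#                  ("is_prompt", True)]
# tensor_list   ~ [the CUDA tensor itself]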
From d627a3d837976a23f89ba808f5ef6908fdb65bfa Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 29 Apr 2024 20:05:47 -0400
Subject: [PATCH 153/413] [Misc] Upgrade to `torch==2.3.0` (#4454)
---
.github/workflows/publish.yml | 2 +-
CMakeLists.txt | 2 +-
Dockerfile | 2 +-
pyproject.toml | 2 +-
requirements-build.txt | 2 +-
requirements-cpu.txt | 2 +-
requirements-cuda.txt | 4 ++--
7 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 4b9fc3d04d872..d79681f03b003 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -49,7 +49,7 @@ jobs:
matrix:
os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11']
- pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt.
+ pytorch-version: ['2.3.0'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1']
steps:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1558dbf313ce7..f817f3382c5e1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -31,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
-set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
diff --git a/Dockerfile b/Dockerfile
index d1d29177b0f44..e471a6e93b963 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -85,7 +85,7 @@ FROM dev as flash-attn-builder
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# flash attention version
-ARG flash_attn_version=v2.5.6
+ARG flash_attn_version=v2.5.8
ENV FLASH_ATTN_VERSION=${flash_attn_version}
WORKDIR /usr/src/flash-attention-v2
diff --git a/pyproject.toml b/pyproject.toml
index 2e026c1ac8911..6a448defc16e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ requires = [
"ninja",
"packaging",
"setuptools >= 49.4.0",
- "torch == 2.2.1",
+ "torch == 2.3.0",
"wheel",
]
build-backend = "setuptools.build_meta"
diff --git a/requirements-build.txt b/requirements-build.txt
index 2bc07fb152aac..1a07a94e82e04 100644
--- a/requirements-build.txt
+++ b/requirements-build.txt
@@ -3,5 +3,5 @@ cmake>=3.21
ninja
packaging
setuptools>=49.4.0
-torch==2.2.1
+torch==2.3.0
wheel
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
index e911ad03295f0..b739642d8d344 100644
--- a/requirements-cpu.txt
+++ b/requirements-cpu.txt
@@ -2,5 +2,5 @@
-r requirements-common.txt
# Dependencies for x86_64 CPUs
-torch == 2.2.1+cpu
+torch == 2.3.0+cpu
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
\ No newline at end of file
diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index 1bddae4c6f40f..6548d7a6684b2 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -5,5 +5,5 @@
ray >= 2.9
nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
-torch == 2.2.1
-xformers == 0.0.25 # Requires PyTorch 2.2.1
+torch == 2.3.0
+xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
From fa32207842f1ed5a966372ed0513914bff8426c4 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 29 Apr 2024 22:05:40 -0700
Subject: [PATCH 154/413] [Bugfix][Kernel] Fix compute_type for MoE kernel
(#4463)
---
vllm/model_executor/layers/fused_moe/fused_moe.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d37837a0b2ce8..b4f81527141a8 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -433,6 +433,8 @@ def fused_moe(
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
+ compute_type = (tl.bfloat16
+ if hidden_states.dtype == torch.bfloat16 else tl.float16)
invoke_fused_moe_kernel(hidden_states,
w1,
@@ -447,7 +449,7 @@ def fused_moe(
False,
topk_ids.shape[1],
config,
- compute_type=tl.float16,
+ compute_type=compute_type,
use_fp8=use_fp8)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -465,7 +467,7 @@ def fused_moe(
True,
1,
config,
- compute_type=tl.float16,
+ compute_type=compute_type,
use_fp8=use_fp8)
if inplace:
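The fix above derives the Triton compute dtype from the activation dtype instead of hardcoding float16; restated as a standalone helper (illustrative only, the name is made up):

import torch
import triton.language as tl

def moe_compute_type(hidden_states: torch.Tensor):
    # Match the kernel's compute dtype to the activation dtype: bfloat16
    # activations use tl.bfloat16, everything else falls back to tl.float16.
    return tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16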
From 26f2fb51133c85ad8a57a87c8037f750dda757f4 Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Tue, 30 Apr 2024 12:14:47 +0000
Subject: [PATCH 155/413] [Core]Refactor gptq_marlin ops (#4466)
---
vllm/_custom_ops.py | 16 ++++++++++++++++
.../layers/quantization/gptq_marlin.py | 2 +-
2 files changed, 17 insertions(+), 1 deletion(-)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 5ba104bada7ac..4af8b09b1e16c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -167,6 +167,22 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes)
+# gptq_marlin
+def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
+ size_k: int, size_n: int) -> torch.Tensor:
+ return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n)
+
+
+def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+ b_scales: torch.Tensor, g_idx: torch.Tensor,
+ perm: torch.Tensor, workspace: torch.Tensor, size_m: int,
+ size_n: int, size_k: int,
+ is_k_full: bool) -> torch.Tensor:
+ return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
+ workspace, size_m, size_n, size_k,
+ is_k_full)
+
+
# fp8
def scaled_fp8_quant(
input: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 7bff0e834483f..efbffa0878c4b 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -6,7 +6,7 @@
import torch
from torch.nn.parameter import Parameter
-from vllm._C import ops
+from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
From 4bb53e2dde809ea5727b8cac95a080893733a1ef Mon Sep 17 00:00:00 2001
From: leiwen83
Date: Wed, 1 May 2024 01:12:59 +0800
Subject: [PATCH 156/413] [BugFix] fix num_lookahead_slots missing in async
executor (#4165)
Co-authored-by: Lei Wen
---
tests/spec_decode/e2e/conftest.py | 125 +++++++++++++++++++-
tests/spec_decode/e2e/test_compatibility.py | 15 ++-
tests/spec_decode/e2e/test_correctness.py | 25 ++--
vllm/engine/async_llm_engine.py | 6 +-
vllm/executor/cpu_executor.py | 4 +-
vllm/executor/executor_base.py | 1 +
vllm/executor/gpu_executor.py | 4 +-
vllm/executor/neuron_executor.py | 1 +
vllm/executor/ray_gpu_executor.py | 1 +
9 files changed, 163 insertions(+), 19 deletions(-)
diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index 59fb8311fc5b7..5d3469c4210ee 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -1,10 +1,127 @@
-from typing import List, Tuple
+import asyncio
+from typing import List, Optional, Tuple, Union
import pytest
+import ray
from tests.conftest import cleanup
from vllm import LLM
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import MultiModalData
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import Counter, random_uuid
+
+
+class AsyncLLM:
+ """AsyncLLM
+
+    Note: The current LLM class in vllm does not support async mode. For test
+    purposes, we implement an async version here; it may be moved to
+    vllm/entrypoints/llm.py in the future.
+
+    The AsyncLLM below is borrowed directly from vllm/entrypoints/llm.py, with
+    changes to make it work in async mode.
+ """
+
+ def __init__(
+ self,
+ model: str,
+ tokenizer: Optional[str] = None,
+ tokenizer_mode: str = "auto",
+ skip_tokenizer_init: bool = False,
+ trust_remote_code: bool = False,
+ tensor_parallel_size: int = 1,
+ dtype: str = "auto",
+ quantization: Optional[str] = None,
+ revision: Optional[str] = None,
+ tokenizer_revision: Optional[str] = None,
+ seed: int = 0,
+ gpu_memory_utilization: float = 0.9,
+ swap_space: int = 4,
+ enforce_eager: bool = False,
+ max_context_len_to_capture: int = 8192,
+ disable_custom_all_reduce: bool = False,
+ **kwargs,
+ ) -> None:
+ if "disable_log_stats" not in kwargs:
+ kwargs["disable_log_stats"] = True
+ self.engine_args = AsyncEngineArgs(
+ model=model,
+ tokenizer=tokenizer,
+ tokenizer_mode=tokenizer_mode,
+ skip_tokenizer_init=skip_tokenizer_init,
+ trust_remote_code=trust_remote_code,
+ tensor_parallel_size=tensor_parallel_size,
+ dtype=dtype,
+ quantization=quantization,
+ revision=revision,
+ tokenizer_revision=tokenizer_revision,
+ seed=seed,
+ gpu_memory_utilization=gpu_memory_utilization,
+ swap_space=swap_space,
+ enforce_eager=enforce_eager,
+ max_context_len_to_capture=max_context_len_to_capture,
+ engine_use_ray=True,
+ disable_custom_all_reduce=disable_custom_all_reduce,
+ **kwargs,
+ )
+ self.request_counter = Counter()
+
+ def generate(
+ self,
+ prompts: Optional[Union[str, List[str]]] = None,
+ sampling_params: Optional[Union[SamplingParams,
+ List[SamplingParams]]] = None,
+ prompt_token_ids: Optional[List[List[int]]] = None,
+ use_tqdm: bool = True,
+ lora_request: Optional[LoRARequest] = None,
+ multi_modal_data: Optional[MultiModalData] = None,
+ ) -> List[RequestOutput]:
+
+ llm_engine = AsyncLLMEngine.from_engine_args(
+ self.engine_args, usage_context=UsageContext.LLM_CLASS)
+
+ if prompts is None:
+ raise ValueError("prompts must be provided.")
+ if isinstance(prompts, str):
+ # Convert a single prompt to a list.
+ prompts = [prompts]
+
+ if prompts is not None:
+ num_requests = len(prompts)
+
+ if sampling_params is None:
+ # Use default sampling params.
+ sampling_params = SamplingParams()
+
+ elif isinstance(sampling_params,
+ list) and len(sampling_params) != num_requests:
+ raise ValueError("The lengths of prompts and "
+ "sampling_params must be the same.")
+
+ async def get_output(prompt, sampling_param) -> str:
+ request_id = random_uuid()
+ results_generator = llm_engine.generate(prompt, sampling_param,
+ request_id)
+ final_output = None
+ async for request_output in results_generator:
+ final_output = request_output
+ return final_output
+
+ outputs = []
+ try:
+ for i in range(num_requests):
+ prompt = prompts[i] if prompts is not None else None
+ res = asyncio.run(get_output(prompt, sampling_params))
+ outputs.append(res)
+ finally:
+ ray.shutdown()
+ return outputs
@pytest.fixture
@@ -36,8 +153,12 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
def generator_inner():
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
- llm = LLM(**kwargs)
+ use_async = False
+ if "use_async" in kwargs:
+ use_async = kwargs.pop("use_async")
+
+ llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
set_random_seed(seed)
yield llm
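The core of AsyncLLM.generate above is draining the AsyncLLMEngine's async generator and keeping only the final RequestOutput. A minimal standalone sketch of that pattern; engine and prompt are placeholders here:

import asyncio

from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

async def collect_final_output(engine, prompt: str):
    request_id = random_uuid()
    final_output = None
    async for request_output in engine.generate(prompt,
                                                SamplingParams(max_tokens=8),
                                                request_id):
        final_output = request_output
    return final_output

# output = asyncio.run(collect_final_output(engine, "Hello"))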
diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py
index fde950c14382c..60c20ed7db7a3 100644
--- a/tests/spec_decode/e2e/test_compatibility.py
+++ b/tests/spec_decode/e2e/test_compatibility.py
@@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator):
temperature=temperature,
)
- with pytest.raises(AssertionError,
- match="Speculative decoding not yet supported for "):
- get_output_from_llm_generator(test_llm_generator, prompts,
- sampling_params)
+ try:
+ with pytest.raises(
+ AssertionError,
+ match="Speculative decoding not yet supported for "):
+ get_output_from_llm_generator(test_llm_generator, prompts,
+ sampling_params)
+ finally:
+        # We need to free up the Ray resources
+        # so that later tests can use the GPU we allocated here.
+ import ray
+ ray.shutdown()
@pytest.mark.parametrize(
diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py
index 0536cc4ecde76..ab8d913fb894a 100644
--- a/tests/spec_decode/e2e/test_correctness.py
+++ b/tests/spec_decode/e2e/test_correctness.py
@@ -40,17 +40,24 @@
@pytest.mark.parametrize(
"common_llm_kwargs",
- [{
- # Use a small model for a fast test.
- # Note this is repeated in the test body; to initialize a tokenizer.
- "model": "JackFram/llama-68m",
+ [
+ {
+ # Use a small model for a fast test.
+            # Note: this is repeated in the test body to initialize a tokenizer.
+ "model": "JackFram/llama-68m",
- # Skip cuda graph recording for fast test.
- "enforce_eager": True,
+ # Skip cuda graph recording for fast test.
+ "enforce_eager": True,
- # Required for spec decode.
- "use_v2_block_manager": True
- }])
+ # Required for spec decode.
+ "use_v2_block_manager": True,
+
+            # Whether to use the AsyncLLM engine
+ "use_async": async_mode,
+ }
+ # Try both async and sync engine execution
+ for async_mode in [True, False]
+ ])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 7c1eb2ecbe550..4aceb19b50776 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -211,9 +211,11 @@ async def step_async(self) -> List[RequestOutput]:
if not scheduler_outputs.is_empty():
# Execute the model.
output = await self.model_executor.execute_model_async(
- seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+ seq_group_metadata_list,
+ scheduler_outputs.blocks_to_swap_in,
scheduler_outputs.blocks_to_swap_out,
- scheduler_outputs.blocks_to_copy)
+ scheduler_outputs.blocks_to_copy,
+ num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
else:
output = []
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index e4436b2144bd3..da1b500cddaf6 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -109,12 +109,14 @@ async def execute_model_async(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
+ num_lookahead_slots: int,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
- blocks_to_copy=blocks_to_copy)
+ blocks_to_copy=blocks_to_copy,
+ num_lookahead_slots=num_lookahead_slots)
return output
async def check_health_async(self) -> None:
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index c36aa18fb25bb..96cd18250bb37 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -112,6 +112,7 @@ async def execute_model_async(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
+ num_lookahead_slots: int,
) -> List[SamplerOutput]:
"""Executes one model step on the given sequences."""
raise NotImplementedError
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 5ac62f02b99c7..489e66d586028 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -163,10 +163,12 @@ async def execute_model_async(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
+ num_lookahead_slots: int,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
- blocks_to_copy=blocks_to_copy)
+ blocks_to_copy=blocks_to_copy,
+ num_lookahead_slots=num_lookahead_slots)
return output
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index f406287f3c1d8..8a3b9cde84311 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -84,6 +84,7 @@ async def execute_model_async(
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
+ num_lookahead_slots: int,
) -> List[SamplerOutput]:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list, )
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index b6bcda4e6b18c..3eb3726bd5a6d 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -196,6 +196,7 @@ def execute_model(self,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
+ "num_lookahead_slots": num_lookahead_slots,
},
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
From b31a1fb63c98fa1c64666aaae15579439af60d95 Mon Sep 17 00:00:00 2001
From: Prashant Gupta
Date: Tue, 30 Apr 2024 10:41:59 -0700
Subject: [PATCH 157/413] [Doc] add visualization for multi-stage dockerfile
(#4456)
Signed-off-by: Prashant Gupta
Co-authored-by: Roger Wang
---
Dockerfile | 4 ++
.../dev/dockerfile-stages-dependency.png | Bin 0 -> 118207 bytes
docs/source/dev/dockerfile/dockerfile.rst | 50 ++++++++++++++++++
docs/source/index.rst | 1 +
4 files changed, 55 insertions(+)
create mode 100644 docs/source/assets/dev/dockerfile-stages-dependency.png
create mode 100644 docs/source/dev/dockerfile/dockerfile.rst
diff --git a/Dockerfile b/Dockerfile
index e471a6e93b963..e8a9842c089dd 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,10 @@
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server.
+# Please update any changes made here to
+# docs/source/dev/dockerfile/dockerfile.rst and
+# docs/source/assets/dev/dockerfile-stages-dependency.png
+
#################### BASE BUILD IMAGE ####################
# prepare basic build environment
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
diff --git a/docs/source/assets/dev/dockerfile-stages-dependency.png b/docs/source/assets/dev/dockerfile-stages-dependency.png
new file mode 100644
index 0000000000000000000000000000000000000000..b016531f1e0a06bb38b01b1989df932f161e3aee
GIT binary patch
literal 118207 (binary image data omitted)