From c5d7fb9ddc16d9eb68f1018cfb384faf3be301be Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Mon, 28 Oct 2024 22:39:21 -0400
Subject: [PATCH 01/23] [Doc] fix third-party model example (#9771)
Signed-off-by: Russell Bryant
---
docs/source/models/adding_model.rst | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst
index ae09259c0756c..c6d88cc38e99b 100644
--- a/docs/source/models/adding_model.rst
+++ b/docs/source/models/adding_model.rst
@@ -133,7 +133,9 @@ If you are running api server with :code:`vllm serve `, you can wrap the e
from vllm import ModelRegistry
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
- import runpy
- runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
+
+ if __name__ == '__main__':
+ import runpy
+ runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
Save the above code in a file and run it with :code:`python your_file.py `.
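Applied in full, the corrected wrapper from the patched doc reads roughly as follows; YourModelForCausalLM and your_code are the doc's placeholders, not real modules:

    # Sketch of the corrected example script from the patched doc (placeholders, not real modules).
    from vllm import ModelRegistry
    from your_code import YourModelForCausalLM

    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

    if __name__ == '__main__':
        import runpy
        runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')

The __main__ guard matters because the API server may spawn worker processes that re-import this file; without it, each re-import would try to launch the server again.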
From 7a4df5f200f0943113dd2d9be49cbcae38ad10bb Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Tue, 29 Oct 2024 12:14:07 +0800
Subject: [PATCH 02/23] [Model][LoRA]LoRA support added for Qwen (#9622)
Signed-off-by: Jee Jee Li
---
vllm/lora/models.py | 6 +-
vllm/model_executor/models/qwen.py | 109 ++++++++++++++++++++++++++---
2 files changed, 101 insertions(+), 14 deletions(-)
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index aaadca9a4d16d..d0279f273db7a 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -578,10 +578,10 @@ def _filter_unsupported_mm_module(self, module_name: str) -> bool:
be filtered out.
"""
if self.supports_mm:
- prefix = module_name.split(".")[0]
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
- return (prefix in module_mapping.connector
- or prefix in module_mapping.tower_model)
+ prefix_lst = module_mapping.connector + module_mapping.tower_model
+ return any(
+ [module_name.startswith(prefix) for prefix in prefix_lst])
return False
def _register_packed_modules(self, module_full_name: str) -> None:
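The new check matches full dotted module paths against the connector and tower-model prefixes instead of only the first path component; a standalone sketch of that behaviour, using the Qwen-VL prefixes registered later in this patch and illustrative module names, is:

    # Standalone sketch of the new prefix matching; module names are illustrative.
    connector = ["transformer.visual.attn_pool"]
    tower_model = ["transformer.visual.transformer"]
    prefix_lst = connector + tower_model

    def is_unsupported(module_name: str) -> bool:
        return any(module_name.startswith(prefix) for prefix in prefix_lst)

    print(is_unsupported("transformer.visual.attn_pool.kv_proj"))  # True: filtered out
    print(is_unsupported("transformer.h.0.attn.c_attn"))           # False: kept for LoRA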
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index cd3f7c1b6c4db..0a1b40927e9f9 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -20,7 +20,7 @@
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
@@ -30,6 +30,7 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
+ ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -39,6 +40,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
@@ -46,7 +48,7 @@
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of
-from .interfaces import SupportsMultiModal, SupportsPP
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
@@ -122,8 +124,8 @@ def __init__(
# Strided linear layer.
assert self._qkv_same_embed_dim, \
'Visual Attention implementation only supports self-attention'
- self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
- self.out_proj = nn.Linear(embed_dim, embed_dim)
+ self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
+ self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
def forward(
@@ -133,7 +135,7 @@ def forward(
) -> torch.Tensor:
# query/key/value: [sq, b, h]
sq, b, _ = x.size()
- mixed_x_layer = self.in_proj(x)
+ mixed_x_layer, _ = self.in_proj(x)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
@@ -182,7 +184,7 @@ def forward(
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
- output = self.out_proj(context_layer)
+ output, _ = self.out_proj(context_layer)
return output
@@ -860,11 +862,7 @@ def dummy_data_for_qwen(
return seq_data, mm_data
-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
-class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
+class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def __init__(
self,
@@ -872,6 +870,7 @@ def __init__(
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
@@ -990,3 +989,91 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
+
+
+class QWenLLM(QWenBaseModel):
+ packed_modules_mapping = {
+ "c_attn": ["c_attn"],
+ "gate_up_proj": [
+ "w2",
+ "w1",
+ ],
+ }
+ # LoRA specific attributes
+ supported_lora_modules = [
+ "c_attn",
+ "gate_up_proj",
+ "c_proj",
+ ]
+
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+
+class QWenVL(QWenBaseModel):
+ packed_modules_mapping = {
+ "c_attn": ["c_attn"],
+ "gate_up_proj": [
+ "w2",
+ "w1",
+ ],
+ }
+ # LoRA specific attributes
+ supported_lora_modules = [
+ "c_attn",
+ "gate_up_proj",
+ "c_proj",
+ # visual module
+ "out_proj",
+ "in_proj",
+ "c_fc",
+ # resampler
+ "kv_proj",
+ ]
+
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="transformer.h",
+ connector="transformer.visual.attn_pool",
+ tower_model="transformer.visual.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+class QWenLMHeadModel(QWenBaseModel):
+ """
+ QWenLMHeadModel covers both the text-only LLM and the VL variant, which does
+ not fit vLLM's current LoRA integration logic, so the two are separated into
+ QWenLLM and QWenVL.
+ """
+ # Ensure that the LoRA support check passes when the class is not
+ # initialized, but set all these attributes to empty.
+ packed_modules_mapping = {}
+ supported_lora_modules = []
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ def __new__(
+ cls,
+ config: PretrainedConfig,
+ multimodal_config: MultiModalConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ lora_config: Optional[LoRAConfig] = None,
+ ):
+ # Initialize VL
+ if hasattr(config, "visual"):
+ return QWenVL(config, multimodal_config, cache_config,
+ quant_config, lora_config)
+ # Initialize LLM
+ else:
+ return QWenLLM(config, multimodal_config, cache_config,
+ quant_config, lora_config)
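With the QWenLLM/QWenVL split in place, LoRA can be requested for Qwen through the usual offline API; a hedged usage sketch, where the adapter path is a placeholder:

    # Hedged sketch: running Qwen with a LoRA adapter after this patch; adapter path is a placeholder.
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(model="Qwen/Qwen-7B-Chat", enable_lora=True, trust_remote_code=True)
    outputs = llm.generate(
        ["Tell me a joke."],
        SamplingParams(max_tokens=32),
        lora_request=LoRARequest("my_adapter", 1, "/path/to/qwen-lora"),
    )
    print(outputs[0].outputs[0].text)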
From e74f2d448c9b984f6b2c91137c58919441456503 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Tue, 29 Oct 2024 13:07:57 +0800
Subject: [PATCH 03/23] [Doc] Specify async engine args in docs (#9726)
---
vllm/engine/async_llm_engine.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 1f57aecb6481d..e9848a14cbe17 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -999,6 +999,7 @@ async def generate(
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
+ >>> # note that engine_args here is an AsyncEngineArgs instance
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "prompt": "What is LLM?",
@@ -1082,6 +1083,7 @@ async def encode(
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
+ >>> # note that engine_args here is an AsyncEngineArgs instance
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "input": "What is LLM?",
From eae3d48181b1ad27f132f14df18e8cff203f7552 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Tue, 29 Oct 2024 13:08:20 +0800
Subject: [PATCH 04/23] [Bugfix] Use temporary directory in registry (#9721)
---
vllm/model_executor/models/registry.py | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 595a9256f958e..32b9341ae0b93 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -1,4 +1,5 @@
import importlib
+import os
import pickle
import subprocess
import sys
@@ -423,9 +424,13 @@ def is_attention_free_model(self, architectures: Union[str,
def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
- with tempfile.NamedTemporaryFile() as output_file:
+ # NOTE: We use a temporary directory instead of a temporary file to avoid
+ # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
+ with tempfile.TemporaryDirectory() as tempdir:
+ output_filepath = os.path.join(tempdir, "registry_output.tmp")
+
# `cloudpickle` allows pickling lambda functions directly
- input_bytes = cloudpickle.dumps((fn, output_file.name))
+ input_bytes = cloudpickle.dumps((fn, output_filepath))
# cannot use `sys.executable __file__` here because the script
# contains relative imports
@@ -442,7 +447,7 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
raise RuntimeError(f"Error raised in subprocess:\n"
f"{returned.stderr.decode()}") from e
- with open(output_file.name, "rb") as f:
+ with open(output_filepath, "rb") as f:
return pickle.load(f)
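In isolation, the tempdir-plus-explicit-path pattern adopted here looks like the following sketch; the subprocess command is illustrative only:

    # Sketch of the pattern: let the child process create the output file itself inside a temp dir.
    import os
    import subprocess
    import sys
    import tempfile

    with tempfile.TemporaryDirectory() as tempdir:
        output_filepath = os.path.join(tempdir, "registry_output.tmp")
        subprocess.run(
            [sys.executable, "-c", f"open({output_filepath!r}, 'wb').write(b'ok')"],
            check=True)
        with open(output_filepath, "rb") as f:
            print(f.read())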
From ef7865b4f9013e1d328058091b12b28c4a078e91 Mon Sep 17 00:00:00 2001
From: Zhong Qishuai
Date: Tue, 29 Oct 2024 19:49:47 +0800
Subject: [PATCH 05/23] [Frontend] re-enable multi-modality input in the new
beam search implementation (#9427)
Signed-off-by: Qishuai <Ferdinandzhong@gmail.com>
---
tests/entrypoints/openai/test_vision.py | 71 +++++++++++++++
vllm/beam_search.py | 9 +-
vllm/engine/protocol.py | 88 ++++++++++++-------
vllm/entrypoints/openai/protocol.py | 4 +-
vllm/entrypoints/openai/serving_chat.py | 7 +-
vllm/entrypoints/openai/serving_completion.py | 10 ++-
vllm/sampling_params.py | 1 +
7 files changed, 150 insertions(+), 40 deletions(-)
diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py
index 8311a5cb3c2d4..68804d6833c73 100644
--- a/tests/entrypoints/openai/test_vision.py
+++ b/tests/entrypoints/openai/test_vision.py
@@ -107,6 +107,42 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image_beamsearch(client: openai.AsyncOpenAI,
+ model_name: str,
+ image_url: str):
+ messages = [{
+ "role":
+ "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": image_url
+ }
+ },
+ {
+ "type": "text",
+ "text": "What's in this image?"
+ },
+ ],
+ }]
+
+ chat_completion = await client.chat.completions.create(
+ model=model_name,
+ messages=messages,
+ n=2,
+ max_tokens=10,
+ logprobs=True,
+ top_logprobs=5,
+ extra_body=dict(use_beam_search=True))
+ assert len(chat_completion.choices) == 2
+ assert chat_completion.choices[
+ 0].message.content != chat_completion.choices[1].message.content
+
+
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@@ -162,6 +198,41 @@ async def test_single_chat_session_image_base64encoded(
assert message.content is not None and len(message.content) >= 0
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
+async def test_single_chat_session_image_base64encoded_beamsearch(
+ client: openai.AsyncOpenAI, model_name: str, image_url: str,
+ base64_encoded_image: Dict[str, str]):
+
+ messages = [{
+ "role":
+ "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url":
+ f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
+ }
+ },
+ {
+ "type": "text",
+ "text": "What's in this image?"
+ },
+ ],
+ }]
+ chat_completion = await client.chat.completions.create(
+ model=model_name,
+ messages=messages,
+ n=2,
+ max_tokens=10,
+ extra_body=dict(use_beam_search=True))
+ assert len(chat_completion.choices) == 2
+ assert chat_completion.choices[
+ 0].message.content != chat_completion.choices[1].message.content
+
+
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
diff --git a/vllm/beam_search.py b/vllm/beam_search.py
index 1b48538734dae..026037e5434d1 100644
--- a/vllm/beam_search.py
+++ b/vllm/beam_search.py
@@ -1,8 +1,11 @@
from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from vllm.sequence import Logprob
+if TYPE_CHECKING:
+ from vllm.multimodal import MultiModalDataDict
+
@dataclass
class BeamSearchSequence:
@@ -16,6 +19,10 @@ class BeamSearchSequence:
logprobs: List[Dict[int, Logprob]]
cum_logprob: float = 0.0
text: Optional[str] = None
+ finish_reason: Optional[str] = None
+ stop_reason: Union[int, str, None] = None
+ multi_modal_data: Optional["MultiModalDataDict"] = None
+ mm_processor_kwargs: Optional[Dict[str, Any]] = None
@dataclass
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index b00dd136d4a47..6a09361c56865 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -6,6 +6,7 @@
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -59,7 +60,8 @@ def generate(
async def beam_search(
self,
- prompt: Union[str, List[int]],
+ prompt: Union[PromptType, List[int]],
+ model_config: ModelConfig,
request_id: str,
params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]:
@@ -69,32 +71,40 @@ async def beam_search(
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty
+ include_stop_str_in_output = params.include_stop_str_in_output
- tokenizer = await self.get_tokenizer(lora_request=None)
- if isinstance(prompt, str):
- tokenized_prompt = tokenizer.encode(prompt)
- prompt_text = prompt
- else:
- tokenized_prompt = prompt
- prompt_text = None
- tokenized_length = len(tokenized_prompt)
+ tokenizer = await self.get_tokenizer()
+ input_preprocessor = InputPreprocessor(model_config, tokenizer)
+
+ (prompt_text, prompt_token_ids, multi_modal_data,
+ mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
+ prompt,
+ request_id=request_id,
+ )
+ tokenized_length = len(prompt_token_ids)
sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty)
- beam_search_params = SamplingParams(logprobs=2 * beam_width,
- max_tokens=1,
- temperature=temperature)
+ beam_search_params = SamplingParams(
+ logprobs=2 * beam_width,
+ max_tokens=1,
+ temperature=temperature,
+ )
all_beams = [
- BeamSearchSequence(tokens=tokenized_prompt,
+ BeamSearchSequence(tokens=prompt_token_ids,
+ cum_logprob=0,
logprobs=[],
- cum_logprob=0)
+ multi_modal_data=multi_modal_data,
+ mm_processor_kwargs=mm_processor_kwargs)
]
completed = []
for _ in range(max_tokens):
prompts_batch = [
- TokensPrompt(prompt_token_ids=beam.tokens)
+ TokensPrompt(prompt_token_ids=beam.tokens,
+ multi_modal_data=beam.multi_modal_data,
+ mm_processor_kwargs=beam.mm_processor_kwargs)
for beam in all_beams
]
@@ -120,17 +130,31 @@ async def beam_search(
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items():
- new_beam = BeamSearchSequence(
- tokens=current_beam.tokens + [token_id],
- logprobs=current_beam.logprobs + [logprobs],
- cum_logprob=current_beam.cum_logprob +
- logprob_obj.logprob)
-
if token_id == tokenizer.eos_token_id and \
not ignore_eos:
- completed.append(new_beam)
+ completed.append(
+ BeamSearchSequence(
+ tokens=current_beam.tokens +
+ [token_id] if include_stop_str_in_output
+ else current_beam.tokens,
+ logprobs=current_beam.logprobs +
+ [logprobs],
+ cum_logprob=current_beam.cum_logprob +
+ logprob_obj.logprob,
+ finish_reason="stop",
+ stop_reason=tokenizer.eos_token_id))
else:
- new_beams.append(new_beam)
+ new_beams.append(
+ BeamSearchSequence(
+ tokens=current_beam.tokens + [token_id],
+ logprobs=current_beam.logprobs +
+ [logprobs],
+ cum_logprob=current_beam.cum_logprob +
+ logprob_obj.logprob,
+ multi_modal_data=current_beam.
+ multi_modal_data,
+ mm_processor_kwargs=current_beam.
+ mm_processor_kwargs))
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]
@@ -151,16 +175,18 @@ async def beam_search(
request_id=request_id,
prompt=prompt_text,
outputs=[
- CompletionOutput(
- text=beam.text,
- cumulative_logprob=beam.cum_logprob,
- token_ids=beam.tokens[tokenized_length:],
- index=i,
- logprobs=beam.logprobs,
- ) for (i, beam) in enumerate(best_beams)
+ CompletionOutput(text=beam.text,
+ cumulative_logprob=beam.cum_logprob,
+ token_ids=beam.tokens[tokenized_length:],
+ index=i,
+ logprobs=beam.logprobs,
+ finish_reason=beam.finish_reason if
+ beam.finish_reason is not None else "length",
+ stop_reason=beam.stop_reason)
+ for (i, beam) in enumerate(best_beams)
],
finished=True,
- prompt_token_ids=tokenized_prompt,
+ prompt_token_ids=prompt_token_ids,
prompt_logprobs=None)
yield beam_search_output
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index a212c0d608ddb..7f270a81a7692 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -308,7 +308,7 @@ def to_beam_search_params(self,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
- )
+ include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
@@ -606,7 +606,7 @@ def to_beam_search_params(self,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
- )
+ include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index cd2883a3b323b..1f951d15a7a32 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -236,9 +236,10 @@ async def create_chat_completion(
if isinstance(sampling_params, BeamSearchParams):
result_generator = self.engine_client.beam_search(
- engine_inputs['prompt_token_ids'],
- request_id,
- sampling_params,
+ prompt=engine_inputs,
+ model_config=self.model_config,
+ request_id=request_id,
+ params=sampling_params,
)
else:
result_generator = self.engine_client.generate(
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 56e35950410a0..da521a6012530 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -150,9 +150,13 @@ async def create_completion(
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
- prompt_inputs["prompt_token_ids"],
- request_id_item,
- sampling_params,
+ prompt={
+ "prompt_token_ids":
+ prompt_inputs["prompt_token_ids"]
+ },
+ model_config=self.model_config,
+ request_id=request_id,
+ params=sampling_params,
)
else:
generator = self.engine_client.generate(
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index bac32c991a0e3..5e191c6e715e0 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -500,3 +500,4 @@ class BeamSearchParams(
ignore_eos: bool = False
temperature: float = 0.0
length_penalty: float = 1.0
+ include_stop_str_in_output: bool = False
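Client-side, the re-enabled path is exercised the same way as in the new tests; a hedged sketch where the base URL, model name, and image URL are placeholders:

    # Sketch of a beam-search chat request with an image (placeholders throughout).
    import openai

    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    chat_completion = client.chat.completions.create(
        model="your-multimodal-model",
        messages=[{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
                {"type": "text", "text": "What's in this image?"},
            ],
        }],
        n=2,
        max_tokens=10,
        extra_body=dict(use_beam_search=True),
    )
    print([choice.message.content for choice in chat_completion.choices])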
From 09500f7ddeb974730972fd9284bd93c08a557cf6 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Tue, 29 Oct 2024 20:20:02 +0800
Subject: [PATCH 06/23] [Model] Add BNB quantization support for Mllama (#9720)
---
.../layers/quantization/bitsandbytes.py | 35 ++++++++++++++--
vllm/model_executor/model_loader/loader.py | 19 +++++++--
vllm/model_executor/models/mllama.py | 42 ++++++++++++++++---
3 files changed, 84 insertions(+), 12 deletions(-)
diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py
index faa8d92e83de3..7a039a78f09b8 100644
--- a/vllm/model_executor/layers/quantization/bitsandbytes.py
+++ b/vllm/model_executor/layers/quantization/bitsandbytes.py
@@ -3,6 +3,7 @@
import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+ UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -23,7 +24,7 @@ def __init__(
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
llm_int8_has_fp16_weight: bool = False,
- llm_int8_skip_modules: Optional[Any] = None,
+ llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_threshold: float = 0.0,
) -> None:
@@ -34,11 +35,15 @@ def __init__(
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
- self.llm_int8_skip_modules = llm_int8_skip_modules
+ self.llm_int8_skip_modules = llm_int8_skip_modules or []
self.llm_int8_threshold = llm_int8_threshold
def __repr__(self) -> str:
- return "BitsAndBytesConfig"
+ return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
+ f"load_in_4bit={self.load_in_4bit}, "
+ f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
+ f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
+ f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
@classmethod
def get_name(self) -> str:
@@ -102,8 +107,10 @@ def get_safe_value(config, keys, default_value=None):
llm_int8_threshold=llm_int8_threshold)
def get_quant_method(self, layer: torch.nn.Module,
- prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
+ prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
+ if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
+ return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self)
return None
@@ -111,6 +118,10 @@ def get_scaled_act_names(self) -> List[str]:
return []
+def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
+ return any(module_name in prefix for module_name in llm_int8_skip_modules)
+
+
class BitsAndBytesLinearMethod(LinearMethodBase):
"""Linear method for BitsAndBytes.
@@ -211,6 +222,11 @@ def _apply_8bit_weight(
from bitsandbytes import MatmulLtState, matmul
original_type = x.dtype
+ original_shape = x.shape
+ reshape_after_matmul = False
+ if x.ndim > 2:
+ x = x.reshape(-1, x.size(-1))
+ reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)
qweight = layer.qweight
@@ -265,6 +281,9 @@ def _apply_8bit_weight(
out = out.to(original_type)
+ if reshape_after_matmul:
+ out = out.view(*original_shape[:-1], out.size(-1))
+
if bias is not None:
out += bias
@@ -282,6 +301,11 @@ def _apply_4bit_weight(
from bitsandbytes import matmul_4bit
original_type = x.dtype
+ original_shape = x.shape
+ reshape_after_matmul = False
+ if x.ndim > 2:
+ x = x.reshape(-1, x.size(-1))
+ reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)
qweight = layer.qweight
@@ -310,6 +334,9 @@ def _apply_4bit_weight(
out = out.to(original_type)
+ if reshape_after_matmul:
+ out = out.view(*original_shape[:-1], out.size(-1))
+
if bias is not None:
out += bias
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 813f58339da37..3cfee13b9fa6e 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -899,6 +899,19 @@ def _get_quantized_weights_iterator(
return self._unquantized_generator(hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
+ def _is_8bit_weight_name(self, weight_name: str):
+ quantized_suffix = {".scb", ".weight_format"}
+ return any(weight_name.lower().endswith(suffix)
+ for suffix in quantized_suffix)
+
+ def _is_4bit_weight_name(self, weight_name: str):
+ quantized_suffix = {
+ "absmax", "quant_map", "nested_absmax", "nested_quant_map",
+ "bitsandbytes"
+ }
+ suffix = weight_name.split(".")[-1]
+ return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
@@ -912,7 +925,7 @@ def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
- if not weight_name.endswith((".weight", ".bias")):
+ if self._is_8bit_weight_name(weight_name):
continue
qweight_name = weight_name.replace(".weight", ".qweight")
@@ -932,7 +945,7 @@ def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
- if weight_name.endswith((".weight", ".bias")):
+ if not self._is_4bit_weight_name(weight_name):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
@@ -956,7 +969,7 @@ def _parse_quant_state(param_name: str,
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
- if not weight_name.endswith((".weight", ".bias")):
+ if self._is_4bit_weight_name(weight_name):
continue
if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 44ef49729c969..5cf5272cae878 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -325,7 +325,10 @@ def forward(self, hidden_state: torch.Tensor,
# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
- def __init__(self, config: config_mllama.MllamaVisionConfig):
+ def __init__(self,
+ config: config_mllama.MllamaVisionConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
super().__init__()
model_parallel_size = get_tensor_model_parallel_world_size()
@@ -341,12 +344,16 @@ def __init__(self, config: config_mllama.MllamaVisionConfig):
self.head_dim,
self.num_heads,
bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=False,
input_is_parallel=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
)
def forward(
@@ -393,7 +400,8 @@ def __init__(
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size
- self.self_attn = MllamaVisionSdpaAttention(config)
+ self.self_attn = MllamaVisionSdpaAttention(
+ config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
@@ -1002,6 +1010,7 @@ def __init__(
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
+ prefix=f"{prefix}.lm_head",
)
def forward(
@@ -1037,6 +1046,26 @@ def forward(
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
+ # BitsAndBytes-specific attributes
+ default_bitsandbytes_target_modules = [
+ ".gate_proj.",
+ ".down_proj.",
+ ".up_proj.",
+ ".q_proj.",
+ ".k_proj.",
+ ".v_proj.",
+ ".o_proj.",
+ ]
+ # in TP, these weights are partitioned along the column dimension (dim=-1)
+ column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+ bitsandbytes_stacked_params_mapping = {
+ # shard_name, weight_name, index
+ "q_proj": ("qkv_proj", 0),
+ "k_proj": ("qkv_proj", 1),
+ "v_proj": ("qkv_proj", 2),
+ "gate_proj": ("gate_up_proj", 0),
+ "up_proj": ("gate_up_proj", 1),
+ }
def __init__(self,
config: config_mllama.MllamaConfig,
@@ -1061,10 +1090,13 @@ def __init__(self,
quant_config=quant_config,
prefix="language_model",
)
- self.multi_modal_projector = nn.Linear(
+ self.multi_modal_projector = ColumnParallelLinear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
bias=True,
+ quant_config=quant_config,
+ gather_output=True,
+ prefix="multi_modal_projector",
)
self.logits_processor = LogitsProcessor(config.output_hidden_states,
config.text_config.vocab_size)
@@ -1128,7 +1160,7 @@ def _parse_and_validate_image_input(self, **kwargs: object):
raise ValueError("No images provided.")
max_num_tiles = max(
max([len(x) for x in y[0]]) for y in pixel_values)
- device = self.multi_modal_projector.weight.device
+ device = next(self.multi_modal_projector.parameters()).device
bsz = len(pixel_values)
out_num_tiles = []
out_images = torch.zeros(
@@ -1204,7 +1236,7 @@ def get_cross_attention_states(
cross_attention_states = self.vision_model(pixel_values,
aspect_ratio_ids,
aspect_ratio_mask)
- cross_attention_states = self.multi_modal_projector(
+ cross_attention_states, _ = self.multi_modal_projector(
cross_attention_states)
bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
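The reshape handling added to both apply paths reduces to this generic pattern; torch.matmul stands in for the bitsandbytes matmul call:

    # Generic sketch of the flatten-to-2D / restore-shape pattern used above.
    import torch

    def apply_2d_only_matmul(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        original_shape = x.shape
        reshape_after_matmul = False
        if x.ndim > 2:
            x = x.reshape(-1, x.size(-1))  # collapse all leading dims into one
            reshape_after_matmul = True
        out = torch.matmul(x, weight)
        if reshape_after_matmul:
            out = out.view(*original_shape[:-1], out.size(-1))
        return out

    x = torch.randn(2, 3, 8)   # e.g. (batch, seq, hidden)
    w = torch.randn(8, 16)
    print(apply_2d_only_matmul(x, w).shape)  # torch.Size([2, 3, 16])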
From 622b7ab955186f37879208d7a30e9faf985be220 Mon Sep 17 00:00:00 2001
From: wangshuai09 <391746016@qq.com>
Date: Tue, 29 Oct 2024 22:47:44 +0800
Subject: [PATCH 07/23] [Hardware] using current_platform.seed_everything
(#9785)
Signed-off-by: wangshuai09 <391746016@qq.com>
---
benchmarks/kernels/benchmark_layernorm.py | 6 +++---
benchmarks/kernels/benchmark_moe.py | 7 ++++---
.../kernels/benchmark_paged_attention.py | 5 +++--
benchmarks/kernels/benchmark_quant.py | 6 +++---
benchmarks/kernels/benchmark_rope.py | 5 +++--
tests/kernels/test_activation.py | 6 +++---
tests/kernels/test_attention.py | 6 +++---
tests/kernels/test_awq_triton.py | 6 +++---
tests/kernels/test_blocksparse_attention.py | 6 +++---
tests/kernels/test_cache.py | 12 +++++------
tests/kernels/test_causal_conv1d.py | 12 +++++------
tests/kernels/test_flash_attn.py | 6 +++---
tests/kernels/test_flashinfer.py | 10 ++++-----
tests/kernels/test_fp8_quant.py | 8 +++----
tests/kernels/test_gguf.py | 6 +++---
tests/kernels/test_int8_quant.py | 10 ++++-----
tests/kernels/test_layernorm.py | 4 ++--
tests/kernels/test_mamba_ssm.py | 6 +++---
tests/kernels/test_moe.py | 3 +--
tests/kernels/test_pos_encoding.py | 8 +++----
tests/kernels/test_prefix_prefill.py | 7 ++++---
tests/lora/test_layers.py | 4 ++--
tests/lora/test_punica_sizes.py | 10 ++++-----
tests/lora/test_punica_variation.py | 12 +++++------
vllm/model_executor/utils.py | 3 +--
vllm/platforms/interface.py | 14 +++++++++++++
vllm/utils.py | 21 ++-----------------
27 files changed, 104 insertions(+), 105 deletions(-)
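The patch is a mechanical substitution applied at every call site, along these lines:

    # Before (old helper in vllm.utils):
    #   from vllm.utils import seed_everything
    #   seed_everything(0)
    # After (platform-aware helper added by this patch):
    from vllm.platforms import current_platform
    current_platform.seed_everything(0)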
diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py
index 92f6053cc6d7e..7acea6087fdfd 100644
--- a/benchmarks/kernels/benchmark_layernorm.py
+++ b/benchmarks/kernels/benchmark_layernorm.py
@@ -3,8 +3,8 @@
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
- seed_everything)
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode()
@@ -16,7 +16,7 @@ def main(num_tokens: int,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device("cuda")
layer = RMSNorm(hidden_size).to(dtype=dtype)
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 4f88e8e6eb1a6..8f538c21f7f7e 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -10,7 +10,8 @@
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import *
-from vllm.utils import FlexibleArgumentParser, seed_everything
+from vllm.platforms import current_platform
+from vllm.utils import FlexibleArgumentParser
class BenchmarkConfig(TypedDict):
@@ -167,7 +168,7 @@ class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
- seed_everything(seed)
+ current_platform.seed_everything(seed)
self.seed = seed
def benchmark(
@@ -181,7 +182,7 @@ def benchmark(
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]:
- seed_everything(self.seed)
+ current_platform.seed_everything(self.seed)
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index 87864d038d593..14eef00b855ac 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -5,8 +5,9 @@
import torch
from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
- create_kv_caches_with_random, seed_everything)
+ create_kv_caches_with_random)
NUM_BLOCKS = 1024
PARTITION_SIZE = 512
@@ -28,7 +29,7 @@ def main(
device: str = "cuda",
kv_cache_dtype: Optional[str] = None,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
scale = float(1.0 / (head_size**0.5))
query = torch.empty(num_seqs,
diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py
index 743a5744e8614..1d62483448946 100644
--- a/benchmarks/kernels/benchmark_quant.py
+++ b/benchmarks/kernels/benchmark_quant.py
@@ -3,8 +3,8 @@
import torch
from vllm import _custom_ops as ops
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
- seed_everything)
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode()
@@ -17,7 +17,7 @@ def main(num_tokens: int,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device("cuda")
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py
index 784b1cf9844e4..250d505168d09 100644
--- a/benchmarks/kernels/benchmark_rope.py
+++ b/benchmarks/kernels/benchmark_rope.py
@@ -6,7 +6,8 @@
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
-from vllm.utils import FlexibleArgumentParser, seed_everything
+from vllm.platforms import current_platform
+from vllm.utils import FlexibleArgumentParser
def benchmark_rope_kernels_multi_lora(
@@ -22,7 +23,7 @@ def benchmark_rope_kernels_multi_lora(
max_position: int = 8192,
base: int = 10000,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py
index 0e3d3c3a2e987..057a11746014c 100644
--- a/tests/kernels/test_activation.py
+++ b/tests/kernels/test_activation.py
@@ -8,7 +8,7 @@
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
GeluAndMul, NewGELU,
QuickGELU, SiluAndMul)
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
from .allclose_default import get_default_atol, get_default_rtol
@@ -37,7 +37,7 @@ def test_act_and_mul(
seed: int,
device: str,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu":
@@ -85,7 +85,7 @@ def test_activation(
seed: int,
device: str,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation[0]()
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 1604aa4d2d6e5..4ecd0fc1a21ad 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -7,7 +7,7 @@
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
-from vllm.utils import get_max_shared_memory_bytes, seed_everything
+from vllm.utils import get_max_shared_memory_bytes
from .allclose_default import get_default_atol, get_default_rtol
@@ -144,7 +144,7 @@ def test_paged_attention(
or (version == "rocm" and head_size not in (64, 128))):
pytest.skip()
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
@@ -382,7 +382,7 @@ def test_multi_query_kv_attention(
seed: int,
device: str,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py
index e95e5bd948212..406a0c8dd8080 100644
--- a/tests/kernels/test_awq_triton.py
+++ b/tests/kernels/test_awq_triton.py
@@ -7,7 +7,7 @@
from vllm.model_executor.layers.quantization.awq_triton import (
AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
device = "cuda"
@@ -80,7 +80,7 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
zeros_cols = qweight_cols
zeros_dtype = torch.int32
- seed_everything(0)
+ current_platform.seed_everything(0)
qweight = torch.randint(0,
torch.iinfo(torch.int32).max,
@@ -134,7 +134,7 @@ def test_gemm(N, K, M, splitK, group_size):
qzeros_rows = scales_rows
qzeros_cols = qweight_cols
- seed_everything(0)
+ current_platform.seed_everything(0)
input = torch.rand((input_rows, input_cols),
dtype=input_dtype,
diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py
index b65efb3abc230..fb601852dd523 100644
--- a/tests/kernels/test_blocksparse_attention.py
+++ b/tests/kernels/test_blocksparse_attention.py
@@ -8,7 +8,7 @@
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.platforms import current_platform
-from vllm.utils import get_max_shared_memory_bytes, seed_everything
+from vllm.utils import get_max_shared_memory_bytes
from .allclose_default import get_default_atol, get_default_rtol
@@ -173,7 +173,7 @@ def test_paged_attention(
blocksparse_block_size: int,
blocksparse_head_sliding_step: int,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
@@ -384,7 +384,7 @@ def test_varlen_blocksparse_attention_prefill(
seed: int,
device: str,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index b0e7097fdfbd4..5b8311a33c361 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -6,7 +6,7 @@
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -56,7 +56,7 @@ def test_copy_blocks(
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
@@ -132,7 +132,7 @@ def test_reshape_and_cache(
) -> None:
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
@@ -224,7 +224,7 @@ def test_reshape_and_cache_flash(
device: str,
kv_cache_dtype: str,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
# Create a random slot mapping.
@@ -339,7 +339,7 @@ def test_swap_blocks(
if kv_cache_dtype == "fp8" and head_size % 16:
pytest.skip()
- seed_everything(seed)
+ current_platform.seed_everything(seed)
src_device = device if direction[0] == "cuda" else 'cpu'
dst_device = device if direction[1] == "cuda" else 'cpu'
@@ -408,7 +408,7 @@ def test_fp8_e4m3_conversion(
seed: int,
device: str,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
low = -224.0
high = 224.0
diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py
index 277d7e4977d73..96bfe06d74ae5 100644
--- a/tests/kernels/test_causal_conv1d.py
+++ b/tests/kernels/test_causal_conv1d.py
@@ -9,7 +9,7 @@
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
def causal_conv1d_ref(
@@ -70,7 +70,7 @@ def causal_conv1d_update_ref(x,
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
- The conv_state will be updated by copying x to the
+ The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
@@ -161,7 +161,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
- seed_everything(0)
+ current_platform.seed_everything(0)
x = torch.randn(batch, dim, seqlen, device=device,
dtype=itype).contiguous()
@@ -223,7 +223,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
- seed_everything(0)
+ current_platform.seed_everything(0)
batch = 2
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone()
@@ -270,7 +270,7 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
rtol, atol = 1e-2, 5e-2
# set seed
- seed_everything(0)
+ current_platform.seed_everything(0)
batch_size = 3
padding = 5 if with_padding else 0
@@ -343,7 +343,7 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
- seed_everything(0)
+ current_platform.seed_everything(0)
seqlens = []
batch_size = 4
if seqlen < 10:
diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index 35c29c5bd1028..a20c73345218f 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -3,7 +3,7 @@
import pytest
import torch
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
@@ -91,7 +91,7 @@ def test_flash_attn_with_paged_kv(
sliding_window: Optional[int],
) -> None:
torch.set_default_device("cuda")
- seed_everything(0)
+ current_platform.seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
@@ -161,7 +161,7 @@ def test_varlen_with_paged_kv(
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
- seed_everything(0)
+ current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py
index 80a388db6530e..a2c8f71665737 100644
--- a/tests/kernels/test_flashinfer.py
+++ b/tests/kernels/test_flashinfer.py
@@ -4,7 +4,7 @@
import pytest
import torch
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
HEAD_SIZES = [128, 256]
@@ -84,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv(
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
- seed_everything(0)
+ current_platform.seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
@@ -170,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
- seed_everything(0)
+ current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
@@ -268,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
head_size: int, dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
- seed_everything(0)
+ current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
@@ -381,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
) -> None:
# test doesn't work for num_heads = (16,16)
torch.set_default_device("cuda")
- seed_everything(0)
+ current_platform.seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py
index c18f5f468dc5a..ebaaae2321885 100644
--- a/tests/kernels/test_fp8_quant.py
+++ b/tests/kernels/test_fp8_quant.py
@@ -6,7 +6,7 @@
ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)
from tests.kernels.utils import opcheck
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
@@ -46,7 +46,7 @@ def opcheck_fp8_quant(output,
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, scale_ub: bool,
seed: int) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6 # avoid nans
@@ -76,7 +76,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
@@ -95,7 +95,7 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
@torch.inference_mode()
@pytest.mark.parametrize("seed", SEEDS)
def test_fp8_quant_large(seed: int) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings
hidden_size = 1152 # Smallest hidden_size to reproduce the error
diff --git a/tests/kernels/test_gguf.py b/tests/kernels/test_gguf.py
index 1513fc196153c..893af99ba4977 100644
--- a/tests/kernels/test_gguf.py
+++ b/tests/kernels/test_gguf.py
@@ -7,7 +7,7 @@
from huggingface_hub import snapshot_download
import vllm._custom_ops as ops
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
@@ -75,7 +75,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
- seed_everything(0)
+ current_platform.seed_everything(0)
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
@@ -111,7 +111,7 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype,
@torch.inference_mode()
def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
- seed_everything(0)
+ current_platform.seed_everything(0)
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py
index 41e103e1d09f9..8db6a0d0d9fa4 100644
--- a/tests/kernels/test_int8_quant.py
+++ b/tests/kernels/test_int8_quant.py
@@ -4,7 +4,7 @@
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from tests.kernels.utils import opcheck
from vllm._custom_ops import scaled_int8_quant
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
@@ -45,7 +45,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
@torch.inference_mode()
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@@ -68,7 +68,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
@torch.inference_mode()
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
@@ -112,7 +112,7 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@@ -138,7 +138,7 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float, azp: int) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py
index 382079d472ee9..9dfa2cbe45e94 100644
--- a/tests/kernels/test_layernorm.py
+++ b/tests/kernels/test_layernorm.py
@@ -3,7 +3,7 @@
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
@@ -31,7 +31,7 @@ def test_rms_norm(
seed: int,
device: str,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py
index e92d401368a7b..bf7ff3b5c59b8 100644
--- a/tests/kernels/test_mamba_ssm.py
+++ b/tests/kernels/test_mamba_ssm.py
@@ -8,7 +8,7 @@
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
def selective_state_update_ref(state,
@@ -235,7 +235,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
rtolw = max(rtolw, rtol)
atolw = max(atolw, atol)
# set seed
- seed_everything(0)
+ current_platform.seed_everything(0)
batch_size = 1
dim = 4
dstate = 8
@@ -358,7 +358,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
if torch.version.hip:
atol *= 2
# set seed
- seed_everything(0)
+ current_platform.seed_everything(0)
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 70906ab2187bc..19c3fc1e1fe3a 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -19,7 +19,6 @@
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
-from vllm.utils import seed_everything
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@@ -115,7 +114,7 @@ def test_fused_marlin_moe(
num_bits: int,
is_k_full: bool,
):
- seed_everything(7)
+ current_platform.seed_everything(7)
# Filter act_order
if act_order:
diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index 94da00915d40e..b408559cc0b07 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -5,7 +5,7 @@
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
from .allclose_default import get_default_atol, get_default_rtol
@@ -48,7 +48,7 @@ def test_rotary_embedding(
if rotary_dim is None:
rotary_dim = head_size
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
@@ -100,7 +100,7 @@ def test_batched_rotary_embedding(
max_position: int = 8192,
base: int = 10000,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
@@ -160,7 +160,7 @@ def test_batched_rotary_embedding_multi_lora(
max_position: int = 8192,
base: int = 10000,
) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py
index 3181d92562399..a8a187ebaede4 100644
--- a/tests/kernels/test_prefix_prefill.py
+++ b/tests/kernels/test_prefix_prefill.py
@@ -9,7 +9,8 @@
from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.prefix_prefill import context_attention_fwd
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, seed_everything
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
NUM_HEADS = [64]
NUM_QUERIES_PER_KV = [1, 8, 64]
@@ -39,7 +40,7 @@ def test_contexted_kv_attention(
kv_cache_dtype: str,
device: str,
) -> None:
- seed_everything(0)
+ current_platform.seed_everything(0)
torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process
@@ -234,7 +235,7 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str,
device: str,
) -> None:
- seed_everything(0)
+ current_platform.seed_everything(0)
torch.set_default_device(device)
# Need this, otherwise when we capture the graph the process
diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index db877219a285c..eb882faf3974a 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -39,7 +39,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed
-from vllm.utils import seed_everything
+from vllm.platforms import current_platform
from .utils import DummyLoRAManager
@@ -923,7 +923,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
seq_len) -> None:
dtype = torch.float16
seed = 0
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch.set_default_device(device)
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py
index 41c37a4813c68..e756544d96e98 100644
--- a/tests/lora/test_punica_sizes.py
+++ b/tests/lora/test_punica_sizes.py
@@ -1,5 +1,5 @@
"""
-This script is mainly used to tests various hidden_sizes. We have collected the
+This script is mainly used to test various hidden_sizes. We have collected the
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
@@ -15,8 +15,8 @@
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry
-from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@@ -146,7 +146,7 @@ def test_punica_sgmv(
device: str,
):
torch.set_default_device(device)
- seed_everything(seed)
+ current_platform.seed_everything(seed)
seq_length = 128
(
@@ -239,7 +239,7 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device)
- seed_everything(seed)
+ current_platform.seed_everything(seed)
seq_length = 1
(
@@ -327,7 +327,7 @@ def test_punica_expand_nslices(
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device)
- seed_everything(seed)
+ current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py
index 185da6399a06a..dc0edeb10ef46 100644
--- a/tests/lora/test_punica_variation.py
+++ b/tests/lora/test_punica_variation.py
@@ -1,6 +1,6 @@
"""
-This script is mainly used to test whether trtion kernels can run normally
-under different conditions, including various batches, numbers of LoRA , and
+This script is mainly used to test whether triton kernels can run normally
+under different conditions, including various batches, numbers of LoRAs, and
maximum ranks.
"""
from unittest.mock import patch
@@ -14,8 +14,8 @@
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry
-from vllm.utils import seed_everything
from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@@ -61,7 +61,7 @@ def test_punica_sgmv(
device: str,
):
torch.set_default_device(device)
- seed_everything(seed)
+ current_platform.seed_everything(seed)
seq_length = 128
(
@@ -154,7 +154,7 @@ def test_punica_bgmv(
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
torch.set_default_device(device)
- seed_everything(seed)
+ current_platform.seed_everything(seed)
seq_length = 1
(
@@ -242,7 +242,7 @@ def test_punica_expand_nslices(
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
torch.set_default_device(device)
- seed_everything(seed)
+ current_platform.seed_everything(seed)
seq_length = 128 if op_type == "sgmv" else 1
(
diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index c27b1cf6ac7b9..39ead08c238ce 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -4,11 +4,10 @@
import torch
from vllm.platforms import current_platform
-from vllm.utils import seed_everything
def set_random_seed(seed: int) -> None:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
def set_weight_attrs(
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 7c933385d6ff6..c3a3e7a284457 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -1,6 +1,8 @@
import enum
+import random
from typing import NamedTuple, Optional, Tuple, Union
+import numpy as np
import torch
@@ -111,6 +113,18 @@ def inference_mode(cls):
"""
return torch.inference_mode(mode=True)
+ @classmethod
+ def seed_everything(cls, seed: int) -> None:
+ """
+ Set the seed of each random module.
+ `torch.manual_seed` will set seed on all devices.
+
+ Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/utils.py b/vllm/utils.py
index c3f9a6bdd8b80..fea318ebcdf41 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -7,7 +7,6 @@
import inspect
import ipaddress
import os
-import random
import socket
import subprocess
import sys
@@ -331,22 +330,6 @@ def get_cpu_memory() -> int:
return psutil.virtual_memory().total
-def seed_everything(seed: int) -> None:
- """
- Set the seed of each random module.
-
- Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
- """
- random.seed(seed)
- np.random.seed(seed)
-
- if current_platform.is_cuda_alike():
- torch.cuda.manual_seed_all(seed)
-
- if current_platform.is_xpu():
- torch.xpu.manual_seed_all(seed)
-
-
def random_uuid() -> str:
return str(uuid.uuid4().hex)
@@ -643,7 +626,7 @@ def create_kv_caches_with_random_flash(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
@@ -685,7 +668,7 @@ def create_kv_caches_with_random(
f"Does not support key cache of type fp8 with head_size {head_size}"
)
- seed_everything(seed)
+ current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
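
Taken together, this patch replaces the helper in vllm.utils with a platform-level method, so every call site now seeds through current_platform. A minimal standalone sketch of the consolidated behavior, using only the calls shown in the interface diff above (the class name here is illustrative, not the real Platform):

    import random

    import numpy as np
    import torch

    class PlatformSketch:
        @classmethod
        def seed_everything(cls, seed: int) -> None:
            # torch.manual_seed seeds every visible device, so the separate
            # cuda/xpu branches from the old vllm.utils helper are no longer
            # needed here.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

    # call sites now read: current_platform.seed_everything(0)
    PlatformSketch.seed_everything(0)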
From 74fc2d77aec13304550bb52b459bd8c6da756d39 Mon Sep 17 00:00:00 2001
From: 科英
Date: Wed, 30 Oct 2024 01:32:56 +0800
Subject: [PATCH 08/23] [Misc] Add metrics for request queue time, forward
time, and execute time (#9659)
---
vllm/config.py | 7 -----
vllm/engine/llm_engine.py | 15 +++++++++
vllm/engine/metrics.py | 60 +++++++++++++++++++++++++++++++-----
vllm/engine/metrics_types.py | 3 ++
4 files changed, 70 insertions(+), 15 deletions(-)
diff --git a/vllm/config.py b/vllm/config.py
index 99a82c8f1b40b..3814e41aeb92d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1892,13 +1892,6 @@ def __post_init__(self):
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
f"installed. Original error:\n{otel_import_error_traceback}")
- if ((self.collect_model_forward_time
- or self.collect_model_execute_time)
- and self.otlp_traces_endpoint is None):
- raise ValueError(
- "collect_model_forward_time or collect_model_execute_time "
- "requires --otlp-traces-endpoint to be set.")
-
@dataclass(frozen=True)
class EngineConfig:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index ede77f04b1db9..60575210c9386 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1645,6 +1645,9 @@ def _get_stats(self,
# Request stats
# Latency
time_e2e_requests: List[float] = []
+ time_in_queue_requests: List[float] = []
+ model_forward_time_requests: List[float] = []
+ model_execute_time_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
@@ -1738,6 +1741,15 @@ def _get_stats(self,
# Latency timings
time_e2e_requests.append(now -
seq_group.metrics.arrival_time)
+ if seq_group.metrics.time_in_queue is not None:
+ time_in_queue_requests.append(
+ seq_group.metrics.time_in_queue)
+ if seq_group.metrics.model_forward_time is not None:
+ model_forward_time_requests.append(
+ seq_group.metrics.model_forward_time)
+ if seq_group.metrics.model_execute_time is not None:
+ model_execute_time_requests.append(
+ seq_group.metrics.model_execute_time * 1000)
# Metadata
num_prompt_tokens_requests.append(
len(seq_group.prompt_token_ids))
@@ -1795,6 +1807,9 @@ def _get_stats(self,
# Request stats
# Latency
time_e2e_requests=time_e2e_requests,
+ time_in_queue_requests=time_in_queue_requests,
+ model_forward_time_requests=model_forward_time_requests,
+ model_execute_time_requests=model_execute_time_requests,
# Metadata
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index a46625eff1e4a..0f5615ff14db1 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -133,7 +133,31 @@ def __init__(self, labelnames: List[str], max_model_len: int):
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
- buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
+ buckets=[
+ 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
+ 40.0, 50.0, 60.0
+ ])
+ self.histogram_time_in_queue_request = self._histogram_cls(
+ name="vllm:time_in_queue_requests",
+ documentation=
+ "Histogram of time the request spent in the queue in seconds.",
+ labelnames=labelnames,
+ buckets=[
+ 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
+ 40.0, 50.0, 60.0
+ ])
+ self.histogram_model_forward_time_request = self._histogram_cls(
+ name="vllm:model_forward_time_milliseconds",
+ documentation=
+ "Histogram of time spent in the model forward pass in ms.",
+ labelnames=labelnames,
+ buckets=build_1_2_3_5_8_buckets(3000))
+ self.histogram_model_execute_time_request = self._histogram_cls(
+ name="vllm:model_execute_time_milliseconds",
+ documentation=
+ "Histogram of time spent in the model execute function in ms.",
+ labelnames=labelnames,
+ buckets=build_1_2_3_5_8_buckets(3000))
# Metadata
self.histogram_num_prompt_tokens_request = self._histogram_cls(
name="vllm:request_prompt_tokens",
@@ -299,16 +323,12 @@ def _unregister_vllm_metrics(self) -> None:
pass
-def build_1_2_5_buckets(max_value: int) -> List[int]:
+def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
"""
- Builds a list of buckets with increasing powers of 10 multiplied by
- mantissa values (1, 2, 5) until the value exceeds the specified maximum.
+ Builds a list of buckets with increasing powers of 10 multiplied by
+ mantissa values until the value exceeds the specified maximum.
- Example:
- >>> build_1_2_5_buckets(100)
- [1, 2, 5, 10, 20, 50, 100]
"""
- mantissa_lst = [1, 2, 5]
exponent = 0
buckets: List[int] = []
while True:
@@ -321,6 +341,24 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
exponent += 1
+def build_1_2_5_buckets(max_value: int) -> List[int]:
+ """
+ Example:
+ >>> build_1_2_5_buckets(100)
+ [1, 2, 5, 10, 20, 50, 100]
+ """
+ return build_buckets([1, 2, 5], max_value)
+
+
+def build_1_2_3_5_8_buckets(max_value: int) -> List[int]:
+ """
+ Example:
+ >>> build_1_2_3_5_8_buckets(100)
+ [1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100]
+ """
+ return build_buckets([1, 2, 3, 5, 8], max_value)
+
+
def local_interval_elapsed(now: float, last_log: float,
local_interval: float) -> bool:
elapsed_time = now - last_log
@@ -486,6 +524,12 @@ def _log_prometheus(self, stats: Stats) -> None:
# Latency
self._log_histogram(self.metrics.histogram_e2e_time_request,
stats.time_e2e_requests)
+ self._log_histogram(self.metrics.histogram_time_in_queue_request,
+ stats.time_in_queue_requests)
+ self._log_histogram(self.metrics.histogram_model_forward_time_request,
+ stats.model_forward_time_requests)
+ self._log_histogram(self.metrics.histogram_model_execute_time_request,
+ stats.model_execute_time_requests)
# Metadata
finished_reason_counter = CollectionsCounter(
stats.finished_reason_requests)
diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py
index e9a5bd3b586be..510dd04bb3e55 100644
--- a/vllm/engine/metrics_types.py
+++ b/vllm/engine/metrics_types.py
@@ -46,6 +46,9 @@ class Stats:
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests: List[float]
+ time_in_queue_requests: List[float]
+ model_forward_time_requests: List[float]
+ model_execute_time_requests: List[float]
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
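
The generalized bucket helper above drives the new forward-time and execute-time histograms. A small self-contained sketch of how the buckets come out (the loop body mirrors the pre-existing 1-2-5 implementation, so treat it as illustrative rather than a copy of the source):

    from typing import List

    def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
        # Multiply each mantissa by increasing powers of 10 until a value
        # would exceed max_value, then stop.
        exponent = 0
        buckets: List[int] = []
        while True:
            for m in mantissa_lst:
                value = m * 10**exponent
                if value <= max_value:
                    buckets.append(value)
                else:
                    return buckets
            exponent += 1

    # The histograms above use build_1_2_3_5_8_buckets(3000), i.e.:
    print(build_buckets([1, 2, 3, 5, 8], 3000))
    # [1, 2, 3, 5, 8, 10, 20, 30, 50, 80, 100, 200, 300, 500, 800,
    #  1000, 2000, 3000]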
From 08600ddc685558d8504eb94bbbf382230f6de386 Mon Sep 17 00:00:00 2001
From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com>
Date: Wed, 30 Oct 2024 01:36:59 +0800
Subject: [PATCH 09/23] Fix the log to correct guide user to install modelscope
(#9793)
Signed-off-by: yuze.zyz
---
vllm/transformers_utils/__init__.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/vllm/transformers_utils/__init__.py b/vllm/transformers_utils/__init__.py
index 74ca396276c3f..eeec029fc051a 100644
--- a/vllm/transformers_utils/__init__.py
+++ b/vllm/transformers_utils/__init__.py
@@ -9,7 +9,7 @@
if version.parse(modelscope.__version__) <= version.parse('1.18.0'):
raise ImportError(
'Using vLLM with ModelScope needs modelscope>=1.18.1, please '
- 'install by `pip install modelscope>=1.18.1`')
+ 'install by `pip install modelscope -U`')
from modelscope.utils.hf_util import patch_hub
From 0f43387157010bf84da05c68fc5ff366b3252f01 Mon Sep 17 00:00:00 2001
From: Sven Seeberg
Date: Tue, 29 Oct 2024 18:37:59 +0100
Subject: [PATCH 10/23] [Bugfix] Use host argument to bind to interface (#9798)
---
vllm/entrypoints/openai/api_server.py | 2 +-
vllm/entrypoints/openai/cli_args.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index ae44b26a6c55a..afa370a1cb40b 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -541,7 +541,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.bind(("", args.port))
+ sock.bind((args.host, args.port))
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index a089985ac9758..f4dd9df9587ce 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -77,7 +77,7 @@ def __call__(
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
- default=None,
+ default="0.0.0.0",
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
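
With this change the listening socket honors --host, which now defaults to 0.0.0.0 rather than None. A minimal sketch of the binding behavior, with a hypothetical args namespace standing in for the parsed CLI arguments:

    import socket
    from argparse import Namespace

    args = Namespace(host="127.0.0.1", port=8000)  # hypothetical parsed args

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    # Binding to args.host restricts which interface accepts connections;
    # "0.0.0.0" (the new default) still listens on all interfaces.
    sock.bind((args.host, args.port))
    sock.close()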
From 0ce7798f44c586e11c65d59725724eb805086e93 Mon Sep 17 00:00:00 2001
From: yannicks1 <43552841+yannicks1@users.noreply.github.com>
Date: Tue, 29 Oct 2024 18:39:20 +0100
Subject: [PATCH 11/23] [Misc]: Typo fix: Renaming classes (casualLM ->
causalLM) (#9801)
Signed-off-by: Yannick Schnider
---
vllm/model_executor/model_loader/neuron.py | 4 ++--
vllm/model_executor/model_loader/openvino.py | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index a9f1e6e88d792..a90fbd648def9 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -37,7 +37,7 @@
}
-class NeuronCasualLM(nn.Module):
+class NeuronCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
@@ -184,7 +184,7 @@ def get_neuron_model(model_config: ModelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
# Create a model instance.
- model = NeuronCasualLM(
+ model = NeuronCausalLM(
model_config.hf_config,
_is_neuron_on_device_sampling_disabled(model_config))
diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py
index 8ada2210d0d51..573f2a04895d9 100644
--- a/vllm/model_executor/model_loader/openvino.py
+++ b/vllm/model_executor/model_loader/openvino.py
@@ -95,7 +95,7 @@ def _require_model_export(model_id, revision=None, subfolder=None):
return True
-class OpenVINOCasualLM(nn.Module):
+class OpenVINOCausalLM(nn.Module):
def __init__(
self,
@@ -199,5 +199,5 @@ def get_model(
"be added in the future. If this is important to you, "
"please open an issue on github.")
- return OpenVINOCasualLM(ov_core, model_config, device_config,
+ return OpenVINOCausalLM(ov_core, model_config, device_config,
kv_cache_dtype)
From ac3d748dba446b9a8417fe3005345c12989d8de0 Mon Sep 17 00:00:00 2001
From: Junichi Sato
Date: Wed, 30 Oct 2024 02:40:35 +0900
Subject: [PATCH 12/23] [Model] Add LlamaEmbeddingModel as an embedding
Implementation of LlamaModel (#9806)
---
vllm/model_executor/models/registry.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 32b9341ae0b93..30dfff31f7e48 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -95,6 +95,7 @@
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
+ "LlamaModel": ("llama", "LlamaEmbeddingModel"),
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": (
From ab6f981671c4e5035575f5e5ef6172f4df52e121 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Tue, 29 Oct 2024 14:12:43 -0400
Subject: [PATCH 13/23] [CI][Bugfix] Skip chameleon for transformers 4.46.1
(#9808)
---
tests/models/decoder_only/vision_language/test_broadcast.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/models/decoder_only/vision_language/test_broadcast.py b/tests/models/decoder_only/vision_language/test_broadcast.py
index fd7af4a8b0b29..38c4a95de16f4 100644
--- a/tests/models/decoder_only/vision_language/test_broadcast.py
+++ b/tests/models/decoder_only/vision_language/test_broadcast.py
@@ -24,7 +24,7 @@ def test_models(hf_runner, vllm_runner, image_assets,
elif model.startswith("llava-hf/llava-v1.6"):
from .test_llava_next import models, run_test # type: ignore[no-redef]
elif model.startswith("facebook/chameleon"):
- if transformers.__version__.startswith("4.46.0"):
+ if transformers.__version__.startswith("4.46"):
pytest.skip("Model broken in HF, "
"see huggingface/transformers#34379")
from .test_chameleon import models, run_test # type: ignore[no-redef]
From 7585ec996f7ec88735627cb2ab13949226f9bfce Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Tue, 29 Oct 2024 15:24:42 -0400
Subject: [PATCH 14/23] [CI/Build] mergify: fix rules for ci/build label
(#9804)
Signed-off-by: Russell Bryant
---
.github/mergify.yml | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/.github/mergify.yml b/.github/mergify.yml
index 2a3dee7c662d1..1ce5039a061b2 100644
--- a/.github/mergify.yml
+++ b/.github/mergify.yml
@@ -13,13 +13,14 @@ pull_request_rules:
- name: label-ci-build
description: Automatically apply ci/build label
conditions:
- - files~=^\.github/
- - files~=\.buildkite/
- - files~=^cmake/
- - files=CMakeLists.txt
- - files~=^Dockerfile
- - files~=^requirements.*\.txt
- - files=setup.py
+ - or:
+ - files~=^\.github/
+ - files~=\.buildkite/
+ - files~=^cmake/
+ - files=CMakeLists.txt
+ - files~=^Dockerfile
+ - files~=^requirements.*\.txt
+ - files=setup.py
actions:
label:
add:
From 0ad216f5750742115c686723bf38698372d483fd Mon Sep 17 00:00:00 2001
From: Kunjan
Date: Tue, 29 Oct 2024 12:52:19 -0700
Subject: [PATCH 15/23] [MISC] Set label value to timestamp over 0, to keep
track of recent history (#9777)
Signed-off-by: Kunjan Patel
---
vllm/engine/metrics.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 0f5615ff14db1..9ed30e1e99857 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -1,3 +1,4 @@
+import time
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Type, Union, cast
@@ -253,6 +254,10 @@ def labels(self, **labels):
def set(self, value: Union[int, float]):
return self._gauge.set(value)
+ def set_to_current_time(self):
+        # ray metrics doesn't have set_to_current_time; see https://docs.ray.io/en/latest/_modules/ray/util/metrics.html
+ return self._gauge.set(time.time())
+
class _RayCounterWrapper:
"""Wraps around ray.util.metrics.Counter to provide same API as
@@ -479,7 +484,7 @@ def _log_histogram(self, histogram, data: Union[List[int],
histogram.labels(**self.labels).observe(datum)
def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None:
- gauge.labels(**data).set(1)
+ gauge.labels(**data).set_to_current_time()
def _log_prometheus(self, stats: Stats) -> None:
# System state data
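
The Ray gauge wrapper has no native set_to_current_time, so the patch emulates it with set(time.time()); prometheus_client gauges expose the method directly. A brief sketch of that fallback, with a stand-in gauge object so it runs on its own (the dummy class is illustrative, not part of Ray):

    import time

    class _DummyGauge:
        # Stand-in for ray.util.metrics.Gauge, which only offers set().
        def set(self, value: float) -> None:
            self.value = value

    class _RayGaugeWrapperSketch:
        def __init__(self) -> None:
            self._gauge = _DummyGauge()

        def set_to_current_time(self) -> None:
            # Record the current unix timestamp to mimic the
            # prometheus_client Gauge.set_to_current_time() API.
            self._gauge.set(time.time())

    _RayGaugeWrapperSketch().set_to_current_time()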
From 67bdf8e523e4020a559b6d74981936c8156243f9 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Tue, 29 Oct 2024 16:13:20 -0500
Subject: [PATCH 16/23] [Bugfix][Frontend] Guard against bad token ids (#9634)
Signed-off-by: Joe Runde
---
.../entrypoints/llm/test_prompt_validation.py | 8 +++-
tests/entrypoints/openai/test_completion.py | 18 ++++-----
.../openai/test_prompt_validation.py | 15 +++++++
vllm/engine/async_llm_engine.py | 15 +++++--
vllm/engine/llm_engine.py | 40 +++++++++++++++++--
vllm/transformers_utils/tokenizer.py | 5 +++
vllm/transformers_utils/tokenizers/mistral.py | 5 +++
7 files changed, 89 insertions(+), 17 deletions(-)
diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py
index 565dfa01346cc..675a980ab3f3f 100644
--- a/tests/entrypoints/llm/test_prompt_validation.py
+++ b/tests/entrypoints/llm/test_prompt_validation.py
@@ -4,6 +4,12 @@
def test_empty_prompt():
- llm = LLM(model="gpt2")
+ llm = LLM(model="gpt2", enforce_eager=True)
with pytest.raises(ValueError, match='Prompt cannot be empty'):
llm.generate([""])
+
+
+def test_out_of_vocab_token():
+ llm = LLM(model="gpt2", enforce_eager=True)
+ with pytest.raises(ValueError, match='out of vocabulary'):
+ llm.generate({"prompt_token_ids": [999999]})
diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index f03bdb045f640..c81cfdbbe5cff 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -157,15 +157,15 @@ async def test_added_lora_tokens(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
# test using token IDs
- completion = await client.completions.create(
- model=MODEL_NAME,
- prompt=[0, 0, 32000, 32001, 32002],
- echo=True,
- max_tokens=5,
- temperature=0.0,
- )
- # Added tokens should not appear in tokenized prompt
- assert "vllm" not in completion.choices[0].text
+ with pytest.raises(openai.BadRequestError, match="out of vocabulary"):
+ # Added tokens should be rejected by the base model
+ await client.completions.create(
+ model=MODEL_NAME,
+ prompt=[0, 0, 32000, 32001, 32002],
+ echo=True,
+ max_tokens=5,
+ temperature=0.0,
+ )
@pytest.mark.asyncio
diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py
index 0a573a0066d32..58075f7023821 100644
--- a/tests/entrypoints/openai/test_prompt_validation.py
+++ b/tests/entrypoints/openai/test_prompt_validation.py
@@ -20,3 +20,18 @@ async def test_empty_prompt():
prompt="",
max_tokens=5,
temperature=0.0)
+
+
+@pytest.mark.asyncio
+async def test_out_of_vocab_token_ids():
+ model_name = "gpt2"
+ server_args = ["--enforce-eager"]
+ with RemoteOpenAIServer(model_name, server_args) as remote_server:
+ client = remote_server.get_async_client()
+
+ with pytest.raises(openai.BadRequestError,
+ match=re.compile('.*out of vocabulary.*')):
+ await client.completions.create(model=model_name,
+ prompt=[999999],
+ max_tokens=5,
+ temperature=0.0)
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index e9848a14cbe17..5198467a6ac40 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -412,6 +412,12 @@ async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
+ async def get_tokenizer_async(self,
+ lora_request: Optional[LoRARequest] = None
+ ) -> AnyTokenizer:
+ return await (
+ self.get_tokenizer_group().get_lora_tokenizer_async(lora_request))
+
@overload # DEPRECATED
async def add_request_async(
self,
@@ -472,6 +478,10 @@ async def add_request_async(
if arrival_time is None:
arrival_time = time.time()
+ if self.tokenizer is not None:
+ tokenizer = await self.get_tokenizer_async(lora_request)
+ self._validate_token_prompt(prompt, tokenizer=tokenizer)
+
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
request_id=request_id,
@@ -488,7 +498,7 @@ async def add_request_async(
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
- tokenizer=self.get_tokenizer(lora_request),
+ tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
@@ -715,8 +725,7 @@ async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
- return await (self.engine.get_tokenizer_group().
- get_lora_tokenizer_async(lora_request))
+ return await self.engine.get_tokenizer_async(lora_request)
def start_background_loop(self) -> None:
"""Start the background loop."""
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 60575210c9386..fde768ed5165e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -10,7 +10,7 @@
from typing import Set, Type, Union, cast, overload
import torch
-from typing_extensions import TypeVar
+from typing_extensions import TypeIs, TypeVar
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@@ -32,7 +32,8 @@
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
- EncoderDecoderInputs, InputRegistry, PromptType)
+ EncoderDecoderInputs, InputRegistry, PromptType,
+ TokensPrompt)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
@@ -667,7 +668,7 @@ def _add_processed_request(
)
return None
- self._validate_model_inputs(processed_inputs)
+ self._validate_model_inputs(processed_inputs, lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
@@ -829,6 +830,11 @@ def add_request(
if arrival_time is None:
arrival_time = time.time()
+ if self.tokenizer is not None:
+ self._validate_token_prompt(
+ prompt,
+ tokenizer=self.get_tokenizer(lora_request=lora_request))
+
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
request_id=request_id,
@@ -855,6 +861,31 @@ def add_request(
priority=priority,
)
+ def _validate_token_prompt(self, prompt: PromptType,
+ tokenizer: AnyTokenizer):
+ # Guard against out-of-vocab tokens.
+ # For some tokenizers, tokenizer.decode will happily return empty text
+ # for token ids that are out of vocab, and we don't detect token ids
+ # that are greater than the max token id before running the model.
+ # However, these token ids will later crash a cuda kernel at runtime
+ # with an index out of bounds error. This will crash the entire engine.
+ # This needs to happen before multimodal input pre-processing, which
+ # may add dummy tokens that aren't part of the tokenizer's
+ # vocabulary.
+ if self._is_token_prompt(prompt):
+ prompt_ids = prompt["prompt_token_ids"]
+ if len(prompt_ids) == 0:
+ # Empty prompt check is handled later
+ return
+ max_input_id = max(prompt_ids)
+ if max_input_id > tokenizer.max_token_id:
+ raise ValueError(
+ "Token id {} is out of vocabulary".format(max_input_id))
+
+ @staticmethod
+ def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
+ return isinstance(prompt, dict) and "prompt_token_ids" in prompt
+
def _create_sequence_group_with_sampling(
self,
request_id: str,
@@ -1942,7 +1973,8 @@ def is_encoder_decoder_model(self):
return self.input_preprocessor.is_encoder_decoder_model()
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
- EncoderDecoderInputs]):
+ EncoderDecoderInputs],
+ lora_request: Optional[LoRARequest]):
if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 94af2388d79db..54f9f895fe541 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -35,6 +35,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
tokenizer.all_special_tokens_extended)
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer)
+ max_token_id = max(tokenizer.get_vocab().values())
class CachedTokenizer(tokenizer.__class__): # type: ignore
@@ -50,6 +51,10 @@ def all_special_tokens(self):
def all_special_tokens_extended(self):
return tokenizer_all_special_tokens_extended
+ @property
+ def max_token_id(self):
+ return max_token_id
+
def __len__(self):
return tokenizer_len
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 23ea657ffb0a9..80e21c2d32ecc 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -85,6 +85,7 @@ def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
self.tokenizer = tokenizer_
+ self._max_token_id = max(self._vocab.values())
@classmethod
def from_pretrained(cls,
@@ -158,6 +159,10 @@ def is_fast(self) -> bool:
def vocab_size(self) -> int:
return len(self._vocab)
+ @property
+ def max_token_id(self) -> int:
+ return self._max_token_id
+
def __len__(self) -> int:
return self.vocab_size
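
The guard added in this patch reduces to comparing the largest prompt token id against the tokenizer's max_token_id before any preprocessing runs. A condensed standalone sketch of that check (the function name and the plain-int argument are illustrative; in the engine the value comes from the cached tokenizer property above):

    from typing import List

    def validate_prompt_token_ids(prompt_token_ids: List[int],
                                  max_token_id: int) -> None:
        # Out-of-vocab ids would otherwise only fail later inside a CUDA
        # kernel with an index-out-of-bounds error, crashing the engine.
        if not prompt_token_ids:
            return  # empty prompts are rejected by a separate check
        max_input_id = max(prompt_token_ids)
        if max_input_id > max_token_id:
            raise ValueError(f"Token id {max_input_id} is out of vocabulary")

    # e.g. gpt2 has vocab size 50257, so its max token id is 50256:
    validate_prompt_token_ids([0, 0, 42], 50256)    # ok
    # validate_prompt_token_ids([999999], 50256)    # raises ValueError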
From 882a1ad0deb9fd26283db611e78e122ac19fb72f Mon Sep 17 00:00:00 2001
From: Will Eaton
Date: Tue, 29 Oct 2024 18:07:37 -0400
Subject: [PATCH 17/23] [Model] tool calling support for
ibm-granite/granite-20b-functioncalling (#8339)
Signed-off-by: Max de Bayser
Co-authored-by: Max de Bayser
Co-authored-by: Maximilien de Bayser
---
.../serving/openai_compatible_server.md | 21 +-
.../tool_chat_template_granite_20b_fc.jinja | 130 +++++++++
tests/tool_use/utils.py | 12 +
.../openai/tool_parsers/__init__.py | 7 +-
.../granite_20b_fc_tool_parser.py | 251 ++++++++++++++++++
.../openai/tool_parsers/llama_tool_parser.py | 27 +-
vllm/entrypoints/openai/tool_parsers/utils.py | 36 ++-
7 files changed, 456 insertions(+), 28 deletions(-)
create mode 100644 examples/tool_chat_template_granite_20b_fc.jinja
create mode 100644 vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index 413c87ab28755..a1f93a9a28578 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -185,7 +185,9 @@ from HuggingFace; and you can find an example of this in a `tokenizer_config.jso
If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template!
+
#### Hermes Models (`hermes`)
+
All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported.
* `NousResearch/Hermes-2-Pro-*`
* `NousResearch/Hermes-2-Theta-*`
@@ -197,7 +199,9 @@ step in their creation_.
Flags: `--tool-call-parser hermes`
+
#### Mistral Models (`mistral`)
+
Supported models:
* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed)
* Additional mistral function-calling models are compatible as well.
@@ -216,7 +220,9 @@ when tools are provided, that results in much better reliability when working wi
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
+
#### Llama Models (`llama3_json`)
+
Supported models:
* `meta-llama/Meta-Llama-3.1-8B-Instruct`
* `meta-llama/Meta-Llama-3.1-70B-Instruct`
@@ -236,7 +242,9 @@ it works better with vLLM.
Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`
+
#### InternLM Models (`internlm`)
+
Supported models:
* `internlm/internlm2_5-7b-chat` (confirmed)
* Additional internlm2.5 function-calling models are compatible as well
@@ -246,6 +254,7 @@ Known issues:
Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`
+
#### Jamba Models (`jamba`)
AI21's Jamba-1.5 models are supported.
* `ai21labs/AI21-Jamba-1.5-Mini`
@@ -255,6 +264,16 @@ AI21's Jamba-1.5 models are supported.
Flags: `--tool-call-parser jamba`
+#### IBM Granite (`granite-20b-fc`)
+
+Supported models:
+* `ibm-granite/granite-20b-functioncalling`
+
+Flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja`
+
+The example chat template deviates slightly from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.
+
+
### How to write a tool parser plugin
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.
@@ -312,5 +331,5 @@ Then you can use this plugin in the command line like this.
--tool-parser-plugin
--tool-call-parser example \
--chat-template \
-```
+```
diff --git a/examples/tool_chat_template_granite_20b_fc.jinja b/examples/tool_chat_template_granite_20b_fc.jinja
new file mode 100644
index 0000000000000..cb52188ec72d9
--- /dev/null
+++ b/examples/tool_chat_template_granite_20b_fc.jinja
@@ -0,0 +1,130 @@
+{%- macro json_to_python_type(json_spec) %}
+ {%- set basic_type_map = {
+ "string": "str",
+ "number": "float",
+ "integer": "int",
+ "boolean": "bool"
+} %}
+
+ {%- if basic_type_map[json_spec.type] is defined %}
+ {{- basic_type_map[json_spec.type] }}
+ {%- elif json_spec.type == "array" %}
+ {{- "list[" + json_to_python_type(json_spec|items) + "]" }}
+ {%- elif json_spec.type == "object" %}
+ {%- if json_spec.additionalProperties is defined %}
+ {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }}
+ {%- else %}
+ {{- "dict" }}
+ {%- endif %}
+ {%- elif json_spec.type is iterable %}
+ {{- "Union[" }}
+ {%- for t in json_spec.type %}
+ {{- json_to_python_type({"type": t}) }}
+ {%- if not loop.last %}
+ {{- "," }}
+ {%- endif %}
+ {%- endfor %}
+ {{- "]" }}
+ {%- else %}
+ {{- "Any" }}
+ {%- endif %}
+{%- endmacro %}
+
+{%- if not full_function_description is defined %}
+ {%- set full_function_description = false %}
+{%- endif %}
+
+{%- macro full_description(tool) %}
+ {{- tool.name + '(' }}
+ {%- if tool.parameters is defined %}
+ {%- for param_name, param_fields in tool.parameters.properties|items %}
+ {{- param_name + ": " + json_to_python_type(param_fields) }}
+ {%- if not loop.last %}
+ {{- ", " }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {{- ")" }}
+ {%- if tool.return is defined %}
+ {{- " -> " + json_to_python_type(tool.return) }}
+ {%- endif %}
+ {{- " - " + tool.description + "\n\n" }}
+ {%- if tool.parameters is defined %}
+ {%- for param_name, param_fields in tool.parameters.properties|items %}
+ {%- if loop.first %}
+ {{- " Args:\n" }}
+ {%- endif %}
+ {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}
+ {%- endfor %}
+ {%- endif %}
+ {%- if tool.return is defined and tool.return.description is defined %}
+ {{- "\n Returns:\n " + tool.return.description }}
+ {%- endif %}
+ {{- '"' }}
+{%- endmacro %}
+
+{%- macro simple_description(tool) %}
+ {{- tool.description }}
+{%- endmacro %}
+
+{%- macro function_description(tool) %}
+ {%- if full_function_description %}
+ {{- full_description(tool) }}
+ {%- else %}
+ {{- simple_description(tool) }}
+ {%- endif %}
+{%- endmacro %}
+
+{%- if messages[0]["role"] == "system" %}
+ {%- set sys_prompt = messages[0]["content"] %}
+ {%- set loop_messages = messages[1:] %}
+{%- else %}
+ {%- set loop_messages = messages %}
+ {% set sys_prompt = 'You are a helpful assistant with access to the following function calls. Your task is to understand the given conversation with function calls and responses and generate natural language response as the ASSISTANT to continue the conversation. You may use the following function calls to understand how to respond to the user query.' %}
+{%- endif %}
+
+{{ 'SYSTEM: ' + sys_prompt }}
+{% if tools is iterable and tools | length > 0 %}
+<|function_call_library|>
+ {%- for tool in tools %}
+ {%- if tool.function is defined %}
+ {%- set tool = tool.function %}
+ {%- endif %}
+ {{- '{"name": "' + tool.name + '", ' }}
+ {{- '"description": "' + function_description(tool) }}
+ {{- ', "parameters": ' }}
+ {%- if not tool.parameters is defined or tool.parameters.properties | length == 0 %}
+ {{- "{}" }}
+ {%- else %}
+ {{- tool.parameters|tojson }}
+ {%- endif %}
+ {{- "}" }}
+ {%- if not loop.last %}
+ {{- "\n" }}
+ {%- endif %}
+ {%- endfor %}
+If none of the functions are relevant or the given question lacks the parameters required by the function, please output \" {\"name\": \"no_function\", \"arguments\": {}}\".
+{%- endif %}
+
+
+
+{% for message in messages %}
+ {% if message['role'] == 'user' %}
+ {{- '\nUSER: ' + message['content'] }}
+ {% elif message['role'] == 'assistant' and message.tool_calls is defined %}
+ {{- '\nASSISTANT:' }}
+ {% for tc in message.tool_calls %}
+ {{- ' ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }}
+ {% endfor %}
+ {{- '<|endoftext|>' }}
+ {% elif message['role'] == 'assistant' %}
+ {{- '\nASSISTANT: ' + message['content'] + ' <|endoftext|>' }}
+ {% elif message['role'] == 'tool' %}
+ {{- ' ' + message['content'] }}
+ {%- else %}
+ {{- raise_exception("Unexpected combination of role and message content") }}
+ {% endif %}
+ {% if loop.last and add_generation_prompt %}
+ {{- '\nASSISTANT: ' }}
+ {% endif %}
+{% endfor %}
diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py
index ce36515a2381c..d9ee0b1d54b0a 100644
--- a/tests/tool_use/utils.py
+++ b/tests/tool_use/utils.py
@@ -88,6 +88,18 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
},
+ ## FIXME: temporary disabled due to lack of hardware specification
+ ## for individual runs
+ #"granite20b": {
+ # "model":
+ # "ibm-granite/granite-20b-functioncalling",
+ # "arguments": [
+ # "--tool-call-parser", "granite-20b-fc", "--chat-template",
+ # str(VLLM_PATH / "examples/tool_chat_template_granite_20b_fc.jinja")
+ # ],
+ # "supports_parallel":
+ # False,
+ #},
"internlm": {
"model":
"internlm/internlm2_5-7b-chat",
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 0e88bb21ca75f..1b299ce655570 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -1,4 +1,5 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
+from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
@@ -6,7 +7,7 @@
from .mistral_tool_parser import MistralToolParser
__all__ = [
- "ToolParser", "ToolParserManager", "Hermes2ProToolParser",
- "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser",
- "JambaToolParser"
+ "ToolParser", "ToolParserManager", "Granite20bFCToolParser",
+ "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser",
+ "Llama3JsonToolParser", "JambaToolParser"
]
diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
new file mode 100644
index 0000000000000..94db8f379e33a
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
@@ -0,0 +1,251 @@
+import json
+import re
+from json import JSONDecoder
+from typing import Dict, Sequence, Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
+ find_common_prefix,
+ is_complete_json,
+ partial_json_loads)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("granite-20b-fc")
+class Granite20bFCToolParser(ToolParser):
+ """
+ Tool call parser for the granite-20b-functioncalling model intended
+    for use with the examples/tool_chat_template_granite_20b_fc.jinja
+ template.
+
+    Used when --enable-auto-tool-choice --tool-call-parser granite-20b-fc
+ are all set
+ """
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+
+        self.bot_token = "<function_call>"
+        self.tool_start_token = self.bot_token
+        self.tool_call_regex = re.compile(r"<function_call>\s*")
+
+ def extract_tool_calls(
+ self, model_output: str,
+ request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+ if self.tool_start_token not in model_output:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ dec = JSONDecoder()
+ try:
+ matches = list(self.tool_call_regex.finditer(model_output))
+ logger.debug("Found %d tool call matches", len(matches))
+
+ raw_function_calls = []
+
+ for i, match in enumerate(matches):
+ # position after the tag
+ start_of_json = match.end()
+ # end_index == the start of the next function call
+ # (if exists)
+ next_function_call_start = (matches[i + 1].start()
+ if i + 1 < len(matches) else None)
+
+ raw_function_calls.append(
+ dec.raw_decode(
+ model_output[start_of_json:next_function_call_start])
+ [0])
+
+ logger.debug("Extracted %d tool calls", len(raw_function_calls))
+ tool_calls = [
+ ToolCall(
+ type="function",
+ function=FunctionCall(
+ name=function_call["name"],
+ # function call args are JSON but as a string
+ arguments=json.dumps(function_call["arguments"]),
+ ),
+ ) for function_call in raw_function_calls
+ ]
+
+ content = model_output[:model_output.find(self.bot_token)]
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=content if content else None,
+ )
+
+ except Exception as e:
+ logger.error("Error in extracting tool call from response %s", e)
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+
+ if len(current_text) < len(
+ self.bot_token) and self.bot_token.startswith(current_text):
+ return None
+
+ if not current_text.startswith(self.bot_token):
+ return DeltaMessage(content=delta_text)
+
+ # bit mask flags for partial JSON parsing. If the name hasn't been
+ # sent yet, don't allow sending
+ # an incomplete string since OpenAI only ever (as far as I have
+ # seen) allows sending the entire tool/ function name at once.
+ flags = Allow.ALL if self.current_tool_name_sent \
+ else Allow.ALL & ~Allow.STR
+ try:
+ tool_call_arr = []
+ is_complete = []
+ try:
+ start_idx = len(self.bot_token)
+ start_idx = consume_space(start_idx, current_text)
+
+ while start_idx < len(current_text):
+ (obj,
+ end_idx) = partial_json_loads(current_text[start_idx:],
+ flags)
+ is_complete.append(
+ is_complete_json(current_text[start_idx:start_idx +
+ end_idx]))
+ start_idx += end_idx
+ start_idx = consume_space(start_idx, current_text)
+ start_idx += len(self.bot_token)
+ start_idx = consume_space(start_idx, current_text)
+ tool_call_arr.append(obj)
+ except partial_json_parser.core.exceptions.MalformedJSON:
+ logger.debug('not enough tokens to parse into JSON yet')
+ return None
+
+                # select the tool call whose state we are currently tracking
+ current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
+ if len(tool_call_arr) > 0 else {}
+
+ # case -- if no tokens have been streamed for the tool, e.g.
+ # only the array brackets, stream nothing
+ if len(tool_call_arr) == 0:
+ return None
+
+ # case: we are starting a new tool in the array
+ # -> array has > 0 length AND length has moved past cursor
+ elif (len(tool_call_arr) > 0
+ and len(tool_call_arr) > self.current_tool_id + 1):
+
+ # if we're moving on to a new call, first make sure we
+ # haven't missed anything in the previous one that was
+ # auto-generated due to JSON completions, but wasn't
+ # streamed to the client yet.
+ if self.current_tool_id >= 0:
+ cur_arguments = current_tool_call.get("arguments")
+ if cur_arguments:
+ cur_args_json = json.dumps(cur_arguments)
+ sent = len(
+ self.streamed_args_for_tool[self.current_tool_id])
+ argument_diff = cur_args_json[sent:]
+
+ logger.debug("got arguments diff: %s", argument_diff)
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=argument_diff).
+ model_dump(exclude_none=True))
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] += argument_diff
+ else:
+ delta = None
+ else:
+ delta = None
+ # re-set stuff pertaining to progress in the current tool
+ self.current_tool_id = len(tool_call_arr) - 1
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool.append("")
+ logger.debug("starting on new tool %d", self.current_tool_id)
+ return delta
+
+ # if the current tool name hasn't been sent, send if available
+ # - otherwise send nothing
+ elif not self.current_tool_name_sent:
+ function_name = current_tool_call.get("name")
+ if function_name:
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ type="function",
+ id=f"chatcmpl-tool-{random_uuid()}",
+ function=DeltaFunctionCall(
+ name=function_name).model_dump(
+ exclude_none=True))
+ ])
+ self.current_tool_name_sent = True
+ else:
+ delta = None
+
+ # now we know we're on the same tool call and we're streaming
+ # arguments
+ else:
+ cur_arguments = current_tool_call.get("arguments")
+ delta = None
+
+ if cur_arguments:
+ sent = len(
+ self.streamed_args_for_tool[self.current_tool_id])
+ cur_args_json = json.dumps(cur_arguments)
+ prev_arguments = self.prev_tool_call_arr[
+ self.current_tool_id].get("arguments")
+
+ argument_diff = None
+ if is_complete[self.current_tool_id]:
+ argument_diff = cur_args_json[sent:]
+ elif prev_arguments:
+ prev_args_json = json.dumps(prev_arguments)
+ if cur_args_json != prev_args_json:
+
+ prefix = find_common_prefix(
+ prev_args_json, cur_args_json)
+ argument_diff = prefix[sent:]
+
+ if argument_diff is not None:
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=argument_diff).
+ model_dump(exclude_none=True))
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] += argument_diff
+
+ self.prev_tool_call_arr = tool_call_arr
+ return delta
+
+ except Exception as e:
+ logger.error("Error trying to handle streaming tool call: %s", e)
+ logger.debug(
+ "Skipping chunk as a result of tool streaming extraction "
+ "error")
+ return None
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
index 1b836a687a1c3..a5f44d69e5fd2 100644
--- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
@@ -1,6 +1,6 @@
import json
import re
-from json import JSONDecodeError, JSONDecoder
+from json import JSONDecoder
from typing import Dict, List, Sequence, Union
import partial_json_parser
@@ -14,34 +14,15 @@
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
-from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
+from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
+ is_complete_json,
+ partial_json_loads)
from vllm.logger import init_logger
from vllm.utils import random_uuid
logger = init_logger(__name__)
-# partial_json_parser doesn't support extra data and
-# JSONDecorder.raw_decode doesn't support partial JSON
-def partial_json_loads(input_str, flags):
- try:
- return (partial_json_parser.loads(input_str, flags), len(input_str))
- except JSONDecodeError as e:
- if "Extra data" in e.msg:
- dec = JSONDecoder()
- return dec.raw_decode(input_str)
- else:
- raise
-
-
-def is_complete_json(input_str):
- try:
- json.loads(input_str)
- return True
- except JSONDecodeError:
- return False
-
-
@ToolParserManager.register_module("llama3_json")
class Llama3JsonToolParser(ToolParser):
"""
diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py
index db7fc5259fc4e..5e4eb23bfaf43 100644
--- a/vllm/entrypoints/openai/tool_parsers/utils.py
+++ b/vllm/entrypoints/openai/tool_parsers/utils.py
@@ -1,3 +1,11 @@
+import json
+from json import JSONDecodeError, JSONDecoder
+from typing import Any, List, Tuple
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+
def find_common_prefix(s1: str, s2: str) -> str:
"""
Finds a common prefix that is shared between two strings, if there is one.
@@ -72,7 +80,7 @@ def extract_intermediate_diff(curr: str, old: str) -> str:
return diff
-def find_all_indices(string, substring):
+def find_all_indices(string: str, substring: str) -> List[int]:
"""
Find all (starting) indices of a substring in a given string. Useful for
tool call extraction
@@ -85,3 +93,29 @@ def find_all_indices(string, substring):
break
indices.append(index)
return indices
+
+
+# partial_json_parser doesn't support extra data and
+# JSONDecorder.raw_decode doesn't support partial JSON
+def partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
+ try:
+ return (partial_json_parser.loads(input_str, flags), len(input_str))
+ except JSONDecodeError as e:
+ if "Extra data" in e.msg:
+ dec = JSONDecoder()
+ return dec.raw_decode(input_str)
+ raise
+
+
+def is_complete_json(input_str: str) -> bool:
+ try:
+ json.loads(input_str)
+ return True
+ except JSONDecodeError:
+ return False
+
+
+def consume_space(i: int, s: str) -> int:
+ while i < len(s) and s[i].isspace():
+ i += 1
+ return i
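
The non-streaming path of the new parser boils down to: locate every <function_call> marker, raw-decode the JSON that follows each one up to the next marker, and wrap the results as tool calls. A condensed sketch of that extraction loop using only the standard library (names are illustrative, and the marker string follows the parser's bot_token above):

    import re
    from json import JSONDecoder
    from typing import List

    BOT_TOKEN = "<function_call>"
    TOOL_CALL_RE = re.compile(r"<function_call>\s*")

    def extract_raw_function_calls(model_output: str) -> List[dict]:
        if BOT_TOKEN not in model_output:
            return []  # no tool calls; the output is plain content
        dec = JSONDecoder()
        matches = list(TOOL_CALL_RE.finditer(model_output))
        calls = []
        for i, match in enumerate(matches):
            start = match.end()  # position right after the marker
            # parse only up to the start of the next marker, if any
            end = matches[i + 1].start() if i + 1 < len(matches) else None
            obj, _ = dec.raw_decode(model_output[start:end])
            calls.append(obj)
        return calls

    out = '<function_call> {"name": "get_weather", "arguments": {"city": "Boston"}}'
    print(extract_raw_function_calls(out))
    # [{'name': 'get_weather', 'arguments': {'city': 'Boston'}}]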
From 8d7724104aef4381cf268de094360f27ff68f4ab Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Tue, 29 Oct 2024 15:19:02 -0700
Subject: [PATCH 18/23] [Docs] Add notes about Snowflake Meetup (#9814)
Signed-off-by: simon-mo
---
README.md | 14 ++++++++++++--
1 file changed, 12 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 0836d872358fb..8c8d6eb291cea 100644
--- a/README.md
+++ b/README.md
@@ -13,9 +13,19 @@ Easy, fast, and cheap LLM serving for everyone
| Documentation | Blog | Paper | Discord | Twitter/X | Developer Slack |
+---
+
+**vLLM x Snowflake Meetup (Wednesday, November 13th, 5:30-8PM PT) at Snowflake HQ, San Mateo**
+
+We are excited to announce the last in-person vLLM meetup of the year!
+Join the vLLM developers and engineers from Snowflake AI Research to chat about the latest LLM inference optimizations and your 2025 vLLM wishlist!
+Register [here](https://lu.ma/h0qvrajz) and be a part of the event!
+
+---
+
*Latest News* 🔥
-- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
+- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/sessioncatalog?tab.day=20241001&search.sessiontracks=1719251906298001uzJ2) from other vLLM contributors and users!
- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing).
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
@@ -42,7 +52,7 @@ vLLM is fast with:
- Speculative decoding
- Chunked prefill
-**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script.
+**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script.
vLLM is flexible and easy to use with:
From bc73e9821cb4f90a88c04e7d550f132d8911266b Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Tue, 29 Oct 2024 19:02:59 -0400
Subject: [PATCH 19/23] [Bugfix] Fix prefix strings for quantized VLMs (#9772)
---
vllm/model_executor/model_loader/loader.py | 11 +++-
vllm/model_executor/models/blip2.py | 5 +-
vllm/model_executor/models/gemma.py | 58 +++++++++++++------
vllm/model_executor/models/internlm2.py | 56 ++++++++++++------
vllm/model_executor/models/internlm2_ve.py | 16 +++--
vllm/model_executor/models/internvl.py | 5 +-
vllm/model_executor/models/llama.py | 7 ++-
vllm/model_executor/models/llava.py | 20 +++++--
vllm/model_executor/models/llava_next.py | 10 +++-
.../model_executor/models/llava_next_video.py | 10 +++-
vllm/model_executor/models/llava_onevision.py | 10 +++-
vllm/model_executor/models/minicpmv.py | 34 ++++++++---
vllm/model_executor/models/opt.py | 34 ++++++++---
vllm/model_executor/models/paligemma.py | 7 ++-
vllm/model_executor/models/phi3v.py | 19 ++++--
vllm/model_executor/models/pixtral.py | 5 +-
vllm/model_executor/models/qwen2.py | 50 +++++++++++-----
vllm/model_executor/models/qwen2_vl.py | 8 ++-
vllm/model_executor/models/ultravox.py | 5 +-
vllm/model_executor/models/utils.py | 15 +++++
20 files changed, 288 insertions(+), 97 deletions(-)
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 3cfee13b9fa6e..3ae8a51859f70 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -147,15 +147,20 @@ def _get_model_initialization_kwargs(
return extra_kwargs
-def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
+def build_model(model_class: Type[nn.Module],
+ hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
- quant_config: Optional[QuantizationConfig], *,
+ quant_config: Optional[QuantizationConfig],
+ *,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
- scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
+ scheduler_config: Optional[SchedulerConfig],
+ prefix: Optional[str] = None) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config)
+ if prefix:
+ extra_kwargs["prefix"] = prefix
return model_class(config=hf_config,
cache_config=cache_config,
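Editor's note: the loader change above only forwards the new keyword when a prefix is actually supplied, so callers that never pass one are unaffected. A minimal sketch of that pattern, using hypothetical ToyModel/toy_build_model names rather than the real vLLM API:

    from typing import Optional

    class ToyModel:
        # Hypothetical stand-in for a model class that accepts a prefix.
        def __init__(self, config: dict, prefix: str = ""):
            self.config = config
            self.prefix = prefix

    def toy_build_model(model_class, config: dict,
                        prefix: Optional[str] = None):
        # Only pass the keyword through when it was given, so classes
        # that do not accept `prefix` keep working unchanged.
        extra_kwargs = {}
        if prefix:
            extra_kwargs["prefix"] = prefix
        return model_class(config=config, **extra_kwargs)

    print(toy_build_model(ToyModel, {}).prefix)                           # ""
    print(toy_build_model(ToyModel, {}, prefix="language_model").prefix)  # "language_model"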
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index cd2013e91514d..c3b3cc8a4ddb6 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -507,7 +507,10 @@ def __init__(self,
)
self.language_model = init_vllm_registered_model(
- config.text_config, cache_config, quant_config)
+ config.text_config,
+ cache_config,
+ quant_config,
+ prefix="language_model")
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 436bd45d53f35..57b2b43c82f89 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -43,7 +43,8 @@
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
- make_empty_intermediate_tensors_factory, make_layers)
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
logger = init_logger(__name__)
@@ -83,16 +84,23 @@ def __init__(
hidden_act: Optional[str] = None,
hidden_activation: Optional[str] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
- hidden_size, [intermediate_size] * 2,
+ hidden_size,
+ [intermediate_size] * 2,
bias=False,
- quant_config=quant_config)
- self.down_proj = RowParallelLinear(intermediate_size,
- hidden_size,
- bias=False,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
+ )
+ self.down_proj = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj",
+ )
self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
def forward(self, x):
@@ -104,15 +112,18 @@ def forward(self, x):
class GemmaAttention(nn.Module):
- def __init__(self,
- hidden_size: int,
- num_heads: int,
- num_kv_heads: int,
- head_dim: int,
- max_position_embeddings: int = 8192,
- rope_theta: float = 10000,
- cache_config: Optional[CacheConfig] = None,
- quant_config: Optional[QuantizationConfig] = None) -> None:
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ max_position_embeddings: int = 8192,
+ rope_theta: float = 10000,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -142,12 +153,14 @@ def __init__(self,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
@@ -186,6 +199,7 @@ def __init__(
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -198,6 +212,7 @@ def __init__(
rope_theta=config.rope_theta,
cache_config=cache_config,
quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
)
self.mlp = GemmaMLP(
hidden_size=self.hidden_size,
@@ -205,6 +220,7 @@ def __init__(
hidden_act=config.hidden_act,
hidden_activation=getattr(config, "hidden_activation", None),
quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -259,8 +275,8 @@ def __init__(
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
- lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config
- ),
+ lambda prefix: GemmaDecoderLayer(
+ config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -366,6 +382,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
@@ -375,7 +392,10 @@ def __init__(
self.lora_config = lora_config
self.quant_config = quant_config
- self.model = GemmaModel(config, cache_config, quant_config)
+ self.model = GemmaModel(config,
+ cache_config,
+ quant_config,
+ prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 9a77e48626ca5..313d98b649b48 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -30,7 +30,8 @@
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
- make_empty_intermediate_tensors_factory, make_layers)
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
class InternLM2MLP(nn.Module):
@@ -41,16 +42,23 @@ def __init__(
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
- hidden_size, [intermediate_size] * 2,
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
+ )
+ self.w2 = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
bias=False,
- quant_config=quant_config)
- self.w2 = RowParallelLinear(intermediate_size,
- hidden_size,
- bias=False,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=f"{prefix}.w2",
+ )
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -75,6 +83,7 @@ def __init__(
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -108,12 +117,14 @@ def __init__(
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
+ prefix=f"{prefix}.wqkv",
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
+ prefix=f"{prefix}.wo",
)
self.rotary_emb = get_rope(
@@ -123,12 +134,15 @@ def __init__(
base=rope_theta,
rope_scaling=rope_scaling,
)
- self.attn = Attention(self.num_heads,
- self.head_dim,
- self.scaling,
- num_kv_heads=self.num_kv_heads,
- cache_config=cache_config,
- quant_config=quant_config)
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ )
def split_qkv(self, qkv: torch.Tensor):
seq_len = qkv.shape[0]
@@ -176,6 +190,7 @@ def __init__(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -192,12 +207,14 @@ def __init__(
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
+ prefix=f"{prefix}.attention",
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
+ prefix=f"{prefix}.feed_forward",
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -251,8 +268,8 @@ def __init__(
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
- lambda prefix: InternLMDecoderLayer(config, cache_config,
- quant_config),
+ lambda prefix: InternLMDecoderLayer(
+ config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
@@ -306,14 +323,19 @@ def __init__(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
- self.model = InternLM2Model(config, cache_config, quant_config)
+ self.model = InternLM2Model(config,
+ cache_config,
+ quant_config,
+ prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "output"))
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py
index 6effd70b75da3..edd867e4b6457 100644
--- a/vllm/model_executor/models/internlm2_ve.py
+++ b/vllm/model_executor/models/internlm2_ve.py
@@ -15,7 +15,7 @@
InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors
-from .utils import make_layers
+from .utils import make_layers, maybe_prefix
class InternLM2VEDecoderLayer(nn.Module):
@@ -25,6 +25,7 @@ def __init__(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -41,18 +42,21 @@ def __init__(
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
+ prefix=f"{prefix}.attention",
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
+ prefix=f"{prefix}.feed_forward",
)
self.feed_forward_ve = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
+ prefix=f"{prefix}.feed_forward_ve",
)
self.attention_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -111,8 +115,8 @@ def __init__(
super().__init__(config, cache_config, quant_config)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
- lambda prefix: InternLM2VEDecoderLayer(config, cache_config,
- quant_config),
+ lambda prefix: InternLM2VEDecoderLayer(
+ config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
def forward(
@@ -161,6 +165,10 @@ def __init__(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__(config, cache_config, quant_config)
- self.model = InternLM2VEModel(config, cache_config, quant_config)
+ self.model = InternLM2VEModel(config,
+ cache_config,
+ quant_config,
+ prefix=maybe_prefix(prefix, "model"))
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 3ae37d9fe5d85..1c1fde5b30983 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -439,7 +439,10 @@ def __init__(self,
)
self.language_model = init_vllm_registered_model(
- config.text_config, cache_config, quant_config)
+ config.text_config,
+ cache_config,
+ quant_config,
+ prefix="language_model")
self.mlp1 = self._init_mlp1(config)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index b0ca1fe006239..98c53bdaae811 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -55,7 +55,8 @@
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
- make_empty_intermediate_tensors_factory, make_layers)
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
class LlamaMLP(nn.Module):
@@ -500,6 +501,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
@@ -510,7 +512,7 @@ def __init__(
cache_config,
quant_config,
lora_config=lora_config,
- prefix="model")
+ prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
@@ -526,6 +528,7 @@ def __init__(
if not lora_config else
lora_config.lora_vocab_padding_size),
quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index b005d83c17f90..eda99c029881f 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -210,6 +210,7 @@ def init_vision_tower_for_llava(
quant_config: Optional[QuantizationConfig],
*,
require_post_norm: Optional[bool] = None,
+ prefix: str = "",
):
vision_config = hf_config.vision_config
@@ -224,23 +225,26 @@ def init_vision_tower_for_llava(
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
- quant_config,
+ quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
+ prefix=prefix,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
- quant_config,
+ quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
+ prefix=prefix,
)
elif isinstance(vision_config, PixtralVisionConfig):
return PixtralHFVisionModel(
vision_config,
- quant_config,
+ quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
+ prefix=prefix,
)
msg = f"Unsupported vision config: {type(vision_config)}"
@@ -274,14 +278,20 @@ def __init__(self,
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava(
- config, quant_config, require_post_norm=False)
+ config,
+ quant_config,
+ require_post_norm=False,
+ prefix="vision_tower")
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
- config.text_config, cache_config, quant_config)
+ config.text_config,
+ cache_config,
+ quant_config,
+ prefix="language_model")
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 2a582deeaa2c9..f85129b206919 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -293,7 +293,10 @@ def __init__(self,
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava(
- config, quant_config, require_post_norm=False)
+ config,
+ quant_config,
+ require_post_norm=False,
+ prefix="vision_tower")
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(
@@ -302,7 +305,10 @@ def __init__(self,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
- config.text_config, cache_config, quant_config)
+ config.text_config,
+ cache_config,
+ quant_config,
+ prefix="language_model")
# The same model class supports both language generation and embedding
# because the architecture name is the same
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 43eec43d56643..b8051d5fc6ae2 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -257,14 +257,20 @@ def __init__(self,
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
- config, quant_config, require_post_norm=False)
+ config,
+ quant_config,
+ require_post_norm=False,
+ prefix="vision_tower")
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
- config.text_config, cache_config, quant_config)
+ config.text_config,
+ cache_config,
+ quant_config,
+ prefix="language_model")
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 9606b126141df..a0cf208a65f36 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -415,10 +415,16 @@ def __init__(self,
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
- config, quant_config, require_post_norm=False)
+ config,
+ quant_config,
+ require_post_norm=False,
+ prefix="vision_tower")
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
- config.text_config, cache_config, quant_config)
+ config.text_config,
+ cache_config,
+ quant_config,
+ prefix="language_model")
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 2ec51dc4647f5..a270282d87bc8 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -394,8 +394,11 @@ def __init__(
self.multimodal_config = multimodal_config
self.version = get_version_by_config(self.config)
- self.llm = self.init_llm(config, cache_config, quant_config)
- self.vpm = self.init_vision_module(config, quant_config)
+ self.llm = self.init_llm(config,
+ cache_config,
+ quant_config,
+ prefix="llm")
+ self.vpm = self.init_vision_module(config, quant_config, prefix="vpm")
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
@@ -403,9 +406,11 @@ def __init__(
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
self.resampler.to(device="cuda", dtype=param_dtype)
+ # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix="llm.lm_head")
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@@ -644,6 +649,7 @@ def init_llm(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> nn.Module:
raise NotImplementedError
@@ -651,6 +657,7 @@ def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
) -> nn.Module:
raise NotImplementedError
@@ -690,17 +697,20 @@ def init_llm(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> nn.Module:
return LLMWrapper(MiniCPMModel(config,
cache_config=cache_config,
- quant_config=quant_config),
+ quant_config=quant_config,
+ prefix=prefix),
name="model")
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
) -> nn.Module:
# TODO :refactor this vision model
try:
@@ -819,19 +829,23 @@ def init_llm(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> nn.Module:
return LLMWrapper(LlamaModel(config,
cache_config=cache_config,
- quant_config=quant_config),
+ quant_config=quant_config,
+ prefix=prefix),
name="model")
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=prefix)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
@@ -935,20 +949,24 @@ def init_llm(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> nn.Module:
return LLMWrapper(Qwen2Model(config,
cache_config=cache_config,
- quant_config=quant_config),
+ quant_config=quant_config,
+ prefix=prefix),
name="model")
def init_vision_module(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
) -> nn.Module:
model = Idefics2VisionTransformer(config.vision_config,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=prefix)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 37c3fa919124e..10cca8b56268a 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -43,7 +43,8 @@
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
- make_empty_intermediate_tensors_factory, make_layers)
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -68,6 +69,7 @@ def __init__(
bias: bool = True,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.embed_dim = embed_dim
@@ -85,18 +87,21 @@ def __init__(
total_num_heads,
bias=bias,
quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
+ prefix=f"{prefix}.out_proj",
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn")
def forward(
self,
@@ -118,6 +123,7 @@ def __init__(
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
):
super().__init__()
self.config = config
@@ -128,6 +134,7 @@ def __init__(
bias=config.enable_bias,
cache_config=cache_config,
quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
)
self.do_layer_norm_before = config.do_layer_norm_before
@@ -139,6 +146,7 @@ def __init__(
config.ffn_dim,
bias=config.enable_bias,
quant_config=quant_config,
+ prefix=f"{prefix}.fc1",
)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
@@ -147,6 +155,7 @@ def __init__(
self.embed_dim,
bias=config.enable_bias,
quant_config=quant_config,
+ prefix=f"{prefix}.fc2",
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
@@ -214,7 +223,8 @@ def __init__(
self.project_out = ReplicatedLinear(config.hidden_size,
config.word_embed_proj_dim,
bias=False,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=f"{prefix}.project_out")
else:
self.project_out = None
@@ -222,7 +232,8 @@ def __init__(
self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
config.hidden_size,
bias=False,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=f"{prefix}.project_in")
else:
self.project_in = None
@@ -239,7 +250,8 @@ def __init__(
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
- lambda prefix: OPTDecoderLayer(config, cache_config, quant_config),
+ lambda prefix: OPTDecoderLayer(
+ config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -288,9 +300,13 @@ def __init__(
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
):
super().__init__()
- self.decoder = OPTDecoder(config, cache_config, quant_config)
+ self.decoder = OPTDecoder(config,
+ cache_config,
+ quant_config,
+ prefix=f"{prefix}.decoder")
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
@@ -335,11 +351,15 @@ def __init__(
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
):
super().__init__()
self.config = config
self.quant_config = quant_config
- self.model = OPTModel(config, cache_config, quant_config)
+ self.model = OPTModel(config,
+ cache_config,
+ quant_config,
+ prefix=maybe_prefix(prefix, "model"))
if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens
else:
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 7a62a098a4525..8e29c6079b994 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -143,14 +143,17 @@ def __init__(self,
self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(config.vision_config,
- quant_config)
+ quant_config,
+ prefix="vision_tower")
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim)
self.quant_config = quant_config
self.language_model = GemmaForCausalLM(config.text_config,
- cache_config, quant_config)
+ cache_config,
+ quant_config,
+ prefix="language_model")
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 855a9b17585a4..0962d3d3847c9 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -71,7 +71,8 @@
def _init_img_processor(hf_config: PretrainedConfig,
- quant_config: Optional[QuantizationConfig]):
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "") -> CLIPVisionModel:
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
layer_idx = hf_config.img_processor.get('layer_idx', -2)
@@ -86,6 +87,7 @@ def _init_img_processor(hf_config: PretrainedConfig,
clip_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
+ prefix=prefix,
)
return img_processor
@@ -152,15 +154,18 @@ def get_img_features(self,
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform."""
- def __init__(self, config: PretrainedConfig,
- quant_config: Optional[QuantizationConfig]) -> None:
+ def __init__(self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig],
+ prefix: str = "") -> None:
super().__init__()
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
- self.img_processor = _init_img_processor(config, quant_config)
+ self.img_processor = _init_img_processor(
+ config, quant_config, prefix=f"{prefix}.img_processor")
image_dim_out = config.img_processor['image_dim_out']
self.num_img_tokens = config.img_processor['num_img_tokens']
@@ -537,11 +542,15 @@ def __init__(self,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
+ prefix="model.embed_tokens",
)
# TODO: Optionally initializes this for supporting input embeddings.
- self.vision_embed_tokens = Phi3HDImageEmbedding(config, quant_config)
+ self.vision_embed_tokens = Phi3HDImageEmbedding(
+ config, quant_config, prefix="model.vision_embed_tokens")
+ # The prefix is empty intentionally because default prefix of
+ # LlamaForCausalLM is "model"
self.language_model = LlamaForCausalLM(config, cache_config,
quant_config)
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index a9dbb3823743a..6b53bf5660096 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -164,7 +164,10 @@ def __init__(self,
# init MistralForCausalLM
self.language_model = init_vllm_registered_model(
- config.text_config, cache_config, quant_config)
+ config.text_config,
+ cache_config,
+ quant_config,
+ prefix="language_model")
self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter(
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 23eb1482ffef1..db1029345a8ac 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -49,7 +49,8 @@
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
- make_empty_intermediate_tensors_factory, make_layers)
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
class Qwen2MLP(nn.Module):
@@ -60,16 +61,23 @@ def __init__(
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
- hidden_size, [intermediate_size] * 2,
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
+ )
+ self.down_proj = RowParallelLinear(
+ intermediate_size,
+ hidden_size,
bias=False,
- quant_config=quant_config)
- self.down_proj = RowParallelLinear(intermediate_size,
- hidden_size,
- bias=False,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj",
+ )
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -92,7 +100,8 @@ def __init__(self,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
- rope_scaling: Optional[Tuple] = None) -> None:
+ rope_scaling: Optional[Tuple] = None,
+ prefix: str = "") -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -122,12 +131,14 @@ def __init__(self,
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
@@ -142,7 +153,8 @@ def __init__(self,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn")
def forward(
self,
@@ -166,6 +178,7 @@ def __init__(
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -180,12 +193,15 @@ def __init__(
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
- rope_scaling=rope_scaling)
+ rope_scaling=rope_scaling,
+ prefix=f"{prefix}.self_attn",
+ )
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -241,6 +257,7 @@ def __init__(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
@@ -249,7 +266,8 @@ def __init__(
config.num_hidden_layers,
lambda prefix: Qwen2DecoderLayer(config=config,
cache_config=cache_config,
- quant_config=quant_config),
+ quant_config=quant_config,
+ prefix=f"{prefix}.layers"),
prefix=f"{prefix}.layers",
)
@@ -393,6 +411,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
+ prefix: str = "",
) -> None:
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
@@ -412,14 +431,19 @@ def __init__(
self.lora_config = lora_config
self.quant_config = quant_config
- self.model = Qwen2Model(config, cache_config, quant_config)
+ self.model = Qwen2Model(config,
+ cache_config,
+ quant_config,
+ prefix=maybe_prefix(prefix, "model"))
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix=maybe_prefix(
+ prefix, "lm_head"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 4e60fe70b25f1..633d66b4af31a 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -938,7 +938,10 @@ def __init__(self,
quant_config=None,
)
- self.model = Qwen2Model(config, cache_config, quant_config)
+ self.model = Qwen2Model(config,
+ cache_config,
+ quant_config,
+ prefix="model")
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
@@ -946,7 +949,8 @@ def __init__(self,
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
- quant_config=quant_config)
+ quant_config=quant_config,
+ prefix="lm_head")
else:
self.lm_head = PPMissingLayer()
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 5f33b872beecb..f08e4aa355086 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -357,7 +357,10 @@ def __init__(self,
))
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model(
- config.text_config, cache_config, quant_config)
+ config.text_config,
+ cache_config,
+ quant_config,
+ prefix="language_model")
if config.text_model_id is not None:
self.secondary_weights.append(
DefaultModelLoader.Source(model_or_path=config.text_model_id,
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 6995f5805c5e1..0aecb5d151a45 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -242,6 +242,7 @@ def init_vllm_registered_model(
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
+ prefix: str = "",
) -> nn.Module:
"""
Helper function to initialize an inner model registered to vLLM,
@@ -257,6 +258,7 @@ def init_vllm_registered_model(
lora_config=lora_config,
multimodal_config=multimodal_config,
scheduler_config=scheduler_config,
+ prefix=prefix,
)
@@ -610,3 +612,16 @@ def get_vit_attn_backend() -> _Backend:
else:
selected_backend = _Backend.XFORMERS
return selected_backend
+
+
+def maybe_prefix(prefix: str, name: str) -> str:
+ """Add a prefix to a name if the prefix is non-empty.
+
+ Args:
+ prefix: The prefix to add. If empty, no prefix will be added.
+ name: The name to potentially prefix.
+
+ Returns:
+ The string "prefix.name" if prefix was non-empty, otherwise just "name".
+ """
+ return name if not prefix else f"{prefix}.{name}"
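Editor's note: the point of maybe_prefix is that a top-level model (empty prefix) keeps its existing weight names, while the same class nested inside a VLM gets the outer prefix prepended. A small illustrative sketch of how the prefixes compose through nested modules; the Toy* classes are invented for illustration and are not the real vLLM layers:

    def maybe_prefix(prefix: str, name: str) -> str:
        return name if not prefix else f"{prefix}.{name}"

    class ToyAttention:
        def __init__(self, prefix: str = ""):
            # e.g. "model.layers.0.self_attn.qkv_proj" for a nested layer
            self.qkv_proj_name = f"{prefix}.qkv_proj"

    class ToyDecoderLayer:
        def __init__(self, prefix: str = ""):
            self.self_attn = ToyAttention(prefix=f"{prefix}.self_attn")

    class ToyForCausalLM:
        def __init__(self, prefix: str = ""):
            # Standalone use yields "model"; nested use yields "<prefix>.model".
            model_prefix = maybe_prefix(prefix, "model")
            self.layer0 = ToyDecoderLayer(prefix=f"{model_prefix}.layers.0")

    print(ToyForCausalLM().layer0.self_attn.qkv_proj_name)
    # model.layers.0.self_attn.qkv_proj
    print(ToyForCausalLM(prefix="language_model").layer0.self_attn.qkv_proj_name)
    # language_model.model.layers.0.self_attn.qkv_proj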
From 1ab6f6b4ad5c4aac6ee72e51b7f6712098f9ccff Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 29 Oct 2024 17:06:24 -0700
Subject: [PATCH 20/23] [core][distributed] fix custom allreduce in pytorch 2.5
(#9815)
Signed-off-by: youkaichao
---
.../device_communicators/custom_all_reduce.py | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 7de5b05a0b053..c3632aee6d11a 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -191,8 +191,20 @@ def capture(self):
def _get_ipc_meta(self, inp: torch.Tensor):
data = inp.untyped_storage()._share_cuda_()
+ handle = data[1]
+ # https://github.com/pytorch/pytorch/pull/130890 changes
+ # the binary format of the ipc handle
+ # it starts from pytorch 2.5
+ if len(handle) > 64:
+ assert len(handle) == 66
+ # only support SHAREABLE_HANDLE_VERSION = 1
+ assert int(handle[0]) == 1
+ # only support SHAREABLE_CUDA_MALLOC = 'c'
+ assert handle[1] == ord("c")
+ handle = handle[2:]
+ # TODO: support expandable segment
shard_data = (
- data[1], # ipc handle to base ptr
+ handle, # ipc handle to base ptr
data[3], # offset of base ptr
)
return self._gather_ipc_meta(shard_data)
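Editor's note: a minimal sketch of the handle layout this patch accounts for, assuming the PyTorch >= 2.5 format described in the comments above (a leading version byte, a type byte, then the original 64-byte CUDA IPC handle); this is an illustration, not the vLLM code path itself:

    def strip_ipc_handle_header(handle: bytes) -> bytes:
        # Pre-2.5 PyTorch returns the raw 64-byte cudaIpcMemHandle_t as-is.
        if len(handle) <= 64:
            return handle
        # 2.5+ prepends a 2-byte header; only version 1 and the plain
        # cudaMalloc ('c') handle type are handled here.
        assert len(handle) == 66
        assert int(handle[0]) == 1       # SHAREABLE_HANDLE_VERSION
        assert handle[1] == ord("c")     # SHAREABLE_CUDA_MALLOC
        return handle[2:]

    # Example with a fabricated 66-byte handle:
    fake_handle = bytes([1, ord("c")]) + bytes(64)
    assert len(strip_ipc_handle_header(fake_handle)) == 64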
From 64cb1cdc3f3a6c0ca976d68b19d454122c720e6d Mon Sep 17 00:00:00 2001
From: Lily Liu
Date: Tue, 29 Oct 2024 17:28:43 -0700
Subject: [PATCH 21/23] Update README.md (#9819)
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 8c8d6eb291cea..b75bfc5c699a7 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ Easy, fast, and cheap LLM serving for everyone
---
-**vLLM x Snowfkale Meetup (Wednesday, November 13th, 5:30-8PM PT) at Snowfkale HQ, San Mateo**
+**vLLM x Snowflake Meetup (Wednesday, November 13th, 5:30-8PM PT) at Snowflake HQ, San Mateo**
We are excited to announce the last in-person vLLM meetup of the year!
Join the vLLM developers and engineers from Snowflake AI Research to chat about the latest LLM inference optimizations and your 2025 vLLM wishlist!
From 226688bd6114749633132b9ed074c59d50904830 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Tue, 29 Oct 2024 22:49:44 -0400
Subject: [PATCH 22/23] [Bugfix][VLM] Make apply_fp8_linear work with >2D input
(#9812)
---
.../layers/quantization/utils/w8a8_utils.py | 33 +++++++++++--------
1 file changed, 20 insertions(+), 13 deletions(-)
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 1879d2855d93d..445117ac99a34 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -96,21 +96,26 @@ def apply_fp8_linear(
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
+ # View input as 2D matrix for fp8 methods
+ input_2d = input.view(-1, input.shape[-1])
+ output_shape = [*input.shape[:-1], weight.shape[1]]
+
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(
- input,
+ input_2d,
input_scale,
scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic)
# Fused GEMM_DQ
- return ops.cutlass_scaled_mm(qinput,
- weight,
- out_dtype=input.dtype,
- scale_a=x_scale,
- scale_b=weight_scale,
- bias=bias)
+ output = ops.cutlass_scaled_mm(qinput,
+ weight,
+ out_dtype=input.dtype,
+ scale_a=x_scale,
+ scale_b=weight_scale,
+ bias=bias)
+ return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
@@ -119,7 +124,7 @@ def apply_fp8_linear(
# for matrices with batch dimension > 16.
# This could change in the future.
qinput, x_scale = ops.scaled_fp8_quant(
- input,
+ input_2d,
input_scale,
num_token_padding=17,
use_per_token_if_dynamic=use_per_token_if_dynamic)
@@ -138,8 +143,10 @@ def apply_fp8_linear(
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
- return torch.narrow(output[0], 0, 0, input.shape[0])
- return torch.narrow(output, 0, 0, input.shape[0])
+ output = output[0]
+
+ return torch.narrow(output, 0, 0,
+ input_2d.shape[0]).view(*output_shape)
else:
# Fallback for channelwise case, where we use unfused DQ
@@ -176,15 +183,15 @@ def apply_fp8_linear(
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
- output = torch.narrow(output, 0, 0, input.shape[0])
- x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
+ output = torch.narrow(output, 0, 0, input_2d.shape[0])
+ x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t()
if bias is not None:
output = output + bias
- return output.to(dtype=input.dtype)
+ return output.to(dtype=input.dtype).view(*output_shape)
def apply_int8_linear(
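Editor's note: the fix above follows a common pattern for kernels that only accept 2D activations: flatten all leading batch/sequence dimensions into one, run the 2D GEMM, then restore the original shape. A hedged sketch of that reshape bookkeeping, with a plain matmul standing in for the actual fp8 kernels:

    import torch

    def linear_any_rank(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # weight is [K, N]; inp may be [..., K] with any number of leading
        # dimensions (e.g. [batch, seq, K] for VLM inputs).
        inp_2d = inp.view(-1, inp.shape[-1])
        output_shape = [*inp.shape[:-1], weight.shape[1]]
        output = inp_2d @ weight          # 2D GEMM, like the quantized kernels
        return output.view(*output_shape)

    x = torch.randn(2, 5, 8)   # [batch, seq, hidden]
    w = torch.randn(8, 16)     # [hidden, out]
    assert linear_any_rank(x, w).shape == (2, 5, 16)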
From 62fac4b9aab3c05124d83fcd71db5732774b17d8 Mon Sep 17 00:00:00 2001
From: "Kevin H. Luu"
Date: Tue, 29 Oct 2024 17:34:55 -1000
Subject: [PATCH 23/23] [ci/build] Pin CI dependencies version with pip-compile
(#9810)
Signed-off-by: kevin
---
Dockerfile.rocm | 2 +
requirements-build.txt | 18 +-
requirements-test.in | 37 +++
requirements-test.txt | 593 ++++++++++++++++++++++++++++++++++++++---
4 files changed, 608 insertions(+), 42 deletions(-)
create mode 100644 requirements-test.in
diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index d35889f053e27..562117a313020 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -121,6 +121,8 @@ ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
+RUN python3 -m pip install --upgrade pip
+
# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard
diff --git a/requirements-build.txt b/requirements-build.txt
index ea2b688bb3108..7b16d9778c1a6 100644
--- a/requirements-build.txt
+++ b/requirements-build.txt
@@ -1,9 +1,9 @@
-# Should be mirrored in pyproject.toml
-cmake>=3.26
-ninja
-packaging
-setuptools>=61
-setuptools-scm>=8
-torch==2.5.0
-wheel
-jinja2
+# Should be mirrored in pyproject.toml
+cmake>=3.26
+ninja
+packaging
+setuptools>=61
+setuptools-scm>=8
+torch==2.5.0
+wheel
+jinja2
diff --git a/requirements-test.in b/requirements-test.in
new file mode 100644
index 0000000000000..3881f2566b556
--- /dev/null
+++ b/requirements-test.in
@@ -0,0 +1,37 @@
+# testing
+pytest
+tensorizer>=2.9.0
+pytest-forked
+pytest-asyncio
+pytest-rerunfailures
+pytest-shard
+
+# testing utils
+awscli
+einops # required for MPT, qwen-vl and Mamba
+httpx
+librosa # required for audio tests
+opencv-python # required for video tests
+peft
+requests
+ray[adag]==2.35
+sentence-transformers # required for embedding
+soundfile # required for audio test
+timm # required for internvl test
+torch==2.5.0
+transformers_stream_generator # required for qwen-vl test
+matplotlib # required for qwen-vl test
+datamodel_code_generator # required for minicpm3 test
+lm-eval[api]==0.4.4 # required for model evaluation test
+
+# TODO: Add this after fully implementing llava(mantis)
+# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test
+
+# Benchmarking
+aiohttp
+
+# quantization
+bitsandbytes>=0.44.0
+buildkite-test-collector==0.1.8
+
+numpy < 2.0.0
diff --git a/requirements-test.txt b/requirements-test.txt
index 9787fa2a4a486..c474c2ec34b22 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,34 +1,561 @@
-# testing
-pytest
-tensorizer>=2.9.0
-pytest-forked
-pytest-asyncio
-pytest-rerunfailures
-pytest-shard
-
-# testing utils
-awscli
-einops # required for MPT, qwen-vl and Mamba
-httpx
-librosa # required for audio tests
-opencv-python # required for video tests
-peft
-requests
-ray[adag]==2.35
-sentence-transformers # required for embedding
-soundfile # required for audio test
-timm # required for internvl test
-transformers_stream_generator # required for qwen-vl test
-matplotlib # required for qwen-vl test
-datamodel_code_generator # required for minicpm3 test
-lm-eval[api]==0.4.4 # required for model evaluation test
-
-# TODO: Add this after fully implementing llava(mantis)
-# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test
-
-# Benchmarking
-aiohttp
-
-# quantization
-bitsandbytes>=0.44.0
+#
+# This file is autogenerated by pip-compile with Python 3.12
+# by the following command:
+#
+# pip-compile --output-file=requirements-test.txt requirements-test.in
+#
+absl-py==2.1.0
+ # via rouge-score
+accelerate==1.0.1
+ # via
+ # lm-eval
+ # peft
+aiohappyeyeballs==2.4.3
+ # via aiohttp
+aiohttp==3.10.10
+ # via
+ # -r requirements-test.in
+ # datasets
+ # fsspec
+ # lm-eval
+aiosignal==1.3.1
+ # via
+ # aiohttp
+ # ray
+annotated-types==0.7.0
+ # via pydantic
+anyio==4.6.2.post1
+ # via httpx
+argcomplete==3.5.1
+ # via datamodel-code-generator
+attrs==24.2.0
+ # via
+ # aiohttp
+ # jsonlines
+ # jsonschema
+ # referencing
+audioread==3.0.1
+ # via librosa
+awscli==1.35.16
+ # via -r requirements-test.in
+bitsandbytes==0.44.1
+ # via -r requirements-test.in
+black==24.10.0
+ # via datamodel-code-generator
+boto3==1.35.50
+ # via tensorizer
+botocore==1.35.50
+ # via
+ # awscli
+ # boto3
+ # s3transfer
buildkite-test-collector==0.1.8
+ # via -r requirements-test.in
+certifi==2024.8.30
+ # via
+ # httpcore
+ # httpx
+ # requests
+cffi==1.17.1
+ # via soundfile
+chardet==5.2.0
+ # via mbstrdecoder
+charset-normalizer==3.4.0
+ # via requests
+click==8.1.7
+ # via
+ # black
+ # nltk
+ # ray
+colorama==0.4.6
+ # via
+ # awscli
+ # sacrebleu
+ # tqdm-multiprocess
+contourpy==1.3.0
+ # via matplotlib
+cupy-cuda12x==13.3.0
+ # via ray
+cycler==0.12.1
+ # via matplotlib
+datamodel-code-generator==0.26.2
+ # via -r requirements-test.in
+dataproperty==1.0.1
+ # via
+ # pytablewriter
+ # tabledata
+datasets==3.0.2
+ # via
+ # evaluate
+ # lm-eval
+decorator==5.1.1
+ # via librosa
+dill==0.3.8
+ # via
+ # datasets
+ # evaluate
+ # lm-eval
+ # multiprocess
+dnspython==2.7.0
+ # via email-validator
+docutils==0.16
+ # via awscli
+einops==0.8.0
+ # via -r requirements-test.in
+email-validator==2.2.0
+ # via pydantic
+evaluate==0.4.3
+ # via lm-eval
+fastrlock==0.8.2
+ # via cupy-cuda12x
+filelock==3.16.1
+ # via
+ # datasets
+ # huggingface-hub
+ # ray
+ # torch
+ # transformers
+ # triton
+fonttools==4.54.1
+ # via matplotlib
+frozenlist==1.5.0
+ # via
+ # aiohttp
+ # aiosignal
+ # ray
+fsspec[http]==2024.9.0
+ # via
+ # datasets
+ # evaluate
+ # huggingface-hub
+ # torch
+genson==1.3.0
+ # via datamodel-code-generator
+h11==0.14.0
+ # via httpcore
+hiredis==3.0.0
+ # via tensorizer
+httpcore==1.0.6
+ # via httpx
+httpx==0.27.2
+ # via -r requirements-test.in
+huggingface-hub==0.26.2
+ # via
+ # accelerate
+ # datasets
+ # evaluate
+ # peft
+ # sentence-transformers
+ # timm
+ # tokenizers
+ # transformers
+idna==3.10
+ # via
+ # anyio
+ # email-validator
+ # httpx
+ # requests
+ # yarl
+inflect==5.6.2
+ # via datamodel-code-generator
+iniconfig==2.0.0
+ # via pytest
+isort==5.13.2
+ # via datamodel-code-generator
+jinja2==3.1.4
+ # via
+ # datamodel-code-generator
+ # torch
+jmespath==1.0.1
+ # via
+ # boto3
+ # botocore
+joblib==1.4.2
+ # via
+ # librosa
+ # nltk
+ # scikit-learn
+jsonlines==4.0.0
+ # via lm-eval
+jsonschema==4.23.0
+ # via ray
+jsonschema-specifications==2024.10.1
+ # via jsonschema
+kiwisolver==1.4.7
+ # via matplotlib
+lazy-loader==0.4
+ # via librosa
+libnacl==2.1.0
+ # via tensorizer
+librosa==0.10.2.post1
+ # via -r requirements-test.in
+llvmlite==0.43.0
+ # via numba
+lm-eval[api]==0.4.4
+ # via -r requirements-test.in
+lxml==5.3.0
+ # via sacrebleu
+markupsafe==3.0.2
+ # via jinja2
+matplotlib==3.9.2
+ # via -r requirements-test.in
+mbstrdecoder==1.1.3
+ # via
+ # dataproperty
+ # pytablewriter
+ # typepy
+more-itertools==10.5.0
+ # via lm-eval
+mpmath==1.3.0
+ # via sympy
+msgpack==1.1.0
+ # via
+ # librosa
+ # ray
+multidict==6.1.0
+ # via
+ # aiohttp
+ # yarl
+multiprocess==0.70.16
+ # via
+ # datasets
+ # evaluate
+mypy-extensions==1.0.0
+ # via black
+networkx==3.2.1
+ # via torch
+nltk==3.9.1
+ # via rouge-score
+numba==0.60.0
+ # via librosa
+numexpr==2.10.1
+ # via lm-eval
+numpy==1.26.4
+ # via
+ # -r requirements-test.in
+ # accelerate
+ # bitsandbytes
+ # contourpy
+ # cupy-cuda12x
+ # datasets
+ # evaluate
+ # librosa
+ # matplotlib
+ # numba
+ # numexpr
+ # opencv-python
+ # pandas
+ # peft
+ # rouge-score
+ # sacrebleu
+ # scikit-learn
+ # scipy
+ # soxr
+ # tensorizer
+ # torchvision
+ # transformers
+nvidia-cublas-cu12==12.4.5.8
+ # via
+ # nvidia-cudnn-cu12
+ # nvidia-cusolver-cu12
+ # torch
+nvidia-cuda-cupti-cu12==12.4.127
+ # via torch
+nvidia-cuda-nvrtc-cu12==12.4.127
+ # via torch
+nvidia-cuda-runtime-cu12==12.4.127
+ # via torch
+nvidia-cudnn-cu12==9.1.0.70
+ # via torch
+nvidia-cufft-cu12==11.2.1.3
+ # via torch
+nvidia-curand-cu12==10.3.5.147
+ # via torch
+nvidia-cusolver-cu12==11.6.1.9
+ # via torch
+nvidia-cusparse-cu12==12.3.1.170
+ # via
+ # nvidia-cusolver-cu12
+ # torch
+nvidia-nccl-cu12==2.21.5
+ # via torch
+nvidia-nvjitlink-cu12==12.4.127
+ # via
+ # nvidia-cusolver-cu12
+ # nvidia-cusparse-cu12
+ # torch
+nvidia-nvtx-cu12==12.4.127
+ # via torch
+opencv-python==4.10.0.84
+ # via -r requirements-test.in
+packaging==24.1
+ # via
+ # accelerate
+ # black
+ # datamodel-code-generator
+ # datasets
+ # evaluate
+ # huggingface-hub
+ # lazy-loader
+ # matplotlib
+ # peft
+ # pooch
+ # pytest
+ # pytest-rerunfailures
+ # ray
+ # transformers
+ # typepy
+pandas==2.2.3
+ # via
+ # datasets
+ # evaluate
+pathspec==0.12.1
+ # via black
+pathvalidate==3.2.1
+ # via pytablewriter
+peft==0.13.2
+ # via
+ # -r requirements-test.in
+ # lm-eval
+pillow==11.0.0
+ # via
+ # matplotlib
+ # sentence-transformers
+ # torchvision
+platformdirs==4.3.6
+ # via
+ # black
+ # pooch
+pluggy==1.5.0
+ # via pytest
+pooch==1.8.2
+ # via librosa
+portalocker==2.10.1
+ # via sacrebleu
+propcache==0.2.0
+ # via yarl
+protobuf==5.28.3
+ # via
+ # ray
+ # tensorizer
+psutil==6.1.0
+ # via
+ # accelerate
+ # peft
+ # tensorizer
+py==1.11.0
+ # via pytest-forked
+pyarrow==18.0.0
+ # via datasets
+pyasn1==0.6.1
+ # via rsa
+pybind11==2.13.6
+ # via lm-eval
+pycparser==2.22
+ # via cffi
+pydantic[email]==2.9.2
+ # via datamodel-code-generator
+pydantic-core==2.23.4
+ # via pydantic
+pyparsing==3.2.0
+ # via matplotlib
+pytablewriter==1.2.0
+ # via lm-eval
+pytest==8.3.3
+ # via
+ # -r requirements-test.in
+ # buildkite-test-collector
+ # pytest-asyncio
+ # pytest-forked
+ # pytest-rerunfailures
+ # pytest-shard
+pytest-asyncio==0.24.0
+ # via -r requirements-test.in
+pytest-forked==1.6.0
+ # via -r requirements-test.in
+pytest-rerunfailures==14.0
+ # via -r requirements-test.in
+pytest-shard==0.1.2
+ # via -r requirements-test.in
+python-dateutil==2.9.0.post0
+ # via
+ # botocore
+ # matplotlib
+ # pandas
+ # typepy
+pytz==2024.2
+ # via
+ # pandas
+ # typepy
+pyyaml==6.0.2
+ # via
+ # accelerate
+ # awscli
+ # datamodel-code-generator
+ # datasets
+ # huggingface-hub
+ # peft
+ # ray
+ # timm
+ # transformers
+ray[adag]==2.35.0
+ # via -r requirements-test.in
+redis==5.2.0
+ # via tensorizer
+referencing==0.35.1
+ # via
+ # jsonschema
+ # jsonschema-specifications
+regex==2024.9.11
+ # via
+ # nltk
+ # sacrebleu
+ # tiktoken
+ # transformers
+requests==2.32.3
+ # via
+ # -r requirements-test.in
+ # buildkite-test-collector
+ # datasets
+ # evaluate
+ # huggingface-hub
+ # lm-eval
+ # pooch
+ # ray
+ # tiktoken
+ # transformers
+rouge-score==0.1.2
+ # via lm-eval
+rpds-py==0.20.0
+ # via
+ # jsonschema
+ # referencing
+rsa==4.7.2
+ # via awscli
+s3transfer==0.10.3
+ # via
+ # awscli
+ # boto3
+sacrebleu==2.4.3
+ # via lm-eval
+safetensors==0.4.5
+ # via
+ # accelerate
+ # peft
+ # timm
+ # transformers
+scikit-learn==1.5.2
+ # via
+ # librosa
+ # lm-eval
+ # sentence-transformers
+scipy==1.13.1
+ # via
+ # librosa
+ # scikit-learn
+ # sentence-transformers
+sentence-transformers==3.2.1
+ # via -r requirements-test.in
+six==1.16.0
+ # via
+ # python-dateutil
+ # rouge-score
+sniffio==1.3.1
+ # via
+ # anyio
+ # httpx
+soundfile==0.12.1
+ # via
+ # -r requirements-test.in
+ # librosa
+soxr==0.5.0.post1
+ # via librosa
+sqlitedict==2.1.0
+ # via lm-eval
+sympy==1.13.1
+ # via torch
+tabledata==1.3.3
+ # via pytablewriter
+tabulate==0.9.0
+ # via sacrebleu
+tcolorpy==0.1.6
+ # via pytablewriter
+tenacity==9.0.0
+ # via lm-eval
+tensorizer==2.9.0
+ # via -r requirements-test.in
+threadpoolctl==3.5.0
+ # via scikit-learn
+tiktoken==0.8.0
+ # via lm-eval
+timm==1.0.11
+ # via -r requirements-test.in
+tokenizers==0.20.1
+ # via transformers
+torch==2.5.0
+ # via
+ # -r requirements-test.in
+ # accelerate
+ # bitsandbytes
+ # lm-eval
+ # peft
+ # sentence-transformers
+ # tensorizer
+ # timm
+ # torchvision
+torchvision==0.20.0
+ # via timm
+tqdm==4.66.6
+ # via
+ # datasets
+ # evaluate
+ # huggingface-hub
+ # lm-eval
+ # nltk
+ # peft
+ # sentence-transformers
+ # tqdm-multiprocess
+ # transformers
+tqdm-multiprocess==0.0.11
+ # via lm-eval
+transformers==4.45.2
+ # via
+ # lm-eval
+ # peft
+ # sentence-transformers
+ # transformers-stream-generator
+transformers-stream-generator==0.0.5
+ # via -r requirements-test.in
+triton==3.1.0
+ # via torch
+typepy[datetime]==1.3.2
+ # via
+ # dataproperty
+ # pytablewriter
+ # tabledata
+typing-extensions==4.12.2
+ # via
+ # huggingface-hub
+ # librosa
+ # pydantic
+ # pydantic-core
+ # torch
+tzdata==2024.2
+ # via pandas
+urllib3==1.26.20
+ # via
+ # botocore
+ # requests
+word2number==1.1
+ # via lm-eval
+xxhash==3.5.0
+ # via
+ # datasets
+ # evaluate
+yarl==1.17.0
+ # via aiohttp
+zstandard==0.23.0
+ # via lm-eval
+
+# The following packages are considered to be unsafe in a requirements file:
+# setuptools