diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index 5780f09a646cb..721c9c026cf16 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -32,22 +32,17 @@ jobs:
         pip install types-setuptools
     - name: Mypy
       run: |
-        mypy tests --config-file pyproject.toml
-        mypy vllm/*.py --config-file pyproject.toml
-        mypy vllm/attention --config-file pyproject.toml
-        mypy vllm/core --config-file pyproject.toml
-        mypy vllm/distributed --config-file pyproject.toml
-        mypy vllm/engine --config-file pyproject.toml
-        mypy vllm/entrypoints --config-file pyproject.toml
-        mypy vllm/executor --config-file pyproject.toml
-        mypy vllm/inputs --config-file pyproject.toml
-        mypy vllm/logging --config-file pyproject.toml
-        mypy vllm/lora --config-file pyproject.toml
-        mypy vllm/model_executor --config-file pyproject.toml
-        mypy vllm/multimodal --config-file pyproject.toml
-        mypy vllm/platforms --config-file pyproject.toml
-        mypy vllm/spec_decode --config-file pyproject.toml
-        mypy vllm/transformers_utils --config-file pyproject.toml
-        mypy vllm/usage --config-file pyproject.toml
-        mypy vllm/worker --config-file pyproject.toml
+        mypy tests --follow-imports skip
+        mypy vllm/attention --follow-imports skip
+        mypy vllm/core --follow-imports skip
+        mypy vllm/distributed --follow-imports skip
+        mypy vllm/engine --follow-imports skip
+        mypy vllm/entrypoints --follow-imports skip
+        mypy vllm/executor --follow-imports skip
+        mypy vllm/lora --follow-imports skip
+        mypy vllm/model_executor --follow-imports skip
+        mypy vllm/prompt_adapter --follow-imports skip
+        mypy vllm/spec_decode --follow-imports skip
+        mypy vllm/worker --follow-imports skip
+        mypy

diff --git a/format.sh b/format.sh
index 5ad6d6f2938bb..71697cffacfb4 100755
--- a/format.sh
+++ b/format.sh
@@ -96,23 +96,19 @@
 echo 'vLLM yapf: Done'

 # Run mypy
 echo 'vLLM mypy:'
-mypy tests --config-file pyproject.toml
-mypy vllm/*.py --config-file pyproject.toml
-mypy vllm/attention --config-file pyproject.toml
-mypy vllm/core --config-file pyproject.toml
-mypy vllm/distributed --config-file pyproject.toml
-mypy vllm/engine --config-file pyproject.toml
-mypy vllm/entrypoints --config-file pyproject.toml
-mypy vllm/executor --config-file pyproject.toml
-mypy vllm/logging --config-file pyproject.toml
-mypy vllm/lora --config-file pyproject.toml
-mypy vllm/model_executor --config-file pyproject.toml
-mypy vllm/multimodal --config-file pyproject.toml
-mypy vllm/prompt_adapter --config-file pyproject.toml
-mypy vllm/spec_decode --config-file pyproject.toml
-mypy vllm/transformers_utils --config-file pyproject.toml
-mypy vllm/usage --config-file pyproject.toml
-mypy vllm/worker --config-file pyproject.toml
+mypy tests --follow-imports skip
+mypy vllm/attention --follow-imports skip
+mypy vllm/core --follow-imports skip
+mypy vllm/distributed --follow-imports skip
+mypy vllm/engine --follow-imports skip
+mypy vllm/entrypoints --follow-imports skip
+mypy vllm/executor --follow-imports skip
+mypy vllm/lora --follow-imports skip
+mypy vllm/model_executor --follow-imports skip
+mypy vllm/prompt_adapter --follow-imports skip
+mypy vllm/spec_decode --follow-imports skip
+mypy vllm/worker --follow-imports skip
+mypy

 # If git diff returns a file that is in the skip list, the file may be checked anyway:

diff --git a/pyproject.toml b/pyproject.toml
index 1ba1eacd90084..cd5d196a16200 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,9 +48,23 @@ python_version = "3.8"

 ignore_missing_imports = true
 check_untyped_defs = true
-follow_imports = "skip"
+follow_imports = "silent"

-files = "vllm"
+# After fixing type errors resulting from follow_imports: "skip" -> "silent",
+# move the directory here and remove it from format.sh and mypy.yaml
+files = [
+    "vllm/*.py",
+    "vllm/adapter_commons",
+    "vllm/assets",
+    "vllm/inputs",
+    "vllm/logging",
+    "vllm/multimodal",
+    "vllm/platforms",
+    "vllm/server",
+    "vllm/transformers_utils",
+    "vllm/triton_utils",
+    "vllm/usage",
+]
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
 exclude = [
     "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 6ca667eb85640..e351d602189e2 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -239,7 +239,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
-                      out_dtype: Type[torch.dtype],
+                      out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)

diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index b4721b4e1aedd..2156f6b18adb6 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -25,27 +25,33 @@ def _reshape_activation_tensor(
         x2 = x2.reshape(num, d)
         return x1, x2

+    @staticmethod
     def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
         x1, x2 = ipex_ops._reshape_activation_tensor(x)
         ipex.llm.functional.silu_mul(x1, x2, out)

+    @staticmethod
     def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
         x1, x2 = ipex_ops._reshape_activation_tensor(x)
         ipex.llm.functional.gelu_mul(x1, x2, out, "none")

+    @staticmethod
     def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
         x1, x2 = ipex_ops._reshape_activation_tensor(x)
         ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")

+    @staticmethod
     def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
         out.copy_(torch.nn.functional.gelu(x))

+    @staticmethod
     def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
         out.copy_(torch.nn.functional.gelu(x))

     # TODO add implementation of gelu_quick here
     # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:

+    @staticmethod
     def paged_attention_v1(
         out: torch.Tensor,
         query: torch.Tensor,
@@ -78,12 +84,21 @@ def paged_attention_v1(
         ).view(num_kv_heads,
                1).repeat_interleave(num_queries_per_tokens).flatten()
         # todo: ipex will refactor namespace
-        torch.xpu.paged_attention_v1(out, query.contiguous(),
-                                     key_cache.view_as(value_cache),
-                                     value_cache, head_mapping, scale,
-                                     block_tables, context_lens, block_size,
-                                     max_context_len, alibi_slopes)
+        torch.xpu.paged_attention_v1(  # type: ignore
+            out,
+            query.contiguous(),
+            key_cache.view_as(value_cache),
+            value_cache,
+            head_mapping,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )

+    @staticmethod
     def paged_attention_v2(
         out: torch.Tensor,
         exp_sum: torch.Tensor,
@@ -119,13 +134,24 @@ def paged_attention_v2(
         ).view(num_kv_heads,
                1).repeat_interleave(num_queries_per_tokens).flatten()
         # todo: ipex will refactor namespace
-        torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
-                                     query.contiguous(),
-                                     key_cache.view_as(value_cache),
-                                     value_cache, head_mapping, block_tables,
-                                     context_lens, scale, block_size,
-                                     max_context_len, alibi_slopes)
+        torch.xpu.paged_attention_v2(  # type: ignore
+            out,
+            exp_sum,
+            max_logits,
+            tmp_out,
+            query.contiguous(),
+            key_cache.view_as(value_cache),
+            value_cache,
+            head_mapping,
+            block_tables,
+            context_lens,
+            scale,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+        )

+    @staticmethod
     def rotary_embedding(
         positions: torch.Tensor,  # [batch_size, seq_len]
         query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
@@ -158,6 +184,7 @@ def rotary_embedding(
         ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                              rotary_dim, is_neox, positions)

+    @staticmethod
     def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                  key: torch.Tensor, head_size: int,
                                  cos_sin_cache: torch.Tensor, is_neox: bool,
@@ -189,17 +216,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
         ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                              rotary_dim, is_neox, positions)

+    @staticmethod
     def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                  epsilon: float) -> None:
         tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
         out.copy_(tmp)

+    @staticmethod
     def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                            weight: torch.Tensor, epsilon: float) -> None:
         tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                                epsilon, True)
         input.copy_(tmp)

+    @staticmethod
     def varlen_attention(
         query: torch.Tensor,
         key: torch.Tensor,
@@ -222,6 +252,7 @@ def varlen_attention(
                                           softmax_scale, zero_tensors,
                                           is_causal, return_softmax, gen_)

+    @staticmethod
     def reshape_and_cache(
         key: torch.Tensor,
         value: torch.Tensor,
@@ -240,8 +271,13 @@ def reshape_and_cache(
     def copy_blocks(key_caches: List[torch.Tensor],
                     value_caches: List[torch.Tensor],
                     block_mapping: torch.Tensor) -> None:
-        torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)
+        torch.xpu.copy_blocks(  # type: ignore
+            key_caches,
+            value_caches,
+            block_mapping,
+        )

+    @staticmethod
     def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                     block_mapping: torch.Tensor) -> None:
-        torch.xpu.swap_blocks(src, dst, block_mapping)
+        torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore

diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py
index 6939b1405f3e1..a5c04ab78fbe8 100644
--- a/vllm/adapter_commons/models.py
+++ b/vllm/adapter_commons/models.py
@@ -31,7 +31,7 @@ def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
         super().__init__(capacity)
         self.deactivate_fn = deactivate_fn

-    def _on_remove(self, key: Hashable, value: T):
+    def _on_remove(self, key: Hashable, value: Optional[T]):
         logger.debug("Removing adapter int id: %d", key)
         self.deactivate_fn(key)
         return super()._on_remove(key, value)
@@ -59,46 +59,46 @@ def __len__(self) -> int:

     @property
     @abstractmethod
-    def adapter_slots(self):
-        ...
+    def adapter_slots(self) -> int:
+        raise NotImplementedError

     @property
     @abstractmethod
-    def capacity(self):
-        ...
+    def capacity(self) -> int:
+        raise NotImplementedError

     @abstractmethod
     def activate_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError

     @abstractmethod
     def deactivate_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError

     @abstractmethod
     def add_adapter(self, adapter: Any) -> bool:
-        ...
+        raise NotImplementedError

     @abstractmethod
     def set_adapter_mapping(self, mapping: Any) -> None:
-        ...
+        raise NotImplementedError

     @abstractmethod
     def remove_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError

     @abstractmethod
-    def remove_all_adapters(self):
-        ...
+    def remove_all_adapters(self) -> None:
+        raise NotImplementedError

     @abstractmethod
     def get_adapter(self, adapter_id: int) -> Optional[Any]:
-        ...
+        raise NotImplementedError

     @abstractmethod
     def list_adapters(self) -> Dict[int, Any]:
-        ...
+        raise NotImplementedError

     @abstractmethod
     def pin_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError

diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py
index 69775ab7d4548..f98adeba1c705 100644
--- a/vllm/adapter_commons/request.py
+++ b/vllm/adapter_commons/request.py
@@ -1,19 +1,19 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from dataclasses import dataclass


 @dataclass
-class AdapterRequest:
+class AdapterRequest(ABC):
     """
     Base class for adapter requests.
     """

     @property
     @abstractmethod
-    def adapter_id(self):
-        ...
+    def adapter_id(self) -> int:
+        raise NotImplementedError

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         if self.adapter_id < 1:
             raise ValueError(f"id must be > 0, got {self.adapter_id}")

diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py
index acf18993af6d7..83929e82ebf04 100644
--- a/vllm/adapter_commons/worker_manager.py
+++ b/vllm/adapter_commons/worker_manager.py
@@ -12,25 +12,25 @@ def __init__(self, device: torch.device):
     @property
     @abstractmethod
     def is_enabled(self) -> bool:
-        ...
+        raise NotImplementedError

     @abstractmethod
     def set_active_adapters(self, requests: Set[Any],
                             mapping: Optional[Any]) -> None:
-        ...
+        raise NotImplementedError

     @abstractmethod
     def add_adapter(self, adapter_request: Any) -> bool:
-        ...
+        raise NotImplementedError

     @abstractmethod
     def remove_adapter(self, adapter_id: int) -> bool:
-        ...
+        raise NotImplementedError

     @abstractmethod
-    def remove_all_adapters(self):
-        ...
+    def remove_all_adapters(self) -> None:
+        raise NotImplementedError

     @abstractmethod
     def list_adapters(self) -> Set[int]:
-        ...
+        raise NotImplementedError

diff --git a/vllm/config.py b/vllm/config.py
index e7b54e04b00d5..fd48cc3a6b371 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -724,7 +724,7 @@ def __init__(
                 backend)

         self._verify_args()
-        self.rank = 0
+        self.rank: int = 0

     @property
     def use_ray(self) -> bool:
@@ -850,6 +850,7 @@ def _verify_args(self) -> None:


 class DeviceConfig:
+    device: Optional[torch.device]

     def __init__(self, device: str = "auto") -> None:
         if device == "auto":

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 004348d4c49a3..1efe2206abe81 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -5,8 +5,6 @@
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union

-from transformers import PreTrainedTokenizer
-
 import vllm.envs as envs
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -40,7 +38,8 @@
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
+from vllm.transformers_utils.tokenizer_group import (AnyTokenizer,
+                                                     BaseTokenizerGroup,
                                                      get_tokenizer_group)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
@@ -477,13 +476,12 @@ def get_tokenizer_group(
         return self.tokenizer

     def get_tokenizer(
-            self,
-            lora_request: Optional[LoRARequest] = None
-    ) -> "PreTrainedTokenizer":
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
         return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

-    def get_tokenizer_for_seq(self,
-                              sequence: Sequence) -> "PreTrainedTokenizer":
+    def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
         return self.get_tokenizer_group().get_lora_tokenizer(
             sequence.lora_request)

diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 321c9ac2c1d5f..b374a7946b11e 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -5,7 +5,6 @@
 from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union

 from pydantic import Field
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Annotated

 from vllm.config import ModelConfig
@@ -30,6 +29,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob
+from vllm.transformers_utils.tokenizer_group import AnyTokenizer

 logger = init_logger(__name__)

@@ -49,8 +49,6 @@ class LoRAModulePath:
 AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
                    EmbeddingRequest, TokenizeRequest]

-AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-

 class TextTokensPrompt(TypedDict):
     prompt: str

diff --git a/vllm/scripts.py b/vllm/scripts.py
index aefa5cec93a57..403b22239aed0 100644
--- a/vllm/scripts.py
+++ b/vllm/scripts.py
@@ -4,9 +4,10 @@
 import os
 import signal
 import sys
-from typing import Optional
+from typing import List, Optional

 from openai import OpenAI
+from openai.types.chat import ChatCompletionMessageParam

 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import make_arg_parser
@@ -63,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:

 def chat(system_prompt: Optional[str], model_name: str,
          client: OpenAI) -> None:
-    conversation = []
+    conversation: List[ChatCompletionMessageParam] = []
     if system_prompt is not None:
         conversation.append({"role": "system", "content": system_prompt})

     print("Please enter a message for the chat model:")
     while True:
         input_message = input("> ")
-        message = {"role": "user", "content": input_message}
-        conversation.append(message)
+        conversation.append({"role": "user", "content": input_message})

         chat_completion = client.chat.completions.create(model=model_name,
                                                          messages=conversation)
@@ -79,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,

         response_message = chat_completion.choices[0].message
         output = response_message.content
-        conversation.append(response_message)
+        conversation.append(response_message)  # type: ignore

         print(output)

diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index 0a45028e7759b..76f418674532f 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -37,6 +37,8 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
             The prompt logprobs with the decoded tokens.
         """
         prms = seq_group.sampling_params
+        assert prms is not None
+
         # We can pick any sequence for the prompt.
         seq = next(iter(seq_group.seqs_dict.values()))
         # Only prompt, without the generated token.

diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py
index 9f54f5409b181..7a0436dd1fb16 100644
--- a/vllm/transformers_utils/tokenizer_group/__init__.py
+++ b/vllm/transformers_utils/tokenizer_group/__init__.py
@@ -2,10 +2,9 @@

 from vllm.config import TokenizerPoolConfig
 from vllm.executor.ray_utils import ray
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
-from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
-    TokenizerGroup)
+
+from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .tokenizer_group import TokenizerGroup

 if ray:
     from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
@@ -34,4 +33,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
     return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)


-__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
+__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]

diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
index 9682db6966ddf..abbcdf2807f6f 100644
--- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -1,11 +1,13 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Union

-from transformers import PreTrainedTokenizer
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest

+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+

 class BaseTokenizerGroup(ABC):
     """A group of tokenizers that can be used for LoRA adapters."""
@@ -47,17 +49,17 @@ async def encode_async(

     @abstractmethod
     def get_lora_tokenizer(
-            self,
-            lora_request: Optional[LoRARequest] = None
-    ) -> "PreTrainedTokenizer":
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
         """Get a tokenizer for a LoRA request."""
         pass

     @abstractmethod
     async def get_lora_tokenizer_async(
-            self,
-            lora_request: Optional[LoRARequest] = None
-    ) -> "PreTrainedTokenizer":
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
         """Get a tokenizer for a LoRA request."""
         pass

diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
index 32384398a4c12..eebdf7bf644d0 100644
--- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -6,18 +6,16 @@
 try:
     from ray.exceptions import ActorDiedError
 except ImportError:
     # For older versions of Ray
-    from ray.exceptions import RayActorError as ActorDiedError
+    from ray.exceptions import RayActorError as ActorDiedError  # type: ignore
 from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
-from transformers import PreTrainedTokenizer

 from vllm.config import TokenizerPoolConfig
 from vllm.executor.ray_utils import ray
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
-from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
-    TokenizerGroup)
+
+from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+from .tokenizer_group import TokenizerGroup

 logger = init_logger(__name__)
@@ -67,7 +65,7 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
             **self._tokenizer_config,
         )
         self._ray_tokenizer_group_cls = ray.remote(
-            self._worker_cls).options(**ray_actor_options)
+            self._worker_cls).options(**ray_actor_options)  # type: ignore
         self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
         self._idle_actors: Optional[asyncio.Queue] = None
@@ -83,8 +81,10 @@ def pool_size(self) -> int:
         return len(self.tokenizer_actors)

     def ping(self):
-        return ray.get(
-            [actor.ping.remote() for actor in self.tokenizer_actors])
+        return ray.get([
+            actor.ping.remote()  # type: ignore
+            for actor in self.tokenizer_actors
+        ])

     def _ensure_queue_initialized(self):
         if self._idle_actors is None:
@@ -208,15 +208,15 @@ def get_max_input_len(self,
         return self._local_tokenizer_group.get_max_input_len(lora_request)

     def get_lora_tokenizer(
-            self,
-            lora_request: Optional[LoRARequest] = None
-    ) -> "PreTrainedTokenizer":
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
         return self._local_tokenizer_group.get_lora_tokenizer(lora_request)

     async def get_lora_tokenizer_async(
-            self,
-            lora_request: Optional[LoRARequest] = None
-    ) -> "PreTrainedTokenizer":
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
         return await self._local_tokenizer_group.get_lora_tokenizer_async(
             lora_request)

diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
index 74c041f13bad9..a5186e48068e9 100644
--- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -1,16 +1,14 @@
 from typing import List, Optional

-from transformers import PreTrainedTokenizer
-
 from vllm.config import TokenizerPoolConfig
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
                                                get_lora_tokenizer_async,
                                                get_tokenizer)
-from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
 from vllm.utils import LRUCache

+from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
+

 class TokenizerGroup(BaseTokenizerGroup):
     """A group of tokenizers that can be used for LoRA adapters."""
@@ -22,8 +20,8 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
-        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
-            capacity=max_num_seqs) if enable_lora else None
+        self.lora_tokenizers = LRUCache[AnyTokenizer](
+            capacity=max_num_seqs if enable_lora else 0)

     @classmethod
     def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
@@ -41,7 +39,7 @@ def get_max_input_len(self,
         return self.max_input_length

     def _raise_if_input_too_long(self,
-                                 encoded_tokens: List[str],
+                                 encoded_tokens: List[int],
                                  lora_request: Optional[LoRARequest] = None):
         input_length = len(encoded_tokens)
         if lora_request:
@@ -72,9 +70,9 @@ async def encode_async(
         return ret

     def get_lora_tokenizer(
-            self,
-            lora_request: Optional[LoRARequest] = None
-    ) -> "PreTrainedTokenizer":
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -83,12 +81,12 @@ def get_lora_tokenizer(
             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
             return tokenizer
         else:
-            return self.lora_tokenizers.get(lora_request.lora_int_id)
+            return self.lora_tokenizers[lora_request.lora_int_id]

     async def get_lora_tokenizer_async(
-            self,
-            lora_request: Optional[LoRARequest] = None
-    ) -> "PreTrainedTokenizer":
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -97,4 +95,4 @@ async def get_lora_tokenizer_async(
             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
             return tokenizer
         else:
-            return self.lora_tokenizers.get(lora_request.lora_int_id)
+            return self.lora_tokenizers[lora_request.lora_int_id]

diff --git a/vllm/utils.py b/vllm/utils.py
index 1448316e66edb..b7589ca50ba5b 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -94,8 +94,10 @@ def __contains__(self, key: Hashable) -> bool:
     def __len__(self) -> int:
         return len(self.cache)

-    def __getitem__(self, key: Hashable) -> Optional[T]:
-        return self.get(key)
+    def __getitem__(self, key: Hashable) -> T:
+        value = self.cache[key]  # Raise KeyError if not exists
+        self.cache.move_to_end(key)
+        return value

     def __setitem__(self, key: Hashable, value: T) -> None:
         self.put(key, value)
@@ -109,8 +111,9 @@ def touch(self, key: Hashable) -> None:
     def get(self,
             key: Hashable,
             default_value: Optional[T] = None) -> Optional[T]:
+        value: Optional[T]
         if key in self.cache:
-            value: Optional[T] = self.cache[key]
+            value = self.cache[key]
             self.cache.move_to_end(key)
         else:
             value = default_value
@@ -590,8 +593,8 @@ def current_memory_usage(self) -> float:
             torch.cuda.reset_peak_memory_stats(self.device)
             mem = torch.cuda.max_memory_allocated(self.device)
         elif is_xpu():
-            torch.xpu.reset_peak_memory_stats(self.device)
-            mem = torch.xpu.max_memory_allocated(self.device)
+            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
+            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
         return mem

     def __enter__(self):