[CI/Build] mypy: Resolve some errors from checking vllm/engine (#9267)
Signed-off-by: Russell Bryant <[email protected]>
russellb authored Oct 16, 2024
1 parent 8345045 commit 776dbd7
Showing 20 changed files with 109 additions and 74 deletions.
12 changes: 1 addition & 11 deletions tools/mypy.sh
@@ -13,24 +13,14 @@ run_mypy() {

run_mypy # Note that this is less strict than CI
run_mypy tests
run_mypy vllm/assets
run_mypy vllm/attention
#run_mypy vllm/compilation
#run_mypy vllm/core
run_mypy vllm/compilation
run_mypy vllm/distributed
run_mypy vllm/engine
run_mypy vllm/entrypoints
run_mypy vllm/executor
#run_mypy vllm/inputs
run_mypy vllm/logging
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/multimodal
run_mypy vllm/platforms
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/transformers_utils
run_mypy vllm/usage
#run_mypy vllm/vllm_flash_attn
run_mypy vllm/worker
2 changes: 1 addition & 1 deletion vllm/attention/layer.py
@@ -92,7 +92,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
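The hunk above narrows kv_cache from Optional[torch.Tensor] to torch.Tensor, so mypy no longer demands a None check before tensor operations in the body. A minimal standalone sketch of the same idea, with illustrative names rather than vLLM's API:

from typing import Optional

import torch

def scale_optional(x: Optional[torch.Tensor]) -> torch.Tensor:
    # With an Optional parameter, mypy requires narrowing before use.
    assert x is not None
    return x.mul(2)

def scale(x: torch.Tensor) -> torch.Tensor:
    # With the tighter annotation, no None check is needed.
    return x.mul(2)

print(scale(torch.ones(2)))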
4 changes: 2 additions & 2 deletions vllm/compilation/backends.py
@@ -244,8 +244,8 @@ def compiled_graph_wrapper(*args):

def select_default_backend(level: int) -> Union[str, Callable]:
if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
backend = "eager"
return backend
backend_str = "eager"
return backend_str
assert level in [
CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
], f"Invalid level {level}"
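Renaming the local to backend_str sidesteps a mypy "incompatible types in assignment" error that appears when the same name would later be bound to a Callable. A simplified sketch of the failure mode and the fix, not the vLLM function itself:

from typing import Callable, Union

def select_backend(eager: bool) -> Union[str, Callable[[], None]]:
    if eager:
        backend_str = "eager"  # str-typed binding, returned immediately
        return backend_str

    def compiled_backend() -> None:
        pass

    # Reusing the name backend_str here would make mypy flag an incompatible
    # assignment (str vs Callable), so a separate name is used instead.
    backend_fn = compiled_backend
    return backend_fn

print(select_backend(eager=True))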
8 changes: 5 additions & 3 deletions vllm/compilation/decorators.py
@@ -35,6 +35,8 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
def cls_decorator_helper(cls: type):
# helper to pass `dynamic_arg_dims` to `_support_torch_compile`
# to avoid too much indentation for `_support_torch_compile`
if not hasattr(cls, 'forward'):
raise TypeError("decorated class should have a forward method.")
sig = inspect.signature(cls.forward)
for k in dynamic_arg_dims:
if k not in sig.parameters:
@@ -63,13 +65,13 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

old_init = cls.__init__
old_init = cls.__init__ # type: ignore

def __init__(self, *args, **kwargs):
old_init(self, *args, **kwargs)
TorchCompileWrapperWithCustomDispatcher.__init__(self)

cls.__init__ = __init__
cls.__init__ = __init__ # type: ignore

def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation
@@ -109,5 +111,5 @@ def __call__(self, *args, **kwargs):
model_output = self.forward(*args, **kwargs)
return model_output

cls.__call__ = __call__
cls.__call__ = __call__ # type: ignore
return cls
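Reassigning __init__ and __call__ on a class object works at runtime, but mypy rejects assignment to methods, hence the # type: ignore markers; the new hasattr guard also gives a clear error for classes without a forward method. A small sketch of the same decorator shape, with toy names rather than the vLLM wrapper:

def log_init(cls: type) -> type:
    if not hasattr(cls, "forward"):
        raise TypeError("decorated class should have a forward method.")

    old_init = cls.__init__  # type: ignore

    def __init__(self, *args, **kwargs):
        old_init(self, *args, **kwargs)
        print(f"initialized {cls.__name__}")

    cls.__init__ = __init__  # type: ignore
    return cls

@log_init
class ToyModel:
    def forward(self, x: int) -> int:
        return x + 1

ToyModel().forward(1)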
2 changes: 1 addition & 1 deletion vllm/compilation/wrapper.py
@@ -73,7 +73,7 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
return
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
frame = sys._getframe()
while True:
while frame and frame.f_back:
frame = frame.f_back
code_name = frame.f_code.co_name
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
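frame.f_back is typed Optional[FrameType], so the unconditional while True loop tripped mypy when it dereferenced f_back; guarding the loop condition narrows the type. A standalone sketch of walking the call stack with that guard, not vLLM's bytecode hook:

import sys

def outermost_frame_name() -> str:
    frame = sys._getframe()
    # f_back is Optional[FrameType]; the explicit check narrows it for mypy.
    while frame.f_back is not None:
        frame = frame.f_back
    return frame.f_code.co_name

print(outermost_frame_name())  # usually "<module>" when run as a script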
10 changes: 6 additions & 4 deletions vllm/config.py
@@ -626,13 +626,14 @@ def __init__(
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb

self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()

# Will be set after profiling.
self.num_gpu_blocks = None
self.num_cpu_blocks = None
self.num_gpu_blocks: Optional[int] = None
self.num_cpu_blocks: Optional[int] = None

def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
@@ -709,7 +710,8 @@ def __post_init__(self):

@classmethod
def create_config(
cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
cls, tokenizer_pool_size: int,
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
@@ -1544,7 +1546,7 @@ class LoRAConfig:
max_loras: int
fully_sharded_loras: bool = False
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_dtype: Optional[Union[torch.dtype, str]] = None
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
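Assigning None in __init__ without an annotation leads mypy to infer the attribute type as None, so a later integer assignment fails; the explicit Optional[int] annotations in this hunk fix that. A minimal sketch of the pattern with a toy class:

from typing import Optional

class BlockStats:
    def __init__(self) -> None:
        # Annotated as Optional[int] so later integer assignments type-check.
        self.num_gpu_blocks: Optional[int] = None
        self.num_cpu_blocks: Optional[int] = None

    def record(self, gpu_blocks: int, cpu_blocks: int) -> None:
        self.num_gpu_blocks = gpu_blocks
        self.num_cpu_blocks = cpu_blocks

stats = BlockStats()
stats.record(128, 32)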
7 changes: 4 additions & 3 deletions vllm/core/scheduler.py
@@ -4,8 +4,9 @@
import time
from collections import deque
from dataclasses import dataclass, field
from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
Tuple, Union)
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@@ -115,7 +116,7 @@ class ScheduledSequenceGroup:
class SchedulerOutputs:
"""The scheduling decision made from a scheduler."""
# Scheduled sequence groups.
scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
# Number of prefill groups scheduled.
num_prefill_groups: int
# Total number of batched tokens.
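Because vLLM defines its own Sequence class, typing.Sequence is imported under the alias GenericSequence, and the field's annotation is tightened from Iterable to a Sequence so it supports len() and indexing. A sketch of the aliasing pattern with toy classes:

from dataclasses import dataclass, field
from typing import Sequence as GenericSequence

class Sequence:
    """Domain class whose name shadows typing.Sequence."""

@dataclass
class Outputs:
    # The alias keeps the annotation unambiguous despite the name clash.
    scheduled: GenericSequence[Sequence] = field(default_factory=list)

    def count(self) -> int:
        return len(self.scheduled)

print(Outputs([Sequence(), Sequence()]).count())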
12 changes: 7 additions & 5 deletions vllm/engine/arg_utils.py
@@ -3,7 +3,7 @@
import json
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union)
Tuple, Type, Union, cast)

import torch

@@ -89,7 +89,7 @@ class EngineArgs:
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
config_format: str = 'auto'
config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
@@ -181,7 +181,7 @@ class EngineArgs:
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

def __post_init__(self):
if self.tokenizer is None:
if not self.tokenizer:
self.tokenizer = self.model

# Setup plugins
@@ -837,7 +837,8 @@ def from_cli_args(cls, args: argparse.Namespace):
def create_model_config(self) -> ModelConfig:
return ModelConfig(
model=self.model,
tokenizer=self.tokenizer,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
@@ -908,8 +909,9 @@ def create_engine_config(self) -> EngineConfig:
self.enable_prefix_caching = False

cache_config = CacheConfig(
# neuron needs block_size = max_model_len
block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len
(self.max_model_len if self.max_model_len is not None else 0),
gpu_memory_utilization=self.gpu_memory_utilization,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
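When __post_init__ guarantees that an Optional field gets filled in, typing.cast lets a later use site treat the value as a plain str without a runtime check, which is the pattern behind the tokenizer change above. A minimal sketch with an illustrative dataclass, not EngineArgs itself:

from dataclasses import dataclass
from typing import Optional, cast

@dataclass
class Args:
    model: str
    tokenizer: Optional[str] = None

    def __post_init__(self) -> None:
        # Default the tokenizer to the model name when it is unset.
        if not self.tokenizer:
            self.tokenizer = self.model

    def tokenizer_name(self) -> str:
        # We know this is not None because __post_init__ set it.
        return cast(str, self.tokenizer)

print(Args(model="my-model").tokenizer_name())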
20 changes: 12 additions & 8 deletions vllm/engine/llm_engine.py
@@ -6,7 +6,7 @@
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, overload
from typing import Set, Type, Union, cast, overload

import torch
from typing_extensions import TypeVar
@@ -44,7 +44,7 @@
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
@@ -188,7 +188,7 @@ def validate_output(
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")

return output
return cast(_O, output)

@classmethod
def validate_outputs(
@@ -1039,6 +1039,7 @@ def _process_model_outputs(self,
scheduler_outputs.scheduled_seq_groups)

has_multiple_outputs: bool = len(outputs) > 1
outputs_by_sequence_group: List[List[SequenceGroupOutput]]
if has_multiple_outputs:
assert self.scheduler_config.is_multi_step or \
self.speculative_config
@@ -1084,6 +1085,7 @@ def _process_model_outputs(self,
finished_before.append(i)
continue

output: List[SequenceGroupOutput]
if has_multiple_outputs:
output = outputs_by_sequence_group[i]
else:
@@ -1096,21 +1098,21 @@ def _process_model_outputs(self,
seq_group, seq_group_meta, is_first_step_output)
else:
seq_group.update_num_computed_tokens(
seq_group_meta.token_chunk_size)
seq_group_meta.token_chunk_size or 0)

if outputs:
for o in outputs:
if (isinstance(o, SamplerOutput)
and seq_group.metrics is not None):
if seq_group.metrics.model_forward_time is not None:
seq_group.metrics.model_forward_time += (
o.model_forward_time)
o.model_forward_time or 0)
else:
seq_group.metrics.model_forward_time = (
o.model_forward_time)
if seq_group.metrics.model_execute_time is not None:
seq_group.metrics.model_execute_time += (
o.model_execute_time)
o.model_execute_time or 0)
else:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
@@ -1236,8 +1238,10 @@ def _advance_to_next_step(
seq_group, seq_group_metadata,
seq_group.state.num_steps == 1)
else:
seq_group.update_num_computed_tokens(
seq_group_metadata.token_chunk_size)
token_chunk_size = (seq_group_metadata.token_chunk_size
if seq_group_metadata.token_chunk_size
is not None else 0)
seq_group.update_num_computed_tokens(token_chunk_size)

if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
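Two fixes recur in this file: declaring a local's type up front when different branches assign it different shapes, and adding Optional numeric fields into running totals with an `or 0` fallback. A compact sketch of both, with toy names:

from typing import List, Optional

def gather(outputs: List[List[int]], multiple: bool) -> List[int]:
    flat: List[int]  # declared once so both branches type-check against it
    if multiple:
        flat = [x for group in outputs for x in group]
    else:
        flat = outputs[0]
    return flat

def accumulate(total: float, forward_time: Optional[float]) -> float:
    # Treat a missing timing as contributing nothing to the total.
    return total + (forward_time or 0)

print(gather([[1, 2], [3]], multiple=True), accumulate(1.5, None))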
14 changes: 9 additions & 5 deletions vllm/engine/metrics.py
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Type, Union, cast

import numpy as np
import prometheus_client
@@ -249,10 +249,11 @@ def __init__(self,
labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
boundaries = buckets if buckets else []
self._histogram = ray_metrics.Histogram(name=name,
description=documentation,
tag_keys=labelnames_tuple,
boundaries=buckets)
boundaries=boundaries)

def labels(self, **labels):
self._histogram.set_default_tags(labels)
@@ -267,9 +268,12 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_gauge_cls = _RayGaugeWrapper
_counter_cls = _RayCounterWrapper
_histogram_cls = _RayHistogramWrapper
_gauge_cls: Type[prometheus_client.Gauge] = cast(
Type[prometheus_client.Gauge], _RayGaugeWrapper)
_counter_cls: Type[prometheus_client.Counter] = cast(
Type[prometheus_client.Counter], _RayCounterWrapper)
_histogram_cls: Type[prometheus_client.Histogram] = cast(
Type[prometheus_client.Histogram], _RayHistogramWrapper)

def __init__(self, labelnames: List[str], max_model_len: int):
if ray_metrics is None:
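The Ray wrappers only duck-type the prometheus_client classes, so the subclass casts them to the Type[...] declared on the base class instead of actually inheriting from it. A simplified sketch of that cast with toy classes, not the prometheus or ray APIs:

from typing import Type, cast

class Counter:
    """Interface the base class expects."""
    def inc(self, n: float = 1.0) -> None:
        raise NotImplementedError

class WrappedCounter:
    """Duck-typed substitute that is not a Counter subclass."""
    def inc(self, n: float = 1.0) -> None:
        print(f"+{n}")

class Metrics:
    _counter_cls: Type[Counter] = Counter

class AltMetrics(Metrics):
    # cast() satisfies the inherited annotation without real inheritance.
    _counter_cls: Type[Counter] = cast(Type[Counter], WrappedCounter)

AltMetrics._counter_cls().inc()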
17 changes: 12 additions & 5 deletions vllm/engine/multiprocessing/client.py
@@ -3,7 +3,7 @@
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, overload)
Optional, Union, cast, overload)

import cloudpickle
import zmq
@@ -513,9 +513,14 @@ def encode(
assert (prompt is not None and pooling_params is not None
and request_id is not None)

return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers, None,
priority)
return cast(
AsyncGenerator[EmbeddingRequestOutput, None],
self._process_request(prompt,
pooling_params,
request_id,
lora_request,
trace_headers,
priority=priority))

async def _process_request(
self,
@@ -543,7 +548,9 @@ async def _process_request(
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.guided_decoding_backend
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
)

# 1) Create output queue for this request.
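_process_request is declared to yield either RequestOutput or EmbeddingRequestOutput, so the encode() path casts the generator to the narrower type it actually produces at that call site. A small sketch of narrowing a union-typed async generator with cast, using illustrative types:

import asyncio
from typing import AsyncGenerator, Union, cast

async def _produce(as_text: bool) -> AsyncGenerator[Union[int, str], None]:
    for i in range(3):
        yield str(i) if as_text else i

def produce_ints() -> AsyncGenerator[int, None]:
    # This call site only ever yields ints, so narrow the declared type.
    return cast(AsyncGenerator[int, None], _produce(as_text=False))

async def main() -> None:
    async for value in produce_ints():
        print(value + 1)

asyncio.run(main())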
6 changes: 2 additions & 4 deletions vllm/engine/multiprocessing/engine.py
@@ -73,11 +73,9 @@ def __init__(self,
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and sent over the socket, which frees
# the python object to be reused again.
use_cached_outputs = True
kwargs['use_cached_outputs'] = True

self.engine = LLMEngine(*args,
**kwargs,
use_cached_outputs=use_cached_outputs)
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests

self.use_async_sockets = use_async_sockets
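Folding use_cached_outputs into kwargs before the call keeps a single source for that keyword instead of passing it alongside **kwargs, where a caller-supplied duplicate would raise "got multiple values for keyword argument". A minimal sketch of the style, with a stand-in constructor:

class Engine:
    def __init__(self, *args, use_cached_outputs: bool = False, **kwargs) -> None:
        self.use_cached_outputs = use_cached_outputs
        self.extra = kwargs

def build_engine(*args, **kwargs) -> Engine:
    # Set the flag in kwargs rather than as a separate keyword next to
    # **kwargs, so there is exactly one place the value can come from.
    kwargs["use_cached_outputs"] = True
    return Engine(*args, **kwargs)

print(build_engine(log_requests=False).use_cached_outputs)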
(Diffs for the remaining changed files were not loaded.)
