[CI/Build] Fix some mypy errors from checking vllm/engine
This fixes 27 of the 100 errors that occur when running:

    mypy --follow-imports silent vllm/engine

It is submitted as a checkpoint to make it easier to review.

Signed-off-by: Russell Bryant <[email protected]>
russellb committed Oct 11, 2024
1 parent cbac1b7 commit d1061ce
Showing 13 changed files with 57 additions and 44 deletions.
2 changes: 1 addition & 1 deletion vllm/attention/layer.py
@@ -90,7 +90,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
- kv_cache: Optional[torch.Tensor],
+ kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
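The change above tightens kv_cache from Optional[torch.Tensor] to torch.Tensor. As a rough illustration of why this helps mypy (hypothetical function, not vLLM code): once a parameter is non-Optional, the body no longer needs a None check before attribute access.

    import torch


    def cached_token_count(kv_cache: torch.Tensor) -> int:
        # With a non-Optional parameter, mypy allows direct attribute
        # access; an Optional[torch.Tensor] would first require a None check.
        return kv_cache.numel()


    print(cached_token_count(torch.zeros(2, 3)))  # 6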
13 changes: 7 additions & 6 deletions vllm/config.py
@@ -579,6 +579,10 @@ class CacheConfig:
profiled num_gpu_blocks if specified. Does nothing if None.
"""

+ # Will be set after profiling.
+ num_gpu_blocks: Optional[int] = None
+ num_cpu_blocks: Optional[int] = None

def __init__(
self,
block_size: int,
@@ -602,10 +606,6 @@ def __init__(
self._verify_cache_dtype()
self._verify_prefix_caching()

- # Will be set after profiling.
- self.num_gpu_blocks = None
- self.num_cpu_blocks = None

def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
@@ -681,7 +681,8 @@ def __post_init__(self):

@classmethod
def create_config(
- cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
+ cls, tokenizer_pool_size: int,
+ tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
@@ -1514,7 +1515,7 @@ class LoRAConfig:
max_loras: int
fully_sharded_loras: bool = False
max_cpu_loras: Optional[int] = None
- lora_dtype: Optional[torch.dtype] = None
+ lora_dtype: Optional[Union[torch.dtype, str]] = None
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
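The CacheConfig change moves num_gpu_blocks and num_cpu_blocks from assignments in __init__ to class-level annotations, so mypy knows their types are Optional[int] even though real values arrive only after profiling. A minimal sketch of the idiom (hypothetical class):

    from typing import Optional


    class ProfiledConfig:
        # Declared here so mypy knows the attribute types; the actual
        # values are filled in only after profiling.
        num_gpu_blocks: Optional[int] = None
        num_cpu_blocks: Optional[int] = None

        def record_profile(self, gpu_blocks: int, cpu_blocks: int) -> None:
            self.num_gpu_blocks = gpu_blocks
            self.num_cpu_blocks = cpu_blocks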
7 changes: 4 additions & 3 deletions vllm/core/scheduler.py
@@ -4,8 +4,9 @@
import time
from collections import deque
from dataclasses import dataclass, field
- from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
-                     Tuple, Union)
+ from typing import Callable, Deque, Dict, Iterable, List, Optional
+ from typing import Sequence as SequenceType
+ from typing import Set, Tuple, Union

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@@ -115,7 +116,7 @@ class ScheduledSequenceGroup:
class SchedulerOutputs:
"""The scheduling decision made from a scheduler."""
# Scheduled sequence groups.
- scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
+ scheduled_seq_groups: SequenceType[ScheduledSequenceGroup]
# Number of prefill groups scheduled.
num_prefill_groups: int
# Total number of batched tokens.
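The scheduler imports typing.Sequence under the alias SequenceType because vLLM defines its own Sequence class, and annotates scheduled_seq_groups as a Sequence rather than an Iterable so mypy knows it supports len() and indexing. A small sketch under those assumptions (hypothetical dataclasses):

    from dataclasses import dataclass
    from typing import Sequence as SequenceType


    @dataclass
    class ScheduledGroup:
        request_id: str


    @dataclass
    class SchedulerResult:
        # Sequence, unlike Iterable, guarantees len() and indexing,
        # which downstream consumers rely on.
        scheduled_seq_groups: SequenceType[ScheduledGroup]


    result = SchedulerResult(scheduled_seq_groups=[ScheduledGroup("req-0")])
    assert len(result.scheduled_seq_groups) == 1
    assert result.scheduled_seq_groups[0].request_id == "req-0"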
12 changes: 7 additions & 5 deletions vllm/engine/arg_utils.py
@@ -3,7 +3,7 @@
import json
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
- Tuple, Type, Union)
+ Tuple, Type, Union, cast)

import torch

@@ -89,7 +89,7 @@ class EngineArgs:
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
- config_format: str = 'auto'
+ config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
@@ -181,7 +181,7 @@ class EngineArgs:
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"

def __post_init__(self):
- if self.tokenizer is None:
+ if not self.tokenizer:
self.tokenizer = self.model
from vllm.plugins import load_general_plugins
load_general_plugins()
@@ -834,7 +834,8 @@ def from_cli_args(cls, args: argparse.Namespace):
def create_model_config(self) -> ModelConfig:
return ModelConfig(
model=self.model,
- tokenizer=self.tokenizer,
+ # We know this is not None because we set it in __post_init__
+ tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
dtype=self.dtype,
@@ -905,8 +906,9 @@ def create_engine_config(self) -> EngineConfig:
self.enable_prefix_caching = False

cache_config = CacheConfig(
+ # neuron needs block_size = max_model_len
block_size=self.block_size if self.device != "neuron" else
- self.max_model_len, # neuron needs block_size = max_model_len
+ (self.max_model_len if self.max_model_len is not None else 0),
gpu_memory_utilization=self.gpu_memory_utilization,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
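EngineArgs.tokenizer is declared Optional[str] but __post_init__ always falls back to the model name, so create_model_config uses typing.cast to tell mypy the None case is gone. A minimal sketch of that pattern, with hypothetical names:

    from dataclasses import dataclass
    from typing import Optional, cast


    @dataclass
    class Args:
        model: str
        tokenizer: Optional[str] = None

        def __post_init__(self) -> None:
            if not self.tokenizer:
                self.tokenizer = self.model

        def tokenizer_name(self) -> str:
            # Safe because __post_init__ always sets a value; cast() only
            # informs the type checker and does nothing at runtime.
            return cast(str, self.tokenizer)


    assert Args(model="facebook/opt-125m").tokenizer_name() == "facebook/opt-125m"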
9 changes: 5 additions & 4 deletions vllm/engine/llm_engine.py
@@ -6,7 +6,7 @@
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
- from typing import Set, Type, Union, overload
+ from typing import Set, Type, Union, cast, overload

import torch
from typing_extensions import TypeVar
@@ -188,7 +188,7 @@ def validate_output(
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")

- return output
+ return cast(_O, output)

@classmethod
def validate_outputs(
@@ -1236,8 +1236,9 @@ def _advance_to_next_step(
seq_group, seq_group_metadata,
seq_group.state.num_steps == 1)
else:
- seq_group.update_num_computed_tokens(
-     seq_group_metadata.token_chunk_size)
+ token_chunk_size = (seq_group_metadata.token_chunk_size if
+                     seq_group_metadata.token_chunk_size else 0)
+ seq_group.update_num_computed_tokens(token_chunk_size)

if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
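validate_output verifies the type with isinstance at runtime, but mypy cannot connect that check to the TypeVar return type, hence the cast. Roughly, the idea is (hypothetical helper):

    from typing import Type, TypeVar, cast

    _O = TypeVar("_O")


    def validate_output(output: object, output_type: Type[_O]) -> _O:
        if not isinstance(output, output_type):
            raise TypeError(f"Expected {output_type}, got {type(output)}")
        # The isinstance check proves this at runtime; cast() restates it
        # for mypy, which cannot infer it for a TypeVar on its own.
        return cast(_O, output)


    count: int = validate_output(42, int)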
14 changes: 9 additions & 5 deletions vllm/engine/metrics.py
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from typing import Counter as CollectionsCounter
- from typing import Dict, List, Optional, Union
+ from typing import Dict, List, Optional, Type, Union, cast

import numpy as np
import prometheus_client
@@ -255,10 +255,11 @@ def __init__(self,
labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
+ boundaries = buckets if buckets else []
self._histogram = ray_metrics.Histogram(name=name,
description=documentation,
tag_keys=labelnames_tuple,
- boundaries=buckets)
+ boundaries=boundaries)

def labels(self, **labels):
self._histogram.set_default_tags(labels)
@@ -273,9 +274,12 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
- _gauge_cls = _RayGaugeWrapper
- _counter_cls = _RayCounterWrapper
- _histogram_cls = _RayHistogramWrapper
+ _gauge_cls: Type[prometheus_client.Gauge] = cast(
+     Type[prometheus_client.Gauge], _RayGaugeWrapper)
+ _counter_cls: Type[prometheus_client.Counter] = cast(
+     Type[prometheus_client.Counter], _RayCounterWrapper)
+ _histogram_cls: Type[prometheus_client.Histogram] = cast(
+     Type[prometheus_client.Histogram], _RayHistogramWrapper)

def __init__(self, labelnames: List[str], max_model_len: int):
if ray_metrics is None:
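The Ray wrappers mimic the prometheus_client classes without inheriting from them, so overriding the class attributes declared on Metrics needs a cast to the prometheus types. A rough sketch of that arrangement (hypothetical wrapper and base class):

    from typing import Type, cast

    import prometheus_client


    class _GaugeWrapper:
        """Duck-typed stand-in that mirrors the prometheus Gauge API."""

        def __init__(self, name: str, documentation: str = "",
                     **kwargs: object) -> None:
            self._name = name


    class BaseMetrics:
        _gauge_cls: Type[prometheus_client.Gauge] = prometheus_client.Gauge


    class WrappedMetrics(BaseMetrics):
        # The wrapper is not a real subclass, so cast it to the declared
        # attribute type instead of widening the base-class annotation.
        _gauge_cls: Type[prometheus_client.Gauge] = cast(
            Type[prometheus_client.Gauge], _GaugeWrapper)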
16 changes: 11 additions & 5 deletions vllm/engine/multiprocessing/client.py
@@ -3,7 +3,7 @@
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
- Optional, Union, overload)
+ Optional, Union, cast, overload)

import cloudpickle
import zmq
@@ -605,9 +605,14 @@ def encode(
assert (prompt is not None and pooling_params is not None
and request_id is not None)

- return self._process_request(prompt, pooling_params, request_id,
-     lora_request, trace_headers, None,
-     priority)
+ return cast(
+     AsyncGenerator[EmbeddingRequestOutput, None],
+     self._process_request(prompt,
+                           pooling_params,
+                           request_id,
+                           lora_request,
+                           trace_headers,
+                           priority=priority))

async def _process_request(
self,
@@ -630,7 +635,8 @@ async def _process_request(
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
- params.guided_decoding is not None:
+ params.guided_decoding is not None and \
+ self.decoding_config and lora_request:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
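encode() shares _process_request with generate(), and the shared helper is typed broadly enough to cover both, so the call site casts the result to the embedding-specific generator type. A simplified, self-contained sketch (hypothetical output classes, not the vLLM types):

    import asyncio
    from typing import AsyncGenerator, Union, cast


    class TextOutput: ...
    class EmbeddingOutput: ...


    async def _process_request(
            kind: str) -> AsyncGenerator[Union[TextOutput, EmbeddingOutput], None]:
        # Shared by both entry points, so it yields the broad Union type.
        yield EmbeddingOutput() if kind == "embedding" else TextOutput()


    def encode() -> AsyncGenerator[EmbeddingOutput, None]:
        # Only embedding outputs flow through this path, so narrow the
        # shared generator's type for mypy with a cast.
        return cast(AsyncGenerator[EmbeddingOutput, None],
                    _process_request("embedding"))


    async def main() -> None:
        async for out in encode():
            assert isinstance(out, EmbeddingOutput)


    asyncio.run(main())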
6 changes: 2 additions & 4 deletions vllm/engine/multiprocessing/engine.py
@@ -73,11 +73,9 @@ def __init__(self,
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and send over the socket, which frees
# the python object to be reused again.
- use_cached_outputs = True
+ kwargs['use_cached_outputs'] = True

- self.engine = LLMEngine(*args,
-                         **kwargs,
-                         use_cached_outputs=use_cached_outputs)
+ self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests

self.use_async_sockets = use_async_sockets
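Here the flag is folded into kwargs before constructing LLMEngine instead of being passed as a separate keyword next to **kwargs, which is the call shape mypy objected to. A rough sketch of the pattern, with hypothetical classes; the exact mypy complaint is not reproduced here:

    from typing import Any


    class Engine:
        def __init__(self, model: str, use_cached_outputs: bool = False) -> None:
            self.model = model
            self.use_cached_outputs = use_cached_outputs


    class EngineWrapper:
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            # Fold the extra flag into kwargs rather than passing it as a
            # separate keyword alongside **kwargs.
            kwargs["use_cached_outputs"] = True
            self.engine = Engine(*args, **kwargs)


    assert EngineWrapper("facebook/opt-125m").engine.use_cached_outputs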
5 changes: 3 additions & 2 deletions vllm/engine/output_processor/stop_checker.py
@@ -57,7 +57,8 @@ def maybe_stop_sequence(
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
- if last_token_id in sampling_params.stop_token_ids:
+ if sampling_params.stop_token_ids and \
+     last_token_id in sampling_params.stop_token_ids:
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
@@ -92,7 +93,7 @@ def _check_stop_strings(seq: Sequence, new_char_count: int,
Returns the stop string if matched or else None.
"""
- if not new_char_count:
+ if not new_char_count or not sampling_params.stop:
return None

for stop_str in sampling_params.stop:
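Both stop-checker fixes follow the same shape: stop_token_ids and stop are Optional on SamplingParams, so mypy needs a truthiness guard before the membership test or iteration. A minimal sketch with hypothetical fields:

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class SamplingOpts:
        stop_token_ids: Optional[List[int]] = None
        stop: Optional[List[str]] = None


    def hit_stop_token(opts: SamplingOpts, last_token_id: int) -> bool:
        # The truthiness check guards against None (and empty lists) and
        # lets mypy narrow the Optional before the membership test.
        return bool(opts.stop_token_ids
                    and last_token_id in opts.stop_token_ids)


    assert hit_stop_token(SamplingOpts(stop_token_ids=[2]), 2)
    assert not hit_stop_token(SamplingOpts(), 2)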
8 changes: 3 additions & 5 deletions vllm/engine/output_processor/util.py
@@ -1,13 +1,11 @@
- from typing import List
+ from typing import Iterable, List
from typing import Sequence as GenericSequence
- from typing import Union

- from vllm.model_executor.layers.sampler import SamplerOutput
- from vllm.sequence import PoolerOutput, SequenceGroupOutput
+ from vllm.sequence import SequenceGroupOutput


def create_output_by_sequence_group(
- outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
+ outputs: GenericSequence[Iterable],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
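create_output_by_sequence_group only iterates each step's outputs, so the element type is loosened to Iterable and the imports of the concrete output classes disappear. Roughly, the transposition it performs looks like this (simplified reimplementation, not the vLLM code):

    from typing import Iterable, List
    from typing import Sequence as GenericSequence


    def by_sequence_group(outputs: GenericSequence[Iterable],
                          num_seq_groups: int) -> List[List]:
        # Transpose [step][sequence group] into [sequence group][step];
        # only iteration over each step is needed, so Iterable suffices.
        grouped: List[List] = [[] for _ in range(num_seq_groups)]
        for step in outputs:
            for group_idx, output in enumerate(step):
                grouped[group_idx].append(output)
        return grouped


    assert by_sequence_group([["a0", "b0"], ["a1", "b1"]], 2) == \
        [["a0", "a1"], ["b0", "b1"]]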
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/sampler.py
@@ -117,7 +117,7 @@ class SamplerOutput(
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None

- def __getitem__(self, idx: int):
+ def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
return self.outputs[idx]

def __setitem__(self, idx: int, value):
3 changes: 2 additions & 1 deletion vllm/outputs.py
@@ -4,6 +4,7 @@
from typing import Sequence as GenericSequence
from typing import Union

+ from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -92,7 +93,7 @@ class RequestOutput:
def __init__(
self,
request_id: str,
- prompt: Optional[str],
+ prompt: Optional[PromptType],
prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
4 changes: 2 additions & 2 deletions vllm/sequence.py
@@ -765,7 +765,7 @@ def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
self.init_multi_step(num_steps=num_lookahead_slots + 1)

- def get_last_latency(self, now: float) -> Optional[float]:
+ def get_last_latency(self, now: float) -> float:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if self.is_prefill():
@@ -1175,7 +1175,7 @@ class PoolerOutput(

spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

- def __getitem__(self, idx: int):
+ def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
return self.outputs[idx]

def __setitem__(self, idx: int, value):
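The __getitem__ changes in sampler.py and sequence.py share one idea: without an explicit return annotation mypy types the result of indexing as Any, so attribute access on it goes unchecked. A small sketch (hypothetical classes):

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class GroupOutput:
        text: str


    @dataclass
    class BatchOutput:
        outputs: List[GroupOutput]

        def __getitem__(self, idx: int) -> GroupOutput:
            # The explicit return type lets mypy check attribute access on
            # the indexed element instead of treating it as Any.
            return self.outputs[idx]


    batch = BatchOutput(outputs=[GroupOutput(text="hello")])
    assert batch[0].text == "hello"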
