diff --git a/tools/mypy.sh b/tools/mypy.sh
index e6187a08ffd98..d69b61c7f34fc 100755
--- a/tools/mypy.sh
+++ b/tools/mypy.sh
@@ -13,24 +13,14 @@ run_mypy() {
 
 run_mypy # Note that this is less strict than CI
 run_mypy tests
-run_mypy vllm/assets
 run_mypy vllm/attention
-#run_mypy vllm/compilation
-#run_mypy vllm/core
+run_mypy vllm/compilation
 run_mypy vllm/distributed
 run_mypy vllm/engine
-run_mypy vllm/entrypoints
 run_mypy vllm/executor
-#run_mypy vllm/inputs
-run_mypy vllm/logging
 run_mypy vllm/lora
 run_mypy vllm/model_executor
-run_mypy vllm/multimodal
-run_mypy vllm/platforms
 run_mypy vllm/plugins
 run_mypy vllm/prompt_adapter
 run_mypy vllm/spec_decode
-run_mypy vllm/transformers_utils
-run_mypy vllm/usage
-#run_mypy vllm/vllm_flash_attn
 run_mypy vllm/worker
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 0112f49876996..b46f0721d0caf 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -92,7 +92,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         attn_type: AttentionType = AttentionType.DECODER,
     ) -> torch.Tensor:
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 4780358cea517..6d9832e2c39c0 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -244,8 +244,8 @@ def compiled_graph_wrapper(*args):
 
 def select_default_backend(level: int) -> Union[str, Callable]:
     if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
-        backend = "eager"
-        return backend
+        backend_str = "eager"
+        return backend_str
     assert level in [
         CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
     ], f"Invalid level {level}"
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 655c4c4430179..3ae74cc5cb7dd 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -35,6 +35,8 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
     def cls_decorator_helper(cls: type):
         # helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
         # to avoid too much indentation for `_support_torch_compile``
+        if not hasattr(cls, 'forward'):
+            raise TypeError("decorated class should have a forward method.")
         sig = inspect.signature(cls.forward)
         for k in dynamic_arg_dims:
             if k not in sig.parameters:
@@ -63,13 +65,13 @@ def _support_torch_compile(cls: type,
     # other than TorchCompileWrapperWithCustomDispatcher
     cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
 
-    old_init = cls.__init__
+    old_init = cls.__init__  # type: ignore
 
     def __init__(self, *args, **kwargs):
         old_init(self, *args, **kwargs)
         TorchCompileWrapperWithCustomDispatcher.__init__(self)
 
-    cls.__init__ = __init__
+    cls.__init__ = __init__  # type: ignore
 
     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
@@ -109,5 +111,5 @@ def __call__(self, *args, **kwargs):
         model_output = self.forward(*args, **kwargs)
         return model_output
 
-    cls.__call__ = __call__
+    cls.__call__ = __call__  # type: ignore
     return cls
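Aside (not part of the patch): the decorators.py hunk above silences mypy where methods are swapped out on a class at runtime. A minimal sketch of that pattern, using a hypothetical log_init decorator; the ignores mirror the diff, since mypy flags both the read and the reassignment of `__init__` on a bare `type`:

from typing import Any


def log_init(cls: type) -> type:
    # Runtime behaviour is unaffected; the ignores only quiet mypy's
    # objections to touching `__init__` on a plain `type` object.
    old_init = cls.__init__  # type: ignore

    def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
        print(f"constructing {type(self).__name__}")
        old_init(self, *args, **kwargs)

    cls.__init__ = __init__  # type: ignore
    return cls


@log_init
class Example:

    def __init__(self, x: int) -> None:
        self.x = x


Example(1)  # prints "constructing Example"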
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 1594b64a61b94..7366ed4d16b0b 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -73,7 +73,7 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
             return
         # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
         frame = sys._getframe()
-        while True:
+        while frame and frame.f_back:
             frame = frame.f_back
             code_name = frame.f_code.co_name
             file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
diff --git a/vllm/config.py b/vllm/config.py
index 7a3248f4087ae..9bf109721380b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -610,13 +610,14 @@ def __init__(
         self.sliding_window = sliding_window
         self.enable_prefix_caching = enable_prefix_caching
         self.cpu_offload_gb = cpu_offload_gb
 
+        self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
 
         # Will be set after profiling.
-        self.num_gpu_blocks = None
-        self.num_cpu_blocks = None
+        self.num_gpu_blocks: Optional[int] = None
+        self.num_cpu_blocks: Optional[int] = None
 
     def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
@@ -693,7 +694,8 @@ def __post_init__(self):
 
     @classmethod
     def create_config(
-        cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
+        cls, tokenizer_pool_size: int,
+        tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]],
         tokenizer_pool_extra_config: Optional[Union[str, dict]]
     ) -> Optional["TokenizerPoolConfig"]:
         """Create a TokenizerPoolConfig from the given parameters.
@@ -1528,7 +1530,7 @@ class LoRAConfig:
     max_loras: int
     fully_sharded_loras: bool = False
     max_cpu_loras: Optional[int] = None
-    lora_dtype: Optional[torch.dtype] = None
+    lora_dtype: Optional[Union[torch.dtype, str]] = None
     lora_extra_vocab_size: int = 256
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 1f0a121711db5..e7eaaf12272d6 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -4,8 +4,9 @@
 import time
 from collections import deque
 from dataclasses import dataclass, field
-from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set,
-                    Tuple, Union)
+from typing import Callable, Deque, Dict, Iterable, List, Optional
+from typing import Sequence as GenericSequence
+from typing import Set, Tuple, Union
 
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@@ -115,7 +116,7 @@ class ScheduledSequenceGroup:
 class SchedulerOutputs:
     """The scheduling decision made from a scheduler."""
     # Scheduled sequence groups.
-    scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
+    scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
     # Number of prefill groups scheduled.
     num_prefill_groups: int
     # Total number of batched tokens.
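Aside (not part of the patch): the CacheConfig change above declares `num_gpu_blocks`/`num_cpu_blocks` as `Optional[int]` so mypy tracks the not-yet-profiled state. A small self-contained illustration with a hypothetical class:

from typing import Optional


class ProfiledCache:

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        # Will be set after profiling, so declare it as Optional[int].
        self.num_gpu_blocks: Optional[int] = None

    def total_tokens(self) -> int:
        # mypy forces callers to handle the None case explicitly.
        if self.num_gpu_blocks is None:
            raise ValueError("profiling has not been run yet")
        return self.num_gpu_blocks * self.block_size


cache = ProfiledCache(block_size=16)
cache.num_gpu_blocks = 8
assert cache.total_tokens() == 128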
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1b132cf76a10d..ef06cb1e6217f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -3,7 +3,7 @@
 import json
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union)
+                    Tuple, Type, Union, cast)
 
 import torch
 
@@ -89,7 +89,7 @@ class EngineArgs:
     trust_remote_code: bool = False
    download_dir: Optional[str] = None
     load_format: str = 'auto'
-    config_format: str = 'auto'
+    config_format: ConfigFormat = ConfigFormat.AUTO
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
     quantization_param_path: Optional[str] = None
@@ -181,7 +181,7 @@ class EngineArgs:
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
 
     def __post_init__(self):
-        if self.tokenizer is None:
+        if not self.tokenizer:
            self.tokenizer = self.model
 
         # Setup plugins
@@ -836,7 +836,8 @@ def from_cli_args(cls, args: argparse.Namespace):
     def create_model_config(self) -> ModelConfig:
         return ModelConfig(
             model=self.model,
-            tokenizer=self.tokenizer,
+            # We know this is not None because we set it in __post_init__
+            tokenizer=cast(str, self.tokenizer),
             tokenizer_mode=self.tokenizer_mode,
             trust_remote_code=self.trust_remote_code,
             dtype=self.dtype,
@@ -907,8 +908,9 @@ def create_engine_config(self) -> EngineConfig:
             self.enable_prefix_caching = False
 
         cache_config = CacheConfig(
+            # neuron needs block_size = max_model_len
             block_size=self.block_size if self.device != "neuron" else
-            self.max_model_len,  # neuron needs block_size = max_model_len
+            (self.max_model_len if self.max_model_len is not None else 0),
             gpu_memory_utilization=self.gpu_memory_utilization,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
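Aside (not part of the patch): `create_model_config()` above uses `typing.cast` because `__post_init__` guarantees the tokenizer has been filled in. A compact sketch of the same invariant-plus-cast pattern with hypothetical names:

from dataclasses import dataclass
from typing import Optional, cast


@dataclass
class Args:
    model: str
    tokenizer: Optional[str] = None

    def __post_init__(self) -> None:
        # Invariant: after construction, tokenizer is never None/empty.
        if not self.tokenizer:
            self.tokenizer = self.model

    def tokenizer_name(self) -> str:
        # cast() has no runtime effect; it only tells mypy about the
        # invariant established in __post_init__.
        return cast(str, self.tokenizer)


assert Args(model="facebook/opt-125m").tokenizer_name() == "facebook/opt-125m"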
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 563e52a37d935..36ba49efe42e1 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -6,7 +6,7 @@
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                     Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, overload
+from typing import Set, Type, Union, cast, overload
 
 import torch
 from typing_extensions import TypeVar
@@ -44,7 +44,7 @@
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
                            Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceStatus)
+                           SequenceGroupOutput, SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
@@ -188,7 +188,7 @@ def validate_output(
             raise TypeError(f"Expected output of type {output_type}, "
                             f"but found type {type(output)}")
 
-        return output
+        return cast(_O, output)
 
     @classmethod
     def validate_outputs(
@@ -1039,6 +1039,7 @@ def _process_model_outputs(self,
             scheduler_outputs.scheduled_seq_groups)
 
         has_multiple_outputs: bool = len(outputs) > 1
+        outputs_by_sequence_group: List[List[SequenceGroupOutput]]
         if has_multiple_outputs:
             assert self.scheduler_config.is_multi_step or \
                 self.speculative_config
@@ -1084,6 +1085,7 @@ def _process_model_outputs(self,
                 finished_before.append(i)
                 continue
 
+            output: List[SequenceGroupOutput]
             if has_multiple_outputs:
                 output = outputs_by_sequence_group[i]
             else:
@@ -1096,7 +1098,7 @@ def _process_model_outputs(self,
                         seq_group, seq_group_meta, is_first_step_output)
                 else:
                     seq_group.update_num_computed_tokens(
-                        seq_group_meta.token_chunk_size)
+                        seq_group_meta.token_chunk_size or 0)
 
             if outputs:
                 for o in outputs:
@@ -1104,13 +1106,13 @@
                             and seq_group.metrics is not None):
                         if seq_group.metrics.model_forward_time is not None:
                             seq_group.metrics.model_forward_time += (
-                                o.model_forward_time)
+                                o.model_forward_time or 0)
                         else:
                             seq_group.metrics.model_forward_time = (
                                 o.model_forward_time)
                         if seq_group.metrics.model_execute_time is not None:
                             seq_group.metrics.model_execute_time += (
-                                o.model_execute_time)
+                                o.model_execute_time or 0)
                         else:
                             seq_group.metrics.model_execute_time = (
                                 o.model_execute_time)
@@ -1236,8 +1238,10 @@ def _advance_to_next_step(
                     seq_group, seq_group_metadata,
                     seq_group.state.num_steps == 1)
             else:
-                seq_group.update_num_computed_tokens(
-                    seq_group_metadata.token_chunk_size)
+                token_chunk_size = (seq_group_metadata.token_chunk_size
+                                    if seq_group_metadata.token_chunk_size
+                                    is not None else 0)
+                seq_group.update_num_computed_tokens(token_chunk_size)
 
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 42acd3ea4c94c..98bf59be3469d 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING
 from typing import Counter as CollectionsCounter
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Type, Union, cast
 
 import numpy as np
 import prometheus_client
@@ -249,10 +249,11 @@ def __init__(self,
                  labelnames: Optional[List[str]] = None,
                  buckets: Optional[List[float]] = None):
         labelnames_tuple = tuple(labelnames) if labelnames else None
+        boundaries = buckets if buckets else []
         self._histogram = ray_metrics.Histogram(name=name,
                                                 description=documentation,
                                                 tag_keys=labelnames_tuple,
-                                                boundaries=buckets)
+                                                boundaries=boundaries)
 
     def labels(self, **labels):
         self._histogram.set_default_tags(labels)
@@ -267,9 +268,12 @@ class RayMetrics(Metrics):
     RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
     Provides the same metrics as Metrics but uses Ray's util.metrics library.
     """
-    _gauge_cls = _RayGaugeWrapper
-    _counter_cls = _RayCounterWrapper
-    _histogram_cls = _RayHistogramWrapper
+    _gauge_cls: Type[prometheus_client.Gauge] = cast(
+        Type[prometheus_client.Gauge], _RayGaugeWrapper)
+    _counter_cls: Type[prometheus_client.Counter] = cast(
+        Type[prometheus_client.Counter], _RayCounterWrapper)
+    _histogram_cls: Type[prometheus_client.Histogram] = cast(
+        Type[prometheus_client.Histogram], _RayHistogramWrapper)
 
     def __init__(self, labelnames: List[str], max_model_len: int):
         if ray_metrics is None:
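Aside (not part of the patch): several llm_engine.py hunks above use the `value or 0` idiom to fold an `Optional[float]` into a running total, since adding `Optional[float]` to `float` is a mypy error. A minimal illustration; note the idiom also treats an explicit 0.0 as "missing", which is harmless for an accumulator:

from typing import Optional


def accumulate(total: float, measured: Optional[float]) -> float:
    # `measured or 0` collapses the None case so the addition type-checks.
    return total + (measured or 0)


assert accumulate(1.5, 0.5) == 2.0
assert accumulate(1.5, None) == 1.5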
""" - _gauge_cls = _RayGaugeWrapper - _counter_cls = _RayCounterWrapper - _histogram_cls = _RayHistogramWrapper + _gauge_cls: Type[prometheus_client.Gauge] = cast( + Type[prometheus_client.Gauge], _RayGaugeWrapper) + _counter_cls: Type[prometheus_client.Counter] = cast( + Type[prometheus_client.Counter], _RayCounterWrapper) + _histogram_cls: Type[prometheus_client.Histogram] = cast( + Type[prometheus_client.Histogram], _RayHistogramWrapper) def __init__(self, labelnames: List[str], max_model_len: int): if ray_metrics is None: diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 6bf553666a852..9732c7098e160 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -3,7 +3,7 @@ import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, - Optional, Union, overload) + Optional, Union, cast, overload) import cloudpickle import zmq @@ -513,9 +513,14 @@ def encode( assert (prompt is not None and pooling_params is not None and request_id is not None) - return self._process_request(prompt, pooling_params, request_id, - lora_request, trace_headers, None, - priority) + return cast( + AsyncGenerator[EmbeddingRequestOutput, None], + self._process_request(prompt, + pooling_params, + request_id, + lora_request, + trace_headers, + priority=priority)) async def _process_request( self, @@ -543,7 +548,9 @@ async def _process_request( build_guided_decoding_logits_processor_async( sampling_params=params, tokenizer=await self.get_tokenizer(lora_request), - default_guided_backend=self.decoding_config.guided_decoding_backend + default_guided_backend=(self.decoding_config.guided_decoding_backend + if self.decoding_config + else DecodingConfig.guided_decoding_backend), ) # 1) Create output queue for this requests. diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 2bf0ce83c7607..ad0e970f36ff5 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -73,11 +73,9 @@ def __init__(self, # For MQLLMEngine, we can use cached outputs, since each new request # output is immediately pickled and send over the socket, which frees # the python object to be reused again. 
diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py
index 74ddb250ccd9e..3ed37a269c4b4 100644
--- a/vllm/engine/output_processor/multi_step.py
+++ b/vllm/engine/output_processor/multi_step.py
@@ -1,5 +1,5 @@
 import functools
-from typing import Callable, List
+from typing import Callable, List, cast
 
 from vllm.core.scheduler import Scheduler
 from vllm.engine.output_processor.interfaces import (
@@ -9,8 +9,10 @@
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
-                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
+                           CompletionSequenceGroupOutput, Sequence,
+                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
+                           SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Counter
@@ -57,6 +59,7 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
         """
         for output in outputs:
             # Concatenate single-step prompt logprob processing results.
+            assert isinstance(output, CompletionSequenceGroupOutput)
             single_step_process_prompt_logprob(self, seq_group, output)
 
     @staticmethod
@@ -100,8 +103,18 @@ def process_outputs(self,
                 "Beam search not supported in multi-step decoding.")
         seq = seqs[0]
         seq_id = seq.seq_id
-        assert all(
-            [seq_id == output.samples[0].parent_seq_id for output in outputs])
+        # This method is defined in the more generic
+        # SequenceGroupOutputProcessor, but here we assume that the outputs are
+        # of a more specific type.
+        assert all([
+            isinstance(output, CompletionSequenceGroupOutput)
+            for output in outputs
+        ])
+        compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
+        assert all([
+            seq_id == output.samples[0].parent_seq_id
+            for output in compl_outputs
+        ])
 
         if is_async:
             # Async case: We process tokens one by one. Here, we know the token
@@ -113,7 +126,7 @@ def process_outputs(self,
 
             # Since there's only one sequence per sequence group,
             # we can take the first sample.
-            samples = [output.samples[0] for output in outputs]
+            samples = [output.samples[0] for output in compl_outputs]
 
             # entries in sample tokens may be invalid (eg. due to spec decode
             # rejecting tokens).
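Aside (not part of the patch): `process_outputs()` above narrows a sequence of the base output type to the completion-specific subtype by asserting `isinstance` and then casting. A stripped-down version of that pattern with hypothetical classes:

from typing import List, Sequence, cast


class Output:
    pass


class CompletionOutput(Output):

    def __init__(self, token_id: int) -> None:
        self.token_id = token_id


def first_token_ids(outputs: Sequence[Output]) -> List[int]:
    # Runtime guarantee first, then a static narrowing so the
    # subclass-only attribute can be used below.
    assert all(isinstance(o, CompletionOutput) for o in outputs)
    completions = cast(Sequence[CompletionOutput], outputs)
    return [o.token_id for o in completions]


assert first_token_ids([CompletionOutput(1), CompletionOutput(2)]) == [1, 2]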
diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py
index cfa84077685a0..9f8ebaf1f4d8c 100644
--- a/vllm/engine/output_processor/single_step.py
+++ b/vllm/engine/output_processor/single_step.py
@@ -6,8 +6,9 @@
     SequenceGroupOutputProcessor)
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
-from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
-                           SequenceOutput, SequenceStatus)
+from vllm.sequence import (CompletionSequenceGroupOutput, Sequence,
+                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
+                           SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.utils import Counter
 
@@ -16,7 +17,7 @@
 
 def single_step_process_prompt_logprob(
         sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
-        output: SequenceGroupOutput) -> None:
+        output: CompletionSequenceGroupOutput) -> None:
     """Process prompt logprobs associated with the
     :class:`SequenceGroupOutput` for a given step.
 
@@ -106,6 +107,7 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
         """
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
+        assert isinstance(output, CompletionSequenceGroupOutput)
         single_step_process_prompt_logprob(self, seq_group, output)
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py
index 0c5f8fb7f5be7..a71ad493d9920 100644
--- a/vllm/engine/output_processor/stop_checker.py
+++ b/vllm/engine/output_processor/stop_checker.py
@@ -57,7 +57,7 @@ def maybe_stop_sequence(
         # Check if a stop token was encountered.
         # This assumes a single token produced per step.
         last_token_id = seq.get_last_token_id()
-        if last_token_id in sampling_params.stop_token_ids:
+        if last_token_id in (sampling_params.stop_token_ids or ()):
             if new_char_count and (
                     not sampling_params.include_stop_str_in_output):
                 # Remove last token
@@ -92,7 +92,7 @@ def _check_stop_strings(seq: Sequence, new_char_count: int,
 
         Returns the stop string if matched or else None.
         """
-        if not new_char_count:
+        if not new_char_count or not sampling_params.stop:
             return None
 
         for stop_str in sampling_params.stop:
diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py
index 76782888031e3..770982a207e6c 100644
--- a/vllm/engine/output_processor/util.py
+++ b/vllm/engine/output_processor/util.py
@@ -1,22 +1,25 @@
 from typing import List
 from typing import Sequence as GenericSequence
-from typing import Union
+from typing import cast
 
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import PoolerOutput, SequenceGroupOutput
+from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput
 
 
 def create_output_by_sequence_group(
-        outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
+        outputs: GenericSequence[SamplerOutput],
         num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
     """Helper method which transforms a 2d list organized by
     [step][sequence group] into [sequence group][step].
     """
-    output_by_sequence_group: List[List[SequenceGroupOutput]] = [
+    output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [
         [] for _ in range(num_seq_groups)
     ]
     for step in outputs:
+        sequence_group_output: CompletionSequenceGroupOutput
         for i, sequence_group_output in enumerate(step):
            output_by_sequence_group[i].append(sequence_group_output)
 
-    return output_by_sequence_group
+    # Cast to the more generic type that CompletionSequenceGroupOutput
+    # inherits from.
+    return cast(List[List[SequenceGroupOutput]], output_by_sequence_group)
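Aside (not part of the patch): the cast on util.py's return value is needed because `List` is invariant under mypy, so a `List[List[Derived]]` is not accepted where `List[List[Base]]` is declared, even though `Derived` subclasses `Base`. A tiny demonstration with hypothetical classes:

from typing import List, cast


class Base:
    pass


class Derived(Base):
    pass


def grouped() -> List[List[Base]]:
    groups: List[List[Derived]] = [[Derived()], [Derived()]]
    # Without the cast, mypy rejects returning List[List[Derived]] here;
    # at runtime the cast is a no-op, and callers only read the result.
    return cast(List[List[Base]], groups)


assert len(grouped()) == 2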
""" - output_by_sequence_group: List[List[SequenceGroupOutput]] = [ + output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [ [] for _ in range(num_seq_groups) ] for step in outputs: + sequence_group_output: CompletionSequenceGroupOutput for i, sequence_group_output in enumerate(step): output_by_sequence_group[i].append(sequence_group_output) - return output_by_sequence_group + # Cast to the more generic type that CompletionSequenceGroupOutput + # inherits from. + return cast(List[List[SequenceGroupOutput]], output_by_sequence_group) diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e5fa1e4184277..842c872da3964 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Sequence, TypedDict, Union, overload +from typing import List, Literal, Sequence, TypedDict, Union, cast, overload from typing_extensions import TypeIs @@ -44,13 +44,16 @@ def parse_and_batch_prompt( if is_list_of(prompt, str): # case 2: array of strings + prompt = cast(List[str], prompt) return [ ParsedText(content=elem, is_tokens=False) for elem in prompt ] if is_list_of(prompt, int): # case 3: array of tokens + prompt = cast(List[int], prompt) return [ParsedTokens(content=prompt, is_tokens=True)] if is_list_of(prompt, list): + prompt = cast(List[List[int]], prompt) if len(prompt[0]) == 0: raise ValueError("please provide at least one prompt") diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 42a6a0e6b3229..f86c6ec362ebe 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from importlib.util import find_spec from math import inf -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple, Union import msgspec import torch @@ -117,12 +117,15 @@ class SamplerOutput( # block/sync across workers, cpu-gpu sync time and sampling time. 
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 07650241cb638..15cb8d53186df 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -4,6 +4,7 @@
 from typing import Sequence as GenericSequence
 from typing import Union
 
+from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -92,7 +93,7 @@ class RequestOutput:
     def __init__(
         self,
         request_id: str,
-        prompt: Optional[str],
+        prompt: Optional[PromptType],
         prompt_token_ids: Optional[List[int]],
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 3bb35ea955c8c..4797d7bd5c4c5 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -765,7 +765,7 @@ def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
         assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
         self.init_multi_step(num_steps=num_lookahead_slots + 1)
 
-    def get_last_latency(self, now: float) -> Optional[float]:
+    def get_last_latency(self, now: float) -> float:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
         if self.is_prefill():
@@ -1175,7 +1175,7 @@ class PoolerOutput(
     spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
 
-    def __getitem__(self, idx: int):
+    def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
         return self.outputs[idx]
 
     def __setitem__(self, idx: int, value):