From e51719ae72dd1dcdf55436a99ac8bed245b51422 Mon Sep 17 00:00:00 2001 From: Lucas Tucker <47258766+lucas-tucker@users.noreply.github.com> Date: Mon, 23 Dec 2024 07:55:49 -0600 Subject: [PATCH] mypy type checking for vllm/worker (#11418) Signed-off-by: lucast2021 Co-authored-by: lucast2021 --- vllm/worker/cpu_worker.py | 3 +-- vllm/worker/multi_step_model_runner.py | 13 +++++++------ vllm/worker/worker_base.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 09758a5d9accf..b5dfebfce6f75 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -333,9 +333,8 @@ def execute_worker( def prepare_worker_input( self, execute_model_req: ExecuteModelRequest) -> WorkerInput: assert execute_model_req is not None - virtual_engine = execute_model_req.virtual_engine + virtual_engine: int = execute_model_req.virtual_engine num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) - blocks_to_copy = execute_model_req.blocks_to_copy blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device="cpu", dtype=torch.int64).view(-1, 2) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 18b03bf1bfb56..f3d7c726a29f1 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -406,8 +406,9 @@ def _async_process_outputs(self, model_input: StatefulModelInput, if not cont: break - def _final_process_outputs(self, model_input: StatefulModelInput, - output_proc_callback: Optional[Callable]): + def _final_process_outputs( + self, model_input: StatefulModelInput, + output_proc_callback: Optional[Callable]) -> List[SamplerOutput]: assert model_input.frozen_model_input is not None has_async_callback = output_proc_callback is not None @@ -594,8 +595,8 @@ def execute_model( # should be [SamplerOutput] return output - def _update_sampling_metadata(self, sampling_metadata, num_seqs, - num_queries): + def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata, + num_seqs: Optional[int], num_queries: int): assert sampling_metadata.num_prompts == 0 assert len(sampling_metadata.seq_groups) == num_queries @@ -850,13 +851,13 @@ def _pythonize_sampler_output( seq_ids = seq_group.seq_ids next_token_ids = sample_result parent_ids = [0] + seq_outputs: List[SequenceOutput] if cache is not None: completion_seq_group_output: CompletionSequenceGroupOutput = \ cache.cached_completion_seq_group_output.get_object() completion_seq_group_output.samples.clear() - seq_outputs: List[ - SequenceOutput] = completion_seq_group_output.samples + seq_outputs = completion_seq_group_output.samples else: seq_outputs = [] diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 6d00102e0a324..3ac7fb8dfb766 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -452,7 +452,7 @@ def init_worker(self, *args, **kwargs): self.worker = worker_class(*args, **kwargs) assert self.worker is not None - def execute_method(self, method, *args, **kwargs): + def execute_method(self, method: str, *args, **kwargs): try: target = self if self.worker is None else self.worker executor = getattr(target, method)