From c3e9afe07f6346c8ca42d95fe18c99495162bd37 Mon Sep 17 00:00:00 2001
From: Akshat Tripathi
Date: Tue, 10 Dec 2024 13:01:49 +0000
Subject: [PATCH] Fixed formatting issues

Signed-off-by: Akshat Tripathi
---
 tests/lora/test_punica_variation.py |  7 ++++---
 vllm/worker/cpu_model_runner.py     | 26 +++++++++++---------------
 2 files changed, 15 insertions(+), 18 deletions(-)

diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py
index 9e8f3196f2731..f280c758c6fd1 100644
--- a/tests/lora/test_punica_variation.py
+++ b/tests/lora/test_punica_variation.py
@@ -26,9 +26,9 @@
     sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
     sgmv_shrink = torch.ops.vllm.sgmv_shrink
 else:
-    from vllm.lora.ops.default.lora_ops import (bgmv_expand, bgmv_expand_slice,
-                                                bgmv_shrink, sgmv_expand,
-                                                sgmv_expand_slice, sgmv_shrink)
+    from vllm.lora.ops.default.lora_ops import (  # type: ignore
+        bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand,
+        sgmv_expand_slice, sgmv_shrink)
 
 from vllm.platforms import current_platform
 
@@ -57,6 +57,7 @@ def assert_close(a, b):
     }[a.dtype]
     torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
 
+
 @pytest.mark.parametrize("batches", BATCHES)
 @pytest.mark.parametrize("num_loras", NUM_LORA)
 @pytest.mark.parametrize("rank", MAX_RANKS)
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 593f70b7fd564..db8a03abf0cc5 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -2,7 +2,7 @@
 import weakref
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
                     TypeVar, Union)
 
 import torch
@@ -24,7 +24,6 @@
                              MultiModalKwargs, MultiModalPlaceholderMap)
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
-from vllm.utils import make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -198,7 +197,7 @@ def build(self) -> ModelInputForCPU:
             input_data.seq_lens, input_data.query_lens, -1, -1)
 
         is_prompt = (self.seq_group_metadata_list[0].is_prompt
-            if self.seq_group_metadata_list else None)
+                     if self.seq_group_metadata_list else None)
         # LoRA data.
         lora_requests = set()
         lora_mapping = None
@@ -210,17 +209,15 @@ def build(self) -> ModelInputForCPU:
             lora_mapping = self._prepare_lora_input(
                 self.seq_group_metadata_list, is_prompt)
 
-        return self.model_input_cls(
-            input_tokens=input_tokens,
-            input_positions=input_positions,
-            token_type_ids=token_type_ids,
-            seq_lens=input_data.seq_lens,
-            query_lens=input_data.query_lens,
-            attn_metadata=attn_metadata,
-            multi_modal_kwargs=multi_modal_kwargs,
-            lora_mapping=lora_mapping,
-            lora_requests=lora_requests
-        )
+        return self.model_input_cls(input_tokens=input_tokens,
+                                    input_positions=input_positions,
+                                    token_type_ids=token_type_ids,
+                                    seq_lens=input_data.seq_lens,
+                                    query_lens=input_data.query_lens,
+                                    attn_metadata=attn_metadata,
+                                    multi_modal_kwargs=multi_modal_kwargs,
+                                    lora_mapping=lora_mapping,
+                                    lora_requests=lora_requests)
 
     def _build_input_data(self):
         for seq_group_metadata in self.seq_group_metadata_list:
@@ -411,7 +408,6 @@ def _compute_multi_modal_input(self,
             self.input_data.multi_modal_placeholder_maps[modality].extend(
                 placeholder_map)
 
-
     def _prepare_lora_input(
             self, seq_group_metadata_list: List[SequenceGroupMetadata],
             is_prefill: bool) -> LoRAMapping: