Skip to content

Commit

Permalink
Fixed formatting issues
Browse files (browse the repository at this point in the history)
Signed-off-by: Akshat Tripathi <[email protected]>
  • Loading branch information
Akshat-Tripathi committed Dec 10, 2024
1 parent 622c344 commit c3e9afe
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
7 changes: 4 additions & 3 deletions tests/lora/test_punica_variation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
sgmv_shrink = torch.ops.vllm.sgmv_shrink
else:
-from vllm.lora.ops.default.lora_ops import (bgmv_expand, bgmv_expand_slice,
-bgmv_shrink, sgmv_expand,
-sgmv_expand_slice, sgmv_shrink)
+from vllm.lora.ops.default.lora_ops import ( # type: ignore
+bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand,
+sgmv_expand_slice, sgmv_shrink)

from vllm.platforms import current_platform

Expand Down Expand Up @@ -57,6 +57,7 @@ def assert_close(a, b):
}[a.dtype]
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)


@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
Expand Down
26 changes: 11 additions & 15 deletions vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import weakref
from collections import defaultdict
from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
TypeVar, Union)

import torch
Expand All @@ -24,7 +24,6 @@
MultiModalKwargs, MultiModalPlaceholderMap)
from vllm.sequence import (IntermediateTensors, SequenceData,
SequenceGroupMetadata)
-from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
Expand Down Expand Up @@ -198,7 +197,7 @@ def build(self) -> ModelInputForCPU:
input_data.seq_lens, input_data.query_lens, -1, -1)

is_prompt = (self.seq_group_metadata_list[0].is_prompt
-if self.seq_group_metadata_list else None)
+if self.seq_group_metadata_list else None)
# LoRA data.
lora_requests = set()
lora_mapping = None
Expand All @@ -210,17 +209,15 @@ def build(self) -> ModelInputForCPU:
lora_mapping = self._prepare_lora_input(
self.seq_group_metadata_list, is_prompt)

-return self.model_input_cls(
-input_tokens=input_tokens,
-input_positions=input_positions,
-token_type_ids=token_type_ids,
-seq_lens=input_data.seq_lens,
-query_lens=input_data.query_lens,
-attn_metadata=attn_metadata,
-multi_modal_kwargs=multi_modal_kwargs,
-lora_mapping=lora_mapping,
-lora_requests=lora_requests
-)
+return self.model_input_cls(input_tokens=input_tokens,
+input_positions=input_positions,
+token_type_ids=token_type_ids,
+seq_lens=input_data.seq_lens,
+query_lens=input_data.query_lens,
+attn_metadata=attn_metadata,
+multi_modal_kwargs=multi_modal_kwargs,
+lora_mapping=lora_mapping,
+lora_requests=lora_requests)

def _build_input_data(self):
for seq_group_metadata in self.seq_group_metadata_list:
Expand Down Expand Up @@ -411,7 +408,6 @@ def _compute_multi_modal_input(self,
self.input_data.multi_modal_placeholder_maps[modality].extend(
placeholder_map)


def _prepare_lora_input(
self, seq_group_metadata_list: List[SequenceGroupMetadata],
is_prefill: bool) -> LoRAMapping:
Expand Down

0 comments on commit c3e9afe

Please sign in to comment.