Clean-up LoRA flow (#518)
This PR cleans up the LoRA flow by removing unnecessary functions and variables.

Removed the special handling of `max_num_batched_tokens` for HPU in `models.py`, since this is now handled internally in `PunicaWrapperHPU` ([PR](https://github.com/vllm-project/vllm/blob/d51d66c3252107d5b986d2eab7af1c210dceb708/vllm/lora/punica_wrapper/punica_hpu.py#L17))
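
For context, here is a minimal sketch of what that internal handling looks like on the HPU side. The body is paraphrased from the linked `punica_hpu.py`, not quoted verbatim, and the exact constructor signature is an assumption:

```python
# Sketch of the HPU-specific Punica wrapper (paraphrased, not verbatim).
from typing import Union

import torch

from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase


class PunicaWrapperHPU(PunicaWrapperBase):

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        # Scale the token budget here (the same 3x factor that used to be
        # special-cased in models.py), so LoRAModelManager can pass the plain
        # max_num_batched_tokens value on every platform.
        super().__init__(3 * max_num_batched_tokens, max_batches, device)
```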

Removed `convert_mapping` from `models.py` based on this
[PR](https://github.com/vllm-project/vllm/pull/5036/files#:~:text=def%20convert_mapping)
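
With `convert_mapping` gone, the index tensors it used to build are produced inside the Punica wrapper when its metadata is updated. Roughly, the call site in `LoRAModelManager` reduces to something like the following sketch (the argument list is an assumption based on the surrounding code, not a verbatim quote):

```python
# Hypothetical shape of the call site after the clean-up; the exact arguments
# are assumptions. The point is that the index tensors formerly built by
# convert_mapping are now constructed inside the Punica wrapper.
def _set_adapter_mapping(self, mapping) -> None:
    self.punica_wrapper.update_metadata(
        mapping,                                 # rows in batch -> LoRA ids
        self.lora_index_to_id,                   # LoRA id -> wrapper slot
        self.lora_slots + 1,                     # max_loras
        self.vocab_size,
        self.lora_config.lora_extra_vocab_size,  # extra vocab per LoRA
        self.long_lora_context,
    )
```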

Co-authored-by: Vivek Goel <[email protected]>
SanjuCSudhakaran and vivekgoe authored Jan 17, 2025
1 parent b3a0db2 commit 4db525d
Showing 3 changed files with 2 additions and 127 deletions.
vllm/lora/layers.py (1 change: 0 additions & 1 deletion)
@@ -232,7 +232,6 @@ def set_lora(

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = None
        embeddings_indices = self.punica_wrapper.embeddings_indices
        indices = embeddings_indices[1].view_as(x)
        full_lora_a_embeddings = F.embedding(
vllm/lora/models.py (119 changes: 2 additions & 117 deletions)
@@ -4,8 +4,7 @@
import os
import re
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, List, Optional, Sequence, Tuple, Type,
                    Union)
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import safetensors.torch
import torch
@@ -30,7 +29,6 @@
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available

logger = init_logger(__name__)
@@ -50,116 +48,6 @@ class LongContextLoRAContext:
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)


def convert_mapping(
    mapping: LoRAMapping,
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    long_lora_context: Optional[LongContextLoRAContext] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.
    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        long_lora_context: Passed if there are long context lora in a batch.
    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indicies. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for sampler with padding.
                Same as sampler_indicies, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long context lora doesn't exist.
            indices_len: List of lengths of the above tensors.
                Used to index into each tensor. It contains length for
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices). If long_lora doesn't
                exist, it only contains first 4 entries.
    """
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    long_lora_offsets: Optional[torch.Tensor] = None
    device = "hpu" if current_platform.is_hpu() else "cuda"
    if long_lora_context:
        long_lora_offsets = torch.zeros(len(index_mapping_indices),
                                        device=device,
                                        dtype=torch.long)
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx
        if long_lora_context:
            assert long_lora_offsets is not None
            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                index_mapping_indices[i], 0)
            long_lora_offsets[i] = lora_offset

    indices_list: List[Union[List[int], torch.Tensor]] = [
        index_mapping_indices, lora_indices, embedding_indices
    ]
    if long_lora_context:
        assert long_lora_offsets is not None
        indices_list.append(long_lora_offsets)
    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device=device,
                                         dtype=torch.long)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size)
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = (
        torch.arange(
            0, len(sampler_indices_padded), device=device, dtype=torch.long) +
        (sampler_indices_padded * len(sampler_indices_padded)))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contain length of indices tensors. Used to index into each tensor.
    indices_len = [
        base_indices.shape[-1], sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)

    return (base_indices, sampler_indices, sampler_indices_padded,
            embeddings_indices, long_lora_indices, indices_len)


def get_lora_id():
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
@@ -440,10 +328,7 @@ def __init__(
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        punica_max_num_batched_tokens = max_num_batched_tokens
        if current_platform.is_hpu():
            punica_max_num_batched_tokens = 3 * max_num_batched_tokens
        self.punica_wrapper = get_punica_wrapper(punica_max_num_batched_tokens,
        self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
                                                 max_batches=self.max_num_seqs,
                                                 device=self.device)
        # Scaling factor -> offset to the sin_cos_cache to it.
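
For reference, the `get_punica_wrapper` call above can stay platform-agnostic because the wrapper selection itself is platform-aware. A rough sketch of that dispatch is shown below; the module paths are assumptions and other backends are omitted:

```python
# Rough sketch of how get_punica_wrapper might select a backend-specific
# wrapper; module paths are assumptions and error handling is simplified.
from vllm.platforms import current_platform


def get_punica_wrapper(*args, **kwargs):
    if current_platform.is_hpu():
        # The HPU wrapper scales max_num_batched_tokens internally, which is
        # why the 3x special case removed above is no longer needed here.
        from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
        return PunicaWrapperHPU(*args, **kwargs)
    # Default: the CUDA/ROCm implementation.
    from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
    return PunicaWrapperGPU(*args, **kwargs)
```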
vllm/worker/hpu_model_runner.py (9 changes: 0 additions & 9 deletions)
@@ -1914,15 +1914,6 @@ def get_counter_dict(self, cache_config, duration, seq_len,
        return counters


def unwrap_model(model):
    if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
        return unwrap_model(model._orig_mod)
    else:
        model = list(vars(model)['_modules'].values())[0]
        modules = list(vars(model)['_modules'].values())
        return modules


class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.
