Commit 797dab2
format and fixes
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Varun Sundar Rabindranath committed Dec 17, 2024
1 parent 0c9ca09 commit 797dab2
Showing 6 changed files with 32 additions and 25 deletions.
2 changes: 1 addition & 1 deletion vllm/v1/core/scheduler.py
@@ -5,9 +5,9 @@
 
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.base import PlaceholderRange
-from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 from vllm.v1.core.kv_cache_manager import KVCacheManager
3 changes: 1 addition & 2 deletions vllm/v1/engine/async_llm.py
@@ -345,8 +345,7 @@ async def get_tokenizer(
         self,
         lora_request: Optional[LoRARequest] = None,
     ) -> AnyTokenizer:
-        assert lora_request is None
-        return self.detokenizer.tokenizer
+        return self.detokenizer.get_tokenizer(lora_request)
 
     async def is_tracing_enabled(self) -> bool:
         return False
14 changes: 10 additions & 4 deletions vllm/v1/engine/detokenizer.py
@@ -3,6 +3,7 @@
 
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import RequestOutputKind
 from vllm.transformers_utils.detokenizer_utils import (
@@ -208,6 +209,14 @@ def __init__(self,
         # Request id -> IncrementalDetokenizer
         self.request_states: Dict[str, IncrementalDetokenizer] = {}
 
+    def get_tokenizer(self,
+                      lora_request: Optional[LoRARequest] = None
+                      ) -> AnyTokenizer:
+        if lora_request:
+            return get_lora_tokenizer(lora_request)
+        else:
+            return self._base_tokenizer
+
     def is_request_active(self, request_id: str):
         return request_id in self.request_states
 
@@ -234,10 +243,7 @@ def add_request(
 
         assert (request.request_id not in self.request_states)
 
-        req_tokenizer = self._base_tokenizer if (
-            request.lora_request is None) else get_lora_tokenizer(
-                request.lora_request)
-
+        req_tokenizer = self.get_tokenizer(request.lora_request)
         request_state = IncrementalDetokenizer.from_new_request(
            req_tokenizer, request)
         self.request_states[request.request_id] = request_state
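
For context, here is a minimal, self-contained sketch of the tokenizer-dispatch pattern that the new Detokenizer.get_tokenizer follows: return the base tokenizer unless the request carries a LoRA adapter, in which case look up an adapter-specific tokenizer. The names below (TokenizerDispatcher, FakeLoRARequest, the lambda lookup) are illustrative stand-ins, not vLLM's real classes or get_lora_tokenizer.

# Illustrative sketch only -- made-up stand-ins, not vLLM objects.
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class FakeLoRARequest:
    lora_name: str
    lora_int_id: int
    lora_path: str


class TokenizerDispatcher:
    def __init__(self, base_tokenizer: str,
                 lora_lookup: Callable[[FakeLoRARequest], str]):
        self._base_tokenizer = base_tokenizer
        self._lora_lookup = lora_lookup

    def get_tokenizer(self,
                      lora_request: Optional[FakeLoRARequest] = None) -> str:
        # LoRA requests may come with their own tokenizer; otherwise use the base one.
        if lora_request:
            return self._lora_lookup(lora_request)
        return self._base_tokenizer


dispatcher = TokenizerDispatcher(
    base_tokenizer="base-tokenizer",
    lora_lookup=lambda req: f"tokenizer-for-{req.lora_name}")
print(dispatcher.get_tokenizer())  # -> base-tokenizer
print(dispatcher.get_tokenizer(FakeLoRARequest("demo", 1, "/tmp/demo")))  # -> tokenizer-for-demo
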
5 changes: 3 additions & 2 deletions vllm/v1/worker/gpu_input_batch.py
@@ -291,8 +291,9 @@ def make_sampling_metadata(
             max_num_logprobs=self.max_num_logprobs,
         )
 
-    def make_lora_inputs(self, num_scheduled_tokens: np.array) \
-        -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]:
+    def make_lora_inputs(
+            self, num_scheduled_tokens: np.ndarray
+    ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]:
         """
         Given the num_scheduled_tokens for each request in the batch, return
         datastructures used to activate the current LoRAs.
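
As a rough illustration of the per-request versus per-token mappings that make_lora_inputs' docstring describes, here is a standalone NumPy sketch (an assumed expansion, not the actual method body): a per-request LoRA id mapping is repeated num_scheduled_tokens times per request to obtain a per-token mapping.

# Standalone NumPy sketch; values and the np.repeat expansion are illustrative assumptions.
import numpy as np

prompt_lora_mapping = np.array([1, 0, 2], dtype=np.int32)   # LoRA id per request (0 = no LoRA)
num_scheduled_tokens = np.array([3, 2, 1], dtype=np.int32)  # tokens scheduled per request
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
print(tuple(token_lora_mapping))  # (1, 1, 1, 0, 0, 2)
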
23 changes: 12 additions & 11 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1,6 +1,6 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
 
 import numpy as np
 import torch
@@ -220,7 +220,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
-                lora_request=req_data.lora_request,
+                lora_request=new_req_data.lora_request,
             )
             req_ids_to_add.append(req_id)
 
@@ -264,15 +264,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
 
         # Get the number of scheduled tokens for each request.
         # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens = []
+        num_scheduled_tokens_list = []
         max_num_scheduled_tokens = 0
         for req_id in self.input_batch.req_ids[:num_reqs]:
             assert req_id is not None
             num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens.append(num_tokens)
+            num_scheduled_tokens_list.append(num_tokens)
             max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                            num_tokens)
-        num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
+        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
+                                                    dtype=np.int32)
         assert max_num_scheduled_tokens > 0
 
         # Get request indices.
@@ -632,13 +633,13 @@ def profile_run(self) -> None:
         num_tokens = self.max_num_tokens
         min_tokens_per_req: int = num_tokens // num_reqs
 
-        num_scheduled_tokens: List[int] = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens) == num_tokens
-        assert len(num_scheduled_tokens) == num_reqs
+        num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens_list) == num_tokens
+        assert len(num_scheduled_tokens_list) == num_reqs
 
-        num_scheduled_tokens: np.array = np.array(num_scheduled_tokens,
-                                                  dtype=np.int32)
+        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
+                                                    dtype=np.int32)
         logit_indices = np.cumsum(num_scheduled_tokens) - 1
 
         with self.maybe_profile_with_lora(self.lora_config,
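
A quick NumPy check of the logit_indices line in profile_run above: with tokens laid out request-by-request, np.cumsum(num_scheduled_tokens) - 1 gives the flat index of each request's last scheduled token. The batch values below are arbitrary examples.

# Arbitrary example values; only the cumsum arithmetic mirrors the diff above.
import numpy as np

num_scheduled_tokens = np.array([3, 1, 4], dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
print(logit_indices)  # [2 3 7] -> last-token position of each request
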
10 changes: 5 additions & 5 deletions vllm/v1/worker/lora_model_runner_mixin.py
@@ -3,7 +3,7 @@
 """
 
 from contextlib import contextmanager
-from typing import List, Set, Tuple
+from typing import Set, Tuple
 
 import numpy as np
 import torch.nn as nn
@@ -71,7 +71,7 @@ def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...],
         self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
 
     def set_active_loras(self, input_batch: InputBatch,
-                         num_scheduled_tokens: np.array) -> None:
+                         num_scheduled_tokens: np.ndarray) -> None:
 
         prompt_lora_mapping: Tuple[int, ...]  # of size input_batch.num_reqs
         token_lora_mapping: Tuple[int,
@@ -84,7 +84,7 @@ def set_active_loras(self, input_batch: InputBatch,
 
     @contextmanager
     def maybe_profile_with_lora(self, lora_config: LoRAConfig,
-                                num_scheduled_tokens: np.array):
+                                num_scheduled_tokens: np.ndarray):
         if lora_config is None:
             yield
         else:
@@ -105,12 +105,12 @@ def maybe_profile_with_lora(self, lora_config: LoRAConfig,
                 num_scheduled_tokens)
 
             # Make dummy lora requests
-            lora_requests: List[LoRARequest] = [
+            lora_requests: Set[LoRARequest] = {
                 LoRARequest(lora_name=f"warmup_{lora_id}",
                             lora_int_id=lora_id,
                             lora_path="/not/a/real/path")
                 for lora_id in range(1, num_loras + 1)
-            ]
+            }
 
             with self.lora_manager.dummy_lora_cache():
                 # Add the dummy LoRAs here so _set_active_loras doesn't try to
