From 0c9ca096300437735b0f28bbe752cf3224e95c5b Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 8 Nov 2024 17:17:26 +0000 Subject: [PATCH] Add lora support Signed-off-by: Varun Sundar Rabindranath --- tests/lora/conftest.py | 18 +++ tests/lora/test_baichuan.py | 9 ++ tests/lora/test_chatglm3_tp.py | 12 ++ tests/lora/test_gemma.py | 8 ++ tests/lora/test_llama_tp.py | 12 ++ tests/lora/test_lora_bias_e2e.py | 8 ++ tests/lora/test_minicpmv.py | 8 ++ tests/lora/test_phi.py | 10 ++ tests/lora/test_quant_model.py | 9 ++ vllm/engine/arg_utils.py | 4 + vllm/envs.py | 5 + vllm/v1/core/scheduler.py | 24 +++- vllm/v1/engine/__init__.py | 2 + vllm/v1/engine/detokenizer.py | 21 ++-- vllm/v1/engine/processor.py | 3 +- vllm/v1/worker/gpu_input_batch.py | 58 +++++++++- vllm/v1/worker/gpu_model_runner.py | 52 +++++++-- vllm/v1/worker/lora_model_runner_mixin.py | 129 ++++++++++++++++++++++ 18 files changed, 370 insertions(+), 22 deletions(-) create mode 100644 vllm/v1/worker/lora_model_runner_mixin.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 29ecf37808205..57d937d95ab6c 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -269,3 +269,21 @@ def get_model_patched(**kwargs): def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. model_runner.model) + + +@pytest.fixture(params=[True, False]) +def run_with_both_engines_lora(request): + # Automatically runs tests twice, once with V1 and once without + use_v1 = request.param + # Tests decorated with `@skip_v1` are only run without v1 + skip_v1 = request.node.get_closest_marker("skip_v1") + + if use_v1: + if skip_v1: + pytest.skip("Skipping test on vllm V1") + with patch('vllm.envs.VLLM_USE_V1', True), patch( + 'vllm.envs.VLLM_V1_FORCE_DISABLE_PREFIX_CACHING', True): + yield + else: + with patch('vllm.envs.VLLM_USE_V1', False): + yield diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 0ba2ce3617b67..393026ef47c3f 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -40,6 +40,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_baichuan_lora(baichuan_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, @@ -62,6 +70,7 @@ def test_baichuan_lora(baichuan_lora_files): assert output2[i] == expected_lora_output[i] +@pytest.mark.skip_v1 @pytest.mark.parametrize("fully_sharded", [True, False]) def test_baichuan_tensor_parallel_equality(baichuan_lora_files, num_gpus_available, fully_sharded): diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index 49a527b99ac16..91d69fefddcb6 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -1,5 +1,7 @@ from typing import List +import pytest + import vllm from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest @@ -45,6 +47,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + 
@fork_new_process_for_each_test def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, @@ -64,6 +74,7 @@ def test_chatglm3_lora(chatglm3_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] +@pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) @fork_new_process_for_each_test def test_chatglm3_lora_tp4(chatglm3_lora_files): @@ -85,6 +96,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files): assert output2[i] == EXPECTED_LORA_OUTPUT[i] +@pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) @fork_new_process_for_each_test def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 5ae705e474ec6..93bc619069570 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -31,6 +31,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.xfail(current_platform.is_rocm(), reason="There can be output mismatch on ROCm") def test_gemma_lora(gemma_lora_files): diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index dfeac380951d8..07aec1a85d173 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -1,5 +1,6 @@ from typing import List +import pytest import ray import vllm @@ -71,6 +72,14 @@ def generate_and_test(llm, sql_lora_files): print("removing lora") +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @fork_new_process_for_each_test def test_llama_lora(sql_lora_files): @@ -111,6 +120,7 @@ def get_num_gpu_blocks_no_lora(): "less when using lora than when not using lora") +@pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) @fork_new_process_for_each_test def test_llama_lora_tp4(sql_lora_files): @@ -126,6 +136,7 @@ def test_llama_lora_tp4(sql_lora_files): generate_and_test(llm, sql_lora_files) +@pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) @fork_new_process_for_each_test def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): @@ -142,6 +153,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): generate_and_test(llm, sql_lora_files) +@pytest.mark.skip_v1 @multi_gpu_test(num_gpus=4) @fork_new_process_for_each_test def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files): diff --git a/tests/lora/test_lora_bias_e2e.py b/tests/lora/test_lora_bias_e2e.py index c2520c847d873..2f7ab4128b553 100644 --- a/tests/lora/test_lora_bias_e2e.py +++ b/tests/lora/test_lora_bias_e2e.py @@ -28,6 +28,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("lora_bias", [True]) @pytest.mark.parametrize("fully_sharded", [True, False]) def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool): diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index 1f3de9edc0d0f..776df53e66e25 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -56,6 +56,14 @@ def 
do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.xfail( current_platform.is_rocm(), reason="MiniCPM-V dependency xformers incompatible with ROCm") diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py index 5a3fcb8d690d9..8656a4b6b870b 100644 --- a/tests/lora/test_phi.py +++ b/tests/lora/test_phi.py @@ -1,5 +1,7 @@ from typing import List +import pytest + import vllm from vllm.lora.request import LoRARequest @@ -46,6 +48,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_phi2_lora(phi2_lora_files): # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # Otherwise, the lora-test will fail due to CUDA OOM. diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 026269667b473..9ba2cf9c142a1 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -68,6 +68,14 @@ def format_prompt_tuples(prompt): return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tp_size", [1]) def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, @@ -163,6 +171,7 @@ def expect_match(output, expected_output): cleanup_dist_env_and_memory() +@pytest.mark.skip_v1 @pytest.mark.parametrize("model", MODELS) def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, model): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0aa367a173b6c..511834713ffb2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -203,6 +203,10 @@ def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model + # Force prefix caching disable. + if envs.VLLM_USE_V1 and envs.VLLM_V1_FORCE_DISABLE_PREFIX_CACHING: + self.enable_prefix_caching = False + # Override the default value of enable_prefix_caching if it's not set # by user. if self.enable_prefix_caching is None: diff --git a/vllm/envs.py b/vllm/envs.py index da17b747ea215..c42cc9dbac328 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -69,6 +69,7 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False + VLLM_V1_FORCE_DISABLE_PREFIX_CACHING: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 @@ -458,6 +459,10 @@ def get_default_config_root(): "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), + # If set, disable prefix caching irrespective of the engine args setting. + "VLLM_V1_FORCE_DISABLE_PREFIX_CACHING": + lambda: bool(int(os.getenv("VLLM_V1_FORCE_DISABLE_PREFIX_CACHING", "0"))), + # If set, enable multiprocessing in LLM for the V1 code path. 
"VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index f76364f64033d..3f4d510326fce 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -7,6 +7,7 @@ from vllm.logger import init_logger from vllm.multimodal import MultiModalKwargs from vllm.multimodal.base import PlaceholderRange +from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -32,8 +33,6 @@ def __init__( self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config - # TODO: Support LoRA. - assert lora_config is None, "V1 does not support LoRA yet." # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -174,6 +173,14 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Record the LoRAs in scheduled_running_reqs + requested_loras: Set[int] = set() + if self.lora_config: + requested_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(requested_loras) <= self.lora_config.max_loras + # Next, schedule the WAITING requests. if not preempted_reqs: while self.waiting: @@ -185,6 +192,17 @@ def schedule(self) -> "SchedulerOutput": break request = self.waiting[0] + + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request: + req_lora_id = request.lora_request.lora_int_id + if len(requested_loras) == self.lora_config.max_loras and ( + req_lora_id not in requested_loras): + # cannot schedule + break + requested_loras.add(req_lora_id) + # Get already-cached tokens. 
computed_blocks = self.kv_cache_manager.get_computed_blocks( request) @@ -521,6 +539,7 @@ class NewRequestData: sampling_params: SamplingParams block_ids: List[int] num_computed_tokens: int + lora_request: Optional[LoRARequest] @classmethod def from_request( @@ -538,6 +557,7 @@ def from_request( sampling_params=request.sampling_params, block_ids=block_ids, num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, ) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index cc0c7ea23469a..51a83847330d3 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -22,6 +22,8 @@ class DetokenizerRequest: stop: List[str] include_stop_str_in_output: bool + lora_request: Optional[LoRARequest] + @dataclass class EngineCoreRequest: diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 02f34e2b54dd5..fcfd7466a39c5 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -7,7 +7,7 @@ from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.transformers_utils.tokenizer import get_lora_tokenizer, get_tokenizer from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput logger = init_logger(__name__) @@ -197,12 +197,13 @@ def __init__(self, tokenizer_mode: str = "auto", trust_remote_code: bool = False, revision: Optional[str] = None): - # TODO: once we support LoRA, we should should pass the tokenizer - # here. We currently have two copies (this + in the LLMEngine). - self.tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode=tokenizer_mode, - trust_remote_code=trust_remote_code, - revision=revision) + # per-request tokenizers, like in LoRA, are created in + # add_request. All other requests use the base_tokenizer. + self._base_tokenizer = get_tokenizer( + tokenizer_name=tokenizer_name, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + revision=revision) # Request id -> IncrementalDetokenizer self.request_states: Dict[str, IncrementalDetokenizer] = {} @@ -233,8 +234,12 @@ def add_request( assert (request.request_id not in self.request_states) + req_tokenizer = self._base_tokenizer if ( + request.lora_request is None) else get_lora_tokenizer( + request.lora_request) + request_state = IncrementalDetokenizer.from_new_request( - self.tokenizer, request) + req_tokenizer, request) self.request_states[request.request_id] = request_state def step( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 679bf8e25e9ca..a7c1f4d28640f 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -49,7 +49,7 @@ def __init__( ) if model_config.mm_cache_preprocessor else None # TODO: run in an ThreadpoolExecutor or BackgroundProcess. - # This ideally should releases the GIL, so we should not block the + # This ideally should release the GIL, so we should not block the # asyncio loop while this is running. def process_inputs( self, @@ -132,6 +132,7 @@ def process_inputs( sampling_params.output_kind, sampling_params.stop, sampling_params.include_stop_str_in_output, + lora_request, ) # Make Request for EngineCore. 
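Before moving on to the worker-side changes, here is a standalone sketch (not part of the patch; the helper name `can_schedule_lora` is illustrative) restating the admission rule the scheduler hunk above enforces for the waiting queue: a request is only scheduled if the number of distinct LoRA ids active in the step stays within max_loras.

# Standalone sketch of the waiting-queue LoRA admission rule.
from typing import Optional, Set


def can_schedule_lora(requested_loras: Set[int], lora_int_id: Optional[int],
                      max_loras: int) -> bool:
    if not lora_int_id:
        # Base-model request; never counts against max_loras.
        return True
    if lora_int_id in requested_loras:
        # The adapter is already active in this step.
        return True
    return len(requested_loras) < max_loras


# Example: with max_loras=2 and adapters {1, 2} already scheduled,
# a request for adapter 3 stays in the waiting queue.
assert can_schedule_lora({1, 2}, 2, max_loras=2)
assert not can_schedule_lora({1, 2}, 3, max_loras=2)

Requests that would exceed the cap simply remain at the head of the waiting queue until one of the active adapters finishes.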
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 5c113c74778df..c4383a11aeb91 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -1,11 +1,12 @@
 # Datastructures defining an input batch

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

 import numpy as np
 import torch

+from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -29,6 +30,8 @@ class CachedRequestState:
     num_computed_tokens: int
     output_token_ids: List[int]

+    lora_request: Optional[LoRARequest]
+
     @property
     def num_tokens(self) -> int:
         return len(self.prompt_token_ids) + len(self.output_token_ids)
@@ -110,6 +113,11 @@ def __init__(
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()

+        # LoRA related
+        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
+                                             dtype=np.int32)
+        self.lora_requests: Set[LoRARequest] = set()
+
         # req_index -> generator
         # NOTE(woosuk): The indices of the requests that do not have their own
         # generator should not be included in the dictionary.
@@ -169,6 +177,15 @@ def add_request(
         if sampling_params.prompt_logprobs:
             self.prompt_logprob_reqs.add(req_id)

+        # Record the request's LoRA ID.
+        if request.lora_request:
+            self.request_lora_mapping[
+                req_index] = request.lora_request.lora_int_id
+            self.lora_requests.add(request.lora_request)
+        else:
+            # No LoRA
+            self.request_lora_mapping[req_index] = 0
+
     def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
         if req_index is None:
@@ -182,6 +199,12 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
+
+        # LoRA
+        # Only update request_lora_mapping here. Defer the update of
+        # lora_requests to make_lora_inputs.
+        self.request_lora_mapping[req_index] = 0
+
         return req_index

     def clear(self) -> None:
@@ -194,6 +217,9 @@ def clear(self) -> None:
         self.generators.clear()
         self.num_logprobs.clear()
         self.prompt_logprob_reqs.clear()
+        self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
+                                             dtype=np.int32)
+        self.lora_requests.clear()

     def condense(self, empty_req_indices: List[int]) -> None:
         if self.num_reqs == 0:
@@ -236,6 +262,9 @@ def condense(self, empty_req_indices: List[int]) -> None:
             if generator is not None:
                 self.generators[empty_index] = generator

+            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
+                last_req_index]
+
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1

@@ -262,6 +291,33 @@ def make_sampling_metadata(
             max_num_logprobs=self.max_num_logprobs,
         )

+    def make_lora_inputs(self, num_scheduled_tokens: np.array) \
+        -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]:
+        """
+        Given the num_scheduled_tokens for each request in the batch, return
+        the data structures used to activate the current LoRAs.
+        Returns:
+            1. prompt_lora_mapping: A tuple of size self.num_reqs where
+               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
+            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
+               where token_lora_mapping[i] is the LoRA id to use for the ith token.
+            3. lora_requests: Set of relevant LoRA requests.
+ """ + + req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + prompt_lora_mapping = tuple(req_lora_mapping) + token_lora_mapping = tuple( + req_lora_mapping.repeat(num_scheduled_tokens)) + + active_lora_ids: Set[int] = set(np.unique(req_lora_mapping)) + active_lora_requests: Set[LoRARequest] = \ + set({lr for lr in self.lora_requests \ + if lr.lora_int_id in active_lora_ids}) + # Update lora requests + self.lora_requests = active_lora_requests + + return prompt_lora_mapping, token_lora_mapping, self.lora_requests + @property def num_reqs(self) -> int: return len(self.req_id_to_index) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 67166fb05085c..b7b744a5981e6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,6 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Tuple, cast, Optional import numpy as np import torch @@ -22,6 +22,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -29,7 +30,7 @@ logger = init_logger(__name__) -class GPUModelRunner: +class GPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -219,6 +220,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], + lora_request=req_data.lora_request, ) req_ids_to_add.append(req_id) @@ -352,6 +354,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): block_table=self.input_batch.block_table[:num_reqs], slot_mapping=slot_mapping, ) + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this # partial request, we do so for simplicity. We will ignore the sampled @@ -570,6 +577,12 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -612,14 +625,33 @@ def profile_run(self) -> None: torch.tensor([], dtype=torch.float32, device=self.device) for _ in range(self.num_attn_layers) ] - # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) - logits = self.model.compute_logits(hidden_states, None) - logits = logits[:self.max_num_tokens] - # TODO(woosuk): Consider the memory usage of the sampler. - torch.cuda.synchronize() - del hidden_states, logits + + # compute num tokens per request. For profile, have maximum num_reqs and + # that collectively have maximum num_tokens. 
+        num_reqs = self.scheduler_config.max_num_seqs
+        num_tokens = self.max_num_tokens
+        min_tokens_per_req: int = num_tokens // num_reqs
+
+        num_scheduled_tokens: List[int] = [min_tokens_per_req] * num_reqs
+        num_scheduled_tokens[-1] += num_tokens % num_reqs
+        assert sum(num_scheduled_tokens) == num_tokens
+        assert len(num_scheduled_tokens) == num_reqs
+
+        num_scheduled_tokens: np.array = np.array(num_scheduled_tokens,
+                                                  dtype=np.int32)
+        logit_indices = np.cumsum(num_scheduled_tokens) - 1
+
+        with self.maybe_profile_with_lora(self.lora_config,
+                                          num_scheduled_tokens):
+            # Trigger compilation for general shape.
+            hidden_states = self._dummy_run(self.model, self.max_num_tokens,
+                                            dummy_kv_caches)
+            hidden_states = hidden_states[logit_indices]
+            logits = self.model.compute_logits(hidden_states, None)
+            # TODO(woosuk): Consider the memory usage of the sampler.
+            torch.cuda.synchronize()
+
+            del hidden_states, logits
         gc.collect()

     def capture_model(self) -> None:
diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py
new file mode 100644
index 0000000000000..19156fd75e5e3
--- /dev/null
+++ b/vllm/v1/worker/lora_model_runner_mixin.py
@@ -0,0 +1,129 @@
+"""
+Define a mixin class that adds LoRA support to the model runner.
+"""
+
+from contextlib import contextmanager
+from typing import List, Set, Tuple
+
+import numpy as np
+import torch.nn as nn
+
+from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
+from vllm.logger import init_logger
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.request import LoRARequest
+from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+logger = init_logger(__name__)
+
+
+# Defined as a mixin for GPUModelRunner
+class LoRAModelRunnerMixin:
+
+    LORA_WARMUP_RANK = 8
+
+    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
+                        scheduler_config: SchedulerConfig,
+                        lora_config: LoRAConfig, device: str) -> nn.Module:
+
+        assert supports_lora(
+            model), f"{model.__class__.__name__} does not support LoRA yet."
+
+        if supports_multimodal(model):
+            logger.warning("Regarding multimodal models, vLLM currently "
+                           "only supports adding LoRA to language model.")
+
+        # It's necessary to distinguish between the max_position_embeddings
+        # of VLMs and LLMs.
+        if hasattr(model.config, "max_position_embeddings"):
+            max_pos_embeddings = model.config.max_position_embeddings
+        else:
+            max_pos_embeddings = (
+                model.config.text_config.max_position_embeddings)

+        # Add the LoRA manager to the model runner.
+        self.lora_manager = LRUCacheWorkerLoRAManager(
+            scheduler_config.max_num_seqs,
+            scheduler_config.max_num_batched_tokens,
+            model_config.get_vocab_size(),
+            lora_config,
+            device,
+            model.embedding_modules,
+            model.embedding_padding_modules,
+            max_position_embeddings=max_pos_embeddings,
+        )
+        return self.lora_manager.create_lora_manager(model)
+
+    def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...],
+                          token_lora_mapping: Tuple[int, ...],
+                          lora_requests: Set[LoRARequest]) -> None:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+
+        # We don't make any distinction between prefills and decodes in the
+        # scheduler. To that effect, set is_prefill to True so that we always
+        # use the SGMV Punica kernels.
+ lora_mapping = LoRAMapping(token_lora_mapping, + prompt_lora_mapping, + is_prefill=True) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def set_active_loras(self, input_batch: InputBatch, + num_scheduled_tokens: np.array) -> None: + + prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs + token_lora_mapping: Tuple[int, + ...] # of size np.sum(num_scheduled_tokens) + lora_requests: Set[LoRARequest] + prompt_lora_mapping, token_lora_mapping, lora_requests = \ + input_batch.make_lora_inputs(num_scheduled_tokens) + return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, + lora_requests) + + @contextmanager + def maybe_profile_with_lora(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.array): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + num_reqs = len(num_scheduled_tokens) + num_loras = lora_config.max_loras + + # Make prompt lora mapping + # Assign LoRA IDs to requests arbitrarily + prompt_lora_mapping = np.random.randint(low=1, + high=num_loras + 1, + size=num_reqs, + dtype=np.int32) + # Make token lora mapping + token_lora_mapping = np.repeat(prompt_lora_mapping, + num_scheduled_tokens) + + # Make dummy lora requests + lora_requests: List[LoRARequest] = [ + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, num_loras + 1) + ] + + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. + for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + self._set_active_loras(tuple(prompt_lora_mapping), + tuple(token_lora_mapping), + lora_requests) + + yield + + # __exit__ code + self.lora_manager.remove_all_adapters()
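For readers skimming the new InputBatch.make_lora_inputs and LoRAMapping plumbing, the short standalone sketch below (not part of the patch; the toy ids and token counts are invented) shows how the per-request LoRA ids are expanded into the per-token mapping consumed by the Punica kernels.

# Standalone sketch of the request-level to token-level LoRA id expansion.
import numpy as np

# LoRA id per request slot in the batch (0 means "no LoRA").
request_lora_mapping = np.array([1, 0, 2, 1], dtype=np.int32)
# Number of tokens scheduled for each request in this step.
num_scheduled_tokens = np.array([3, 2, 1, 4], dtype=np.int32)

prompt_lora_mapping = tuple(request_lora_mapping)
token_lora_mapping = tuple(request_lora_mapping.repeat(num_scheduled_tokens))

print([int(i) for i in prompt_lora_mapping])  # [1, 0, 2, 1]
print([int(i) for i in token_lora_mapping])   # [1, 1, 1, 0, 0, 2, 1, 1, 1, 1]

Keeping the expansion as a plain np.repeat makes building the mapping linear in the number of scheduled tokens per step.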