Skip to content

Commit

Permalink
Add lora support
Browse files Browse the repository at this point in the history
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
  • Loading branch information
Varun Sundar Rabindranath committed Dec 6, 2024
1 parent 3bc94ca commit e730a07
Show file tree
Hide file tree
Showing 18 changed files with 653 additions and 300 deletions.
18 changes: 18 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,21 @@ def get_model_patched(**kwargs):
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)


@pytest.fixture(params=[True, False])
def run_with_both_engines_lora(request):
# Automatically runs tests twice, once with V1 and once without
use_v1 = request.param
# Tests decorated with `@skip_v1` are only run without v1
skip_v1 = request.node.get_closest_marker("skip_v1")

if use_v1:
if skip_v1:
pytest.skip("Skipping test on vllm V1")
with patch('vllm.envs.VLLM_USE_V1', True), \
patch('vllm.envs.VLLM_V1_FORCE_DISABLE_PREFIX_CACHING', True):
yield
else:
with patch('vllm.envs.VLLM_USE_V1', False):
yield
9 changes: 9 additions & 0 deletions tests/lora/test_baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
Expand All @@ -62,6 +70,7 @@ def test_baichuan_lora(baichuan_lora_files):
assert output2[i] == expected_lora_output[i]


@pytest.mark.skip_v1
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
num_gpus_available, fully_sharded):
Expand Down
12 changes: 12 additions & 0 deletions tests/lora/test_chatglm3_tp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List

import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest
Expand Down Expand Up @@ -45,6 +47,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
Expand All @@ -63,6 +73,7 @@ def test_chatglm3_lora(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4(chatglm3_lora_files):
Expand All @@ -83,6 +94,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
Expand Down
8 changes: 8 additions & 0 deletions tests/lora/test_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.xfail(current_platform.is_rocm(),
reason="There can be output mismatch on ROCm")
def test_gemma_lora(gemma_lora_files):
Expand Down
16 changes: 13 additions & 3 deletions tests/lora/test_llama_tp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List

import pytest
import ray

import vllm
Expand Down Expand Up @@ -48,10 +49,9 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

return generated_texts


Expand All @@ -71,9 +71,16 @@ def generate_and_test(llm, sql_lora_files):
print("removing lora")


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
Expand Down Expand Up @@ -110,6 +117,7 @@ def get_num_gpu_blocks_no_lora():
"less when using lora than when not using lora")


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4(sql_lora_files):
Expand All @@ -124,6 +132,7 @@ def test_llama_lora_tp4(sql_lora_files):
generate_and_test(llm, sql_lora_files)


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
Expand All @@ -139,6 +148,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
generate_and_test(llm, sql_lora_files)


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
Expand Down
8 changes: 8 additions & 0 deletions tests/lora/test_lora_bias_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
Expand Down
8 changes: 8 additions & 0 deletions tests/lora/test_minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.xfail(
current_platform.is_rocm(),
reason="MiniCPM-V dependency xformers incompatible with ROCm")
Expand Down
10 changes: 10 additions & 0 deletions tests/lora/test_phi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest

Expand Down Expand Up @@ -46,6 +48,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


def test_phi2_lora(phi2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM.
Expand Down
9 changes: 9 additions & 0 deletions tests/lora/test_quant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def format_prompt_tuples(prompt):
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
Expand Down Expand Up @@ -162,6 +170,7 @@ def expect_match(output, expected_output):
cleanup_dist_env_and_memory()


@pytest.mark.skip_v1
@pytest.mark.parametrize("model", MODELS)
def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
model):
Expand Down
4 changes: 4 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model

# Force prefix caching disable.
if envs.VLLM_USE_V1 and envs.VLLM_V1_FORCE_DISABLE_PREFIX_CACHING:
self.enable_prefix_caching = False

# Override the default value of enable_prefix_caching if it's not set
# by user.
if self.enable_prefix_caching is None:
Expand Down
5 changes: 5 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
VLLM_V1_FORCE_DISABLE_PREFIX_CACHING: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False


Expand Down Expand Up @@ -454,6 +455,10 @@ def get_default_config_root():
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),

# If set, disable prefix caching irrespective of the engine args setting.
"VLLM_V1_FORCE_DISABLE_PREFIX_CACHING":
lambda: bool(int(os.getenv("VLLM_V1_FORCE_DISABLE_PREFIX_CACHING", "0"))),

# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
Expand Down
26 changes: 24 additions & 2 deletions vllm/v1/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
Expand All @@ -30,8 +31,6 @@ def __init__(
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
# TODO: Support LoRA.
assert lora_config is None, "V1 does not support LoRA yet."

# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
Expand Down Expand Up @@ -171,6 +170,16 @@ def schedule(self) -> "SchedulerOutput":
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget

# Record the LoRAs in scheduled_running_reqs
requested_loras: set[int] = set()
if self.lora_config:
requested_loras = \
set(req.lora_request.lora_int_id \
for req in scheduled_running_reqs \
if req.lora_request and \
req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras

# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting:
Expand All @@ -182,6 +191,17 @@ def schedule(self) -> "SchedulerOutput":
break

request = self.waiting[0]

# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request:
req_lora_id = request.lora_request.lora_int_id
if len(requested_loras) == self.lora_config.max_loras and \
req_lora_id not in requested_loras:
# cannot schedule
break
requested_loras.add(req_lora_id)

# Get already-cached tokens.
computed_blocks = self.kv_cache_manager.get_computed_blocks(
request)
Expand Down Expand Up @@ -514,6 +534,7 @@ class NewRequestData:
sampling_params: SamplingParams
block_ids: List[int]
num_computed_tokens: int
lora_request: Optional[LoRARequest]

@classmethod
def from_request(
Expand All @@ -531,6 +552,7 @@ def from_request(
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
lora_request=request.lora_request,
)


Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class DetokenizerRequest:
stop: List[str]
include_stop_str_in_output: bool

lora_request: Optional[LoRARequest]


@dataclass
class EngineCoreRequest:
Expand Down
21 changes: 13 additions & 8 deletions vllm/v1/engine/detokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer import get_lora_tokenizer, get_tokenizer
from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput

logger = init_logger(__name__)
Expand Down Expand Up @@ -197,12 +197,13 @@ def __init__(self,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
revision: Optional[str] = None):
# TODO: once we support LoRA, we should should pass the tokenizer
# here. We currently have two copies (this + in the LLMEngine).
self.tokenizer = get_tokenizer(tokenizer_name=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
revision=revision)
# per-request tokenizers, like in LoRA, are created in
# add_request. All other requests use the base_tokenizer.
self._base_tokenizer = get_tokenizer(
tokenizer_name=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
revision=revision)

# Request id -> IncrementalDetokenizer
self.request_states: Dict[str, IncrementalDetokenizer] = {}
Expand Down Expand Up @@ -233,8 +234,12 @@ def add_request(

assert (request.request_id not in self.request_states)

req_tokenizer = self._base_tokenizer \
if request.lora_request is None else \
get_lora_tokenizer(request.lora_request)

request_state = IncrementalDetokenizer.from_new_request(
self.tokenizer, request)
req_tokenizer, request)
self.request_states[request.request_id] = request_state

def step(
Expand Down
Loading

0 comments on commit e730a07

Please sign in to comment.