Skip to content

Commit

Permalink
[Bugfix][Refactor] Unify model management in frontend (#11660)
Browse files Browse the repository at this point in the history
Signed-off-by: Joe Runde <[email protected]>
  • Loading branch information
joerunde authored Jan 1, 2025
1 parent 0c6f998 commit 4db72e5
Show file tree
Hide file tree
Showing 15 changed files with 365 additions and 307 deletions.
2 changes: 1 addition & 1 deletion tests/entrypoints/openai/test_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.utils import FlexibleArgumentParser

from ...utils import VLLM_PATH
Expand Down
32 changes: 29 additions & 3 deletions tests/entrypoints/openai/test_lora_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def server_with_lora_modules_json(zephyr_lora_files):
"64",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
# Enable the /v1/load_lora_adapter endpoint
envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}

with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
yield remote_server


Expand All @@ -67,8 +70,8 @@ async def client_for_lora_lineage(server_with_lora_modules_json):


@pytest.mark.asyncio
async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
zephyr_lora_files):
async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
zephyr_lora_files):
models = await client_for_lora_lineage.models.list()
models = models.data
served_model = models[0]
Expand All @@ -81,3 +84,26 @@ async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"


@pytest.mark.asyncio
async def test_dynamic_lora_lineage(
client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):

response = await client_for_lora_lineage.post("load_lora_adapter",
cast_to=str,
body={
"lora_name":
"zephyr-lora-3",
"lora_path":
zephyr_lora_files
})
# Ensure adapter loads before querying /models
assert "success" in response

models = await client_for_lora_lineage.models.list()
models = models.data
dynamic_lora_model = models[-1]
assert dynamic_lora_model.root == zephyr_lora_files
assert dynamic_lora_model.parent == MODEL_NAME
assert dynamic_lora_model.id == "zephyr-lora-3"
20 changes: 10 additions & 10 deletions tests/entrypoints/openai/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "openai-community/gpt2"
Expand Down Expand Up @@ -50,14 +51,13 @@ async def _async_serving_chat_init():
engine = MockEngine()
model_config = await engine.get_model_config()

models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
serving_completion = OpenAIServingChat(engine,
model_config,
BASE_MODEL_PATHS,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_completion

Expand All @@ -72,14 +72,14 @@ def test_serving_chat_should_set_correct_max_tokens():
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
model_config=MockModelConfig())
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),
BASE_MODEL_PATHS,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
Expand Down Expand Up @@ -115,14 +115,14 @@ def test_serving_chat_could_load_correct_generation_config():
mock_engine.errored = False

# Initialize the serving chat
models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
BASE_MODEL_PATHS,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
req = ChatCompletionRequest(
model=MODEL_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import pytest

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.lora.request import LoRARequest

MODEL_NAME = "meta-llama/Llama-2-7b"
Expand All @@ -19,101 +19,99 @@
"Success: LoRA adapter '{lora_name}' removed successfully.")


async def _async_serving_engine_init():
mock_engine_client = MagicMock(spec=EngineClient)
async def _async_serving_models_init() -> OpenAIServingModels:
mock_model_config = MagicMock(spec=ModelConfig)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config.max_model_len = 2048

serving_engine = OpenAIServing(mock_engine_client,
mock_model_config,
BASE_MODEL_PATHS,
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_engine
serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config,
lora_modules=None,
prompt_adapters=None)

return serving_models


@pytest.mark.asyncio
async def test_serving_model_name():
serving_engine = await _async_serving_engine_init()
assert serving_engine._get_model_name(None) == MODEL_NAME
serving_models = await _async_serving_models_init()
assert serving_models.model_name(None) == MODEL_NAME
request = LoRARequest(lora_name="adapter",
lora_path="/path/to/adapter2",
lora_int_id=1)
assert serving_engine._get_model_name(request) == request.lora_name
assert serving_models.model_name(request) == request.lora_name


@pytest.mark.asyncio
async def test_load_lora_adapter_success():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter",
lora_path="/path/to/adapter2")
response = await serving_engine.load_lora_adapter(request)
response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
assert len(serving_engine.lora_requests) == 1
assert serving_engine.lora_requests[0].lora_name == "adapter"
assert len(serving_models.lora_requests) == 1
assert serving_models.lora_requests[0].lora_name == "adapter"


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="", lora_path="")
response = await serving_engine.load_lora_adapter(request)
response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_engine.load_lora_adapter(request)
response = await serving_models.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1')
assert len(serving_engine.lora_requests) == 1
assert len(serving_models.lora_requests) == 1

request = LoadLoraAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_engine.load_lora_adapter(request)
response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
assert len(serving_engine.lora_requests) == 1
assert len(serving_models.lora_requests) == 1


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = LoadLoraAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_engine.load_lora_adapter(request)
assert len(serving_engine.lora_requests) == 1
response = await serving_models.load_lora_adapter(request)
assert len(serving_models.lora_requests) == 1

request = UnloadLoraAdapterRequest(lora_name="adapter1")
response = await serving_engine.unload_lora_adapter(request)
response = await serving_models.unload_lora_adapter(request)
assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1')
assert len(serving_engine.lora_requests) == 0
assert len(serving_models.lora_requests) == 0


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
response = await serving_engine.unload_lora_adapter(request)
response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
serving_engine = await _async_serving_engine_init()
serving_models = await _async_serving_models_init()
request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
response = await serving_engine.unload_lora_adapter(request)
response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
Loading

0 comments on commit 4db72e5

Please sign in to comment.