diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index 45e6980a94630..e49562ad6a21f 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -4,7 +4,7 @@ from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) -from vllm.entrypoints.openai.serving_engine import LoRAModulePath +from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.utils import FlexibleArgumentParser from ...utils import VLLM_PATH diff --git a/tests/entrypoints/openai/test_lora_lineage.py b/tests/entrypoints/openai/test_lora_lineage.py index ab39684c2f31a..ce4f85c13fff9 100644 --- a/tests/entrypoints/openai/test_lora_lineage.py +++ b/tests/entrypoints/openai/test_lora_lineage.py @@ -55,7 +55,10 @@ def server_with_lora_modules_json(zephyr_lora_files): "64", ] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + # Enable the /v1/load_lora_adapter endpoint + envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"} + + with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server: yield remote_server @@ -67,8 +70,8 @@ async def client_for_lora_lineage(server_with_lora_modules_json): @pytest.mark.asyncio -async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, - zephyr_lora_files): +async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, + zephyr_lora_files): models = await client_for_lora_lineage.models.list() models = models.data served_model = models[0] @@ -81,3 +84,26 @@ async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" assert lora_models[1].id == "zephyr-lora2" + + +@pytest.mark.asyncio +async def test_dynamic_lora_lineage( + client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files): + + response = await client_for_lora_lineage.post("load_lora_adapter", + cast_to=str, + body={ + "lora_name": + "zephyr-lora-3", + "lora_path": + zephyr_lora_files + }) + # Ensure adapter loads before querying /models + assert "success" in response + + models = await client_for_lora_lineage.models.list() + models = models.data + dynamic_lora_model = models[-1] + assert dynamic_lora_model.root == zephyr_lora_files + assert dynamic_lora_model.parent == MODEL_NAME + assert dynamic_lora_model.id == "zephyr-lora-3" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 61677b65af342..97248f1150979 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -8,7 +8,8 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" @@ -50,14 +51,13 @@ async def _async_serving_chat_init(): engine = MockEngine() model_config = await engine.get_model_config() + models = OpenAIServingModels(model_config, BASE_MODEL_PATHS) serving_completion = OpenAIServingChat(engine, model_config, - BASE_MODEL_PATHS, + models, response_role="assistant", chat_template=CHAT_TEMPLATE, 
chat_template_content_format="auto", - lora_modules=None, - prompt_adapters=None, request_logger=None) return serving_completion @@ -72,14 +72,14 @@ def test_serving_chat_should_set_correct_max_tokens(): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS, + model_config=MockModelConfig()) serving_chat = OpenAIServingChat(mock_engine, MockModelConfig(), - BASE_MODEL_PATHS, + models, response_role="assistant", chat_template=CHAT_TEMPLATE, chat_template_content_format="auto", - lora_modules=None, - prompt_adapters=None, request_logger=None) req = ChatCompletionRequest( model=MODEL_NAME, @@ -115,14 +115,14 @@ def test_serving_chat_could_load_correct_generation_config(): mock_engine.errored = False # Initialize the serving chat + models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) serving_chat = OpenAIServingChat(mock_engine, mock_model_config, - BASE_MODEL_PATHS, + models, response_role="assistant", chat_template=CHAT_TEMPLATE, chat_template_content_format="auto", - lora_modules=None, - prompt_adapters=None, request_logger=None) req = ChatCompletionRequest( model=MODEL_NAME, diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_models.py similarity index 61% rename from tests/entrypoints/openai/test_serving_engine.py rename to tests/entrypoints/openai/test_serving_models.py index 096ab6fa0ac09..96897dc730da2 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -4,11 +4,11 @@ import pytest from vllm.config import ModelConfig -from vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) from vllm.lora.request import LoRARequest MODEL_NAME = "meta-llama/Llama-2-7b" @@ -19,47 +19,45 @@ "Success: LoRA adapter '{lora_name}' removed successfully.") -async def _async_serving_engine_init(): - mock_engine_client = MagicMock(spec=EngineClient) +async def _async_serving_models_init() -> OpenAIServingModels: mock_model_config = MagicMock(spec=ModelConfig) # Set the max_model_len attribute to avoid missing attribute mock_model_config.max_model_len = 2048 - serving_engine = OpenAIServing(mock_engine_client, - mock_model_config, - BASE_MODEL_PATHS, - lora_modules=None, - prompt_adapters=None, - request_logger=None) - return serving_engine + serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + lora_modules=None, + prompt_adapters=None) + + return serving_models @pytest.mark.asyncio async def test_serving_model_name(): - serving_engine = await _async_serving_engine_init() - assert serving_engine._get_model_name(None) == MODEL_NAME + serving_models = await _async_serving_models_init() + assert serving_models.model_name(None) == MODEL_NAME request = LoRARequest(lora_name="adapter", lora_path="/path/to/adapter2", lora_int_id=1) - assert serving_engine._get_model_name(request) == request.lora_name + assert serving_models.model_name(request) == request.lora_name @pytest.mark.asyncio async def test_load_lora_adapter_success(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() 
request = LoadLoraAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2") - response = await serving_engine.load_lora_adapter(request) + response = await serving_models.load_lora_adapter(request) assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') - assert len(serving_engine.lora_requests) == 1 - assert serving_engine.lora_requests[0].lora_name == "adapter" + assert len(serving_models.lora_requests) == 1 + assert serving_models.lora_requests[0].lora_name == "adapter" @pytest.mark.asyncio async def test_load_lora_adapter_missing_fields(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = LoadLoraAdapterRequest(lora_name="", lora_path="") - response = await serving_engine.load_lora_adapter(request) + response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" assert response.code == HTTPStatus.BAD_REQUEST @@ -67,43 +65,43 @@ async def test_load_lora_adapter_missing_fields(): @pytest.mark.asyncio async def test_load_lora_adapter_duplicate(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = LoadLoraAdapterRequest(lora_name="adapter1", lora_path="/path/to/adapter1") - response = await serving_engine.load_lora_adapter(request) + response = await serving_models.load_lora_adapter(request) assert response == LORA_LOADING_SUCCESS_MESSAGE.format( lora_name='adapter1') - assert len(serving_engine.lora_requests) == 1 + assert len(serving_models.lora_requests) == 1 request = LoadLoraAdapterRequest(lora_name="adapter1", lora_path="/path/to/adapter1") - response = await serving_engine.load_lora_adapter(request) + response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" assert response.code == HTTPStatus.BAD_REQUEST - assert len(serving_engine.lora_requests) == 1 + assert len(serving_models.lora_requests) == 1 @pytest.mark.asyncio async def test_unload_lora_adapter_success(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = LoadLoraAdapterRequest(lora_name="adapter1", lora_path="/path/to/adapter1") - response = await serving_engine.load_lora_adapter(request) - assert len(serving_engine.lora_requests) == 1 + response = await serving_models.load_lora_adapter(request) + assert len(serving_models.lora_requests) == 1 request = UnloadLoraAdapterRequest(lora_name="adapter1") - response = await serving_engine.unload_lora_adapter(request) + response = await serving_models.unload_lora_adapter(request) assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( lora_name='adapter1') - assert len(serving_engine.lora_requests) == 0 + assert len(serving_models.lora_requests) == 0 @pytest.mark.asyncio async def test_unload_lora_adapter_missing_fields(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None) - response = await serving_engine.unload_lora_adapter(request) + response = await serving_models.unload_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" assert response.code == HTTPStatus.BAD_REQUEST @@ -111,9 +109,9 @@ async def test_unload_lora_adapter_missing_fields(): @pytest.mark.asyncio async def 
test_unload_lora_adapter_not_found(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter") - response = await serving_engine.unload_lora_adapter(request) + response = await serving_models.unload_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" assert response.code == HTTPStatus.BAD_REQUEST diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bac72d87376da..74fe378fdae42 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -58,7 +58,9 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( @@ -269,6 +271,10 @@ def base(request: Request) -> OpenAIServing: return tokenization(request) +def models(request: Request) -> OpenAIServingModels: + return request.app.state.openai_serving_models + + def chat(request: Request) -> Optional[OpenAIServingChat]: return request.app.state.openai_serving_chat @@ -336,10 +342,10 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): @router.get("/v1/models") async def show_available_models(raw_request: Request): - handler = base(raw_request) + handler = models(raw_request) - models = await handler.show_available_models() - return JSONResponse(content=models.model_dump()) + models_ = await handler.show_available_models() + return JSONResponse(content=models_.model_dump()) @router.get("/version") @@ -505,26 +511,22 @@ async def stop_profile(raw_request: Request): @router.post("/v1/load_lora_adapter") async def load_lora_adapter(request: LoadLoraAdapterRequest, raw_request: Request): - for route in [chat, completion, embedding]: - handler = route(raw_request) - if handler is not None: - response = await handler.load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + handler = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @router.post("/v1/unload_lora_adapter") async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Request): - for route in [chat, completion, embedding]: - handler = route(raw_request) - if handler is not None: - response = await handler.unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + handler = models(raw_request) + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @@ 
-628,13 +630,18 @@ def init_app_state( resolved_chat_template = load_chat_template(args.chat_template) logger.info("Using supplied chat template:\n%s", resolved_chat_template) + state.openai_serving_models = OpenAIServingModels( + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + ) + # TODO: The chat template is now broken for lora adapters :( state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, - base_model_paths, + state.openai_serving_models, args.response_role, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -646,16 +653,14 @@ def init_app_state( state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, - base_model_paths, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, + state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) if model_config.runner_type == "generate" else None state.openai_serving_pooling = OpenAIServingPooling( engine_client, model_config, - base_model_paths, + state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -663,7 +668,7 @@ def init_app_state( state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, - base_model_paths, + state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -671,14 +676,13 @@ def init_app_state( state.openai_serving_scores = OpenAIServingScores( engine_client, model_config, - base_model_paths, + state.openai_serving_models, request_logger=request_logger ) if model_config.task == "score" else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, - base_model_paths, - lora_modules=args.lora_modules, + state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 908f8c3532c9e..22206ef8dbfe6 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -12,7 +12,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) -from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, +from vllm.entrypoints.openai.serving_models import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.utils import FlexibleArgumentParser diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 572ed27b39083..822c0f5f7c211 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -20,7 +20,8 @@ # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_models import (BaseModelPath, + OpenAIServingModels) from 
vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION @@ -213,13 +214,17 @@ async def main(args): request_logger = RequestLogger(max_log_len=args.max_log_len) # Create the openai serving objects. + openai_serving_models = OpenAIServingModels( + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=None, + prompt_adapters=None, + ) openai_serving_chat = OpenAIServingChat( engine, model_config, - base_model_paths, + openai_serving_models, args.response_role, - lora_modules=None, - prompt_adapters=None, request_logger=request_logger, chat_template=None, chat_template_content_format="auto", @@ -228,7 +233,7 @@ async def main(args): openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, - base_model_paths, + openai_serving_models, request_logger=request_logger, chat_template=None, chat_template_content_format="auto", diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d085333563d19..9ba5eeb7709c9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -21,10 +21,8 @@ ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (BaseModelPath, - LoRAModulePath, - OpenAIServing, - PromptAdapterPath) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput @@ -42,11 +40,9 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, response_role: str, *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, @@ -57,9 +53,7 @@ def __init__( ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=lora_modules, - prompt_adapters=prompt_adapters, + models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids) @@ -126,7 +120,7 @@ async def create_chat_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - model_name = self._get_model_name(lora_request) + model_name = self.models.model_name(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index aaad7b8c7f44c..17197dce8da23 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -21,10 +21,8 @@ RequestResponseMetadata, UsageInfo) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (BaseModelPath, - LoRAModulePath, - OpenAIServing, - PromptAdapterPath) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params 
import BeamSearchParams, SamplingParams @@ -41,18 +39,14 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=lora_modules, - prompt_adapters=prompt_adapters, + models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids) diff_sampling_param = self.model_config.get_diff_sampling_param() @@ -170,7 +164,7 @@ async def create_completion( result_generator = merge_async_iterators(*generators) - model_name = self._get_model_name(lora_request) + model_name = self.models.model_name(lora_request) num_prompts = len(engine_prompts) # Similar to the OpenAI API, when n != best_of, we do not stream the diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index b8fb9d6bd77f2..e7116a3d95d10 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -16,7 +16,8 @@ EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput) @@ -46,7 +47,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], chat_template: Optional[str], @@ -54,9 +55,7 @@ def __init__( ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=None, - prompt_adapters=None, + models=models, request_logger=request_logger) self.chat_template = chat_template diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 5b6a089e4c319..319f869240036 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,7 +1,5 @@ import json -import pathlib from concurrent.futures.thread import ThreadPoolExecutor -from dataclasses import dataclass from http import HTTPStatus from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, TypedDict, Union) @@ -28,13 +26,10 @@ DetokenizeRequest, EmbeddingChatRequest, EmbeddingCompletionRequest, - ErrorResponse, - LoadLoraAdapterRequest, - ModelCard, ModelList, - ModelPermission, ScoreRequest, + ErrorResponse, ScoreRequest, TokenizeChatRequest, - TokenizeCompletionRequest, - UnloadLoraAdapterRequest) + TokenizeCompletionRequest) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable from vllm.inputs import TokensPrompt @@ -48,30 +43,10 @@ from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import AtomicCounter, is_list_of, make_async, 
random_uuid +from vllm.utils import is_list_of, make_async, random_uuid logger = init_logger(__name__) - -@dataclass -class BaseModelPath: - name: str - model_path: str - - -@dataclass -class PromptAdapterPath: - name: str - local_path: str - - -@dataclass -class LoRAModulePath: - name: str - path: str - base_model_name: Optional[str] = None - - CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, EmbeddingCompletionRequest, ScoreRequest, TokenizeCompletionRequest] @@ -96,10 +71,8 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): @@ -109,35 +82,7 @@ def __init__( self.model_config = model_config self.max_model_len = model_config.max_model_len - self.base_model_paths = base_model_paths - - self.lora_id_counter = AtomicCounter(0) - self.lora_requests = [] - if lora_modules is not None: - self.lora_requests = [ - LoRARequest(lora_name=lora.name, - lora_int_id=i, - lora_path=lora.path, - base_model_name=lora.base_model_name - if lora.base_model_name - and self._is_model_supported(lora.base_model_name) - else self.base_model_paths[0].name) - for i, lora in enumerate(lora_modules, start=1) - ] - - self.prompt_adapter_requests = [] - if prompt_adapters is not None: - for i, prompt_adapter in enumerate(prompt_adapters, start=1): - with pathlib.Path(prompt_adapter.local_path, - "adapter_config.json").open() as f: - adapter_config = json.load(f) - num_virtual_tokens = adapter_config["num_virtual_tokens"] - self.prompt_adapter_requests.append( - PromptAdapterRequest( - prompt_adapter_name=prompt_adapter.name, - prompt_adapter_id=i, - prompt_adapter_local_path=prompt_adapter.local_path, - prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + self.models = models self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids @@ -150,33 +95,6 @@ def __init__( self._tokenize_prompt_input_or_inputs, executor=self._tokenizer_executor) - async def show_available_models(self) -> ModelList: - """Show available models. 
Right now we only have one model.""" - model_cards = [ - ModelCard(id=base_model.name, - max_model_len=self.max_model_len, - root=base_model.model_path, - permission=[ModelPermission()]) - for base_model in self.base_model_paths - ] - lora_cards = [ - ModelCard(id=lora.lora_name, - root=lora.local_path, - parent=lora.base_model_name if lora.base_model_name else - self.base_model_paths[0].name, - permission=[ModelPermission()]) - for lora in self.lora_requests - ] - prompt_adapter_cards = [ - ModelCard(id=prompt_adapter.prompt_adapter_name, - root=self.base_model_paths[0].name, - permission=[ModelPermission()]) - for prompt_adapter in self.prompt_adapter_requests - ] - model_cards.extend(lora_cards) - model_cards.extend(prompt_adapter_cards) - return ModelList(data=model_cards) - def create_error_response( self, message: str, @@ -205,11 +123,13 @@ async def _check_model( ) -> Optional[ErrorResponse]: if self._is_model_supported(request.model): return None - if request.model in [lora.lora_name for lora in self.lora_requests]: + if request.model in [ + lora.lora_name for lora in self.models.lora_requests + ]: return None if request.model in [ prompt_adapter.prompt_adapter_name - for prompt_adapter in self.prompt_adapter_requests + for prompt_adapter in self.models.prompt_adapter_requests ]: return None return self.create_error_response( @@ -223,10 +143,10 @@ def _maybe_get_adapters( None, PromptAdapterRequest]]: if self._is_model_supported(request.model): return None, None - for lora in self.lora_requests: + for lora in self.models.lora_requests: if request.model == lora.lora_name: return lora, None - for prompt_adapter in self.prompt_adapter_requests: + for prompt_adapter in self.models.prompt_adapter_requests: if request.model == prompt_adapter.prompt_adapter_name: return None, prompt_adapter # if _check_model has been called earlier, this will be unreachable @@ -588,91 +508,5 @@ def _get_decoded_token(logprob: Logprob, return logprob.decoded_token return tokenizer.decode(token_id) - async def _check_load_lora_adapter_request( - self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: - # Check if both 'lora_name' and 'lora_path' are provided - if not request.lora_name or not request.lora_path: - return self.create_error_response( - message="Both 'lora_name' and 'lora_path' must be provided.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - # Check if the lora adapter with the given name already exists - if any(lora_request.lora_name == request.lora_name - for lora_request in self.lora_requests): - return self.create_error_response( - message= - f"The lora adapter '{request.lora_name}' has already been" - "loaded.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - return None - - async def _check_unload_lora_adapter_request( - self, - request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: - # Check if either 'lora_name' or 'lora_int_id' is provided - if not request.lora_name and not request.lora_int_id: - return self.create_error_response( - message= - "either 'lora_name' and 'lora_int_id' needs to be provided.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - # Check if the lora adapter with the given name exists - if not any(lora_request.lora_name == request.lora_name - for lora_request in self.lora_requests): - return self.create_error_response( - message= - f"The lora adapter '{request.lora_name}' cannot be found.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - return 
None - - async def load_lora_adapter( - self, - request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]: - error_check_ret = await self._check_load_lora_adapter_request(request) - if error_check_ret is not None: - return error_check_ret - - lora_name, lora_path = request.lora_name, request.lora_path - unique_id = self.lora_id_counter.inc(1) - self.lora_requests.append( - LoRARequest(lora_name=lora_name, - lora_int_id=unique_id, - lora_path=lora_path)) - return f"Success: LoRA adapter '{lora_name}' added successfully." - - async def unload_lora_adapter( - self, - request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: - error_check_ret = await self._check_unload_lora_adapter_request(request - ) - if error_check_ret is not None: - return error_check_ret - - lora_name = request.lora_name - self.lora_requests = [ - lora_request for lora_request in self.lora_requests - if lora_request.lora_name != lora_name - ] - return f"Success: LoRA adapter '{lora_name}' removed successfully." - def _is_model_supported(self, model_name): - return any(model.name == model_name for model in self.base_model_paths) - - def _get_model_name(self, lora: Optional[LoRARequest]): - """ - Returns the appropriate model name depending on the availability - and support of the LoRA or base model. - Parameters: - - lora: LoRARequest that contain a base_model_name. - Returns: - - str: The name of the base model or the first available model path. - """ - if lora is not None: - return lora.lora_name - return self.base_model_paths[0].name + return self.models.is_base_model(model_name) diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py new file mode 100644 index 0000000000000..26966896bc272 --- /dev/null +++ b/vllm/entrypoints/openai/serving_models.py @@ -0,0 +1,210 @@ +import json +import pathlib +from dataclasses import dataclass +from http import HTTPStatus +from typing import List, Optional, Union + +from vllm.config import ModelConfig +from vllm.entrypoints.openai.protocol import (ErrorResponse, + LoadLoraAdapterRequest, + ModelCard, ModelList, + ModelPermission, + UnloadLoraAdapterRequest) +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.utils import AtomicCounter + + +@dataclass +class BaseModelPath: + name: str + model_path: str + + +@dataclass +class PromptAdapterPath: + name: str + local_path: str + + +@dataclass +class LoRAModulePath: + name: str + path: str + base_model_name: Optional[str] = None + + +class OpenAIServingModels: + """Shared instance to hold data about the loaded base model(s) and adapters. 
+
+    Handles the routes:
+    - /v1/models
+    - /v1/load_lora_adapter
+    - /v1/unload_lora_adapter
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        base_model_paths: List[BaseModelPath],
+        *,
+        lora_modules: Optional[List[LoRAModulePath]] = None,
+        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
+    ):
+        super().__init__()
+
+        self.base_model_paths = base_model_paths
+        self.max_model_len = model_config.max_model_len
+
+        self.lora_id_counter = AtomicCounter(0)
+        self.lora_requests = []
+        if lora_modules is not None:
+            self.lora_requests = [
+                LoRARequest(lora_name=lora.name,
+                            lora_int_id=i,
+                            lora_path=lora.path,
+                            base_model_name=lora.base_model_name
+                            if lora.base_model_name
+                            and self.is_base_model(lora.base_model_name) else
+                            self.base_model_paths[0].name)
+                for i, lora in enumerate(lora_modules, start=1)
+            ]
+
+        self.prompt_adapter_requests = []
+        if prompt_adapters is not None:
+            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
+                with pathlib.Path(prompt_adapter.local_path,
+                                  "adapter_config.json").open() as f:
+                    adapter_config = json.load(f)
+                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
+                    self.prompt_adapter_requests.append(
+                        PromptAdapterRequest(
+                            prompt_adapter_name=prompt_adapter.name,
+                            prompt_adapter_id=i,
+                            prompt_adapter_local_path=prompt_adapter.local_path,
+                            prompt_adapter_num_virtual_tokens=num_virtual_tokens))
+
+    def is_base_model(self, model_name):
+        return any(model.name == model_name for model in self.base_model_paths)
+
+    def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
+        """Returns the appropriate model name depending on the availability
+        and support of the LoRA or base model.
+        Parameters:
+        - lora_request: LoRARequest that contains a base_model_name.
+        Returns:
+        - str: The name of the base model or the first available model path.
+        """
+        if lora_request is not None:
+            return lora_request.lora_name
+        return self.base_model_paths[0].name
+
+    async def show_available_models(self) -> ModelList:
+        """Show available models. This includes the base model and all
+        adapters."""
+        model_cards = [
+            ModelCard(id=base_model.name,
+                      max_model_len=self.max_model_len,
+                      root=base_model.model_path,
+                      permission=[ModelPermission()])
+            for base_model in self.base_model_paths
+        ]
+        lora_cards = [
+            ModelCard(id=lora.lora_name,
+                      root=lora.local_path,
+                      parent=lora.base_model_name if lora.base_model_name else
+                      self.base_model_paths[0].name,
+                      permission=[ModelPermission()])
+            for lora in self.lora_requests
+        ]
+        prompt_adapter_cards = [
+            ModelCard(id=prompt_adapter.prompt_adapter_name,
+                      root=self.base_model_paths[0].name,
+                      permission=[ModelPermission()])
+            for prompt_adapter in self.prompt_adapter_requests
+        ]
+        model_cards.extend(lora_cards)
+        model_cards.extend(prompt_adapter_cards)
+        return ModelList(data=model_cards)
+
+    async def load_lora_adapter(
+            self,
+            request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+        error_check_ret = await self._check_load_lora_adapter_request(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        lora_name, lora_path = request.lora_name, request.lora_path
+        unique_id = self.lora_id_counter.inc(1)
+        self.lora_requests.append(
+            LoRARequest(lora_name=lora_name,
+                        lora_int_id=unique_id,
+                        lora_path=lora_path))
+        return f"Success: LoRA adapter '{lora_name}' added successfully."
+
+    async def unload_lora_adapter(
+            self,
+            request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+        error_check_ret = await self._check_unload_lora_adapter_request(request
+                                                                        )
+        if error_check_ret is not None:
+            return error_check_ret
+
+        lora_name = request.lora_name
+        self.lora_requests = [
+            lora_request for lora_request in self.lora_requests
+            if lora_request.lora_name != lora_name
+        ]
+        return f"Success: LoRA adapter '{lora_name}' removed successfully."
+
+    async def _check_load_lora_adapter_request(
+            self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
+        # Check if both 'lora_name' and 'lora_path' are provided
+        if not request.lora_name or not request.lora_path:
+            return create_error_response(
+                message="Both 'lora_name' and 'lora_path' must be provided.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        # Check if the lora adapter with the given name already exists
+        if any(lora_request.lora_name == request.lora_name
+               for lora_request in self.lora_requests):
+            return create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' has already been "
+                "loaded.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
+
+    async def _check_unload_lora_adapter_request(
+            self,
+            request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
+        # Check if either 'lora_name' or 'lora_int_id' is provided
+        if not request.lora_name and not request.lora_int_id:
+            return create_error_response(
+                message=
+                "either 'lora_name' or 'lora_int_id' needs to be provided.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        # Check if the lora adapter with the given name exists
+        if not any(lora_request.lora_name == request.lora_name
+                   for lora_request in self.lora_requests):
+            return create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' cannot be found.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
+
+
+def create_error_response(
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
+    return ErrorResponse(message=message,
+                         type=err_type,
+                         code=status_code.value)
diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py
index 01852f0df1eca..5830322071e58 100644
--- a/vllm/entrypoints/openai/serving_pooling.py
+++ b/vllm/entrypoints/openai/serving_pooling.py
@@ -15,7 +15,8 @@
                                               PoolingChatRequest,
                                               PoolingRequest, PoolingResponse,
                                               PoolingResponseData, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
 from vllm.utils import merge_async_iterators
@@ -44,7 +45,7 @@ def __init__(
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        base_model_paths: List[BaseModelPath],
+        models: OpenAIServingModels,
         *,
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
@@ -52,9 +53,7 @@
     ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
-                         base_model_paths=base_model_paths,
-                         lora_modules=None,
-                         prompt_adapters=None,
+                         models=models,
                          request_logger=request_logger)
 
         self.chat_template = chat_template
diff --git a/vllm/entrypoints/openai/serving_score.py
b/vllm/entrypoints/openai/serving_score.py index a8a126e697641..5d3e7139d7a17 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -10,7 +10,8 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest, ScoreResponse, ScoreResponseData, UsageInfo) -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput @@ -50,15 +51,13 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=None, - prompt_adapters=None, + models=models, request_logger=request_logger) async def create_score( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 2e849333680d4..b67ecfb01316f 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -15,9 +15,8 @@ TokenizeRequest, TokenizeResponse) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (BaseModelPath, - LoRAModulePath, - OpenAIServing) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger logger = init_logger(__name__) @@ -29,18 +28,15 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, - lora_modules: Optional[List[LoRAModulePath]], request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=lora_modules, - prompt_adapters=None, + models=models, request_logger=request_logger) self.chat_template = chat_template