Merge branch 'main' into release
gsolard committed Sep 17, 2024
2 parents 294f375 + 2d65723 commit da48451
Showing 4 changed files with 20 additions and 14 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -13,8 +13,8 @@ license = {file="LICENSE"}
 readme = "README.md"
 requires-python = ">=3.10,<4.0"
 dependencies = [
-    "vllm>=0.6.1.post1,<1.0",
-    "fastapi>=0.114.1,<1.0",
+    "vllm>=0.6.1.post2,<1.0",
+    "fastapi>=0.114.2,<1.0",
     "pydantic_settings>=2.5.2,<3.0",
     "uvicorn[standard]>=0.30.6,<1.0",
     "prometheus_client>=0.20.0,<1.0",
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
-vllm==0.6.1.post1
-fastapi==0.114.1
+vllm==0.6.1.post2
+fastapi==0.114.2
 pydantic-settings==2.5.2
 uvicorn[standard]==0.30.6
 prometheus_client==0.20.0
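
The exact pins in requirements.txt are bumped in step with the version ranges declared in pyproject.toml above. As a rough illustration (a hypothetical check, not part of this repository), the packaging library can confirm that the pins still satisfy the declared ranges:

# Hypothetical consistency check (not part of this repository): verify that the
# requirements.txt pins satisfy the ranges declared in pyproject.toml.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

pins = {"vllm": "0.6.1.post2", "fastapi": "0.114.2"}                    # requirements.txt
ranges = {"vllm": ">=0.6.1.post2,<1.0", "fastapi": ">=0.114.2,<1.0"}    # pyproject.toml
for name, pinned in pins.items():
    assert Version(pinned) in SpecifierSet(ranges[name]), f"{name} pin drifted from pyproject.toml"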
24 changes: 15 additions & 9 deletions src/happy_vllm/model/model_base.py
@@ -29,6 +29,7 @@
 from vllm.entrypoints.logger import RequestLogger
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.rpc import RPCUtilityRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
@@ -84,24 +85,29 @@ async def _load_model(self, async_engine_client: AsyncEngineRPCClient, args: Nam

logger.info(f"Loading the model from {args.model}")
if args.model_name != "TEST MODEL":
self._model = async_engine_client
if isinstance(self._model.tokenizer, TokenizerGroup): # type: ignore
self._tokenizer = self._model.tokenizer.tokenizer # type: ignore
self._model = async_engine_client
model_config = await self._model.get_model_config()
# Define the tokenizer differently if we have an AsyncLLMEngine
if isinstance(self._model, AsyncLLMEngine):
tokenizer_tmp = self._model.engine.tokenizer
else:
self._tokenizer = self._model.tokenizer # type: ignore
tokenizer_tmp = self._model.tokenizer
if isinstance(tokenizer_tmp, TokenizerGroup): # type: ignore
self._tokenizer = tokenizer_tmp.tokenizer # type: ignore
else:
self._tokenizer = tokenizer_tmp # type: ignore
self._tokenizer_lmformatenforcer = build_token_enforcer_tokenizer_data(self._tokenizer)
self.max_model_len = self._model.model_config.max_model_len # type: ignore
self.max_model_len = model_config.max_model_len # type: ignore
# To take into account Mistral tokenizers
try:
self.original_truncation_side = self._tokenizer.truncation_side # type: ignore
except:
self.original_truncation_side = "left"
model_config = await self._model._get_model_config_rpc()
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
self.openai_serving_chat = OpenAIServingChat(cast(AsyncEngineClient,self._model), model_config, [args.model_name],
self.openai_serving_chat = OpenAIServingChat(cast(AsyncEngineClient, self._model), model_config, [args.model_name],
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
@@ -110,12 +116,12 @@
                                                          return_tokens_as_token_ids=args.return_tokens_as_token_ids,
                                                          enable_auto_tools=args.enable_auto_tool_choice,
                                                          tool_parser=args.tool_call_parser)
-            self.openai_serving_completion = OpenAIServingCompletion(cast(AsyncEngineClient,self._model), model_config, [args.model_name],
+            self.openai_serving_completion = OpenAIServingCompletion(cast(AsyncEngineClient, self._model), model_config, [args.model_name],
                                                                      lora_modules=args.lora_modules,
                                                                      prompt_adapters=args.prompt_adapters,
                                                                      request_logger=request_logger,
                                                                      return_tokens_as_token_ids=args.return_tokens_as_token_ids)
-            self.openai_serving_tokenization = OpenAIServingTokenization(cast(AsyncEngineClient,self._model), model_config, [args.model_name],
+            self.openai_serving_tokenization = OpenAIServingTokenization(cast(AsyncEngineClient, self._model), model_config, [args.model_name],
                                                                          lora_modules=args.lora_modules,
                                                                          request_logger=request_logger,
                                                                          chat_template=args.chat_template)
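
The model_base.py change above makes _load_model work with either engine front end: when happy_vllm runs an in-process AsyncLLMEngine, the tokenizer has to be read from the wrapped engine, whereas the RPC client exposes it directly, and in both cases a TokenizerGroup may wrap the actual Hugging Face tokenizer; the model config is likewise fetched through the generic get_model_config() coroutine rather than the RPC-only _get_model_config_rpc(). Below is a minimal sketch of that tokenizer resolution, written as a standalone reading of the diff rather than code from the repository; the TokenizerGroup import path is assumed from vllm 0.6.x.

# Illustrative sketch only: restates the tokenizer resolution introduced in this commit.
# Import paths are assumed from vllm 0.6.x and may differ in other versions.
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

def resolve_tokenizer(model):
    if isinstance(model, AsyncLLMEngine):
        # In-process engine: the tokenizer lives on the wrapped engine.
        tokenizer_tmp = model.engine.tokenizer
    else:
        # RPC client: the tokenizer is exposed directly on the client.
        tokenizer_tmp = model.tokenizer
    # Either way, a TokenizerGroup wraps the underlying Hugging Face tokenizer.
    if isinstance(tokenizer_tmp, TokenizerGroup):
        return tokenizer_tmp.tokenizer
    return tokenizer_tmp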
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
-1.1.8
+1.1.9
