Skip to content

Commit

Permalink
[Bugfix] Access get_vocab instead of vocab in tool parsers (vllm-…
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Oct 9, 2024
1 parent 21906a6 commit cfaa600
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
7 changes: 7 additions & 0 deletions vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib
import importlib.util
import os
from functools import cached_property
from typing import Callable, Dict, List, Optional, Sequence, Type, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
Expand Down Expand Up @@ -29,6 +30,12 @@ def __init__(self, tokenizer: AnyTokenizer):

self.model_tokenizer = tokenizer

@cached_property
def vocab(self) -> Dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()

def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
"""
Expand Down
7 changes: 3 additions & 4 deletions vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ def __init__(self, tokenizer: AnyTokenizer):
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_call_start_token_id: int = self.model_tokenizer.vocab.get(
self.tool_call_start_token, None)
self.tool_call_end_token_id: int = self.model_tokenizer.vocab.get(
self.tool_call_end_token, None)
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end "
Expand Down
3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def __init__(self, tokenizer: AnyTokenizer):
self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.model_tokenizer.get_vocab().get(
self.bot_token, None)
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
if not self.bot_token_id:
raise RuntimeError(
Expand Down

0 comments on commit cfaa600

Please sign in to comment.