diff --git a/CMakeLists.txt b/CMakeLists.txt index 462d3a01cc..5fd5afff42 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,6 +44,8 @@ if (BUILD_TEST) GIT_REPOSITORY https://github.com/catchorg/Catch2.git GIT_TAG v3.8.0 GIT_SHALLOW ON + GIT_PROGRESS TRUE + USES_TERMINAL_DOWNLOAD TRUE EXCLUDE_FROM_ALL ) FetchContent_MakeAvailable(Catch2) @@ -53,8 +55,10 @@ endif() FetchContent_Declare( repo-cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git - GIT_TAG v3.9.2 - GIT_SHALLOW ON + GIT_TAG v3.9.2 + GIT_SHALLOW ON + GIT_PROGRESS TRUE + USES_TERMINAL_DOWNLOAD TRUE EXCLUDE_FROM_ALL ) @@ -66,13 +70,38 @@ FetchContent_MakeAvailable(repo-cutlass) FetchContent_Declare( yaml-cpp GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git - GIT_TAG 0.8.0 + GIT_TAG 0.8.0 + GIT_PROGRESS TRUE + USES_TERMINAL_DOWNLOAD TRUE PATCH_COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/cmake/yaml-cpp_cmake_policy.patch - UPDATE_DISCONNECTED 1 + UPDATE_DISCONNECTED 1 ) set(YAML_BUILD_SHARED_LIBS OFF CACHE BOOL "Build static library of yaml-cpp") FetchContent_MakeAvailable(yaml-cpp) +FetchContent_Declare( + xgrammar + GIT_REPOSITORY https://github.com/mlc-ai/xgrammar.git + GIT_TAG v0.1.25 + GIT_SUBMODULES "3rdparty/dlpack" + GIT_PROGRESS TRUE + USES_TERMINAL_DOWNLOAD TRUE + UPDATE_DISCONNECTED 1 +) + +FetchContent_GetProperties(xgrammar) +if(NOT xgrammar_POPULATED) + # Fetch the content using previously declared details + FetchContent_Populate(xgrammar) + + file(WRITE ${xgrammar_SOURCE_DIR}/config.cmake "set(XGRAMMAR_BUILD_PYTHON_BINDINGS OFF)\n") + if(NOT MSVC) + file(APPEND ${xgrammar_SOURCE_DIR}/config.cmake "set(CMAKE_CXX_FLAGS \"-Wno-error\")\n") + endif() + + # Bring the populated content into the build + add_subdirectory(${xgrammar_SOURCE_DIR} ${xgrammar_BINARY_DIR}) +endif() # the environment variable # ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0 @@ -266,7 +295,9 @@ add_subdirectory(src) if (BUILD_PY_FFI) if (CALL_FROM_SETUP_PY) install(TARGETS _turbomind DESTINATION ${CMAKE_INSTALL_PREFIX}) + install(TARGETS _xgrammar DESTINATION ${CMAKE_INSTALL_PREFIX}) else() install(TARGETS _turbomind DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/lib) + install(TARGETS _xgrammar DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/lib) endif() endif () diff --git a/debug.sh b/debug.sh index d9c93bab73..95da648d26 100755 --- a/debug.sh +++ b/debug.sh @@ -1,4 +1,4 @@ -#!/bin/sh +#!/bin/bash -e builder="-G Ninja" @@ -15,4 +15,5 @@ cmake ${builder} .. \ -DCMAKE_CUDA_FLAGS="-lineinfo" \ -DUSE_NVTX=ON \ -DPYTHON_EXECUTABLE=$(which python3) \ + -DFETCHCONTENT_QUIET=OFF \ -DBUILD_TEST=ON diff --git a/docker/prepare_wheel.sh b/docker/prepare_wheel.sh index 1ffbbcf06b..4250c8820a 100755 --- a/docker/prepare_wheel.sh +++ b/docker/prepare_wheel.sh @@ -17,7 +17,6 @@ if [[ ${PYTHON_VERSION} = "3.13" ]]; then pip install setuptools_rust pip wheel -v --no-build-isolation --no-deps -w /wheels "git+https://github.com/google/sentencepiece.git@v0.2.0#subdirectory=python" - pip wheel -v --no-build-isolation --no-deps -w /wheels --use-deprecated=legacy-resolver outlines_core==0.1.26 fi if [[ "${CUDA_VERSION_SHORT}" != "cu118" ]]; then diff --git a/generate.sh b/generate.sh index 0c25b8cbf2..a59d5339ae 100755 --- a/generate.sh +++ b/generate.sh @@ -14,4 +14,5 @@ cmake ${builder} .. 
\ -DBUILD_PY_FFI=ON \ -DBUILD_MULTI_GPU=ON \ -DCMAKE_CUDA_FLAGS="-lineinfo" \ - -DUSE_NVTX=ON + -DUSE_NVTX=ON \ + -DFETCHCONTENT_QUIET=OFF diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 73c02e2914..c078d97d75 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -63,7 +63,7 @@ class GenerationConfig: around special tokens. The behavior of Fast tokenizers is to have this to False. This is setup to True in slow tokenizers. logprobs (int): Number of log probabilities to return per output token. - response_format (Dict): Only pytorch backend support formatting + response_format (Dict): Generate responses according to given formatting. response. Examples: { "type": "json_schema", diff --git a/lmdeploy/pytorch/engine/guided_process.py b/lmdeploy/pytorch/engine/guided_process.py index cc25906f60..6f01bd23ae 100644 --- a/lmdeploy/pytorch/engine/guided_process.py +++ b/lmdeploy/pytorch/engine/guided_process.py @@ -1,161 +1,87 @@ -# Copyright 2024- the Outlines developers -# This file is adapted from -# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - +# Copyright (c) OpenMMLab. All rights reserved. import copy -import math -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections import defaultdict +import json +import logging from functools import lru_cache -from typing import DefaultDict, Dict, List, Union +from typing import Optional import torch -from outlines.fsm.guide import CFGGuide, Generate, RegexGuide, Write -from outlines.fsm.json_schema import build_regex_from_schema -from pydantic import BaseModel +import xgrammar as xgr from transformers import PreTrainedTokenizerBase +logger = logging.getLogger('guided_process') -class BaseLogitsProcessor: - - def init_state(self): - """Initialize the FSM states.""" - self.fsm_state: DefaultDict[int, int] = defaultdict(int) - - def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: - """Use the FSM to bias the logits before sampling the next token.""" - - seq_id = hash(tuple(input_ids)) - - if len(input_ids) == 0: - self.init_state() - else: - last_token = input_ids[-1] - last_seq_id = hash(tuple(input_ids[:-1])) - self.fsm_state[seq_id] = self.fsm.get_next_state(state=self.fsm_state[last_seq_id], token_id=last_token) - - instruction = self.fsm.get_next_instruction(self.fsm_state[seq_id]) - if type(instruction) == Generate: - allowed_tokens = instruction.tokens - elif type(instruction) == Write: - # TODO: support fast forward tokens - allowed_tokens = [instruction.tokens[0]] - else: - raise TypeError(f'Unsupported instruction type {type(instruction)}') +class BaseLogitsProcessor: + """Base logits processor that uses xgrammar matcher for guided decoding.""" - mask = torch.full((scores.shape[-1], ), -math.inf, device=scores.device) - mask[allowed_tokens] = 0 - scores.add_(mask) + def __init__(self, compiled_grammar: xgr.CompiledGrammar, tokenizer_info: xgr.TokenizerInfo): + self.matcher = xgr.GrammarMatcher(compiled_grammar, terminate_without_stop_token=True) + 
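+        # xgrammar packs one bit per vocab entry into int32 words:
+        # fill_next_token_bitmask() marks the tokens the grammar currently allows,
+        # and apply_token_bitmask_inplace() pushes the logits of every other token
+        # to -inf before sampling.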
self.token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size) + def process(self, scores: torch.Tensor) -> torch.Tensor: + """Apply grammar constraints to logits before sampling the next + token.""" + self.matcher.fill_next_token_bitmask(self.token_bitmask) + xgr.apply_token_bitmask_inplace(scores, self.token_bitmask.to(scores.device)) return scores - def adapt_tokenizer(self, tokenizer): - """Adapt tokenizer to use to compile the FSM. + def accept(self, token_id: int) -> bool: + """Update matcher state after a token is generated.""" + return self.matcher.accept_token(token_id) - The API of Outlines tokenizers is slightly different to that of `transformers`. In addition we need to handle - the missing spaces to Llama's tokenizer to be able to compile FSMs for this model. - """ - from outlines.integrations.utils import adapt_tokenizer - tokenizer = adapt_tokenizer(tokenizer) - # vocab size greater than logits shape because of '[UNUSED_TOKEN_...]' - if hasattr(tokenizer, '_tokenizer'): - tokenizer.vocabulary = tokenizer._tokenizer.get_vocab(with_added_tokens=False) - return tokenizer + def reset(self): + """Reset matcher state for next generation.""" + self.matcher.reset() class RegexLogitsProcessor(BaseLogitsProcessor): + """Regex-guided logits processor using xgrammar.""" - def __init__(self, regex_string: str, tokenizer): - """Compile the FSM that drives the regex-structured generation. - - Args: - regex_string: A string that represents a regular expression - tokenizer: The model's tokenizer - """ - tokenizer = self.adapt_tokenizer(copy.deepcopy(tokenizer)) - fsm = RegexGuide(regex_string, tokenizer) - self.fsm = fsm - - -class JSONLogitsProcessor(RegexLogitsProcessor): - - def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer): - """Compile the FSM that drives the JSON-guided generation. - - Args: - schema: A str schema that encodes the structure we want the model - to generate - tokenizer: The model's tokenizer - """ - regex_string = build_regex_from_schema(schema) - super().__init__(regex_string, tokenizer) - + def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, vocab_size_padded: Optional[int] = None): + tokenizer = copy.deepcopy(tokenizer) + if vocab_size_padded is None: + vocab_size_padded = tokenizer.vocab_size -class CFGLogitsProcessor(BaseLogitsProcessor): + tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size_padded) - def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): - """Compile the FSM that drives the context free grammar generation. 
+ compiler = xgr.GrammarCompiler(tokenizer_info) + compiled = compiler.compile_regex_grammar(regex_string) - Parameters - ---------- - cfg - A string that represents a context-free grammar - tokenizer - The model's tokenizer - """ - tokenizer = self.adapt_tokenizer(tokenizer) - fsm = CFGGuide(cfg, tokenizer) - self.fsm = fsm + super().__init__(compiled, tokenizer_info) -# copied from https://github.com/vllm-project/vllm/blob/a7f65c2be93f491771aca31106f790bf381c0bad/vllm/model_executor/guided_decoding/outlines_decoding.py#L31 # noqa -JSON_GRAMMAR = r""" -?start: object | array +class JSONLogitsProcessor(BaseLogitsProcessor): + """JSON-schema guided logits processor using xgrammar.""" -?value: object -| array -| UNESCAPED_STRING -| SIGNED_NUMBER -> number -| "true" -> true -| "false" -> false -| "null" -> null + def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase, vocab_size_padded: Optional[int] = None): + tokenizer = copy.deepcopy(tokenizer) + tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size_padded) + if vocab_size_padded is None: + vocab_size_padded = tokenizer.vocab_size -array : "[" [value ("," value)*] "]" -object : "{" [pair ("," pair)*] "}" -pair : UNESCAPED_STRING ":" value + compiler = xgr.GrammarCompiler(tokenizer_info) + if isinstance(schema, str): + schema = json.loads(schema) -%import common.UNESCAPED_STRING -%import common.SIGNED_NUMBER -%import common.WS + assert isinstance(schema, dict) + compiled = compiler.compile_json_schema(schema) -%ignore WS -""" + super().__init__(compiled, tokenizer_info) @lru_cache(maxsize=32) -def _get_guided_logits_processor(guide: str, tokenizer: PreTrainedTokenizerBase, type: str): +def _get_guided_logits_processor(guide: str, + tokenizer: PreTrainedTokenizerBase, + type: str, + vocab_size_padded: Optional[int] = None): try: - if type == 'json_object': - return CFGLogitsProcessor(guide, tokenizer) - elif type == 'json_schema': - return JSONLogitsProcessor(guide, tokenizer) + if type == 'json_schema': + return JSONLogitsProcessor(guide, tokenizer, vocab_size_padded) elif type == 'regex_schema': - return RegexLogitsProcessor(guide, tokenizer) + return RegexLogitsProcessor(guide, tokenizer, vocab_size_padded) else: return None except Exception as e: - from lmdeploy.utils import get_logger - logger = get_logger('lmdeploy') logger.error(e) - return None + raise diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index b30fbb3992..b25447e764 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -78,12 +78,9 @@ def _multinomial_sampling(scores: torch.Tensor, return multinomial_sampling(scores, seeds, offsets, indices) -def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, guided_input_ids: Optional[torch.Tensor], - tokenizer: object): - if guided_input_ids is None: - return scores - for i in range(len(response_formats)): - _format = response_formats[i] +def _get_guided_processors(response_formats: Tuple[Dict], tokenizer: object, vocab_size_padded: int): + processors = {} + for i, _format in enumerate(response_formats): if isinstance(_format, Dict) and _format.get('type', 'text') != 'text': if _format['type'] == 'json_schema': schema = _format['json_schema'] @@ -91,10 +88,8 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, guided for key in ['json_schema', 'schema']: if key in schema: schema = json.dumps(schema[key], ensure_ascii=False) - elif schema is None: 
- from .guided_process import JSON_GRAMMAR - schema = JSON_GRAMMAR - elif isinstance(schema, str): + + if not isinstance(schema, str): raise ValueError(f'Cannot parse schema {schema}. The schema must be ' 'either a dictionary or a string that contains the' ' JSON Schema specification') @@ -102,11 +97,11 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, guided schema = _format.get('regex_schema', '') else: raise ValueError(f"unsupported format type: {_format['type']}") + from .guided_process import _get_guided_logits_processor - processor = _get_guided_logits_processor(schema, tokenizer, _format['type']) - if processor: - scores[i] = processor(guided_input_ids[i].tolist(), scores[i]) - return scores + processors[i] = _get_guided_logits_processor(schema, tokenizer, _format['type'], vocab_size_padded) + + return processors SeqList = List[SchedulerSequence] @@ -131,7 +126,6 @@ class SamplingInputs: logits_processors: List[List[LogitsProcessor]] = None max_num_logprobs: Optional[int] = None all_ids: Optional[torch.Tensor] = None - guided_input_ids: Optional[torch.Tensor] = None num_ignore_eos: torch.Tensor = None batch_size: int = 0 @@ -169,6 +163,8 @@ def __init__(self, self.tokenizer = tokenizer self.sampling_vocab_size = sampling_vocab_size self.logprobs_mode = logprobs_mode + self.guided_processors = _get_guided_processors(sampling_inputs.response_formats, tokenizer, + sampling_vocab_size) async def _wait_stream_once(self): """Wait stream once.""" @@ -205,9 +201,12 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: sampling_inputs = self.sampling_inputs all_ids = sampling_inputs.all_ids - guided_input_ids = sampling_inputs.guided_input_ids - custom_logits_processors = self.sampling_inputs.logits_processors + if self.guided_processors: + await self._wait_stream_once() + for i, processor in self.guided_processors.items(): + scores[i] = processor.process(scores[i]) + if any(custom_logits_processors): await self._wait_stream_once() scores = _apply_custom_logits_processors(custom_logits_processors, all_ids, scores) @@ -232,9 +231,6 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: stop_mask = torch.where(ignore_eos[:, None], stop_mask, False) scores = _process_bad_words_(scores, stop_words, stop_mask) - if guided_input_ids is not None: - await self._wait_stream_once() - scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer) return scores, logprobs @torch.inference_mode() @@ -272,7 +268,7 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor): logits = logits[..., :self.sampling_vocab_size] if sampling_inputs.max_top_k == 1: - return logits.argmax(-1) + result = logits.argmax(-1) else: # sort logits is too slow. 
and we only need topk logits max_topk = sampling_inputs.max_top_k @@ -280,7 +276,13 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor): scores, indices = logits.sort(1, descending=True) else: scores, indices = logits.topk(max_topk, dim=1) - return __random_sampling(scores, indices) + result = __random_sampling(scores, indices) + + if self.guided_processors: + for i, processor in self.guided_processors.items(): + processor.accept(result[i]) + + return result @torch.inference_mode() def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor): diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py index 4096db2cb7..f18419b4f7 100644 --- a/lmdeploy/pytorch/strategies/ar/model_agent.py +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -72,10 +72,6 @@ def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: if all_ids is not None: sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1) - guided_input_ids = sampling_inputs.guided_input_ids - if guided_input_ids is not None: - sampling_inputs.guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None]], 1) - return sampling_inputs def make_stopping_criteria(self, seqs: SeqList) -> ARStoppingCriteria: diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py index b2516f091a..ce5a048cf0 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -27,22 +27,6 @@ def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs) return output -def _gather_guided_input_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'): - """Gather input ids for guided decode.""" - if not any(sampling_inputs.response_formats or ()): - return None - batch = len(seqs) - max_len = max(seq.num_new_tokens for seq in seqs) - output = torch.full((batch, max_len), pad_id, dtype=torch.int64) - for idx, seq in enumerate(seqs): - h_len = seq.num_new_tokens - if h_len == 0: - continue - h_ids = torch.from_numpy(seq.generated_ids) - output[idx, -h_len:] = h_ids - return output - - def _get_num_ignore_eos(seqs: SeqList): """Get num ignore eos.""" ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs] @@ -186,6 +170,5 @@ def __get_bad_words(bad_words): pad_token_id = self.pad_token_id sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input) - sampling_input.guided_input_ids = _gather_guided_input_ids(pad_token_id, seqs, sampling_input) sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs) return sampling_input diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py index 2ad5d5ecd7..45048e25a5 100644 --- a/lmdeploy/pytorch/strategies/dllm/sampling.py +++ b/lmdeploy/pytorch/strategies/dllm/sampling.py @@ -35,7 +35,6 @@ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: 'random_seeds', 'random_offsets', 'all_ids', - 'guided_input_ids', 'num_ignore_eos', ] for name in update_attr_names: diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 604f1604a5..c8a149214e 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -129,17 +129,17 @@ def create_error_response(status: HTTPStatus, message: str, error_type='invalid_ async def check_request(request) -> Optional[JSONResponse]: """Check if a request is valid.""" if 
hasattr(request, 'model') and request.model not in get_model_list(): - return create_error_response(HTTPStatus.NOT_FOUND, f'The model `{request.model}` does not exist.') + return create_error_response(HTTPStatus.NOT_FOUND, f'The model {request.model!r} does not exist.') if hasattr(request, 'n') and request.n <= 0: - return create_error_response(HTTPStatus.BAD_REQUEST, f'The n `{request.n}` must be a positive int.') + return create_error_response(HTTPStatus.BAD_REQUEST, f'The n {request.n!r} must be a positive int.') if hasattr(request, 'top_p') and not (request.top_p > 0 and request.top_p <= 1): - return create_error_response(HTTPStatus.BAD_REQUEST, f'The top_p `{request.top_p}` must be in (0, 1].') + return create_error_response(HTTPStatus.BAD_REQUEST, f'The top_p {request.top_p!r} must be in (0, 1].') if hasattr(request, 'top_k') and request.top_k < 0: return create_error_response(HTTPStatus.BAD_REQUEST, - f'The top_k `{request.top_k}` cannot be a negative integer.') + f'The top_k {request.top_k!r} cannot be a negative integer.') if hasattr(request, 'temperature') and not (request.temperature <= 2 and request.temperature >= 0): return create_error_response(HTTPStatus.BAD_REQUEST, - f'The temperature `{request.temperature}` must be in [0, 2]') + f'The temperature {request.temperature!r} must be in [0, 2]') return @@ -315,8 +315,8 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque 1.0 means no penalty - stop (str | List[str] | None): To stop generating further tokens. Only accept stop words that's encoded to one token idex. - - response_format (Dict | None): Only pytorch backend support formatting - response. Examples: `{"type": "json_schema", "json_schema": {"name": + - response_format (Dict | None): To generate response according to given + schema. 
Examples: `{"type": "json_schema", "json_schema": {"name": "test","schema": {"properties": {"name": {"type": "string"}}, "required": ["name"], "type": "object"}}}` or `{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}` @@ -365,7 +365,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque if error_check_ret is not None: return error_check_ret if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0: - return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id `{request.session_id}` is occupied.') + return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id {request.session_id!r} is occupied.') model_name = request.model adapter_name = None @@ -385,8 +385,6 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque gen_logprobs = request.top_logprobs response_format = None if request.response_format and request.response_format.type != 'text': - if VariableInterface.async_engine.backend != 'pytorch': - return create_error_response(HTTPStatus.BAD_REQUEST, 'only pytorch backend can use response_format now') response_format = request.response_format.model_dump() if request.logit_bias is not None: @@ -717,7 +715,7 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None if error_check_ret is not None: return error_check_ret if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0: - return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id `{request.session_id}` is occupied.') + return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id {request.session_id!r} is occupied.') model_name = request.model adapter_name = None diff --git a/lmdeploy/serve/proxy/proxy.py b/lmdeploy/serve/proxy/proxy.py index 4f88340593..d796bc64a1 100644 --- a/lmdeploy/serve/proxy/proxy.py +++ b/lmdeploy/serve/proxy/proxy.py @@ -312,7 +312,7 @@ async def check_request_model(self, model_name) -> Optional[JSONResponse]: """Check if a request is valid.""" if model_name in self.model_list: return - ret = create_error_response(HTTPStatus.NOT_FOUND, f'The model `{model_name}` does not exist.') + ret = create_error_response(HTTPStatus.NOT_FOUND, f'The model {model_name!r} does not exist.') return ret def handle_unavailable_model(self, model_name): @@ -538,8 +538,8 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque 1.0 means no penalty - stop (str | List[str] | None): To stop generating further tokens. Only accept stop words that's encoded to one token idex. - - response_format (Dict | None): Only pytorch backend support formatting - response. Examples: `{"type": "json_schema", "json_schema": {"name": + - response_format (Dict | None): To generate response according to given + schema. Examples: `{"type": "json_schema", "json_schema": {"name": "test","schema": {"properties": {"name": {"type": "string"}}, "required": ["name"], "type": "object"}}}` or `{"type": "regex_schema", "regex_schema": "call me [A-Za-z]{1,10}"}` diff --git a/lmdeploy/turbomind/tokenizer_info.py b/lmdeploy/turbomind/tokenizer_info.py new file mode 100644 index 0000000000..56af0e7b6c --- /dev/null +++ b/lmdeploy/turbomind/tokenizer_info.py @@ -0,0 +1,343 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Borrowed from xgrammar's TokenizerInfo +"""This module provides the tokenizer info class to handle the tokenizer +information.""" + +import json +import logging +from enum import Enum +from typing import List, Optional, Union + +import _xgrammar as _xgr # noqa: E402 + +try: + import sentencepiece +except ImportError: + sentencepiece = None +try: + import tiktoken +except ImportError: + tiktoken = None + +from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast + +logger = logging.getLogger(__name__) + + +class VocabType(Enum): + """The type of the vocabulary. + + Used in TokenizerInfo. XGrammar supports three types of + vocabularies: RAW, BYTE_FALLBACK, BYTE_LEVEL. + """ + + RAW = 0 + """The vocabulary is in the raw format. + + The tokens in the vocabulary are kept in their original form without any processing. This kind of tokenizer includes + the tiktoken tokenizer, e.g. microsoft/Phi-3-small-8k-instruct, Qwen/Qwen-7B-Chat, etc. + """ + + BYTE_FALLBACK = 1 + r"""The vocabulary used in the byte fallback BPE tokenizer. + + The tokens are encoded through the byte-fallback conversion. E.g. "\u001b" -> "<0x1B>", " apple" -> "▁apple". This + kind of tokenizer includes meta-llama/Llama-2-7b-chat, microsoft/Phi-3.5-mini-instruct, etc. + """ + + BYTE_LEVEL = 2 + """The vocabulary used in the byte level BPE tokenizer. + + The tokens are encoded through the byte-to-unicode conversion, as in + https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59 + + This kind of tokenizer includes meta-llama/Meta-Llama-3-8B-Instruct, + meta-llama/Meta-Llama-3.1-8B-Instruct, etc. + """ + + +class TokenizerInfo(_xgr.TokenizerInfo): + """The tokenizer info contains the vocabulary, the type of the vocabulary, + and necessary information for the grammar-guided generation. + + Note that although some tokenizers will encode the tokens in a special format, e.g. "<0x1B>" for "\u001b" in the + ByteFallback tokenizer, and "Ġ" for " " in the Byte-Level BPE tokenizer, TokenizerInfo always decodes the vocabulary + to the original format (e.g. "\u001b" and " "). + + Also note that some models (e.g. Phi-3 and Deepseek-V2) may pad the vocabulary to a multiple of 32. In this case, + the model's vocab_size is larger than the tokenizer's vocabulary size. Please pass the model's vocab_size to the + vocab_size parameter in the constructor, because this information is used to determine the size of the token mask. + """ + + def __init__( + self, + encoded_vocab: Union[List[bytes], List[str]], + vocab_type: VocabType = VocabType.RAW, + *, + vocab_size: Optional[int] = None, + stop_token_ids: Optional[Union[List[int], int]] = None, + add_prefix_space: bool = False, + ) -> None: + """Construct the tokenizer info. + + Parameters + ---------- + encoded_vocab : Union[List[bytes], List[str]] + The encoded vocabulary of the tokenizer. + + vocab_type : VocabType, default: VocabType.RAW + The type of the vocabulary. See also VocabType. + + vocab_size : Optional[int], default: None + The size of the vocabulary. If not provided, the vocabulary size will be len(encoded_vocab). + + stop_token_ids : Optional[List[int]], default: None + The stop token ids. If not provided, the stop token ids will be auto detected (but may not + be correct). + + add_prefix_space : bool, default: False + Whether the tokenizer will prepend a space before the text in the tokenization process. 
+ """ + if isinstance(stop_token_ids, int): + stop_token_ids = [stop_token_ids] + + super().__init__(encoded_vocab, vocab_type.value, vocab_size, stop_token_ids, add_prefix_space) + + @staticmethod + def _is_tiktoken_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool: + if tiktoken is None: + return False + + # helper to check if tokenizer is a tiktoken tokenizer + has_tiktoken_encoding = hasattr(tokenizer, 'tokenizer') and isinstance(tokenizer.tokenizer, tiktoken.Encoding) + + filename_pattern = (hasattr(tokenizer, 'vocab_files_names') and 'vocab_file' in tokenizer.vocab_files_names + and 'tiktoken' in tokenizer.vocab_files_names['vocab_file']) + + return has_tiktoken_encoding or filename_pattern + + @staticmethod + def _is_sentencepiece_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool: + if sentencepiece is None: + return False + + # helper to check if tokenizer is a sentence piece tokenizer + has_sp_model_attr = hasattr(tokenizer, 'sp_model') and isinstance(tokenizer.sp_model, + sentencepiece.SentencePieceProcessor) + + has_nested_sp_model_attr = (hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model') + and isinstance(tokenizer.tokenizer.sp_model, sentencepiece.SentencePieceProcessor)) + + return has_sp_model_attr or has_nested_sp_model_attr + + @staticmethod + def from_huggingface( + tokenizer: PreTrainedTokenizerBase, + *, + vocab_size: Optional[int] = None, + stop_token_ids: Optional[Union[List[int], int]] = None, + ) -> 'TokenizerInfo': + """Construct the tokenizer info from the huggingface tokenizer. This + constructor supports various tokenizer backends, including the + huggingface fast tokenizer and tiktoken tokenizer. Necessary + information is automatically detected from the tokenizer. + + The vocab_size parameter is introduced to handle the misalignment between the model's + vocab_size and the tokenizer's vocabulary size. User should pass the model's vocab_size + (could be defined in the model config) here. See docs of vocab_size for more details. + + The stop token ids is by default the eos_token_id of the tokenizer. If there are other + stop tokens, you can specify them manually. + + Parameters + ---------- + tokenizer : PreTrainedTokenizerBase + The huggingface tokenizer. + + vocab_size : Optional[int], default: None + The vocabulary size **defined by the model** (**not the tokenizer**). This equals to the + vocab dimension of the model's lm_head. This is the size of the token mask. + + It can be: + + 1. the same as the tokenizer's vocabulary size. This is the most common case. + 2. larger than the tokenizer's vocabulary size. This happens when the model has padding + to lm_head, possibly due to aligning lm_head to the power of 2. + E.g. Phi-3 and Deepseek-V2. + 3. smaller than the tokenizer's vocabulary size. This happens when the tokenizer has + some added tokens that will not supported by the model. E.g. + Llama-3.2 Vision and Molmo-72B-0924 has padded `<|image|>` tokens, but they will not + be considered in lm_head or generated by the model. + + model_vocab_size need to be provided for case 2 and 3. If not provided, it will be + set to the tokenizer's vocabulary size. + + stop_token_ids : Optional[List[int]], default: None + The stop token ids. If not provided, the eos_token_id of the tokenizer will be used. + + Returns + ------- + tokenizer_info : TokenizerInfo + The tokenizer info. 
+ """ + if isinstance(stop_token_ids, int): + stop_token_ids = [stop_token_ids] + if isinstance(stop_token_ids, list) and len(stop_token_ids) == 0: + raise ValueError('stop_token_ids cannot be empty') + + try: + vocab_dict = tokenizer.get_vocab() + except AttributeError as e: + msg = (f'Cannot get the vocabulary of the tokenizer {type(tokenizer)}. The tokenizer ' + 'should have a get_vocab method.') + raise ValueError(msg) from e + + # Some tokenizer don't have token id 0 or 1 or 2. So the max_id could be larger than the + # number of tokens. + max_id = max(vocab_dict.values()) + tokenizer_vocab_size = max(len(vocab_dict), max_id + 1) + + vocab_size = vocab_size or tokenizer_vocab_size + + # maintain tokenizer's indexing + encoded_vocab = [''] * vocab_size + for token, idx in vocab_dict.items(): + if idx < vocab_size: + encoded_vocab[idx] = token + + if isinstance(tokenizer, PreTrainedTokenizerFast): + # huggingface fast tokenizer + # - the vocabulary is directly obtained from tokenizer.get_vocab() + # (tokenizer.backend_tokenizer.to_str() may not contain the full vocab, special + # tokens may be omitted) + # - the vocab size is obtained from len(tokenizer.get_vocab()) or provided by user + # - the vocab type and add_prefix_space are obtained from + # tokenizer.backend_tokenizer.to_str() + # - stop token id is provided by user, or auto detected. + backend_str = tokenizer.backend_tokenizer.to_str() + if stop_token_ids is None: + if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + else: + logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, ' + 'stop_token_ids is neither provided by user nor found from the tokenizer. ' + 'It will be automatically detected.') + metadata = json.loads(TokenizerInfo._detect_metadata_from_hf(backend_str)) + return TokenizerInfo( + encoded_vocab, + vocab_type=VocabType(metadata['vocab_type']), + vocab_size=vocab_size, + stop_token_ids=stop_token_ids, + add_prefix_space=metadata['add_prefix_space'], + ) + + elif TokenizerInfo._is_tiktoken_tokenizer(tokenizer): + # tiktoken tokenizer + # e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously) + if stop_token_ids is None: + if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + else: + logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, ' + 'stop_token_ids is neither provided by user nor found from the tokenizer. ' + 'It will be automatically detected.') + return TokenizerInfo( + encoded_vocab, + VocabType.RAW, + vocab_size=vocab_size, + stop_token_ids=stop_token_ids, + add_prefix_space=False, + ) + + elif TokenizerInfo._is_sentencepiece_tokenizer(tokenizer): + # sentencepiece tokenizer + # e.g. Chatglm3-6b + if hasattr(tokenizer, 'sp_model'): + sp_model = tokenizer.sp_model + elif hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model'): + sp_model = tokenizer.tokenizer.sp_model + + if stop_token_ids is None: + if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + else: + eos_id = sp_model.eos_id() + if eos_id != -1: + stop_token_ids = [eos_id] + else: + logger.warning('When constructing TokenizerInfo from a huggingface tokenizer, ' + 'stop_token_ids is neither provided by user nor found from the tokenizer. 
' + 'It will be automatically detected.') + # detect vocab_type of tokenizer + if '<0x0A>' in vocab_dict: + vocab_type = VocabType.BYTE_FALLBACK + else: + vocab_type = VocabType.RAW + + return TokenizerInfo( + encoded_vocab, + vocab_type=vocab_type, + vocab_size=vocab_size, + stop_token_ids=stop_token_ids, + add_prefix_space=True, + ) + + else: + # TODO(yixin): unsupported tokenizer + raise ValueError(f'Unsupported tokenizer type: {type(tokenizer)}') + + @property + def vocab_type(self) -> VocabType: + """The type of the vocabulary.""" + return VocabType(self._handle.vocab_type) + + @property + def vocab_size(self) -> int: + """The size of the vocabulary.""" + return self._handle.vocab_size + + @property + def add_prefix_space(self) -> bool: + """Whether the tokenizer will prepend a space before the text in the + tokenization process.""" + return self._handle.add_prefix_space + + @property + def prepend_space_in_tokenization(self) -> bool: + """Whether the tokenizer will prepend a space before the text in the + tokenization process. + + This property is deprecated. Use add_prefix_space instead. + """ + logger.warning('prepend_space_in_tokenization is deprecated. Use add_prefix_space instead.') + return self.add_prefix_space + + @property + def decoded_vocab(self) -> List[bytes]: + """The decoded vocabulary of the tokenizer. + + This converts the tokens in the LLM's vocabulary back to the original format of the input text. E.g. for type + ByteFallback, the token <0x1B> is converted back to "\u001b". + """ + return self._handle.decoded_vocab + + @property + def stop_token_ids(self) -> List[int]: + """The stop token ids.""" + return self._handle.stop_token_ids + + @property + def special_token_ids(self) -> List[int]: + """The special token ids. + + Special tokens include control tokens, reserved tokens, padded tokens, etc. Now it is automatically detected + from the vocabulary. + """ + return self._handle.special_token_ids + + def dump_metadata(self) -> str: + """Dump the metadata of the tokenizer to a json string. + + It can be used to construct the tokenizer info from the vocabulary and the metadata string. 
+ """ + return self._handle.dump_metadata() diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index dac5325364..93e945da45 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -33,6 +33,9 @@ lmdeploy_dir = osp.split(lmdeploy.__file__)[0] sys.path.append(osp.join(lmdeploy_dir, 'lib')) import _turbomind as _tm # noqa: E402 +import _xgrammar as _xgr # noqa: E402 + +from .tokenizer_info import TokenizerInfo # noqa: E402 logger = get_logger('lmdeploy') @@ -702,6 +705,26 @@ async def async_stream_infer(self, input_meta=input_meta, gen_config=gen_config) + if gen_config.response_format is not None: + tokenizer = self.tm_model.tokenizer + vocab_size = self.tm_model.config.model_config.vocab_size + decode_grammar_type = gen_config.response_format['type'] + decode_grammar = gen_config.response_format[decode_grammar_type]['schema'] + + tokenizer_info = TokenizerInfo.from_huggingface(tokenizer.model.model, vocab_size=vocab_size) + compiler = _xgr.GrammarCompiler(tokenizer_info) + + if decode_grammar_type == 'json_schema': + decode_grammar = json.dumps(decode_grammar) + grammar = compiler.compile_json_schema(decode_grammar) + elif decode_grammar_type == 'regex': + decode_grammar = str(decode_grammar) + grammar = compiler.compile_regex(decode_grammar) + else: + assert False, f'Decode grammar type {decode_grammar_type} should be in ["json_schema", "regex"]' + + self.model_inst.set_grammar(grammar) + session = _tm.SessionParam(id=session_id, step=step, start=sequence_start, end=sequence_end) inputs = _np_dict_to_tm_dict(inputs) diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt index 984ebdc166..8e037ef521 100644 --- a/requirements/runtime_ascend.txt +++ b/requirements/runtime_ascend.txt @@ -7,7 +7,6 @@ mmengine-lite numpy openai openai_harmony -outlines<0.1.0 partial_json_parser peft<=0.11.1 pillow @@ -24,3 +23,4 @@ torch-npu>=2.3.1,<2.8.0 torchvision>=0.18.1,<0.23.0 transformers uvicorn +xgrammar diff --git a/requirements/runtime_camb.txt b/requirements/runtime_camb.txt index 4ba6ef8462..5b37b003c0 100644 --- a/requirements/runtime_camb.txt +++ b/requirements/runtime_camb.txt @@ -6,7 +6,6 @@ mmengine-lite numpy openai openai_harmony -outlines<0.1.0 partial_json_parser peft<=0.11.1 pillow @@ -21,3 +20,4 @@ torch<=2.6.0,>=2.4.0 torchvision<=0.21.0,>=0.15.0 transformers uvicorn +xgrammar diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt index 21502e1103..f7ac027ee5 100644 --- a/requirements/runtime_cuda.txt +++ b/requirements/runtime_cuda.txt @@ -7,7 +7,6 @@ mmengine-lite numpy openai openai_harmony -outlines partial_json_parser peft<=0.14.0 pillow @@ -26,3 +25,4 @@ torchvision<=0.23.0,>=0.15.0 transformers triton<=3.4.0,>=3.0.0; sys_platform == "linux" uvicorn +xgrammar diff --git a/requirements/runtime_maca.txt b/requirements/runtime_maca.txt index 19a016cbed..70202d5ce5 100644 --- a/requirements/runtime_maca.txt +++ b/requirements/runtime_maca.txt @@ -6,7 +6,6 @@ mmengine-lite numpy openai openai_harmony -outlines<0.1.0 partial_json_parser peft<=0.11.1 pillow @@ -22,3 +21,4 @@ torchvision<=0.21.0,>=0.15.0 transformers triton>=2.1.0; sys_platform == "linux" uvicorn +xgrammar diff --git a/requirements/runtime_rocm.txt b/requirements/runtime_rocm.txt index 1605fd6043..47d6f66fcd 100644 --- a/requirements/runtime_rocm.txt +++ b/requirements/runtime_rocm.txt @@ -6,7 +6,6 @@ mmengine-lite numpy openai openai_harmony -outlines partial_json_parser peft<=0.14.0 pillow @@ -20,3 +19,4 @@ shortuuid 
tiktoken transformers uvicorn +xgrammar diff --git a/requirements/test.txt b/requirements/test.txt index 41b25dac35..3fe279d6ce 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,6 @@ allure-pytest coverage +jsonschema nvidia-ml-py pytest pytest-assume @@ -9,3 +10,4 @@ pytest-rerunfailures pytest-sugar pytest-xdist pyyaml +timm diff --git a/src/turbomind/engine/CMakeLists.txt b/src/turbomind/engine/CMakeLists.txt index 6836d98155..8c46860abc 100644 --- a/src/turbomind/engine/CMakeLists.txt +++ b/src/turbomind/engine/CMakeLists.txt @@ -3,6 +3,6 @@ cmake_minimum_required(VERSION 3.8) add_library(engine STATIC gateway.cc request_queue.cc model_request.cc) -target_link_libraries(engine PRIVATE core) +target_link_libraries(engine PRIVATE core xgrammar) set_property(TARGET engine PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET engine PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/turbomind/engine/model_request.cc b/src/turbomind/engine/model_request.cc index 48b20bff68..ba7ebe321f 100644 --- a/src/turbomind/engine/model_request.cc +++ b/src/turbomind/engine/model_request.cc @@ -127,6 +127,10 @@ auto ModelRequest::Forward(InputParam param, std::function cb) -> Output r->output_ids = outputs_->at("output_ids"); r->sequence_length = outputs_->at("sequence_length"); + if (grammar_) { + r->matcher = std::make_shared(*grammar_); + } + // Keep a weak reference for canceling the request request_ = r; @@ -135,4 +139,9 @@ auto ModelRequest::Forward(InputParam param, std::function cb) -> Output return OutputParam{outputs_, state, metrics}; } +void ModelRequest::setGrammar(const xgrammar::CompiledGrammar& grammar) +{ + grammar_ = std::make_shared(grammar); +} + } // namespace turbomind diff --git a/src/turbomind/engine/model_request.h b/src/turbomind/engine/model_request.h index 7009885550..7582163095 100644 --- a/src/turbomind/engine/model_request.h +++ b/src/turbomind/engine/model_request.h @@ -4,6 +4,8 @@ #include +#include + #include "src/turbomind/core/core.h" #include "src/turbomind/engine/gateway.h" @@ -38,6 +40,7 @@ class ModelRequest { }; OutputParam Forward(InputParam param, std::function cb); + void setGrammar(const xgrammar::CompiledGrammar& grammar); protected: Gateway* const gateway_; @@ -52,8 +55,9 @@ class ModelRequest { std::weak_ptr request_; - std::shared_ptr inputs_; - std::shared_ptr outputs_; + std::shared_ptr inputs_; + std::shared_ptr outputs_; + std::shared_ptr grammar_; }; } // namespace turbomind diff --git a/src/turbomind/engine/request.h b/src/turbomind/engine/request.h index f02e385f5f..aa50a48100 100644 --- a/src/turbomind/engine/request.h +++ b/src/turbomind/engine/request.h @@ -10,6 +10,8 @@ #include #include +#include + #include "src/turbomind/core/core.h" #include "src/turbomind/utils/metrics.h" @@ -151,6 +153,8 @@ struct Request { kCancel = 8, kInconsistency = 9, // Inconsistent request parameters, e.g. 
prefix caching is not allowed in interactive mode }; + + std::shared_ptr matcher; }; inline void UpdateState(Request& r, int status, int seq_len) diff --git a/src/turbomind/kernels/CMakeLists.txt b/src/turbomind/kernels/CMakeLists.txt index 7c63d752ae..2dc16c7a81 100644 --- a/src/turbomind/kernels/CMakeLists.txt +++ b/src/turbomind/kernels/CMakeLists.txt @@ -69,6 +69,9 @@ add_library(sampling_kernels STATIC sampling_kernels.cu) set_property(TARGET sampling_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET sampling_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +add_library(apply_token_bitmask_inplace_cuda STATIC apply_token_bitmask_inplace_cuda.cu) +set_property(TARGET apply_token_bitmask_inplace_cuda PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET apply_token_bitmask_inplace_cuda PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) add_subdirectory(attention) add_subdirectory(gemm) diff --git a/src/turbomind/kernels/apply_token_bitmask_inplace_cuda.cu b/src/turbomind/kernels/apply_token_bitmask_inplace_cuda.cu new file mode 100644 index 0000000000..d77d449009 --- /dev/null +++ b/src/turbomind/kernels/apply_token_bitmask_inplace_cuda.cu @@ -0,0 +1,225 @@ +// Modified from xgrammar python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu + +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// clang-format off +#include +#include +#include + +#include "src/turbomind/core/context.h" +#include "src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h" +// clang-format on + +using namespace std; + +#ifndef CUDART_INF_FP16 +#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U) +#endif + +#if __CUDA_ARCH__ >= 800 +#ifndef CUDART_INF_BF16 +#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) +#endif +#endif + +constexpr int32_t BITS_PER_BLOCK = 32; +constexpr int32_t THREADS_PER_THREAD_BLOCK = 256; + +template +__device__ T NegativeInfinity() +{ + return -INFINITY; +} + +template<> +__device__ __half NegativeInfinity<__half>() +{ + return -CUDART_INF_FP16; +} + +#if __CUDA_ARCH__ >= 800 +template<> +__device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>() +{ + return -CUDART_INF_BF16; +} +#endif + +template +__device__ PackedT PackedNegativeInfinity() +{ + constexpr int kAlignment = sizeof(PackedT) / sizeof(T); + T packed[kAlignment]; +#pragma unroll + for (int i = 0; i < kAlignment; i++) { + packed[i] = NegativeInfinity(); + } + return *reinterpret_cast(packed); +} + +template +__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(T* __restrict__ logits, + const int32_t* __restrict__ bitmask, + const int32_t* __restrict__ indices, + int32_t vocab_size, + int32_t logits_stride, + int32_t bitmask_stride) +{ + constexpr int kAlignment = sizeof(PackedT) / sizeof(T); + constexpr uint32_t kPackedMask = (1 << kAlignment) - 1; + + const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y]; + + const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread; + T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset; + const int32_t* bitmask_gmem_ptr = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK; + const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment); + T logits_reg[kAlignment]; + +#pragma unroll + for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread; + offset += THREADS_PER_THREAD_BLOCK * kAlignment) { + if (block_offset + offset >= vocab_size) { + break; + } + + const uint32_t bitmask_val = + (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask; + + if (bitmask_val == 0) { + continue; + } + + if (bitmask_val == kPackedMask) { + *reinterpret_cast(logits_gmem_ptr + offset) = PackedNegativeInfinity(); + continue; + } + + *reinterpret_cast(logits_reg) = *reinterpret_cast(logits_gmem_ptr + offset); +#pragma unroll + for (int i = 0; i < kAlignment; i++) { + if (((bitmask_val >> i) & 1)) { + logits_reg[i] = NegativeInfinity(); + } + } + *reinterpret_cast(logits_gmem_ptr + offset) = *reinterpret_cast(logits_reg); + } +} + +template::value>> +constexpr auto CeilDiv(T numerator, T denominator) +{ + return (numerator + denominator - 1) / denominator; +} + +template +void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(T* __restrict__ logits, + const int32_t* __restrict__ bitmask, + const int32_t* __restrict__ indices, + int32_t vocab_size, + int32_t logits_stride, + int32_t bitmask_stride, + int32_t num_rows) +{ + constexpr int kAlignment = sizeof(PackedT) / sizeof(T); + const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows); + const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row); + + const dim3 block(THREADS_PER_THREAD_BLOCK); + const auto& stream = 
turbomind::core::Context::stream(); + + if (num_bits_per_thread <= 4 && kAlignment <= 4) { + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows); + LogitsBitmaskKernel + <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); + } + else if (num_bits_per_thread <= 8 && kAlignment <= 8) { + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows); + LogitsBitmaskKernel + <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); + } + else if (num_bits_per_thread <= 16 && kAlignment <= 16) { + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows); + LogitsBitmaskKernel + <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); + } + else { + const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows); + LogitsBitmaskKernel + <<>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); + } +} + +template +void ApplyTokenBitmaskInplaceDispatchToPackedT(T* __restrict__ logits, + const int32_t* __restrict__ bitmask, + const int32_t* __restrict__ indices, + int32_t vocab_size, + int32_t logits_stride, + int32_t bitmask_stride, + int32_t num_rows) +{ + if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) { + ApplyTokenBitmaskInplaceDispatchToBitsPerThread( + logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows); + } + else { + ApplyTokenBitmaskInplaceDispatchToBitsPerThread( + logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows); + } +} + +namespace turbomind { +using namespace turbomind::core; + +void ApplyTokenBitmaskInplace(Tensor logits, Tensor bitmask, std::optional indices) +{ + std::pair logits_shape = + logits.ndim() == 2 ? + std::make_pair(static_cast(logits.shape(0)), static_cast(logits.shape(1))) : + std::make_pair(1, static_cast(logits.shape(0))); + + std::pair bitmask_shape = + bitmask.ndim() == 2 ? + std::make_pair(static_cast(bitmask.shape(0)), static_cast(bitmask.shape(1))) : + std::make_pair(1, static_cast(bitmask.shape(0))); + + int vocab_size = std::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK); + + int32_t num_rows = logits_shape.first; + int32_t* indices_ptr = nullptr; + if (indices) { + num_rows = indices->shape(0); + indices_ptr = indices->data(); + } + else { + TM_CHECK(logits_shape.first == bitmask_shape.first) << "logits and bitmask must have the same batch size."; + } + + // Currently we use only float logits. 
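+    // Bitmask convention: bit i of the int32 words is 1 when token i is allowed;
+    // the kernel above inverts the mask and writes -inf into the logits of every
+    // disallowed token, leaving allowed logits untouched.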
+ TM_CHECK(logits.dtype() == kFloat32); + ApplyTokenBitmaskInplaceDispatchToPackedT(logits.data(), + bitmask.data(), + indices_ptr, + vocab_size, + logits.stride(0), + bitmask.stride(0), + num_rows); +} +} // namespace turbomind diff --git a/src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h b/src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h new file mode 100644 index 0000000000..bffffd0285 --- /dev/null +++ b/src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h @@ -0,0 +1,7 @@ +#include "src/turbomind/core/tensor.h" + +namespace turbomind { +void ApplyTokenBitmaskInplace(core::Tensor logits, + core::Tensor bitmask, + std::optional indices = std::nullopt); +} diff --git a/src/turbomind/layers/CMakeLists.txt b/src/turbomind/layers/CMakeLists.txt index 975ee77ec7..30977b2f5f 100644 --- a/src/turbomind/layers/CMakeLists.txt +++ b/src/turbomind/layers/CMakeLists.txt @@ -21,5 +21,5 @@ add_library(DynamicDecodeLayer STATIC DynamicDecodeLayer.cc) set_property(TARGET DynamicDecodeLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET DynamicDecodeLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(DynamicDecodeLayer PUBLIC CUDA::cudart - LogitsProcessorLayer SamplingLayer StopCriteriaLayer - gpt_kernels nvtx_utils) + LogitsProcessorLayer SamplingLayer StopCriteriaLayer GuidedDecodeLayer + gpt_kernels nvtx_utils) diff --git a/src/turbomind/layers/DynamicDecodeLayer.cc b/src/turbomind/layers/DynamicDecodeLayer.cc index 799beff53a..5a66bf1fb6 100644 --- a/src/turbomind/layers/DynamicDecodeLayer.cc +++ b/src/turbomind/layers/DynamicDecodeLayer.cc @@ -17,6 +17,8 @@ #include "src/turbomind/layers/DynamicDecodeLayer.h" #include "src/turbomind/core/data_type.h" #include "src/turbomind/layers/BaseDynamicDecodeLayer.h" +#include "src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h" +#include "src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.h" #include "src/turbomind/layers/sampling_layers/LogitsProcessorLayer.h" #include "src/turbomind/layers/sampling_layers/SamplingLayer.h" #include "src/turbomind/layers/sampling_layers/StopCriteriaLayer.h" @@ -35,7 +37,9 @@ DynamicDecodeLayer::DynamicDecodeLayer(DataType dtype, TM_CHECK(dtype == kFloat32); BaseDynamicDecodeLayer::BaseParam param{max_batch_size, vocab_size, vocab_size_padded, stream, device_prop}; layers_.emplace_back(new LogitsProcessorLayer{param}); + layers_.emplace_back(new GuidedDecodeMaskLayer{param}); layers_.emplace_back(new SamplingLayer{param}); + layers_.emplace_back(new GuidedDecodeUpdateLayer{param}); layers_.emplace_back(new StopCriteriaLayer{param}); } diff --git a/src/turbomind/layers/sampling_layers/CMakeLists.txt b/src/turbomind/layers/sampling_layers/CMakeLists.txt index c1dc86b8d1..d7ec104508 100644 --- a/src/turbomind/layers/sampling_layers/CMakeLists.txt +++ b/src/turbomind/layers/sampling_layers/CMakeLists.txt @@ -20,17 +20,22 @@ add_library(LogitsProcessorLayer STATIC LogitsProcessorLayer.cc) set_property(TARGET LogitsProcessorLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET LogitsProcessorLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(LogitsProcessorLayer PUBLIC CUDA::cudart ban_bad_words memory_utils - sampling_penalty_kernels + sampling_penalty_kernels xgrammar ) add_library(SamplingLayer STATIC SamplingLayer.cc) set_property(TARGET SamplingLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET SamplingLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(SamplingLayer PUBLIC CUDA::cudart memory_utils 
- sampling_topk_kernels sampling_topp_kernels sampling_kernels + sampling_topk_kernels sampling_topp_kernels sampling_kernels xgrammar ) add_library(StopCriteriaLayer STATIC StopCriteriaLayer.cc) set_property(TARGET StopCriteriaLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET StopCriteriaLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(StopCriteriaLayer PUBLIC CUDA::cudart stop_criteria memory_utils) +target_link_libraries(StopCriteriaLayer PUBLIC CUDA::cudart stop_criteria memory_utils xgrammar) + +add_library(GuidedDecodeLayer STATIC GuidedDecodeMaskLayer.cc GuidedDecodeUpdateLayer.cc) +set_property(TARGET GuidedDecodeLayer PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET GuidedDecodeLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(GuidedDecodeLayer PUBLIC CUDA::cudart apply_token_bitmask_inplace_cuda xgrammar) diff --git a/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc new file mode 100644 index 0000000000..2262992902 --- /dev/null +++ b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2025-2025, OpenMMLab. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h" +#include "src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h" + +namespace turbomind { + +template +GuidedDecodeMaskLayer::GuidedDecodeMaskLayer(const BaseParam& param): BaseDynamicDecodeLayer{param} +{ + const auto bitmask_size = xgrammar::GetBitmaskSize(vocab_size_padded_); + bitmask_buf_ = {{max_batch_size_, bitmask_size}, kCPU}; + bitmask_ = {{max_batch_size_, bitmask_size}, kDEVICE}; +} + +template +void GuidedDecodeMaskLayer::Setup(const std::vector& rs, const TensorMap& args) +{ + TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + matchers_.clear(); + for (const auto& r : rs) { + matchers_.push_back(r->matcher); + } +} + +template +void GuidedDecodeMaskLayer::Forward(TensorMap& args) +{ + TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); + + Tensor_ logits = args.at("logits"); + const ssize_t bsz = logits.shape(0); + + TM_CHECK(bsz == matchers_.size()); + + const auto bitmask_size = bitmask_buf_.shape(1); + std::vector bitmask_shape = {bsz, bitmask_size}; + + DLTensor bitmask_dltensor{bitmask_buf_.data(), + DLDevice{kDLCPU, 0}, + bitmask_buf_.ndim(), + xgrammar::GetBitmaskDLType(), + bitmask_shape.data(), + nullptr, + 0}; + bool need_apply = false; + for (size_t i = 0; i < bsz; ++i) { + const auto& matcher = matchers_[i]; + if (matcher) { + matcher->FillNextTokenBitmask(&bitmask_dltensor, i); + need_apply = true; + } + } + + if (need_apply) { + Copy(bitmask_buf_, bitmask_); + ApplyTokenBitmaskInplace(logits, bitmask_.slice(0, bsz)); + } +} + +template class GuidedDecodeMaskLayer; + +} // namespace turbomind diff --git a/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h new file mode 100644 index 0000000000..45cc917976 --- /dev/null +++ b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2025-2025, OpenMMLab. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "src/turbomind/layers/BaseDynamicDecodeLayer.h" + +#include "src/turbomind/engine/request.h" + +namespace turbomind { + +template +class GuidedDecodeMaskLayer: public BaseDynamicDecodeLayer { +public: + explicit GuidedDecodeMaskLayer(const BaseParam& param); + + void Setup(const std::vector& rs, const TensorMap& args) override; + + void Forward(TensorMap& args) override; + +private: + std::vector> matchers_; + // host buffer + Tensor_ bitmask_buf_; + // device buffer + Tensor_ bitmask_; +}; + +} // namespace turbomind diff --git a/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc b/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc new file mode 100644 index 0000000000..653a8874d8 --- /dev/null +++ b/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2025-2025, OpenMMLab. All rights reserved. 
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.h"
+
+namespace turbomind {
+
+template<typename T>
+GuidedDecodeUpdateLayer<T>::GuidedDecodeUpdateLayer(const BaseParam& param): BaseDynamicDecodeLayer{param}
+{
+}
+
+template<typename T>
+void GuidedDecodeUpdateLayer<T>::Setup(const std::vector<const Request*>& rs, const TensorMap& args)
+{
+    TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
+    matchers_.clear();
+    for (const auto& r : rs) {
+        matchers_.push_back(r->matcher);
+    }
+}
+
+template<typename T>
+void GuidedDecodeUpdateLayer<T>::Forward(TensorMap& args)
+{
+    TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
+    Tensor_<T>    logits     = args.at("logits");
+    Tensor_<int>  output_ids = args.at("output_ids");
+    const int     step       = *args.at("step").data<int>();
+    const ssize_t bsz        = logits.shape(0);
+    Tensor_<int>  output_ids_buf{{bsz}, kCPU};
+
+    FT_CHECK(bsz == matchers_.size());
+    Copy(output_ids.slice(step * bsz, bsz), output_ids_buf);
+
+    for (size_t i = 0; i < bsz; ++i) {
+        const auto& matcher = matchers_[i];
+        if (matcher) {
+            matcher->AcceptToken(output_ids_buf.data()[i]);
+        }
+    }
+}
+
+template class GuidedDecodeUpdateLayer<float>;
+}  // namespace turbomind
diff --git a/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.h b/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.h
new file mode 100644
index 0000000000..94cf1338f4
--- /dev/null
+++ b/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2025-2025, OpenMMLab. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <memory>
+
+#include <xgrammar/xgrammar.h>
+
+#include "src/turbomind/layers/BaseDynamicDecodeLayer.h"
+
+#include "src/turbomind/engine/request.h"
+
+namespace turbomind {
+
+template<typename T>
+class GuidedDecodeUpdateLayer: public BaseDynamicDecodeLayer {
+public:
+    explicit GuidedDecodeUpdateLayer(const BaseParam&);
+
+    void Setup(const std::vector<const Request*>&, const TensorMap&) override;
+
+    void Forward(TensorMap&) override;
+
+private:
+    std::vector<std::shared_ptr<xgrammar::GrammarMatcher>> matchers_;
+};
+
+}  // namespace turbomind
diff --git a/src/turbomind/python/CMakeLists.txt b/src/turbomind/python/CMakeLists.txt
index e58eb15c5b..c4f5673c28 100644
--- a/src/turbomind/python/CMakeLists.txt
+++ b/src/turbomind/python/CMakeLists.txt
@@ -16,6 +16,10 @@ pybind11_add_module(${PROJECT_NAME} bind.cpp)
 target_link_libraries(${PROJECT_NAME} PRIVATE LlamaTritonBackend)
 target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_14)
 
+pybind11_add_module(_xgrammar xgrammar_bind.cpp)
+target_link_libraries(_xgrammar PRIVATE core xgrammar)
+target_compile_features(_xgrammar PRIVATE cxx_std_14)
+
 if (CALL_FROM_SETUP_PY)
     set(_INSTALL_CUDA_RPATH
         "\$ORIGIN"
diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp
index e6f38a2d1b..f4d090fefd 100644
--- a/src/turbomind/python/bind.cpp
+++ b/src/turbomind/python/bind.cpp
@@ -12,6 +12,8 @@
 #include
 #include
+#include <xgrammar/xgrammar.h>
+
 #include "src/turbomind/core/data_type.h"
 #include "src/turbomind/core/tensor.h"
 #include "src/turbomind/engine/model_request.h"
@@ -488,7 +490,15 @@
             },
             py::call_guard<py::gil_scoped_release>(),
             "cb"_a,
-            "session_id"_a);
+            "session_id"_a)
+        .def(
+            "set_grammar",
+            [](ModelRequest* model_request, const xgrammar::CompiledGrammar& grammar) {
+                TM_LOG_INFO("Set grammar for model_request");
+                model_request->setGrammar(grammar);
+            },
+            py::call_guard<py::gil_scoped_release>(),
+            "grammar"_a);
 
     // transformer model
     using ft::LlamaTritonModel;
diff --git a/src/turbomind/python/xgrammar_bind.cpp b/src/turbomind/python/xgrammar_bind.cpp
new file mode 100644
index 0000000000..bcd6945915
--- /dev/null
+++ b/src/turbomind/python/xgrammar_bind.cpp
@@ -0,0 +1,134 @@
+// Modified from xgrammar/nanobind/nanobind.cc from xgrammar project.
+/*!
+ * Copyright (c) 2024 by Contributors
+ * \file xgrammar/nanobind/nanobind.cc
+ */
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <pybind11/typing.h>
+
+#include <memory>
+#include <optional>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include <xgrammar/xgrammar.h>
+
+#include "src/turbomind/core/check.h"
+
+namespace py = pybind11;
+using namespace xgrammar;
+using namespace pybind11::literals;
+
+namespace {
+
+static const std::vector<std::string>
+CommonEncodedVocabType(const py::typing::List<py::typing::Union<py::str, py::bytes>>& lst)
+{
+    std::vector<std::string> out;
+    out.reserve(lst.size());
+    for (const auto& h : lst) {
+        if (py::isinstance<py::str>(h)) {
+            out.emplace_back(h.cast<std::string>());
+        }
+        else if (py::isinstance<py::bytes>(h)) {
+            out.emplace_back(h.cast<std::string>());
+        }
+        else {
+            throw std::invalid_argument("encoded_vocab items must be str or bytes");
+        }
+    }
+    return out;
+}
+
+TokenizerInfo TokenizerInfo_Init(const std::vector<std::string>&     encoded_vocab,
+                                 int                                 vocab_type,
+                                 std::optional<int>                  vocab_size,
+                                 std::optional<std::vector<int32_t>> stop_token_ids,
+                                 bool                                add_prefix_space)
+{
+    TM_CHECK(vocab_type == 0 || vocab_type == 1 || vocab_type == 2) << "Invalid vocab type: " << vocab_type;
+    return TokenizerInfo(
+        encoded_vocab, static_cast<VocabType>(vocab_type), vocab_size, stop_token_ids, add_prefix_space);
+}
+
+int TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer)
+{
+    return static_cast<int>(tokenizer.GetVocabType());
+}
+
+std::vector<py::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer)
+{
+    const auto&            decoded_vocab = tokenizer.GetDecodedVocab();
+    std::vector<py::bytes> py_result;
+    py_result.reserve(decoded_vocab.size());
+    for (const auto& item : decoded_vocab) {
+        py_result.emplace_back(py::bytes(item.c_str()));
+    }
+    return py_result;
+}
+
+}  // namespace
+
+PYBIND11_MODULE(_xgrammar, m)
+{
+    py::class_<TokenizerInfo, std::shared_ptr<TokenizerInfo>>(m, "TokenizerInfo")
+        .def(py::init([](const py::typing::List<py::typing::Union<py::str, py::bytes>>& encoded_vocab,
+                         int                                                            vocab_type,
+                         std::optional<int>                                             vocab_size,
+                         std::optional<std::vector<int32_t>>                            stop_token_ids,
+                         bool                                                           add_prefix_space) {
+                 return TokenizerInfo{TokenizerInfo_Init(CommonEncodedVocabType(encoded_vocab),
+                                                         vocab_type,
+                                                         vocab_size,
+                                                         std::move(stop_token_ids),
+                                                         add_prefix_space)};
+             }),
+             py::arg("encoded_vocab"),
+             py::arg("vocab_type"),
+             py::arg("vocab_size")     = py::none(),
+             py::arg("stop_token_ids") = py::none(),
+             py::arg("add_prefix_space"))
+
+        .def_property_readonly("vocab_type", &TokenizerInfo_GetVocabType)
+        .def_property_readonly("vocab_size", &TokenizerInfo::GetVocabSize)
+        .def_property_readonly("add_prefix_space", &TokenizerInfo::GetAddPrefixSpace)
+        .def_property_readonly("decoded_vocab", &TokenizerInfo_GetDecodedVocab)
+        .def_property_readonly("stop_token_ids", &TokenizerInfo::GetStopTokenIds)
+        .def_property_readonly("special_token_ids", &TokenizerInfo::GetSpecialTokenIds)
+
+        .def("dump_metadata", &TokenizerInfo::DumpMetadata)
+
+        .def_static("from_vocab_and_metadata",
+                    [](const py::typing::List<py::typing::Union<py::str, py::bytes>>& encoded_vocab,
+                       const std::string&                                             metadata) {
+                        return TokenizerInfo::FromVocabAndMetadata(CommonEncodedVocabType(encoded_vocab), metadata);
+                    })
+
+        .def_static("_detect_metadata_from_hf", &TokenizerInfo::DetectMetadataFromHF);
+
+    py::class_<CompiledGrammar>(m, "CompiledGrammar");
+
+    py::class_<GrammarCompiler> pyGrammarCompiler(m, "GrammarCompiler");
+    pyGrammarCompiler
+        .def(py::init<const TokenizerInfo&, int, bool, long long>(),
+             py::arg("tokenizer_info"),
+             py::arg("max_threads")      = 8,
+             py::arg("cache_enabled")    = true,
+             py::arg("max_memory_bytes") = -1)
+        .def("compile_json_schema",
+             &GrammarCompiler::CompileJSONSchema,
+             py::call_guard<py::gil_scoped_release>(),
+             py::arg("schema"),
+             py::arg("any_whitespace")     = false,
+             py::arg("indent")             = py::none(),
+             py::arg("separators")         = py::none(),
+             py::arg("strict_mode")        = true,
+             py::arg("max_whitespace_cnt") = py::none())
+        .def("compile_regex",
+             &GrammarCompiler::CompileRegex,
+             py::call_guard<py::gil_scoped_release>(),
+             py::arg("schema"));
+}
diff --git a/src/turbomind/triton_backend/llama/CMakeLists.txt b/src/turbomind/triton_backend/llama/CMakeLists.txt
index 756f5ac67d..d3b6020356 100644
--- a/src/turbomind/triton_backend/llama/CMakeLists.txt
+++ b/src/turbomind/triton_backend/llama/CMakeLists.txt
@@ -32,6 +32,7 @@ target_link_libraries(LlamaTritonBackend PUBLIC
        core
        memory_utils
        CUDA::cublasLt
-       yaml-cpp::yaml-cpp)
+       yaml-cpp::yaml-cpp
+       xgrammar)
 
 target_compile_features(LlamaTritonBackend PRIVATE cxx_std_14)
diff --git a/tests/test_lmdeploy/test_grammar.py b/tests/test_lmdeploy/test_grammar.py
new file mode 100644
index 0000000000..e45b4f1a42
--- /dev/null
+++ b/tests/test_lmdeploy/test_grammar.py
@@ -0,0 +1,79 @@
+import json
+
+import pytest
+from jsonschema import validate
+
+from lmdeploy import pipeline
+from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig
+
+MODEL_IDS = [
+    'Qwen/Qwen3-0.6B',
+    'OpenGVLab/InternVL3_5-1B',
+]
+
+BACKEND_FACTORIES = [
+    ('tm', lambda: TurbomindEngineConfig(max_batch_size=2, session_len=1024)),
+    ('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),
+]
+
+GUIDE_SCHEMA = {
+    'type': 'object',
+    'properties': {
+        'name': {
+            'type': 'string'
+        },
+        'skills': {
+            'type': 'array',
+            'items': {
+                'type': 'string',
+                'maxLength': 10
+            },
+            'minItems': 3,
+            'maxItems': 10,
+        },
+        'work history': {
+            'type': 'array',
+            'items': {
+                'type': 'object',
+                'properties': {
+                    'company': {
+                        'type': 'string'
+                    },
+                    'duration': {
+                        'type': 'string'
+                    },
+                },
+                'required': ['company'],
+            },
+        },
+    },
+    'required': ['name', 'skills', 'work history'],
+}
+
+
+@pytest.mark.parametrize('model_id', MODEL_IDS)
+@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
+@pytest.mark.parametrize('enable_guide', [True, False])
+def test_guided_matrix(model_id, backend_name, backend_factory, enable_guide):
+    pipe = pipeline(
+        model_id,
+        backend_config=backend_factory(),
+        log_level='INFO',
+    )
+
+    try:
+        if enable_guide:
+            gen_config = GenerationConfig(response_format=dict(
+                type='json_schema',
+                json_schema=dict(name='test', schema=GUIDE_SCHEMA),
+            ), )
+        else:
+            gen_config = GenerationConfig()
+
+        response = pipe(['Make a self introduction please.'] * 3, gen_config=gen_config)
+        assert response and response[0].text
+
+        if enable_guide:
+            validate(instance=json.loads(response[0].text), schema=GUIDE_SCHEMA)
+    finally:
+        pipe.close()
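
Note: the sketch below is illustrative only and is not part of the patch. It shows how the guided-decoding path added above is driven from the Python side, mirroring tests/test_lmdeploy/test_grammar.py; the model ID, prompt, schema, and file name are placeholders.

# guided_decoding_example.py -- illustrative sketch; model ID, prompt, and schema are placeholders
import json

from lmdeploy import pipeline
from lmdeploy.messages import GenerationConfig, TurbomindEngineConfig

schema = {
    'type': 'object',
    'properties': {
        'name': {'type': 'string'},
        'skills': {'type': 'array', 'items': {'type': 'string'}, 'minItems': 3},
    },
    'required': ['name', 'skills'],
}

pipe = pipeline('Qwen/Qwen3-0.6B',  # placeholder model
                backend_config=TurbomindEngineConfig(max_batch_size=1, session_len=1024))
try:
    # Passing a json_schema via response_format constrains decoding with the compiled grammar.
    gen_config = GenerationConfig(response_format=dict(
        type='json_schema',
        json_schema=dict(name='demo', schema=schema),
    ))
    response = pipe(['Make a self introduction please.'], gen_config=gen_config)
    # With the grammar applied, the output text parses as JSON conforming to the schema.
    print(json.loads(response[0].text))
finally:
    pipe.close()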