From c295f192c8f40e94a33ba98cf54260cd8fd83618 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 9 Apr 2025 10:56:09 +0800
Subject: [PATCH 1/3] update

---
 .../llm/infer/infer_engine/lmdeploy_engine.py |   3 -
 swift/llm/infer/infer_engine/sglang_engine.py | 106 ++++++++++++++++++
 2 files changed, 106 insertions(+), 3 deletions(-)
 create mode 100644 swift/llm/infer/infer_engine/sglang_engine.py

diff --git a/swift/llm/infer/infer_engine/lmdeploy_engine.py b/swift/llm/infer/infer_engine/lmdeploy_engine.py
index 4c791b476c..35df646ae6 100644
--- a/swift/llm/infer/infer_engine/lmdeploy_engine.py
+++ b/swift/llm/infer/infer_engine/lmdeploy_engine.py
@@ -150,9 +150,6 @@ def _load_generation_config(self):
         if os.path.isfile(generation_config_path):
             generation_config = GenerationConfig.from_pretrained(self.model_dir)
             kwargs = generation_config.to_dict()
-            max_new_tokens = kwargs.get('max_new_tokens')
-            if max_new_tokens is None:
-                kwargs.pop('max_new_tokens', None)
             parameters = inspect.signature(LmdeployGenerationConfig).parameters
             for k, v in kwargs.copy().items():
                 if k not in parameters or v is None:
diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py
new file mode 100644
index 0000000000..048f8cb785
--- /dev/null
+++ b/swift/llm/infer/infer_engine/sglang_engine.py
@@ -0,0 +1,106 @@
+import os
+from copy import deepcopy
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+
+import sglang as sgl
+import torch
+from transformers import GenerationConfig
+
+from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
+from swift.plugin import Metric
+from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid)
+from .infer_engine import InferEngine
+
+
+class SglangEngine(InferEngine):
+    sampling_parameters = {
+        'max_new_tokens', 'stop', 'temperature', 'top_p', 'top_k', 'min_p', 'frequency_penalty', 'presence_penalty',
+        'repetition_penalty', 'min_new_tokens', 'n'
+    }
+
+    def __init__(
+            self,
+            model_id_or_path: str,
+            torch_dtype: Optional[torch.dtype] = None,
+            *,
+            model_type: Optional[str] = None,
+            use_hf: Optional[bool] = None,
+            hub_token: Optional[str] = None,
+            revision: Optional[str] = None,
+    ):
+        self.processor = get_model_tokenizer(
+            model_id_or_path,
+            torch_dtype,
+            load_model=False,
+            download_model=True,
+            model_type=model_type,
+            use_hf=use_hf,
+            hub_token=hub_token,
+            revision=revision)[1]
+        self._post_init()
+        self.engine = sgl.Engine(model_path=model_id_or_path, dtype=torch_dtype)
+        self._load_generation_config()
+
+    def _load_generation_config(self) -> None:
+        generation_config_path = os.path.join(self.model_dir, 'generation_config.json')
+        if os.path.isfile(generation_config_path):
+            generation_config = GenerationConfig.from_pretrained(self.model_dir)
+            kwargs = generation_config.to_dict()
+            top_k = kwargs.get('top_k')
+            if top_k == 0:
+                kwargs['top_k'] = -1
+
+            for k, v in kwargs.copy().items():
+                if k not in self.sampling_parameters or v is None:
+                    kwargs.pop(k)
+            self.generation_config = kwargs
+        else:
+            self.generation_config = {}
+
+    def _prepare_generation_config(self, request_config: RequestConfig) -> Dict[str, Any]:
+        kwargs = {'max_new_tokens': request_config.max_tokens}
+        for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
+            new_value = getattr(request_config, key)
+            if new_value is None:
+                kwargs[key] = getattr(self.generation_config, key)
+            else:
+                kwargs[key] = new_value
+        for key in ['n', 'frequency_penalty', 'presence_penalty']:
+            kwargs[key] = getattr(request_config, key)
+
+        return kwargs
+
+    def _add_stop_words(self, generation_config: Dict[str, Any], request_config: RequestConfig,
+                        template_meta: TemplateMeta) -> None:
+        stop_words = (request_config.stop or []) + (self.generation_config.get('stop') or []) + template_meta.stop_words
+        generation_config['stop'] = self._get_stop_words(stop_words)
+
+    def infer(self,
+              infer_requests: List[InferRequest],
+              request_config: Optional[RequestConfig] = None,
+              metrics: Optional[List[Metric]] = None,
+              *,
+              template: Optional[Template] = None,
+              **kwargs) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
+        request_config = deepcopy(request_config or RequestConfig())
+        if template is None:
+            template = self.default_template
+        batched_inputs, error_list = self._batch_encode(
+            infer_requests, template=template, strict=getattr(self, 'strict', True))
+        self.set_default_max_tokens(request_config, batched_inputs)
+        generation_config = self._prepare_generation_config(request_config)
+        self._add_stop_words(generation_config, request_config, template.template_meta)
+        input_ids = [inputs['input_ids'] for inputs in batched_inputs]
+        outputs = self.engine.generate(input_ids=input_ids, sampling_params=generation_config)
+        print()
+
+    async def infer_async(
+        self,
+        infer_request: InferRequest,
+        request_config: Optional[RequestConfig] = None,
+        *,
+        template: Optional[Template] = None,
+        pre_infer_hook=None,
+    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
+        pass

From 3db83402e9340b9137f3d9d85c9b1ef7bf82f1de Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 9 Apr 2025 13:28:15 +0800
Subject: [PATCH 2/3] update

---
 ...44\350\241\214\345\217\202\346\225\260.md" |  2 +-
 ...06\345\222\214\351\203\250\347\275\262.md" |  1 +
 .../Instruction/Command-line-parameters.md    |  2 +-
 .../Instruction/Inference-and-deployment.md   |  1 +
 swift/llm/__init__.py                         |  4 +--
 swift/llm/argument/infer_args.py              |  2 +-
 swift/llm/infer/__init__.py                   |  6 ++--
 swift/llm/infer/infer_engine/__init__.py      |  2 ++
 swift/llm/infer/infer_engine/sglang_engine.py | 35 +++++++++++++------
 9 files changed, 36 insertions(+), 19 deletions(-)

diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index 1837a54341..7fc496a125 100644
--- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -437,7 +437,7 @@ repetition 奖励参数
 
 推理参数除包含[基本参数](#基本参数)、[合并参数](#合并参数)、[vLLM参数](#vllm参数)、[LMDeploy参数](#LMDeploy参数)外,还包含下面的部分:
 
-- 🔥infer_backend: 推理加速后端,支持'pt'、'vllm'、'lmdeploy'三种推理引擎。默认为'pt'
+- 🔥infer_backend: 推理加速后端,支持'pt'、'vllm'、'sglang'、'lmdeploy'四种推理引擎。默认为'pt'
 - 🔥max_batch_size: 指定infer_backend为pt时生效,用于批量推理,默认为1
 - ddp_backend: 指定infer_backend为pt时生效,用于指定多卡推理时的分布式后端,默认为None,进行自动选择。多卡推理例子可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/infer/pt)
 - 🔥result_path: 推理结果存储路径(jsonl),默认为None,保存在checkpoint目录(含args.json文件)或者'./result'目录,最终存储路径会在命令行中打印
diff --git "a/docs/source/Instruction/\346\216\250\347\220\206\345\222\214\351\203\250\347\275\262.md" "b/docs/source/Instruction/\346\216\250\347\220\206\345\222\214\351\203\250\347\275\262.md"
index 1f036fe4d2..409781fa64 100644
--- "a/docs/source/Instruction/\346\216\250\347\220\206\345\222\214\351\203\250\347\275\262.md"
+++ "b/docs/source/Instruction/\346\216\250\347\220\206\345\222\214\351\203\250\347\275\262.md"
@@ -6,6 +6,7 @@
 | ------------ | -------------- | ---------- | ------ | -------- | ------ | ----- | ----- |
 | pytorch | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/client/llm/chat/openai_client.py) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/app/mllm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_lora.py) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/pt/batch_ddp.sh) |DDP/device_map |
 | [vllm](https://github.com/vllm-project/vllm) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/vllm/mllm_tp.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/lora/server.sh) | ❌ | ✅ | TP/PP/DP |
+| [sglang](https://github.com/sgl-project/sglang) | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | TP/DP/EP |
 | [lmdeploy](https://github.com/InternLM/lmdeploy) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/lmdeploy/mllm_tp.sh) | ✅ | ❌ | ❌ | ✅ | TP/DP |
 
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index 69705cc4d4..a1297379ec 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -450,7 +450,7 @@ repetition penalty function arguments
 
 Inference arguments include the [base arguments](#base-arguments), [merge arguments](#merge-arguments), [vLLM arguments](#vllm-arguments), [LMDeploy arguments](#LMDeploy-arguments), and also contain the following:
 
-- 🔥infer_backend: Inference acceleration backend, supporting three inference engines: 'pt', 'vllm', and 'lmdeploy'. The default is 'pt'.
+- 🔥infer_backend: Inference acceleration backend, supporting four inference engines: 'pt', 'vllm', 'sglang', and 'lmdeploy'. The default is 'pt'.
 - 🔥max_batch_size: Effective when infer_backend is set to 'pt'; used for batch inference, with a default value of 1.
 - ddp_backend: Effective when infer_backend is set to 'pt'; used to specify the distributed backend for multi-GPU inference. The default is None, which means automatic selection. For an example of multi-GPU inference, you can refer [here](https://github.com/modelscope/ms-swift/tree/main/examples/infer/pt).
 - 🔥result_path: Path to store inference results (jsonl). The default is None, meaning results are saved in the checkpoint directory (with args.json file) or './result' directory. The final storage path will be printed in the command line.
diff --git a/docs/source_en/Instruction/Inference-and-deployment.md b/docs/source_en/Instruction/Inference-and-deployment.md
index d19719b496..0ff7c2146a 100644
--- a/docs/source_en/Instruction/Inference-and-deployment.md
+++ b/docs/source_en/Instruction/Inference-and-deployment.md
@@ -6,6 +6,7 @@ Below are the inference engines supported by Swift along with their correspondin
 | ------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | --------------- | ------------------------------------------------------------ | ----- | ------------------------------------------------------------ | ------------------- |
 | pytorch | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/client/llm/chat/openai_client.py) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/app/mllm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_lora.py) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/pt/batch_ddp.sh) | DDP/device_map |
 | [vllm](https://github.com/vllm-project/vllm) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/vllm/mllm_tp.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/lora/server.sh) | ❌ | ✅ | TP/PP/DP |
+| [sglang](https://github.com/sgl-project/sglang) | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | TP/DP/EP |
 | [lmdeploy](https://github.com/InternLM/lmdeploy) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/lmdeploy/mllm_tp.sh) | ✅ | ❌ | ❌ | ✅ | TP/DP |
 
 ## Inference
diff --git a/swift/llm/__init__.py b/swift/llm/__init__.py
index 1569a5303c..1dc5f77437 100644
--- a/swift/llm/__init__.py
+++ b/swift/llm/__init__.py
@@ -6,7 +6,7 @@
 if TYPE_CHECKING:
     # Recommend using `xxx_main`
     from .infer import (VllmEngine, RequestConfig, LmdeployEngine, PtEngine, InferEngine, infer_main, deploy_main,
-                        InferClient, run_deploy, AdapterRequest, prepare_model_template, BaseInferEngine)
+                        InferClient, run_deploy, AdapterRequest, prepare_model_template, BaseInferEngine, SglangEngine)
     from .export import (export_main, merge_lora, quantize_model, export_to_ollama)
     from .eval import eval_main
     from .app import app_main
@@ -36,7 +36,7 @@
         'rlhf': ['rlhf_main'],
         'infer': [
             'deploy_main', 'VllmEngine', 'RequestConfig', 'LmdeployEngine', 'PtEngine', 'infer_main', 'InferClient',
-            'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template', 'BaseInferEngine'
+            'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template', 'BaseInferEngine', 'SglangEngine'
         ],
         'export': ['export_main', 'merge_lora', 'quantize_model', 'export_to_ollama'],
         'app': ['app_main'],
diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py
index 1c944a1278..e1b1eeeb20 100644
--- a/swift/llm/argument/infer_args.py
+++ b/swift/llm/argument/infer_args.py
@@ -116,7 +116,7 @@ class InferArguments(MergeArguments, VllmArguments, LmdeployArguments, BaseArgum
         max_batch_size (int): Maximum batch size for the pt engine. Default is 1.
         val_dataset_sample (Optional[int]): Sample size for validation dataset. Default is None.
     """
-    infer_backend: Literal['vllm', 'pt', 'lmdeploy'] = 'pt'
+    infer_backend: Literal['vllm', 'pt', 'sglang', 'lmdeploy'] = 'pt'
     result_path: Optional[str] = None
     metric: Literal['acc', 'rouge'] = None
 
diff --git a/swift/llm/infer/__init__.py b/swift/llm/infer/__init__.py
index 984420201d..a4ac07601f 100644
--- a/swift/llm/infer/__init__.py
+++ b/swift/llm/infer/__init__.py
@@ -8,7 +8,7 @@
     from .deploy import deploy_main, SwiftDeploy, run_deploy
     from .protocol import RequestConfig
     from .utils import prepare_model_template
-    from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, PtEngine, InferClient,
+    from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, SglangEngine, PtEngine, InferClient,
                                prepare_generation_config, AdapterRequest, BaseInferEngine)
 else:
     _import_structure = {
@@ -17,8 +17,8 @@
         'protocol': ['RequestConfig'],
         'utils': ['prepare_model_template'],
         'infer_engine': [
-            'InferEngine', 'VllmEngine', 'LmdeployEngine', 'PtEngine', 'InferClient', 'prepare_generation_config',
-            'AdapterRequest', 'BaseInferEngine'
+            'InferEngine', 'VllmEngine', 'LmdeployEngine', 'SglangEngine', 'PtEngine', 'InferClient',
+            'prepare_generation_config', 'AdapterRequest', 'BaseInferEngine'
         ],
     }
 
diff --git a/swift/llm/infer/infer_engine/__init__.py b/swift/llm/infer/infer_engine/__init__.py
index 668e0a726c..01ca779694 100644
--- a/swift/llm/infer/infer_engine/__init__.py
+++ b/swift/llm/infer/infer_engine/__init__.py
@@ -7,6 +7,7 @@
     from .vllm_engine import VllmEngine
     from .grpo_vllm_engine import GRPOVllmEngine
     from .lmdeploy_engine import LmdeployEngine
+    from .sglang_engine import SglangEngine
    from .pt_engine import PtEngine
     from .infer_client import InferClient
     from .infer_engine import InferEngine
@@ -17,6 +18,7 @@
         'vllm_engine': ['VllmEngine'],
         'grpo_vllm_engine': ['GRPOVllmEngine'],
         'lmdeploy_engine': ['LmdeployEngine'],
+        'sglang_engine': ['SglangEngine'],
         'pt_engine': ['PtEngine'],
         'infer_client': ['InferClient'],
         'infer_engine': ['InferEngine'],
diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py
index 048f8cb785..b9cd47d431 100644
--- a/swift/llm/infer/infer_engine/sglang_engine.py
+++ b/swift/llm/infer/infer_engine/sglang_engine.py
@@ -1,9 +1,11 @@
+import inspect
 import os
 from copy import deepcopy
 from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
 
 import sglang as sgl
 import torch
+from sglang.srt.sampling.sampling_params import SamplingParams
 from transformers import GenerationConfig
 
 from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
@@ -14,10 +16,6 @@
 
 
 class SglangEngine(InferEngine):
-    sampling_parameters = {
-        'max_new_tokens', 'stop', 'temperature', 'top_p', 'top_k', 'min_p', 'frequency_penalty', 'presence_penalty',
-        'repetition_penalty', 'min_new_tokens', 'n'
-    }
 
     def __init__(
             self,
@@ -39,7 +37,7 @@ def __init__(
             hub_token=hub_token,
             revision=revision)[1]
         self._post_init()
-        self.engine = sgl.Engine(model_path=model_id_or_path, dtype=torch_dtype)
+        self.engine = sgl.Engine(model_path=self.model_dir, dtype=self.model_info.torch_dtype)
         self._load_generation_config()
 
     def _load_generation_config(self) -> None:
@@ -51,12 +49,13 @@ def _load_generation_config(self) -> None:
             if top_k == 0:
                 kwargs['top_k'] = -1
 
+            parameters = inspect.signature(SamplingParams).parameters
             for k, v in kwargs.copy().items():
-                if k not in self.sampling_parameters or v is None:
+                if k not in parameters or v is None:
                     kwargs.pop(k)
-            self.generation_config = kwargs
+            self.generation_config = SamplingParams(**kwargs)
         else:
-            self.generation_config = {}
+            self.generation_config = SamplingParams()
 
     def _prepare_generation_config(self, request_config: RequestConfig) -> Dict[str, Any]:
         kwargs = {'max_new_tokens': request_config.max_tokens}
@@ -69,13 +68,26 @@ def _prepare_generation_config(self, request_config: RequestConfig) -> Dict[str,
         for key in ['n', 'frequency_penalty', 'presence_penalty']:
             kwargs[key] = getattr(request_config, key)
 
-        return kwargs
+        return SamplingParams(**kwargs)
 
     def _add_stop_words(self, generation_config: Dict[str, Any], request_config: RequestConfig,
                         template_meta: TemplateMeta) -> None:
-        stop_words = (request_config.stop or []) + (self.generation_config.get('stop') or []) + template_meta.stop_words
+        stop_words = (request_config.stop or []) + (self.generation_config.stop or []) + template_meta.stop_words
         generation_config['stop'] = self._get_stop_words(stop_words)
 
+    def _create_chat_completion_response(self, output, template):
+        assert result is not None
+        meta_info = output['meta_info']
+        usage_info = self._get_usage_info(meta_info['prompt_tokens'], meta_info['completion_tokens'])
+        response = output['text']
+        toolcall = self._get_toolcall(response, template.tools_prompt)
+        choice = ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
+            finish_reason=meta_info['finish_reason']['type'],
+            logprobs=None)
+        return ChatCompletionResponse(model=self.model_name, choices=[choice], usage=usage_info, id=request_id)
+
     def infer(self,
               infer_requests: List[InferRequest],
               request_config: Optional[RequestConfig] = None,
@@ -93,7 +105,8 @@ def infer(self,
         self._add_stop_words(generation_config, request_config, template.template_meta)
         input_ids = [inputs['input_ids'] for inputs in batched_inputs]
         outputs = self.engine.generate(input_ids=input_ids, sampling_params=generation_config)
-        print()
+
+        return [self._create_chat_completion_response(output, template) for output in outputs]
 
     async def infer_async(
         self,

From e6e3ee863feaeecce9b35b0f0d84418d92174e1d Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 15 Apr 2025 01:44:28 +0800
Subject: [PATCH 3/3] update

---
 swift/llm/infer/infer_engine/sglang_engine.py | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py
index b9cd47d431..67f14ef44b 100644
--- a/swift/llm/infer/infer_engine/sglang_engine.py
+++ b/swift/llm/infer/infer_engine/sglang_engine.py
@@ -53,30 +53,30 @@ def _load_generation_config(self) -> None:
             for k, v in kwargs.copy().items():
                 if k not in parameters or v is None:
                     kwargs.pop(k)
-            self.generation_config = SamplingParams(**kwargs)
+            self.generation_config = kwargs
         else:
-            self.generation_config = SamplingParams()
+            self.generation_config = {}
 
     def _prepare_generation_config(self, request_config: RequestConfig) -> Dict[str, Any]:
         kwargs = {'max_new_tokens': request_config.max_tokens}
         for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
             new_value = getattr(request_config, key)
             if new_value is None:
-                kwargs[key] = getattr(self.generation_config, key)
+                kwargs[key] = self.generation_config.get(key)
             else:
                 kwargs[key] = new_value
         for key in ['n', 'frequency_penalty', 'presence_penalty']:
             kwargs[key] = getattr(request_config, key)
 
-        return SamplingParams(**kwargs)
+        return kwargs
 
     def _add_stop_words(self, generation_config: Dict[str, Any], request_config: RequestConfig,
                         template_meta: TemplateMeta) -> None:
-        stop_words = (request_config.stop or []) + (self.generation_config.stop or []) + template_meta.stop_words
+        stop_words = (request_config.stop or []) + (self.generation_config.get('stop') or []) + template_meta.stop_words
         generation_config['stop'] = self._get_stop_words(stop_words)
 
     def _create_chat_completion_response(self, output, template):
-        assert result is not None
+        assert output is not None
         meta_info = output['meta_info']
         usage_info = self._get_usage_info(meta_info['prompt_tokens'], meta_info['completion_tokens'])
         response = output['text']
@@ -86,7 +86,7 @@ def _create_chat_completion_response(self, output, template):
             message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
             finish_reason=meta_info['finish_reason']['type'],
             logprobs=None)
-        return ChatCompletionResponse(model=self.model_name, choices=[choice], usage=usage_info, id=request_id)
+        return ChatCompletionResponse(model=self.model_name, choices=[choice], usage=usage_info, id=random_uuid())
 
     def infer(self,
               infer_requests: List[InferRequest],
@@ -104,9 +104,12 @@ def infer(self,
         generation_config = self._prepare_generation_config(request_config)
         self._add_stop_words(generation_config, request_config, template.template_meta)
         input_ids = [inputs['input_ids'] for inputs in batched_inputs]
-        outputs = self.engine.generate(input_ids=input_ids, sampling_params=generation_config)
+        if request_config.stream:
+            pass
+        else:
+            outputs = self.engine.generate(input_ids=input_ids, sampling_params=generation_config)
 
-        return [self._create_chat_completion_response(output, template) for output in outputs]
+            return [self._create_chat_completion_response(output, template) for output in outputs]
 
     async def infer_async(
         self,