[WIP] support sglang engine #3810

Open · wants to merge 6 commits into base: main

2 changes: 1 addition & 1 deletion docs/source/Instruction/命令行参数.md
@@ -450,7 +450,7 @@ soft overlong reward parameters

In addition to the [base arguments](#基本参数), [merge arguments](#合并参数), [vLLM arguments](#vllm参数), and [LMDeploy arguments](#LMDeploy参数), the inference arguments also include the following:

- 🔥infer_backend: Inference acceleration backend; supports three inference engines: 'pt', 'vllm', and 'lmdeploy'. Defaults to 'pt'.
- 🔥infer_backend: Inference acceleration backend; supports four inference engines: 'pt', 'vllm', 'sglang', and 'lmdeploy'. Defaults to 'pt'.
- 🔥max_batch_size: Takes effect when infer_backend is 'pt'; used for batch inference. Defaults to 1.
- ddp_backend: Takes effect when infer_backend is 'pt'; specifies the distributed backend for multi-GPU inference. Defaults to None, in which case it is selected automatically. See [here](https://github.com/modelscope/ms-swift/tree/main/examples/infer/pt) for a multi-GPU inference example.
- 🔥result_path: Path for storing inference results (jsonl). Defaults to None, in which case results are saved in the checkpoint directory (the one containing args.json) or in './result'; the final storage path is printed on the command line.
1 change: 1 addition & 0 deletions docs/source/Instruction/推理和部署.md
@@ -6,6 +6,7 @@
| ------------ | -------------- | ---------- | ------ | -------- | ------ | ----- | ----- |
| pytorch | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/client/llm/chat/openai_client.py) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/app/mllm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_lora.py) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/pt/batch_ddp.sh) |DDP/device_map |
| [vllm](https://github.com/vllm-project/vllm) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/vllm/mllm_tp.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/lora/server.sh) | ❌ | ✅ | TP/PP/DP |
| [sglang](https://github.com/sgl-project/sglang) | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | TP/DP/EP |
| [lmdeploy](https://github.com/InternLM/lmdeploy) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/lmdeploy/mllm_tp.sh) | ✅ | ❌ | ❌ | ✅ | TP/DP |


2 changes: 1 addition & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
@@ -466,7 +466,7 @@ Soft overlong reward parameters:

Inference arguments include the [base arguments](#base-arguments), [merge arguments](#merge-arguments), [vLLM arguments](#vllm-arguments), [LMDeploy arguments](#LMDeploy-arguments), and also contain the following:

- 🔥infer_backend: Inference acceleration backend, supporting three inference engines: 'pt', 'vllm', and 'lmdeploy'. The default is 'pt'.
- 🔥infer_backend: Inference acceleration backend, supporting four inference engines: 'pt', 'vllm', 'sglang', and 'lmdeploy'. The default is 'pt'.
- 🔥max_batch_size: Effective when infer_backend is set to 'pt'; used for batch inference, with a default value of 1.
- ddp_backend: Effective when infer_backend is set to 'pt'; used to specify the distributed backend for multi-GPU inference. The default is None, which means automatic selection. For an example of multi-GPU inference, you can refer [here](https://github.com/modelscope/ms-swift/tree/main/examples/infer/pt).
- 🔥result_path: Path to store inference results (jsonl). The default is None, meaning results are saved in the checkpoint directory (with args.json file) or './result' directory. The final storage path will be printed in the command line.
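For illustration only (not part of this diff), here is a minimal sketch of selecting the new 'sglang' backend from Python. It assumes `InferArguments` accepts a `model` field and that `infer_main` takes an arguments object; neither detail is shown in this PR, so treat both as assumptions.

```python
from swift.llm import infer_main
from swift.llm.argument.infer_args import InferArguments

# Hypothetical model id; substitute any model supported by sglang.
args = InferArguments(
    model='Qwen/Qwen2.5-7B-Instruct',  # assumption: field name not shown in this PR
    infer_backend='sglang',            # the new backend option added by this PR
)
infer_main(args)
```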
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Inference-and-deployment.md
@@ -6,6 +6,7 @@ Below are the inference engines supported by Swift along with their corresponding
| ------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | --------------- | ------------------------------------------------------------ | ----- | ------------------------------------------------------------ | ------------------- |
| pytorch | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/client/llm/chat/openai_client.py) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/app/mllm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_lora.py) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/pt/batch_ddp.sh) | DDP/device_map |
| [vllm](https://github.com/vllm-project/vllm) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/vllm/mllm_tp.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/lora/server.sh) | ❌ | ✅ | TP/PP/DP |
| [sglang](https://github.com/sgl-project/sglang) | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | TP/DP/EP |
| [lmdeploy](https://github.com/InternLM/lmdeploy) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/infer/lmdeploy/mllm_tp.sh) | ✅ | ❌ | ❌ | ✅ | TP/DP |

## Inference
4 changes: 2 additions & 2 deletions swift/llm/__init__.py
@@ -6,7 +6,7 @@
if TYPE_CHECKING:
    # Recommend using `xxx_main`
    from .infer import (VllmEngine, RequestConfig, LmdeployEngine, PtEngine, InferEngine, infer_main, deploy_main,
                        InferClient, run_deploy, AdapterRequest, prepare_model_template, BaseInferEngine)
                        InferClient, run_deploy, AdapterRequest, prepare_model_template, BaseInferEngine, SglangEngine)
    from .export import (export_main, merge_lora, quantize_model, export_to_ollama)
    from .eval import eval_main
    from .app import app_main
@@ -36,7 +36,7 @@
    'rlhf': ['rlhf_main'],
    'infer': [
        'deploy_main', 'VllmEngine', 'RequestConfig', 'LmdeployEngine', 'PtEngine', 'infer_main', 'InferClient',
        'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template', 'BaseInferEngine'
        'run_deploy', 'InferEngine', 'AdapterRequest', 'prepare_model_template', 'BaseInferEngine', 'SglangEngine'
    ],
    'export': ['export_main', 'merge_lora', 'quantize_model', 'export_to_ollama'],
    'app': ['app_main'],
2 changes: 1 addition & 1 deletion swift/llm/argument/infer_args.py
@@ -116,7 +116,7 @@ class InferArguments(MergeArguments, VllmArguments, LmdeployArguments, BaseArgum
        max_batch_size (int): Maximum batch size for the pt engine. Default is 1.
        val_dataset_sample (Optional[int]): Sample size for validation dataset. Default is None.
    """
    infer_backend: Literal['vllm', 'pt', 'lmdeploy'] = 'pt'
    infer_backend: Literal['vllm', 'pt', 'sglang', 'lmdeploy'] = 'pt'

    result_path: Optional[str] = None
    metric: Literal['acc', 'rouge'] = None
6 changes: 3 additions & 3 deletions swift/llm/infer/__init__.py
@@ -8,7 +8,7 @@
    from .deploy import deploy_main, SwiftDeploy, run_deploy
    from .protocol import RequestConfig, Function
    from .utils import prepare_model_template
    from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, PtEngine, InferClient,
    from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, SglangEngine, PtEngine, InferClient,
                               prepare_generation_config, AdapterRequest, BaseInferEngine)
else:
    _import_structure = {
@@ -17,8 +17,8 @@
        'protocol': ['RequestConfig', 'Function'],
        'utils': ['prepare_model_template'],
        'infer_engine': [
            'InferEngine', 'VllmEngine', 'LmdeployEngine', 'PtEngine', 'InferClient', 'prepare_generation_config',
            'AdapterRequest', 'BaseInferEngine'
            'InferEngine', 'VllmEngine', 'LmdeployEngine', 'SglangEngine', 'PtEngine', 'InferClient',
            'prepare_generation_config', 'AdapterRequest', 'BaseInferEngine'
        ],
    }

2 changes: 2 additions & 0 deletions swift/llm/infer/infer_engine/__init__.py
@@ -7,6 +7,7 @@
    from .vllm_engine import VllmEngine
    from .grpo_vllm_engine import GRPOVllmEngine
    from .lmdeploy_engine import LmdeployEngine
    from .sglang_engine import SglangEngine
    from .pt_engine import PtEngine
    from .infer_client import InferClient
    from .infer_engine import InferEngine
@@ -17,6 +18,7 @@
        'vllm_engine': ['VllmEngine'],
        'grpo_vllm_engine': ['GRPOVllmEngine'],
        'lmdeploy_engine': ['LmdeployEngine'],
        'sglang_engine': ['SglangEngine'],
        'pt_engine': ['PtEngine'],
        'infer_client': ['InferClient'],
        'infer_engine': ['InferEngine'],
3 changes: 0 additions & 3 deletions swift/llm/infer/infer_engine/lmdeploy_engine.py
@@ -150,9 +150,6 @@ def _load_generation_config(self):
        if os.path.isfile(generation_config_path):
            generation_config = GenerationConfig.from_pretrained(self.model_dir)
            kwargs = generation_config.to_dict()
            max_new_tokens = kwargs.get('max_new_tokens')
            if max_new_tokens is None:
                kwargs.pop('max_new_tokens', None)
            parameters = inspect.signature(LmdeployGenerationConfig).parameters
            for k, v in kwargs.copy().items():
                if k not in parameters or v is None:
122 changes: 122 additions & 0 deletions swift/llm/infer/infer_engine/sglang_engine.py
@@ -0,0 +1,122 @@
import inspect
import os
from copy import deepcopy
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import sglang as sgl
import torch
from sglang.srt.sampling.sampling_params import SamplingParams
from transformers import GenerationConfig

from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
from swift.plugin import Metric
from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid)
from .infer_engine import InferEngine


class SglangEngine(InferEngine):

    def __init__(
        self,
        model_id_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        *,
        model_type: Optional[str] = None,
        use_hf: Optional[bool] = None,
        hub_token: Optional[str] = None,
        revision: Optional[str] = None,
    ):
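        # Only the tokenizer/processor is loaded here (load_model=False);
        # sgl.Engine below loads the model weights itself from self.model_dir.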
        self.processor = get_model_tokenizer(
            model_id_or_path,
            torch_dtype,
            load_model=False,
            download_model=True,
            model_type=model_type,
            use_hf=use_hf,
            hub_token=hub_token,
            revision=revision)[1]
        self._post_init()
        self.engine = sgl.Engine(model_path=self.model_dir, dtype=self.model_info.torch_dtype)
        self._load_generation_config()

    def _load_generation_config(self) -> None:
        generation_config_path = os.path.join(self.model_dir, 'generation_config.json')
        if os.path.isfile(generation_config_path):
            generation_config = GenerationConfig.from_pretrained(self.model_dir)
            kwargs = generation_config.to_dict()
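            # HF generation configs use top_k=0 to mean "top-k disabled", whereas
            # sglang's SamplingParams expects -1 for the same meaning.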
            top_k = kwargs.get('top_k')
            if top_k == 0:
                kwargs['top_k'] = -1

            parameters = inspect.signature(SamplingParams).parameters
            for k, v in kwargs.copy().items():
                if k not in parameters or v is None:
                    kwargs.pop(k)
            self.generation_config = kwargs
        else:
            self.generation_config = {}

    def _prepare_generation_config(self, request_config: RequestConfig) -> Dict[str, Any]:
        kwargs = {'max_new_tokens': request_config.max_tokens}
        for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
            new_value = getattr(request_config, key)
            if new_value is None:
                kwargs[key] = self.generation_config.get(key)
            else:
                kwargs[key] = new_value
        for key in ['n', 'frequency_penalty', 'presence_penalty']:
            kwargs[key] = getattr(request_config, key)

        return kwargs

    def _add_stop_words(self, generation_config: Dict[str, Any], request_config: RequestConfig,
                        template_meta: TemplateMeta) -> None:
        stop_words = (request_config.stop or []) + (self.generation_config.get('stop') or []) + template_meta.stop_words
        generation_config['stop'] = self._get_stop_words(stop_words)

    def _create_chat_completion_response(self, output, template):
        assert output is not None
        meta_info = output['meta_info']
        usage_info = self._get_usage_info(meta_info['prompt_tokens'], meta_info['completion_tokens'])
        response = output['text']
        toolcall = self._get_toolcall(response, template.tools_prompt)
        choice = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
            finish_reason=meta_info['finish_reason']['type'],
            logprobs=None)
        return ChatCompletionResponse(model=self.model_name, choices=[choice], usage=usage_info, id=random_uuid())

    def infer(self,
              infer_requests: List[InferRequest],
              request_config: Optional[RequestConfig] = None,
              metrics: Optional[List[Metric]] = None,
              *,
              template: Optional[Template] = None,
              **kwargs) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
        request_config = deepcopy(request_config or RequestConfig())
        if template is None:
            template = self.default_template
        batched_inputs, error_list = self._batch_encode(
            infer_requests, template=template, strict=getattr(self, 'strict', True))
        self.set_default_max_tokens(request_config, batched_inputs)
        generation_config = self._prepare_generation_config(request_config)
        self._add_stop_words(generation_config, request_config, template.template_meta)
        input_ids = [inputs['input_ids'] for inputs in batched_inputs]
        if request_config.stream:
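            # Streaming output is not implemented yet in this [WIP] PR.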
            pass
        else:
            outputs = self.engine.generate(input_ids=input_ids, sampling_params=generation_config)

        return [self._create_chat_completion_response(output, template) for output in outputs]

    async def infer_async(
        self,
        infer_request: InferRequest,
        request_config: Optional[RequestConfig] = None,
        *,
        template: Optional[Template] = None,
        pre_infer_hook=None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
        pass
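For reference, a minimal usage sketch of the engine added above. This is not part of the PR: the model id and the `messages`-based `InferRequest` constructor are assumptions based on the existing ms-swift engine API, and streaming as well as `infer_async` are not yet implemented in this WIP version.

```python
from swift.llm import InferRequest, RequestConfig, SglangEngine

# Hypothetical model id; substitute any model supported by sglang.
engine = SglangEngine('Qwen/Qwen2.5-7B-Instruct')

infer_requests = [InferRequest(messages=[{'role': 'user', 'content': 'Who are you?'}])]
# max_tokens/temperature are forwarded to sglang's SamplingParams via
# _prepare_generation_config; request_config.stream is not handled yet.
request_config = RequestConfig(max_tokens=512, temperature=0.)

responses = engine.infer(infer_requests, request_config)
print(responses[0].choices[0].message.content)
```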