Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Dec 29, 2024
1 parent 3ce178b commit 2df4cce
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ RLHF参数继承于[训练参数](#训练参数)
App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数)
- base_url: 模型部署的base_url,例如`http://localhost:8000/v1`。默认为`None`
- studio_title: studio的标题。默认为None,设置为模型名
- is_multimodal: 是否启动多模态版本的app。默认为None,自动根据model判断,若无法判断,设置为False
- lang: 覆盖Web-UI参数,默认为'en'

### 评测参数
Expand Down
1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ App parameters inherit from [deployment arguments](#deployment-arguments) and [W

- base_url: Base URL for the model deployment, for example, `http://localhost:8000/v1`. Default is None.
- studio_title: Title of the studio. Default is None, set to the model name.
- is_multimodal: Whether to launch the multimodal version of the app. Defaults to None, in which case it is determined automatically from the model; if it cannot be determined, it is set to False.
- lang: Overrides the Web-UI Arguments, default is 'en'.

### Evaluation Arguments
Expand Down
3 changes: 2 additions & 1 deletion swift/llm/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def run(self):
base_url = base_url or args.base_url
demo = build_ui(
base_url,
args.model_suffix,
stream=args.stream,
is_multimodal=args.model_meta.is_multimodal,
is_multimodal=args.is_multimodal,
studio_title=args.studio_title,
lang=args.lang,
default_system=args.system)
Expand Down
15 changes: 9 additions & 6 deletions swift/llm/app/build_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def modify_system_session(system: str):
return system, '', []


def _history_to_messages(history: History, system: str):
def _history_to_messages(history: History, system: Optional[str]):
messages = []
if system is not None:
messages.append({'role': 'system', 'content': system})
Expand All @@ -43,12 +43,14 @@ def _history_to_messages(history: History, system: str):
return messages


def model_chat(history: History, system: str, *, client, stream: bool):
def model_chat(history: History, system: Optional[str], *, client, model: str, stream: bool):
if history:
from swift.llm import InferRequest, RequestConfig

messages = _history_to_messages(history, system)
gen_or_res = client.infer([InferRequest(messages=messages)], request_config=RequestConfig(stream=stream))
gen_or_res = client.infer([InferRequest(messages=messages)],
request_config=RequestConfig(stream=stream),
model=model)
if stream:
response = ''
for resp_list in gen_or_res:
Expand Down Expand Up @@ -81,6 +83,7 @@ def add_file(history: History, file):


def build_ui(base_url: str,
model: Optional[str] = None,
*,
stream: bool = True,
is_multimodal: bool = True,
Expand All @@ -89,8 +92,8 @@ def build_ui(base_url: str,
default_system: Optional[str] = None):
from swift.llm import InferClient
client = InferClient(base_url=base_url)
if studio_title is None:
studio_title = client.models[0]
model = model or client.models[0]
studio_title = studio_title or model
with gr.Blocks() as demo:
gr.Markdown(f'<center><font size=8>{studio_title}</center>')
with gr.Row():
Expand All @@ -108,7 +111,7 @@ def build_ui(base_url: str,
clear_history = gr.Button(locale_mapping['clear_history'][lang])

system_state = gr.State(value=default_system)
model_chat_ = partial(model_chat, client=client, stream=stream)
model_chat_ = partial(model_chat, client=client, model=model, stream=stream)

upload.upload(add_file, [chatbot, upload], [chatbot])
textbox.submit(add_text, [chatbot, textbox], [chatbot, textbox]).then(model_chat_, [chatbot, system_state],
Expand Down
14 changes: 10 additions & 4 deletions swift/llm/argument/app_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Literal, Optional

from swift.utils import find_free_port, get_logger
from ..model import get_matched_model_meta
from ..template import get_template_meta
from .deploy_args import DeployArguments
from .webui_args import WebUIArguments
Expand All @@ -14,18 +15,23 @@
class AppArguments(WebUIArguments, DeployArguments):
    """Arguments for the `swift app` command: a Gradio chat UI on top of a model.

    Inherits deployment arguments and Web-UI arguments.

    Attributes:
        base_url: Base URL of an already-deployed model, e.g. ``http://localhost:8000/v1``.
            When set, no local model is loaded. Defaults to None.
        studio_title: Title of the studio page. Defaults to None (the UI falls back to
            the model name).
        is_multimodal: Whether to launch the multimodal version of the app. Defaults to
            None, in which case it is inferred from the matched model meta; set to False
            when it cannot be determined.
        lang: UI language, overriding the Web-UI default. Defaults to 'en'.
    """
    base_url: Optional[str] = None
    studio_title: Optional[str] = None
    # NOTE: annotated Optional[bool] (was Optional[str]): __post_init__ assigns
    # model_meta.is_multimodal / False, which are booleans.
    is_multimodal: Optional[bool] = None

    lang: Literal['en', 'zh'] = 'en'

    def _init_torch_dtype(self) -> None:
        if self.base_url:
            # Chatting against a remote deployment: no local model is loaded, so only
            # resolve the model meta (may be None for unknown models) and skip dtype init.
            self.model_meta = get_matched_model_meta(self.model)
            return
        super()._init_torch_dtype()

    def __post_init__(self):
        super().__post_init__()
        self.server_port = find_free_port(self.server_port)
        # model_meta may be None when the model could not be matched (remote base_url
        # case), so guard every access to it.
        if self.model_meta:
            if self.system is None:
                self.system = get_template_meta(self.model_meta.template).default_system
            if self.is_multimodal is None:
                self.is_multimodal = self.model_meta.is_multimodal
        if self.is_multimodal is None:
            self.is_multimodal = False
6 changes: 3 additions & 3 deletions swift/llm/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Public re-exports for swift.llm.model. The scraped diff merged the old and new
# continuation lines of the `.register` import (a syntax error); this keeps the
# post-commit version, which adds `get_matched_model_meta`.
from .constant import LLMModelType, MLLMModelType, ModelType
from .model_arch import MODEL_ARCH_MAPPING, ModelArch, ModelKeys, MultiModelKeys, get_model_arch, register_model_arch
from .register import (MODEL_MAPPING, Model, ModelGroup, ModelMeta, fix_do_sample_warning, get_default_device_map,
                       get_default_torch_dtype, get_matched_model_meta, get_model_info_meta, get_model_name,
                       get_model_tokenizer, get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn,
                       get_model_with_value_head, load_by_unsloth, register_model)
from .utils import HfConfigFactory, ModelInfo, git_clone_github, safe_snapshot_download
7 changes: 5 additions & 2 deletions tests/app/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ def test_mllm():


def test_audio():
    """Smoke test: deploy an audio LLM locally, then point the app at the deployment URL.

    The scraped diff kept the deleted pre-commit lines (a direct ``app_main`` call that
    loads the model in-process) alongside the new deploy-then-app flow, which would have
    launched the app twice; only the post-commit flow is kept here.
    """
    from swift.llm import AppArguments, app_main, DeployArguments, run_deploy
    deploy_args = DeployArguments(model='Qwen/Qwen2-Audio-7B-Instruct', infer_backend='pt', verbose=False)

    # Deploy first; `return_url=True` yields the base_url the app connects to.
    with run_deploy(deploy_args, return_url=True) as url:
        app_main(AppArguments(model='Qwen2-Audio-7B-Instruct', base_url=url, stream=True))


if __name__ == '__main__':
Expand Down

0 comments on commit 2df4cce

Please sign in to comment.