Skip to content

Commit 7ff5402

Browse files
committed
修复 glm 4.1v (Fix GLM 4.1V support)
1 parent 288c67e commit 7ff5402

File tree

4 files changed

+12
-5
lines changed

4 files changed

+12
-5
lines changed

gpt_server/model_backend/vllm_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,13 @@ async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
6969
request_id = params.get("request_id", "0")
7070
temperature = float(params.get("temperature", 0.8))
7171
top_p = float(params.get("top_p", 0.8))
72-
top_k = params.get("top_k", -1.0)
72+
top_k = int(params.get("top_k", 0))
7373
max_new_tokens = int(params.get("max_new_tokens", 1024 * 8))
7474
stop_str = params.get("stop", None)
7575
stop_token_ids = params.get("stop_words_ids", None) or []
7676
presence_penalty = float(params.get("presence_penalty", 0.0))
7777
frequency_penalty = float(params.get("frequency_penalty", 0.0))
78+
repetition_penalty = float(params.get("repetition_penalty", 1.0))
7879
request = params.get("request", None)
7980
# Handle stop_str
8081
stop = set()
@@ -145,6 +146,7 @@ async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
145146
stop_token_ids=stop_token_ids,
146147
presence_penalty=presence_penalty,
147148
frequency_penalty=frequency_penalty,
149+
repetition_penalty=repetition_penalty,
148150
guided_decoding=guided_decoding,
149151
)
150152
lora_request = None

gpt_server/model_worker/base/model_worker_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def __init__(
9090
# logger.info(f"模型配置:{self.model_config}")
9191
self.vision_config = getattr(self.model_config, "vision_config", None)
9292
is_vision = self.vision_config is not None
93+
if is_vision:
94+
multimodal = True
95+
logger.warning(f"{model_names[0]} 是多模态模型")
9396
super().__init__(
9497
controller_addr,
9598
worker_addr,
@@ -98,7 +101,7 @@ def __init__(
98101
model_names,
99102
limit_worker_concurrency,
100103
conv_template,
101-
multimodal=multimodal or is_vision,
104+
multimodal=multimodal,
102105
)
103106
os.environ["WORKER_NAME"] = self.__class__.__name__
104107
self.worker_name = self.__class__.__name__

gpt_server/model_worker/chatglm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ async def generate_stream_gate(self, params):
7777
# text = self.tokenizer.decode(input_ids.tolist()[0])
7878
params["prompt"] = text
7979
# params["input_ids"] = input_ids
80-
80+
else: # 多模态模型
81+
params["multimodal"] = True
8182
# ---------------添加额外的参数------------------------
8283
params["messages"] = messages
8384
params["stop"].extend(self.stop)

tests/test_openai_vl_chat.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import base64
22
from openai import OpenAI
3+
from pathlib import Path
34

45

56
def image_to_base64(image_path):
@@ -11,7 +12,7 @@ def image_to_base64(image_path):
1112
return base64_prefix + base64_string
1213

1314

14-
image_path = "../assets/logo.png"
15+
image_path = Path(__file__).parent.parent / "assets/logo.png"
1516
# 使用本地的图片
1617
url = image_to_base64(image_path)
1718
# 使用网络图片
@@ -22,7 +23,7 @@ def image_to_base64(image_path):
2223

2324
stream = True
2425
output = client.chat.completions.create(
25-
model="minicpmv", # internlm chatglm3 qwen llama3 chatglm4
26+
model="glm4.1v", # internlm chatglm3 qwen llama3 chatglm4
2627
messages=[
2728
{
2829
"role": "user",

0 commit comments

Comments (0)