From 1ea2c11cf59181db208198cbdafdbf7064e23a86 Mon Sep 17 00:00:00 2001 From: amumu96 <15667065080@163.com> Date: Fri, 3 Jan 2025 17:14:22 +0800 Subject: [PATCH] fix --- xinference/model/llm/llm_family.json | 2 +- .../model/llm/llm_family_modelscope.json | 2 +- xinference/model/llm/transformers/cogagent.py | 54 ++++++++++++------- xinference/model/llm/utils.py | 11 ++-- xinference/types.py | 17 ++++++ 5 files changed, 61 insertions(+), 25 deletions(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 5a1a46e0aa..030c911dc3 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -9015,7 +9015,7 @@ "model_id": "THUDM/cogagent-9b-20241220" } ], - "chat_template": "", + "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "stop_token_ids": [ 151329, 151336, diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index 0175c7ba76..0908d314a2 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -6749,7 +6749,7 @@ "model_hub": "modelscope" } ], - "chat_template": "", + "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "stop_token_ids": [ 151329, 151336, diff --git a/xinference/model/llm/transformers/cogagent.py b/xinference/model/llm/transformers/cogagent.py index e828c5e2d3..cc2d76795e 100644 --- a/xinference/model/llm/transformers/cogagent.py +++ b/xinference/model/llm/transformers/cogagent.py @@ -15,12 +15,17 @@ import re import uuid from concurrent.futures import ThreadPoolExecutor -from typing import Dict, Iterator, List, Literal, Optional, Union +from typing import Dict, Iterator, List, Literal, Optional, Union, cast import torch from ....model.utils import select_device -from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk +from ....types import ( + ChatCompletion, + ChatCompletionChunk, + CogagentGenerateConfig, + CompletionChunk, +) from ..llm_family import LLMFamilyV1, LLMSpecV1 from ..utils import ( _decode_image, @@ -28,7 +33,7 @@ generate_completion_chunk, parse_messages, ) -from .core import PytorchChatModel, PytorchGenerateConfig +from .core import PytorchChatModel from .utils import cache_clean logger = logging.getLogger(__name__) @@ -117,7 +122,7 @@ def _history_content_to_cogagent(self, chat_history: List[Dict]): action_pattern = r"Action:\s*(.*)" def extract_operations(_content: str): - """提取 grounded operation 和 action operation""" + """extract grounded operation and action operation""" _history_step = [] _history_action = [] @@ -178,28 +183,32 @@ def get_query_and_history( # Compose the query with task, platform, and selected format instructions query = f"Task: {task}{history_str}\n{self._platform}{self._format}" - + logger.info(f"query:{query}") return query, image def _sanitize_generate_config( self, - generate_config: Optional[PytorchGenerateConfig], - ) -> PytorchGenerateConfig: + generate_config: Optional[CogagentGenerateConfig], + ) -> CogagentGenerateConfig: + logger.info(f"generate_config:{generate_config}") + generate_config = super()._sanitize_generate_config(generate_config) - return generate_config + return 
cast(CogagentGenerateConfig, generate_config) @cache_clean def chat( self, messages: List[Dict], - generate_config: Optional[PytorchGenerateConfig] = None, + generate_config: Optional[CogagentGenerateConfig] = None, ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: - stream = generate_config.get("stream", False) if generate_config else False - generate_config = self._sanitize_generate_config(generate_config) - - self._platform = generate_config.get("platform") or self._platform - self._format = generate_config.get("format") or self._format + if generate_config is not None: + self._platform = generate_config.pop("platform", self._platform) + self._format = generate_config.pop("format", self._format) + logger.info(f"_platform:{self._platform}") + logger.info(f"_format:{self._format}") + generate_config = self._sanitize_generate_config(generate_config) + stream = generate_config.get("stream") sanitized_config = { "max_length": generate_config.get("max_tokens", 512), "top_k": generate_config.get("top_k", 1), @@ -209,13 +218,19 @@ def chat( query, image = self.get_query_and_history(prompt, chat_history) - inputs = self._tokenizer.apply_chat_template( + full_context_kwargs = { + "return_tensors": "pt", + "return_dict": True, + } + assert self.model_family.chat_template is not None + inputs = self.get_full_context( [{"role": "user", "image": image, "content": query}], - add_generation_prompt=True, + self.model_family.chat_template, + self._tokenizer, tokenize=True, - return_tensors="pt", - return_dict=True, - ).to(self._model.device) + **full_context_kwargs, + ) + inputs.to(self._model.device) if stream: it = self._streaming_chat_response(inputs, sanitized_config) @@ -226,7 +241,6 @@ def chat( outputs = self._model.generate(**inputs, **sanitized_config) outputs = outputs[:, inputs["input_ids"].shape[1] :] response = self._tokenizer.decode(outputs[0], skip_special_tokens=True) - logger.info(f"Model response:\n{response}") return generate_chat_completion(self.model_uid, response) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index b7d07e5224..0bedbc4155 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -97,13 +97,18 @@ def _build_from_raw_template( return rendered def get_full_context( - self, messages: List, chat_template: str, tokenizer=None, **kwargs - ) -> str: + self, + messages: List, + chat_template: str, + tokenizer=None, + tokenize=False, + **kwargs, + ): if tokenizer is not None: try: full_context = tokenizer.apply_chat_template( messages, - tokenize=False, + tokenize=tokenize, chat_template=chat_template, add_generation_prompt=True, **kwargs, diff --git a/xinference/types.py b/xinference/types.py index 4f1b286648..e725239986 100644 --- a/xinference/types.py +++ b/xinference/types.py @@ -300,6 +300,9 @@ class PytorchGenerateConfig(TypedDict, total=False): lora_name: Optional[str] stream_options: Optional[Union[dict, None]] request_id: Optional[str] + + +class CogagentGenerateConfig(PytorchGenerateConfig, total=False): platform: Optional[Literal["Mac", "WIN", "Mobile"]] format: Optional[ Literal[ @@ -440,6 +443,19 @@ class CreateChatModel(BaseModel): CreateChatCompletionLlamaCpp: BaseModel = CreateCompletionLlamaCpp +class CreateExtraChatCompletion(BaseModel): + platform: Optional[Literal["Mac", "WIN", "Mobile"]] + format: Optional[ + Literal[ + "(Answer in Action-Operation-Sensitive format.)", + "(Answer in Status-Plan-Action-Operation format.)", + "(Answer in Status-Action-Operation-Sensitive format.)", + "(Answer in 
Status-Action-Operation format.)", + "(Answer in Action-Operation format.)", + ] + ] + + from ._compat import CreateChatCompletionOpenAI @@ -448,6 +464,7 @@ class CreateChatCompletion( # type: ignore CreateChatCompletionTorch, CreateChatCompletionLlamaCpp, CreateChatCompletionOpenAI, + CreateExtraChatCompletion, ): pass