
Commit

fix
amumu96 committed Jan 6, 2025
1 parent 4ad3085 commit 1ea2c11
Showing 5 changed files with 61 additions and 25 deletions.
2 changes: 1 addition & 1 deletion xinference/model/llm/llm_family.json
@@ -9015,7 +9015,7 @@
           "model_id": "THUDM/cogagent-9b-20241220"
         }
       ],
-      "chat_template": "",
+      "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
       "stop_token_ids": [
         151329,
         151336,
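For reference, the template being added is a standard ChatML-style Jinja2 template (the same string is added to the modelscope variant below). A minimal sketch of how it renders, using the jinja2 package directly rather than xinference's own rendering path:

from jinja2 import Template

# The template exactly as added above; '\n' escapes inside Jinja string
# literals are decoded by Jinja's lexer, so each turn ends with a newline.
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}"
)

messages = [{"role": "user", "content": "Task: open the settings page"}]
prompt = Template(CHAT_TEMPLATE).render(messages=messages, add_generation_prompt=True)
print(prompt)
# <|im_start|>user
# Task: open the settings page<|im_end|>
# <|im_start|>assistant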
2 changes: 1 addition & 1 deletion xinference/model/llm/llm_family_modelscope.json
@@ -6749,7 +6749,7 @@
           "model_hub": "modelscope"
         }
       ],
-      "chat_template": "",
+      "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
       "stop_token_ids": [
         151329,
         151336,
54 changes: 34 additions & 20 deletions xinference/model/llm/transformers/cogagent.py
@@ -15,20 +15,25 @@
 import re
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from typing import Dict, Iterator, List, Literal, Optional, Union
+from typing import Dict, Iterator, List, Literal, Optional, Union, cast

 import torch

 from ....model.utils import select_device
-from ....types import ChatCompletion, ChatCompletionChunk, CompletionChunk
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    CogagentGenerateConfig,
+    CompletionChunk,
+)
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import (
     _decode_image,
     generate_chat_completion,
     generate_completion_chunk,
     parse_messages,
 )
-from .core import PytorchChatModel, PytorchGenerateConfig
+from .core import PytorchChatModel
 from .utils import cache_clean

 logger = logging.getLogger(__name__)
@@ -117,7 +122,7 @@ def _history_content_to_cogagent(self, chat_history: List[Dict]):
         action_pattern = r"Action:\s*(.*)"

         def extract_operations(_content: str):
-            """提取 grounded operation action operation"""
+            """extract grounded operation and action operation"""
             _history_step = []
             _history_action = []
@@ -178,28 +183,32 @@ def get_query_and_history(

         # Compose the query with task, platform, and selected format instructions
         query = f"Task: {task}{history_str}\n{self._platform}{self._format}"

+        logger.info(f"query:{query}")
         return query, image

     def _sanitize_generate_config(
         self,
-        generate_config: Optional[PytorchGenerateConfig],
-    ) -> PytorchGenerateConfig:
+        generate_config: Optional[CogagentGenerateConfig],
+    ) -> CogagentGenerateConfig:
+        logger.info(f"generate_config:{generate_config}")
+
         generate_config = super()._sanitize_generate_config(generate_config)
-        return generate_config
+        return cast(CogagentGenerateConfig, generate_config)

     @cache_clean
     def chat(
         self,
         messages: List[Dict],
-        generate_config: Optional[PytorchGenerateConfig] = None,
+        generate_config: Optional[CogagentGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        stream = generate_config.get("stream", False) if generate_config else False
-        generate_config = self._sanitize_generate_config(generate_config)
-
-        self._platform = generate_config.get("platform") or self._platform
-        self._format = generate_config.get("format") or self._format
+        if generate_config is not None:
+            self._platform = generate_config.pop("platform", self._platform)
+            self._format = generate_config.pop("format", self._format)
+        logger.info(f"_platform:{self._platform}")
+        logger.info(f"_format:{self._format}")
+
+        generate_config = self._sanitize_generate_config(generate_config)
+        stream = generate_config.get("stream")
         sanitized_config = {
             "max_length": generate_config.get("max_tokens", 512),
             "top_k": generate_config.get("top_k", 1),
@@ -209,13 +218,19 @@ def chat(

         query, image = self.get_query_and_history(prompt, chat_history)

-        inputs = self._tokenizer.apply_chat_template(
+        full_context_kwargs = {
+            "return_tensors": "pt",
+            "return_dict": True,
+        }
+        assert self.model_family.chat_template is not None
+        inputs = self.get_full_context(
             [{"role": "user", "image": image, "content": query}],
-            add_generation_prompt=True,
+            self.model_family.chat_template,
+            self._tokenizer,
             tokenize=True,
-            return_tensors="pt",
-            return_dict=True,
-        ).to(self._model.device)
+            **full_context_kwargs,
+        )
+        inputs.to(self._model.device)

         if stream:
             it = self._streaming_chat_response(inputs, sanitized_config)
@@ -226,7 +241,6 @@ def chat(
             outputs = self._model.generate(**inputs, **sanitized_config)
             outputs = outputs[:, inputs["input_ids"].shape[1] :]
             response = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
-            logger.info(f"Model response:\n{response}")

         return generate_chat_completion(self.model_uid, response)
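Taken together, the cogagent.py changes let the two agent-specific knobs ride inside generate_config: chat() pops platform and format out (falling back to the previous values) so that only real sampling parameters reach _sanitize_generate_config and generate(). A hypothetical caller sketch, assuming a running xinference server at the default endpoint, an already-launched cogagent model, and the current client-side chat() signature (which varies between versions):

from xinference.client import Client

client = Client("http://localhost:9997")  # assumed local endpoint
model = client.get_model("cogagent")      # placeholder model UID

completion = model.chat(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Mark the search box on this screen"},
                # Image payload in the OpenAI-style vision schema that
                # _decode_image in ..utils handles.
                {"type": "image_url", "image_url": {"url": "file:///tmp/screenshot.png"}},
            ],
        }
    ],
    generate_config={
        "platform": "Mac",                                 # popped by chat()
        "format": "(Answer in Action-Operation format.)",  # popped by chat()
        "max_tokens": 512,                                 # becomes max_length
    },
)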
11 changes: 8 additions & 3 deletions xinference/model/llm/utils.py
@@ -97,13 +97,18 @@ def _build_from_raw_template(
         return rendered

     def get_full_context(
-        self, messages: List, chat_template: str, tokenizer=None, **kwargs
-    ) -> str:
+        self,
+        messages: List,
+        chat_template: str,
+        tokenizer=None,
+        tokenize=False,
+        **kwargs,
+    ):
         if tokenizer is not None:
             try:
                 full_context = tokenizer.apply_chat_template(
                     messages,
-                    tokenize=False,
+                    tokenize=tokenize,
                     chat_template=chat_template,
                     add_generation_prompt=True,
                     **kwargs,
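The new tokenize parameter is simply forwarded to Hugging Face's apply_chat_template, so get_full_context can now return either the rendered prompt string (the default, tokenize=False, preserving the old behavior) or tokenized tensor inputs. A sketch of the two modes using transformers directly; the model id is the one from llm_family.json, though any tokenizer with a chat template behaves the same:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/cogagent-9b-20241220", trust_remote_code=True
)
messages = [{"role": "user", "content": "Task: open the settings page"}]

# tokenize=False (the wrapper's default): a plain prompt string.
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# tokenize=True plus return_dict/return_tensors (what cogagent.py now
# passes through full_context_kwargs): tensors ready for model.generate().
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)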
17 changes: 17 additions & 0 deletions xinference/types.py
@@ -300,6 +300,9 @@ class PytorchGenerateConfig(TypedDict, total=False):
     lora_name: Optional[str]
     stream_options: Optional[Union[dict, None]]
     request_id: Optional[str]
+
+
+class CogagentGenerateConfig(PytorchGenerateConfig, total=False):
+    platform: Optional[Literal["Mac", "WIN", "Mobile"]]
+    format: Optional[
+        Literal[
@@ -440,6 +443,19 @@ class CreateChatModel(BaseModel):
 CreateChatCompletionLlamaCpp: BaseModel = CreateCompletionLlamaCpp


+class CreateExtraChatCompletion(BaseModel):
+    platform: Optional[Literal["Mac", "WIN", "Mobile"]]
+    format: Optional[
+        Literal[
+            "(Answer in Action-Operation-Sensitive format.)",
+            "(Answer in Status-Plan-Action-Operation format.)",
+            "(Answer in Status-Action-Operation-Sensitive format.)",
+            "(Answer in Status-Action-Operation format.)",
+            "(Answer in Action-Operation format.)",
+        ]
+    ]
+
+
 from ._compat import CreateChatCompletionOpenAI
@@ -448,6 +464,7 @@ class CreateChatCompletion( # type: ignore
     CreateChatCompletionTorch,
     CreateChatCompletionLlamaCpp,
     CreateChatCompletionOpenAI,
+    CreateExtraChatCompletion,
 ):
     pass
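Because PytorchGenerateConfig is declared with total=False, the new CogagentGenerateConfig inherits every sampling key as optional and adds two optional extras; CreateExtraChatCompletion mirrors the same fields on the pydantic side so they validate in the REST layer. A small usage sketch (import path as in this file, xinference.types):

from xinference.types import CogagentGenerateConfig

# total=False: every key, inherited or new, may be omitted.
config: CogagentGenerateConfig = {
    "max_tokens": 512,  # inherited from PytorchGenerateConfig
    "top_k": 1,         # inherited sampling option
    "platform": "Mac",  # cogagent-only; popped in chat()
    "format": "(Answer in Action-Operation format.)",
}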
