Skip to content

FEAT: jina-reranker-m0 #3209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions xinference/model/rerank/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def _auto_detect_type(model_path):
"LlamaTokenizerFast": "LLM-based layerwise",
"GemmaTokenizerFast": "LLM-based",
"XLMRobertaTokenizerFast": "normal",
"CLIPTokenizerFast": "LLM-based multimodal",
}

tokenizer = RerankModel._get_tokenizer(model_path)
Expand Down Expand Up @@ -229,12 +230,40 @@ def load(self):
)
if self._use_fp16:
self._model.model.half()
elif self._model_spec.type == "LLM-based multimodal":
            # Load the multimodal reranker via transformers AutoModel
try:
from transformers import AutoModel

attn_implementation = (
"flash_attention_2" if flash_attn_installed else None
)
self._model = AutoModel.from_pretrained(
self._model_path,
torch_dtype="auto" if self._use_fp16 else None,
trust_remote_code=True,
attn_implementation=attn_implementation,
)

if self._device:
self._model.to(self._device)
self._model.eval()
return

except ImportError:
error_message = "Failed to import module 'transformers'"
installation_guide = [
"Please make sure 'transformers>=4.47.3' is installed. ",
"You can install it by `pip install 'transformers>=4.47.3'`\n",
]
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
else:
try:
if self._model_spec.type == "LLM-based":
from FlagEmbedding import FlagLLMReranker as FlagReranker
elif self._model_spec.type == "LLM-based layerwise":
from FlagEmbedding import LayerWiseFlagLLMReranker as FlagReranker

else:
raise RuntimeError(
f"Unsupported Rank model type: {self._model_spec.type}"
Expand Down Expand Up @@ -265,6 +294,7 @@ def rerank(
if max_chunks_per_doc is not None:
raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)

sentence_combinations = [[query, doc] for doc in documents]
# reset n tokens
self._model.model.n_tokens = 0
Expand All @@ -277,6 +307,37 @@ def rerank(
).cpu()
if similarity_scores.dtype == torch.bfloat16:
similarity_scores = similarity_scores.float()
elif self._model_spec.type == "LLM-based multimodal":
            # Pop the document type from kwargs, defaulting to "text"
doc_type = kwargs.pop("doc_type", "text")

            # Verify the model spec declares support for this document type
if (
hasattr(self._model_spec, "supported_doc_types")
and doc_type not in self._model_spec.supported_doc_types
):
raise ValueError(
f"Model {self._model_spec.model_name} does not support document type: {doc_type}"
)


            # Multimodal scoring path: delegate to the model's compute_score
max_length = kwargs.pop("max_length", 1024)
similarity_scores = self._model.compute_score(
sentence_combinations,
max_length=max_length,
doc_type=doc_type,
**kwargs,
)

if not isinstance(similarity_scores, Sequence):
similarity_scores = [similarity_scores]
elif (
isinstance(similarity_scores, list)
and len(similarity_scores) > 0
and isinstance(similarity_scores[0], Sequence)
):
similarity_scores = similarity_scores[0]
else:
# Related issue: https://github.com/xorbitsai/inference/issues/1775
similarity_scores = self._model.compute_score(
Expand Down Expand Up @@ -340,6 +401,16 @@ def rerank(
gc.collect()
empty_cache()

if self._counter > 10:
items = []
for i in range(10):
items.append(docs[i])
items = list(items)
else:
items.append(docs[0])

items.append(docs[0])

return Rerank(id=str(uuid.uuid1()), results=docs, meta=metadata)


Expand Down
8 changes: 8 additions & 0 deletions xinference/model/rerank/model_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,13 @@
"max_tokens": 1024,
"model_id": "openbmb/MiniCPM-Reranker",
"model_revision": "5d2fd7345b6444c89d4c0fa59c92272888f3f2d0"
},
{
"model_name": "jina-reranker-m0",
"type": "LLM-based multimodal",
"language": ["en", "zh"],
"max_tokens": 10240,
"model_id": "jinaai/jina-reranker-m0",
"model_revision": "main"
}
]
9 changes: 9 additions & 0 deletions xinference/model/rerank/model_spec_modelscope.json
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,14 @@
"max_tokens": 1024,
"model_id": "OpenBMB/MiniCPM-Reranker",
"model_hub": "modelscope"
},
{
"model_name": "jina-reranker-m0",
"type": "LLM-based multimodal",
"language": ["en", "zh"],
"max_tokens": 10240,
"model_id": "jinaai/jina-reranker-m0",
"model_revision": "master",
"model_hub": "modelscope"
}
]
Loading