diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index d8b4564106..8e8c4138f3 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -489,6 +489,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/convert_ids_to_tokens",
+            self.convert_ids_to_tokens,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/rerank",
             self.rerank,
             methods=["POST"],
@@ -1312,6 +1322,41 @@ async def create_embedding(self, request: Request) -> Response:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def convert_ids_to_tokens(self, request: Request) -> Response:
+        payload = await request.json()
+        body = CreateEmbeddingRequest.parse_obj(payload)
+        model_uid = body.model
+        exclude = {
+            "model",
+            "input",
+            "user",
+        }
+        kwargs = {key: value for key, value in payload.items() if key not in exclude}
+
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            decoded_texts = await model.convert_ids_to_tokens(body.input, **kwargs)
+            return Response(decoded_texts, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def rerank(self, request: Request) -> Response:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py
index e145d963de..468f3fdbe3 100644
--- a/xinference/client/restful/restful_client.py
+++ b/xinference/client/restful/restful_client.py
@@ -126,6 +126,43 @@ def create_embedding(self, input: Union[str, List[str]], **kwargs) -> "Embedding
         response_data = response.json()
         return response_data
 
+    def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], **kwargs
+    ) -> List[str]:
+        """
+        Convert token IDs to human-readable tokens via RESTful APIs.
+
+        Parameters
+        ----------
+        input: Union[List, List[List]]
+            Input token IDs to convert; either a single list of token IDs or a list of token ID lists.
+            To convert multiple sequences in a single request, pass a list of token ID lists.
+
+        Returns
+        -------
+        list
+            A list of decoded tokens in human-readable format.
+
+        Raises
+        ------
+        RuntimeError
+            Raised when token conversion fails; the message carries the server-side error detail.
+
+        """
+        url = f"{self._base_url}/v1/convert_ids_to_tokens"
+        request_body = {
+            "model": self._model_uid,
+            "input": input,
+        }
+        request_body.update(kwargs)
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to decode token ids, detail: {_get_error_string(response)}"
+            )
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulRerankModelHandle(RESTfulModelHandle):
     def rerank(
diff --git a/xinference/core/model.py b/xinference/core/model.py
index 42453ddc69..caf6a675b9 100644
--- a/xinference/core/model.py
+++ b/xinference/core/model.py
@@ -794,6 +794,19 @@ async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
                 f"Model {self._model.model_spec} is not for creating embedding."
             )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def convert_ids_to_tokens(
+        self, input: Union[List, List[List]], *args, **kwargs
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "convert_ids_to_tokens"):
+            return await self._call_wrapper_json(
+                self._model.convert_ids_to_tokens, input, *args, **kwargs
+            )
+
+        raise AttributeError(f"Model {self._model.model_spec} is not for converting token ids.")
+
     @request_limit
     @log_async(logger=logger)
     async def rerank(
diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py
index ae66b945b2..5848aa9289 100644
--- a/xinference/model/embedding/core.py
+++ b/xinference/model/embedding/core.py
@@ -193,6 +193,27 @@ def to(self, *args, **kwargs):
                 device=self._device,
                 model_kwargs=model_kwargs,
             )
+        elif (
+            self._kwargs.get("hybrid_mode")
+            and "m3" in self._model_spec.model_name.lower()
+        ):
+            try:
+                from FlagEmbedding import BGEM3FlagModel
+            except ImportError:
+                error_message = "Failed to import module 'BGEM3FlagModel'"
+                installation_guide = [
+                    "Please make sure 'FlagEmbedding' is installed. ",
+                    "You can install it by `pip install FlagEmbedding`\n",
+                ]
+                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+            model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
+            self._model = BGEM3FlagModel(
+                self._model_path,
+                device=self._device,
+                model_kwargs=model_kwargs,
+                trust_remote_code=True,
+            )
         else:
             model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
             self._model = SentenceTransformer(
@@ -203,10 +224,155 @@ def to(self, *args, **kwargs):
             )
 
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
+        from FlagEmbedding import BGEM3FlagModel
         from sentence_transformers import SentenceTransformer
 
         kwargs.setdefault("normalize_embeddings", True)
 
+        @no_type_check
+        def _encode_bgem3(
+            model: Union[SentenceTransformer, BGEM3FlagModel],
+            sentences: Union[str, List[str]],
+            batch_size: int = 32,
+            show_progress_bar: bool = None,
+            output_value: str = "sparse_embedding",
+            convert_to_numpy: bool = True,
+            convert_to_tensor: bool = False,
+            device: str = None,
+            normalize_embeddings: bool = False,
+            **kwargs,
+        ):
+            """
+            Computes sentence embeddings with the bge-m3 model.
+            Nothing special here, just replaces sentence-transformers with FlagEmbedding.
+            TODO: think about how to deduplicate this with the encode method in the future
+
+            :param sentences: the sentences to embed
+            :param batch_size: the batch size used for the computation
+            :param show_progress_bar: Output a progress bar when encoding sentences
+            :param output_value: Default sparse_embedding. Can be set to token_embeddings to get wordpiece token embeddings.
+                Set to None, to get all output values
+            :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
+            :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
+            :param device: Which torch.device to use for the computation
+            :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
+
+            :return:
+                By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
+            """
+            import torch
+            from tqdm.autonotebook import trange
+
+            if show_progress_bar is None:
+                show_progress_bar = (
+                    logger.getEffectiveLevel() == logging.INFO
+                    or logger.getEffectiveLevel() == logging.DEBUG
+                )
+
+            if convert_to_tensor:
+                convert_to_numpy = False
+
+            if output_value != "sparse_embedding":
+                convert_to_tensor = False
+                convert_to_numpy = False
+
+            input_was_string = False
+            if isinstance(sentences, str) or not hasattr(
+                sentences, "__len__"
+            ):  # Cast an individual sentence to a list with length 1
+                sentences = [sentences]
+                input_was_string = True
+
+            if device is None:
+                # Same as SentenceTransformer.py
+                from sentence_transformers.util import get_device_name
+
+                device = get_device_name()
+                logger.info(f"Use pytorch device_name: {device}")
+
+            all_embeddings = []
+            all_token_nums = 0
+
+            # The original code does not support other inference engines
+            def _text_length(text):
+                if isinstance(text, dict):  # {key: value} case
+                    return len(next(iter(text.values())))
+                elif not hasattr(text, "__len__"):  # Object has no len() method
+                    return 1
+                elif len(text) == 0 or isinstance(
+                    text[0], int
+                ):  # Empty string or list of ints
+                    return len(text)
+                else:
+                    return sum(
+                        [len(t) for t in text]
+                    )  # Sum of length of individual strings
+
+            length_sorted_idx = np.argsort([-_text_length(sen) for sen in sentences])
+            sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
+
+            for start_index in trange(
+                0,
+                len(sentences),
+                batch_size,
+                desc="Batches",
+                disable=not show_progress_bar,
+            ):
+                sentences_batch = sentences_sorted[
+                    start_index : start_index + batch_size
+                ]
+
+                with torch.no_grad():
+                    out_features = model.encode(sentences_batch, **kwargs)
+
+                    if output_value == "token_embeddings":
+                        embeddings = []
+                        for token_emb, attention in zip(
+                            out_features[output_value], out_features["attention_mask"]
+                        ):
+                            last_mask_id = len(attention) - 1
+                            while (
+                                last_mask_id > 0 and attention[last_mask_id].item() == 0
+                            ):
+                                last_mask_id -= 1
+
+                            embeddings.append(token_emb[0 : last_mask_id + 1])
+                    elif output_value is None:  # Return all outputs
+                        embeddings = []
+                        for sent_idx in range(len(out_features["sentence_embedding"])):
+                            row = {
+                                name: out_features[name][sent_idx]
+                                for name in out_features
+                            }
+                            embeddings.append(row)
+                    # for sparse embedding
+                    else:
+                        if kwargs.get("return_sparse"):
+                            embeddings = out_features["lexical_weights"]
+                        else:
+                            embeddings = out_features["dense_vecs"]
+
+                        if convert_to_numpy:
+                            embeddings = embeddings.cpu()
+
+                    all_embeddings.extend(embeddings)
+
+            all_embeddings = [
+                all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
+            ]
+
+            if convert_to_tensor:
+                if len(all_embeddings):
+                    all_embeddings = torch.stack(all_embeddings)
+                else:
+                    all_embeddings = torch.Tensor()
+            elif convert_to_numpy:
+                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+
+            if input_was_string:
+                all_embeddings = all_embeddings[0]
+
+            return all_embeddings, all_token_nums
+
         # copied from sentence-transformers, and modify it to return tokens num
         @no_type_check
         def encode(
@@ -390,6 +556,10 @@ def encode(
                 convert_to_numpy=False,
                 **kwargs,
             )
+        elif isinstance(self._model, BGEM3FlagModel):
+            all_embeddings, all_token_nums = _encode_bgem3(
+                self._model, sentences, convert_to_numpy=False, **kwargs
+            )
         else:
             all_embeddings, all_token_nums = encode(
                 self._model,
@@ -401,14 +571,30 @@ def encode(
             all_embeddings = [all_embeddings]
         embedding_list = []
         for index, data in enumerate(all_embeddings):
-            embedding_list.append(
-                EmbeddingData(index=index, object="embedding", embedding=data.tolist())
-            )
+            if kwargs.get("return_sparse") and isinstance(self._model, BGEM3FlagModel):
+                embedding_list.append(
+                    EmbeddingData(
+                        index=index,
+                        object="embedding",
+                        embedding={k: float(v) for k, v in data.items()},
+                    )
+                )
+            else:
+                embedding_list.append(
+                    EmbeddingData(
+                        index=index, object="embedding", embedding=data.tolist()
+                    )
+                )
         usage = EmbeddingUsage(
             prompt_tokens=all_token_nums, total_tokens=all_token_nums
         )
         result = Embedding(
-            object="list",
+            object=(
+                "list"  # type: ignore
+                if not isinstance(self._model, BGEM3FlagModel)
+                and not kwargs.get("return_sparse")
+                else "dict"
+            ),
             model=self._model_uid,
             data=embedding_list,
             usage=usage,
@@ -430,6 +616,38 @@ def encode(
 
         return result
 
+    def convert_ids_to_tokens(
+        self,
+        batch_token_ids: Union[List[Union[int, str]], List[List[Union[int, str]]]],
+        **kwargs,
+    ) -> Union[str, List[str], List[List[str]]]:
+        batch_decoded_texts: List[str] = []
+
+        assert self._model is not None
+
+        if isinstance(batch_token_ids, (int, str)):
+            return self._model.tokenizer.convert_ids_to_tokens(
+                [int(str(batch_token_ids))]
+            )[0]
+
+        # check if it's a nested list
+        if (
+            isinstance(batch_token_ids, list)
+            and batch_token_ids
+            and isinstance(batch_token_ids[0], list)
+        ):
+            for token_ids in batch_token_ids:
+                token_ids = [int(token_id) for token_id in token_ids]
+                batch_decoded_texts.append(
+                    self._model.tokenizer.convert_ids_to_tokens(token_ids)
+                )
+        else:
+            batch_token_ids = [int(token_id) for token_id in batch_token_ids]
+            batch_decoded_texts = self._model.tokenizer.convert_ids_to_tokens(
+                batch_token_ids
+            )
+        return batch_decoded_texts
+
 
 def match_embedding(
     model_name: str,
diff --git a/xinference/model/embedding/tests/test_embedding_models.py b/xinference/model/embedding/tests/test_embedding_models.py
index 7ff47f7c83..a81fb8b07a 100644
--- a/xinference/model/embedding/tests/test_embedding_models.py
+++ b/xinference/model/embedding/tests/test_embedding_models.py
@@ -247,3 +247,19 @@ def test_register_fault_embedding():
     assert any(
         "Invalid model URI /new_data/cache/gte-Qwen2" in str(r.message) for r in record
     )
+
+
+def test_convert_ids_to_tokens():
+    from ..core import EmbeddingModel
+
+    model_path = cache(TEST_MODEL_SPEC_FROM_MODELSCOPE)
+    model = EmbeddingModel("mock", model_path, TEST_MODEL_SPEC_FROM_MODELSCOPE)
+    model.load()
+
+    ids = [[8074, 8059, 8064, 8056], [144, 147, 160, 160, 158]]
+    tokens = model.convert_ids_to_tokens(ids)
+
+    assert isinstance(tokens, list)
+    assert tokens == [["x", "i", "n", "f"], ["b", "e", "r", "r", "p"]]
+
+    shutil.rmtree(model_path, ignore_errors=True)
diff --git a/xinference/types.py b/xinference/types.py
index 613d8709bb..759cf0b7c4 100644
--- a/xinference/types.py
+++ b/xinference/types.py
@@ -71,7 +71,8 @@ class EmbeddingUsage(TypedDict):
 class EmbeddingData(TypedDict):
     index: int
     object: str
-    embedding: List[float]
+    # support sparse embedding
+    embedding: Union[List[float], Dict[str, float]]
 
 
 class Embedding(TypedDict):
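
Not part of the diff: a rough usage sketch of the new route from the RESTful client, assuming a locally running Xinference server and a bge-m3 embedding model launched with hybrid_mode=True; the address, model uid, and sample text are placeholders.

# Usage sketch (assumptions: server at 127.0.0.1:9997, a bge-m3 model launched
# with hybrid_mode=True under the hypothetical uid "my-bge-m3").
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-bge-m3")

# With hybrid mode, return_sparse=True yields a {token_id: weight} mapping per input.
result = model.create_embedding("What is BGE M3?", return_sparse=True)
sparse = result["data"][0]["embedding"]

# Map the sparse token ids back to readable tokens via /v1/convert_ids_to_tokens.
tokens = model.convert_ids_to_tokens([list(sparse.keys())])
print(dict(zip(tokens[0], sparse.values())))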