From fda1a56ef48b331a980b221deae89e402265bddd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=97=A0=E5=89=91?= <1025360135@qq.com> Date: Tue, 25 Jun 2024 19:50:38 +0800 Subject: [PATCH] feat(model): Support tongyi embedding (#1552) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 无剑 Co-authored-by: csunny Co-authored-by: aries_ckt <916701291@qq.com> --- .env.template | 5 ++ dbgpt/configs/model_config.py | 1 + dbgpt/model/adapter/embeddings_loader.py | 8 +++ dbgpt/model/parameter.py | 3 +- dbgpt/rag/embedding/__init__.py | 2 + dbgpt/rag/embedding/embeddings.py | 80 ++++++++++++++++++++++++ 6 files changed, 97 insertions(+), 2 deletions(-) diff --git a/.env.template b/.env.template index 313e027e1..da6fb3bca 100644 --- a/.env.template +++ b/.env.template @@ -92,6 +92,11 @@ KNOWLEDGE_SEARCH_REWRITE=False # proxy_openai_proxy_api_key={your-openai-sk} # proxy_openai_proxy_backend=text-embedding-ada-002 + +## qwen embedding model, See dbgpt/model/parameter.py +# EMBEDDING_MODEL=proxy_tongyi +# proxy_tongyi_proxy_backend=text-embedding-v1 + ## Common HTTP embedding model # EMBEDDING_MODEL=proxy_http_openapi # proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index fefd4c953..0876c29dc 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -263,6 +263,7 @@ def get_device() -> str: # Common HTTP embedding model "proxy_http_openapi": "proxy_http_openapi", "proxy_ollama": "proxy_ollama", + "proxy_tongyi": "proxy_tongyi", # Rerank model, rerank mode is a special embedding model "bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"), "bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"), diff --git a/dbgpt/model/adapter/embeddings_loader.py b/dbgpt/model/adapter/embeddings_loader.py index 1731d21e5..b10cf34af 100644 --- a/dbgpt/model/adapter/embeddings_loader.py +++ 
b/dbgpt/model/adapter/embeddings_loader.py @@ -50,6 +50,14 @@ def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddin if proxy_param.proxy_backend: openapi_param["model_name"] = proxy_param.proxy_backend return OpenAPIEmbeddings(**openapi_param) + elif model_name in ["proxy_tongyi"]: + from dbgpt.rag.embedding import TongYiEmbeddings + + proxy_param = cast(ProxyEmbeddingParameters, param) + tongyi_param = {"api_key": proxy_param.proxy_api_key} + if proxy_param.proxy_backend: + tongyi_param["model_name"] = proxy_param.proxy_backend + return TongYiEmbeddings(**tongyi_param) elif model_name in ["proxy_ollama"]: from dbgpt.rag.embedding import OllamaEmbeddings diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py index 5aad75a50..d549d8db7 100644 --- a/dbgpt/model/parameter.py +++ b/dbgpt/model/parameter.py @@ -665,8 +665,7 @@ def is_rerank_model(self) -> bool: _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = { - ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi," - "proxy_ollama,rerank_proxy_http_openapi", + ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi", } EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {} diff --git a/dbgpt/rag/embedding/__init__.py b/dbgpt/rag/embedding/__init__.py index ce38d77e1..fcd4590f9 100644 --- a/dbgpt/rag/embedding/__init__.py +++ b/dbgpt/rag/embedding/__init__.py @@ -14,6 +14,7 @@ JinaEmbeddings, OllamaEmbeddings, OpenAPIEmbeddings, + TongYiEmbeddings, ) from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401 @@ -29,6 +30,7 @@ "DefaultEmbeddingFactory", "EmbeddingFactory", "WrappedEmbeddingFactory", + "TongYiEmbeddings", "CrossEncoderRerankEmbeddings", "OpenAPIRerankEmbeddings", ] diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py index 401030e75..30e7ba949 100644 --- a/dbgpt/rag/embedding/embeddings.py +++ b/dbgpt/rag/embedding/embeddings.py @@ -826,3 
class TongYiEmbeddings(BaseModel, Embeddings):
    """TongYi (Alibaba DashScope) text embeddings.

    Wraps ``dashscope.TextEmbedding`` behind the project's ``Embeddings``
    interface. Requires the ``dashscope`` package and a DashScope API key.

    Example:
        .. code-block:: python

            embeddings = TongYiEmbeddings(api_key="your-api-key")
            vectors = embeddings.embed_documents(["hello", "world"])
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
    # API key for the DashScope embeddings API; also mirrored onto the
    # dashscope module in __init__ for SDK-level configuration.
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    # DashScope embedding model id, e.g. "text-embedding-v1".
    model_name: str = Field(
        default="text-embedding-v1", description="The name of the model to use."
    )

    def __init__(self, **kwargs):
        """Initialize the TongYiEmbeddings.

        Raises:
            ValueError: If the ``dashscope`` python package is not installed.
        """
        try:
            import dashscope  # type: ignore
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: dashscope "
                "Please install dashscope by command `pip install dashscope`"
            ) from exc
        dashscope.TextEmbedding.api_key = kwargs.get("api_key")
        super().__init__(**kwargs)
        # Keep a private copy so each call can pass the key explicitly
        # (pydantic v2 allows setting undeclared underscore attributes).
        self._api_key = kwargs.get("api_key")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get the embeddings for a list of texts.

        The DashScope API accepts at most 10 texts per request (each up to
        2048 tokens), so longer lists are embedded in batches of 10 and the
        results are concatenated in input order.

        Args:
            texts (Documents): A list of texts to get embeddings for.

        Returns:
            Embedded texts as List[List[float]], where each inner List[float]
            corresponds to a single input text.

        Raises:
            RuntimeError: If an API response contains no "output" field; the
                response's "message" is propagated as the error text.
        """
        from dashscope import TextEmbedding

        max_batch_size = 10  # API limit: at most 10 texts per request
        results: List[List[float]] = []
        for start in range(0, len(texts), max_batch_size):
            batch = texts[start : start + max_batch_size]
            resp = TextEmbedding.call(
                model=self.model_name, input=batch, api_key=self._api_key
            )
            if "output" not in resp:
                raise RuntimeError(resp["message"])
            # "text_index" is relative to this request, so sort within the
            # batch to restore the caller's input order.
            sorted_embeddings = sorted(
                resp["output"]["embeddings"], key=lambda e: e["text_index"]
            )
            results.extend(item["embedding"] for item in sorted_embeddings)
        return results

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using the TongYi embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]