From 74ec8e52cd8a2a17db02bd031b0a89596a38ebe0 Mon Sep 17 00:00:00 2001
From: Fangyin Cheng <staneyffer@gmail.com>
Date: Tue, 5 Mar 2024 20:21:37 +0800
Subject: [PATCH] feat: APIServer supports embeddings (#1256)

---
 dbgpt/model/cluster/apiserver/api.py      |  88 ++++++--
 dbgpt/model/parameter.py                  |   3 +
 dbgpt/rag/embedding/__init__.py           |  16 ++
 dbgpt/rag/embedding/embeddings.py         | 191 +++++++++++++++++-
 dbgpt/storage/vector_store/base.py        |  33 ++-
 dbgpt/storage/vector_store/connector.py   |   6 +-
 .../advanced_usage/OpenAI_SDK_call.md     |  20 +-
 examples/rag/embedding_rag_example.py     |   4 +-
 examples/rag/rag_embedding_api_example.py |  87 ++++++++
 9 files changed, 411 insertions(+), 37 deletions(-)
 create mode 100644 examples/rag/rag_embedding_api_example.py

diff --git a/dbgpt/model/cluster/apiserver/api.py b/dbgpt/model/cluster/apiserver/api.py
index 1a508fddb..c3cdbc5b2 100644
--- a/dbgpt/model/cluster/apiserver/api.py
+++ b/dbgpt/model/cluster/apiserver/api.py
@@ -23,6 +23,8 @@
     ChatCompletionStreamResponse,
     ChatMessage,
     DeltaMessage,
+    EmbeddingsRequest,
+    EmbeddingsResponse,
     ModelCard,
     ModelList,
     ModelPermission,
@@ -51,6 +53,7 @@ def __init__(self, code: int, message: str):
 
 class APISettings(BaseModel):
     api_keys: Optional[List[str]] = None
+    embedding_batch_size: int = 4
 
 
 api_settings = APISettings()
@@ -181,27 +184,29 @@ def get_model_registry(self) -> ModelRegistry:
         return controller
 
     async def get_model_instances_or_raise(
-        self, model_name: str
+        self, model_name: str, worker_type: str = "llm"
     ) -> List[ModelInstance]:
         """Get healthy model instances with request model name
 
         Args:
             model_name (str): Model name
+            worker_type (str, optional): Worker type. Defaults to "llm".
 
         Raises:
             APIServerException: If can't get healthy model instances with request model name
         """
         registry = self.get_model_registry()
-        registry_model_name = f"{model_name}@llm"
+        suffix = f"@{worker_type}"
+        registry_model_name = f"{model_name}{suffix}"
         model_instances = await registry.get_all_instances(
             registry_model_name, healthy_only=True
         )
         if not model_instances:
             all_instances = await registry.get_all_model_instances(healthy_only=True)
             models = [
-                ins.model_name.split("@llm")[0]
+                ins.model_name.split(suffix)[0]
                 for ins in all_instances
-                if ins.model_name.endswith("@llm")
+                if ins.model_name.endswith(suffix)
             ]
             if models:
                 models = "&&".join(models)
@@ -336,6 +341,25 @@ async def chat_completion_generate(
 
         return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
 
+    async def embeddings_generate(
+        self, model: str, texts: List[str]
+    ) -> List[List[float]]:
+        """Generate embeddings
+
+        Args:
+            model (str): Model name
+            texts (List[str]): Texts to embed
+
+        Returns:
+            List[List[float]]: The embeddings of texts
+        """
+        worker_manager: WorkerManager = self.get_worker_manager()
+        params = {
+            "input": texts,
+            "model": model,
+        }
+        return await worker_manager.embeddings(params)
+
 
 def get_api_server() -> APIServer:
     api_server = global_system_app.get_component(
@@ -389,6 +413,40 @@ async def create_chat_completion(
     return await api_server.chat_completion_generate(request.model, params, request.n)
 
 
+@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
+async def create_embeddings(
+    request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
+):
+    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+    texts = request.input
+    if isinstance(texts, str):
+        texts = [texts]
+    batch_size = api_settings.embedding_batch_size
+    batches = [
+        texts[i : min(i + batch_size, len(texts))]
+        for i in range(0, len(texts), batch_size)
+    ]
+    data = []
+    async_tasks = []
+    for batch in batches:
+        async_tasks.append(api_server.embeddings_generate(request.model, batch))
+
+    # Request all embeddings in parallel
+    batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
+    for num_batch, embeddings in enumerate(batch_embeddings):
+        data += [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": num_batch * batch_size + i,
+            }
+            for i, emb in enumerate(embeddings)
+        ]
+    return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
+        exclude_none=True
+    )
+
+
 def _initialize_all(controller_addr: str, system_app: SystemApp):
     from dbgpt.model.cluster.controller.controller import ModelRegistryClient
     from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
@@ -427,6 +485,7 @@ def initialize_apiserver(
     host: str = None,
     port: int = None,
     api_keys: List[str] = None,
+    embedding_batch_size: Optional[int] = None,
 ):
     global global_system_app
     global api_settings
@@ -434,13 +493,6 @@ def initialize_apiserver(
     if not app:
         embedded_mod = False
         app = FastAPI()
-        app.add_middleware(
-            CORSMiddleware,
-            allow_origins=["*"],
-            allow_credentials=True,
-            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
-            allow_headers=["*"],
-        )
 
     if not system_app:
         system_app = SystemApp(app)
@@ -449,6 +501,9 @@ def initialize_apiserver(
     if api_keys:
         api_settings.api_keys = api_keys
 
+    if embedding_batch_size:
+        api_settings.embedding_batch_size = embedding_batch_size
+
     app.include_router(router, prefix="/api", tags=["APIServer"])
 
     @app.exception_handler(APIServerException)
@@ -464,7 +519,15 @@ async def validation_exception_handler(request, exc):
 
     if not embedded_mod:
         import uvicorn
 
-        uvicorn.run(app, host=host, port=port, log_level="info")
+        # https://github.com/encode/starlette/issues/617
+        cors_app = CORSMiddleware(
+            app=app,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+            allow_headers=["*"],
+        )
+        uvicorn.run(cors_app, host=host, port=port, log_level="info")
 
 
 def run_apiserver():
@@ -488,6 +551,7 @@ def run_apiserver():
         host=apiserver_params.host,
         port=apiserver_params.port,
         api_keys=api_keys,
+        embedding_batch_size=apiserver_params.embedding_batch_size,
     )
 
 
diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py
index 810355f0d..946d65ecf 100644
--- a/dbgpt/model/parameter.py
+++ b/dbgpt/model/parameter.py
@@ -113,6 +113,9 @@ class ModelAPIServerParameters(BaseParameters):
         default=None,
         metadata={"help": "Optional list of comma separated API keys"},
     )
+    embedding_batch_size: Optional[int] = field(
+        default=None, metadata={"help": "Embedding batch size"}
+    )
 
     log_level: Optional[str] = field(
         default=None,
diff --git a/dbgpt/rag/embedding/__init__.py b/dbgpt/rag/embedding/__init__.py
index e69de29bb..a4d6d2914 100644
--- a/dbgpt/rag/embedding/__init__.py
+++ b/dbgpt/rag/embedding/__init__.py
@@ -0,0 +1,16 @@
+from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory
+from .embeddings import (
+    Embeddings,
+    HuggingFaceEmbeddings,
+    JinaEmbeddings,
+    OpenAPIEmbeddings,
+)
+
+__all__ = [
+    "OpenAPIEmbeddings",
+    "Embeddings",
+    "HuggingFaceEmbeddings",
+    "JinaEmbeddings",
+    "EmbeddingFactory",
+    "DefaultEmbeddingFactory",
+]
diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py
index e9fb38784..7ccc78ee5 100644
--- a/dbgpt/rag/embedding/embeddings.py
+++ b/dbgpt/rag/embedding/embeddings.py
@@ -2,8 +2,10 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional
 
+import aiohttp
 import requests
-from pydantic import BaseModel, Extra, Field
+
+from dbgpt._private.pydantic import BaseModel, Extra, Field
 
 DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
 DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -363,6 +365,29 @@ def embed_query(self, text: str) -> List[float]:
         return self.embed_documents([text])[0]
 
 
+def _handle_request_result(res: requests.Response) -> List[List[float]]:
+    """Parse the embeddings from an API response.
+
+    Args:
+        res: The response from the request.
+
+    Returns:
+        List[List[float]]: The embeddings.
+
+    Raises:
+        RuntimeError: If the response contains no embedding data.
+    """
+    res.raise_for_status()
+    resp = res.json()
+    if "data" not in resp:
+        raise RuntimeError(resp["detail"])
+    embeddings = resp["data"]
+    # Sort resulting embeddings by index
+    sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+    # Return just the embeddings
+    return [result["embedding"] for result in sorted_embeddings]
+
+
 class JinaEmbeddings(BaseModel, Embeddings):
     """
     This class is used to get embeddings for a list of texts using the Jina AI API.
@@ -406,20 +431,136 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         # Call Jina AI Embedding API
         resp = self.session.post(  # type: ignore
             self.api_url, json={"input": texts, "model": self.model_name}
-        ).json()
-        if "data" not in resp:
-            raise RuntimeError(resp["detail"])
+        )
+        return _handle_request_result(resp)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using the Jina AI API.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        return self.embed_documents([text])[0]
+
+
+class OpenAPIEmbeddings(BaseModel, Embeddings):
+    """This class is used to get embeddings for a list of texts using the API.
+
+    This API is compatible with the OpenAI Embedding API.
+
+    Examples:
+
+        Using OpenAI's API:
+        .. code-block:: python
+
+            from dbgpt.rag.embedding import OpenAPIEmbeddings
+
+            openai_embeddings = OpenAPIEmbeddings(
+                api_url="https://api.openai.com/v1/embeddings",
+                api_key="your_api_key",
+                model_name="text-embedding-3-small",
+            )
+            texts = ["Hello, world!", "How are you?"]
+            openai_embeddings.embed_documents(texts)
+
+        Using DB-GPT APIServer's embedding API:
+        To use the DB-GPT APIServer's embedding API, you should deploy DB-GPT according
+        to the `Cluster Deploy
+        <https://docs.dbgpt.site/docs/installation/model_service/cluster>`_.
+
+        A simple example:
+        1. Deploy Model Cluster with the following command:
+        .. code-block:: bash
+
+            dbgpt start controller --port 8000
+
+        2. Deploy Embedding Model Worker with the following command:
+        .. code-block:: bash
+
+            dbgpt start worker --model_name text2vec \
+            --model_path /app/models/text2vec-large-chinese \
+            --worker_type text2vec \
+            --port 8003 \
+            --controller_addr http://127.0.0.1:8000
+
+        3. Deploy API Server with the following command:
+        .. code-block:: bash
+
+            dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
+            --api_keys my_api_token --port 8100
+
+        Now, you can use the API server to get embeddings:
+        .. code-block:: python
+
+            from dbgpt.rag.embedding import OpenAPIEmbeddings
 
-        embeddings = resp["data"]
+            openai_embeddings = OpenAPIEmbeddings(
+                api_url="http://localhost:8100/api/v1/embeddings",
+                api_key="my_api_token",
+                model_name="text2vec",
+            )
+            texts = ["Hello, world!", "How are you?"]
+            openai_embeddings.embed_documents(texts)
+
+    """
+
+    api_url: str = Field(
+        default="http://localhost:8100/api/v1/embeddings",
+        description="The URL of the embeddings API.",
+    )
+    api_key: Optional[str] = Field(
+        default=None, description="The API key for the embeddings API."
+    )
+    model_name: str = Field(
+        default="text2vec", description="The name of the model to use."
+    )
+    timeout: int = Field(
+        default=60, description="The timeout for the request in seconds."
+    )
+
+    session: requests.Session = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def __init__(self, **kwargs):
+        """Initialize the OpenAPIEmbeddings."""
+        super().__init__(**kwargs)
+        try:
+            import requests
+        except ImportError:
+            raise ValueError(
+                "The requests python package is not installed. "
+                "Please install it with `pip install requests`"
+            )
+        self.session = requests.Session()
+        self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Get the embeddings for a list of texts.
 
-        # Sort resulting embeddings by index
-        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+        Args:
+            texts (Documents): A list of texts to get embeddings for.
 
-        # Return just the embeddings
-        return [result["embedding"] for result in sorted_embeddings]
+        Returns:
+            Embedded texts as List[List[float]], where each inner List[float]
+                corresponds to a single input text.
+        """
+        # Call OpenAI Embedding API
+        res = self.session.post(  # type: ignore
+            self.api_url,
+            json={"input": texts, "model": self.model_name},
+            timeout=self.timeout,
+        )
+        return _handle_request_result(res)
 
     def embed_query(self, text: str) -> List[float]:
-        """Compute query embeddings using a HuggingFace transformer model.
+        """Compute query embeddings using an OpenAPI embedding model.
 
         Args:
             text: The text to embed.
@@ -428,3 +569,33 @@ def embed_query(self, text: str) -> List[float]:
         Embeddings for the text.
         """
         return self.embed_documents([text])[0]
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronously embed search docs.
+
+        Args:
+            texts: A list of texts to get embeddings for.
+
+        Returns:
+            List[List[float]]: Embedded texts as List[List[float]], where each inner
+                List[float] corresponds to a single input text.
+        """
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        async with aiohttp.ClientSession(
+            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
+        ) as session:
+            async with session.post(
+                self.api_url, json={"input": texts, "model": self.model_name}
+            ) as resp:
+                resp.raise_for_status()
+                data = await resp.json()
+                if "data" not in data:
+                    raise RuntimeError(data["detail"])
+                embeddings = data["data"]
+                sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
+                return [result["embedding"] for result in sorted_embeddings]
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Asynchronously embed query text."""
+        embeddings = await self.aembed_documents([text])
+        return embeddings[0]
diff --git a/dbgpt/storage/vector_store/base.py b/dbgpt/storage/vector_store/base.py
index a97fdd601..769a48b08 100644
--- a/dbgpt/storage/vector_store/base.py
+++ b/dbgpt/storage/vector_store/base.py
@@ -2,6 +2,7 @@
 import math
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, List, Optional
 
 from pydantic import BaseModel, Field
@@ -24,11 +25,13 @@ class VectorStoreConfig(BaseModel):
     )
     password: Optional[str] = Field(
         default=None,
-        description="The password of vector store, if not set, will use the default password.",
+        description="The password of vector store, if not set, will use the default "
+        "password.",
     )
     embedding_fn: Optional[Any] = Field(
         default=None,
-        description="The embedding function of vector store, if not set, will use the default embedding function.",
+        description="The embedding function of vector store, if not set, will use the "
+        "default embedding function.",
     )
     max_chunks_once_load: int = Field(
         default=10,
         description="The max number of chunks to load at once. If your document is "
@@ -36,6 +39,11 @@ class VectorStoreConfig(BaseModel):
         "large, you can set this value to a larger number to speed up the loading "
         "process. Default is 10.",
     )
+    max_threads: int = Field(
+        default=1,
+        description="The max number of threads to use. Default is 1. If you set this "
+        "bigger than 1, please make sure your vector store is thread-safe.",
+    )
 
 
 class VectorStoreBase(ABC):
@@ -52,12 +60,13 @@ def load_document(self, chunks: List[Chunk]) -> List[str]:
         pass
 
     def load_document_with_limit(
-        self, chunks: List[Chunk], max_chunks_once_load: int = 10
+        self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
     ) -> List[str]:
         """load document in vector database with limit.
 
         Args:
             chunks: document chunks.
             max_chunks_once_load: Max number of chunks to load at once.
+            max_threads: Max number of threads to use.
 
         Return:
         """
         # Group the chunks into chunks of size max_chunks
         chunk_groups = [
             chunks[i : i + max_chunks_once_load]
             for i in range(0, len(chunks), max_chunks_once_load)
         ]
-        logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
+        logger.info(
+            f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
+            f"{max_threads} threads."
+        )
         ids = []
         loaded_cnt = 0
         start_time = time.time()
-        for chunk_group in chunk_groups:
-            ids.extend(self.load_document(chunk_group))
-            loaded_cnt += len(chunk_group)
-            logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
+        with ThreadPoolExecutor(max_workers=max_threads) as executor:
+            tasks = []
+            for chunk_group in chunk_groups:
+                tasks.append(executor.submit(self.load_document, chunk_group))
+            for future in tasks:
+                success_ids = future.result()
+                ids.extend(success_ids)
+                loaded_cnt += len(success_ids)
+                logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
         logger.info(
             f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
         )
diff --git a/dbgpt/storage/vector_store/connector.py b/dbgpt/storage/vector_store/connector.py
index ddcfc2f06..0c2ffa064 100644
--- a/dbgpt/storage/vector_store/connector.py
+++ b/dbgpt/storage/vector_store/connector.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, List, Optional
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Chunk
 from dbgpt.storage import vector_store
@@ -65,7 +65,9 @@ def load_document(self, chunks: List[Chunk]) -> List[str]:
         Return chunk ids.
         """
         return self.client.load_document_with_limit(
-            chunks, self._vector_store_config.max_chunks_once_load
+            chunks,
+            self._vector_store_config.max_chunks_once_load,
+            self._vector_store_config.max_threads,
         )
 
     def similar_search(self, doc: str, topk: int) -> List[Chunk]:
diff --git a/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md b/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md
index 25dad39b2..beaaa27bb 100644
--- a/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md
+++ b/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md
@@ -10,7 +10,7 @@ The call of multi-model services is compatible with the OpenAI interface, and th
 
 ## Start apiserver
 After deploying the model service, you need to start the API Server. By default, the model API Server uses port `8100` to start.
-```python
+```bash
 dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY
 ```
@@ -25,7 +25,7 @@ After the apiserver is started, the service call can be verified. First, let's l
 :::tip
 List models
 :::
-```python
+```bash
 curl http://127.0.0.1:8100/api/v1/models \
 -H "Authorization: Bearer EMPTY" \
 -H "Content-Type: application/json"
 ```
@@ -34,17 +34,31 @@
 :::tip
 Chat
 :::
-```python
+```bash
 curl http://127.0.0.1:8100/api/v1/chat/completions \
 -H "Authorization: Bearer EMPTY" \
 -H "Content-Type: application/json" \
 -d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
 ```
 
+:::tip
+Embedding
+:::
+```bash
+curl http://127.0.0.1:8100/api/v1/embeddings \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json" \
+-d '{
+    "model": "text2vec",
+    "input": "Hello world!"
+}' +``` + + ## Verify via OpenAI SDK -```python +```bash import openai openai.api_key = "EMPTY" openai.api_base = "http://127.0.0.1:8100/api/v1" diff --git a/examples/rag/embedding_rag_example.py b/examples/rag/embedding_rag_example.py index ef2b5c591..de493b59a 100644 --- a/examples/rag/embedding_rag_example.py +++ b/examples/rag/embedding_rag_example.py @@ -1,7 +1,7 @@ import asyncio import os -from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH +from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH from dbgpt.rag.chunk_manager import ChunkParameters from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory from dbgpt.rag.knowledge.factory import KnowledgeFactory @@ -37,7 +37,7 @@ def _create_vector_connector(): async def main(): - file_path = "docs/docs/awel.md" + file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md") knowledge = KnowledgeFactory.from_file_path(file_path) vector_connector = _create_vector_connector() chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE") diff --git a/examples/rag/rag_embedding_api_example.py b/examples/rag/rag_embedding_api_example.py new file mode 100644 index 000000000..b5f669ed8 --- /dev/null +++ b/examples/rag/rag_embedding_api_example.py @@ -0,0 +1,87 @@ +"""A RAG example using the OpenAPIEmbeddings. + +Example: + + Test with `OpenAI embeddings + <https://platform.openai.com/docs/api-reference/embeddings/create>`_. + + .. code-block:: shell + + export API_SERVER_BASE_URL=${OPENAI_API_BASE:-"https://api.openai.com/v1"} + export API_SERVER_API_KEY="${OPENAI_API_KEY}" + export API_SERVER_EMBEDDINGS_MODEL="text-embedding-ada-002" + python examples/rag/rag_embedding_api_example.py + + Test with DB-GPT `API Server + <https://docs.dbgpt.site/docs/installation/advanced_usage/OpenAI_SDK_call#start-apiserver>`_. + + .. 
code-block:: shell + export API_SERVER_BASE_URL="http://localhost:8100/api/v1" + export API_SERVER_API_KEY="your_api_key" + export API_SERVER_EMBEDDINGS_MODEL="text2vec" + python examples/rag/rag_embedding_api_example.py + +""" +import asyncio +import os +from typing import Optional + +from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH +from dbgpt.rag.chunk_manager import ChunkParameters +from dbgpt.rag.embedding import OpenAPIEmbeddings +from dbgpt.rag.knowledge.factory import KnowledgeFactory +from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler +from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig +from dbgpt.storage.vector_store.connector import VectorStoreConnector + + +def _create_embeddings( + api_url: str = None, api_key: Optional[str] = None, model_name: Optional[str] = None +) -> OpenAPIEmbeddings: + if not api_url: + api_server_base_url = os.getenv( + "API_SERVER_BASE_URL", "http://localhost:8100/api/v1/" + ) + api_url = f"{api_server_base_url}/embeddings" + if not api_key: + api_key = os.getenv("API_SERVER_API_KEY") + + if not model_name: + model_name = os.getenv("API_SERVER_EMBEDDINGS_MODEL", "text2vec") + + return OpenAPIEmbeddings(api_url=api_url, api_key=api_key, model_name=model_name) + + +def _create_vector_connector(): + """Create vector connector.""" + + return VectorStoreConnector.from_default( + "Chroma", + vector_store_config=ChromaVectorConfig( + name="example_embedding_api_vector_store_name", + persist_path=os.path.join(PILOT_PATH, "data"), + ), + embedding_fn=_create_embeddings(), + ) + + +async def main(): + file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md") + knowledge = KnowledgeFactory.from_file_path(file_path) + vector_connector = _create_vector_connector() + chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE") + # get embedding assembler + assembler = EmbeddingAssembler.load_from_knowledge( + knowledge=knowledge, + chunk_parameters=chunk_parameters, + vector_store_connector=vector_connector, + ) + assembler.persist() + # get embeddings retriever + retriever = assembler.as_retriever(3) + chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3) + print(f"embedding rag example results:{chunks}") + + +if __name__ == "__main__": + asyncio.run(main())
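
As a quick end-to-end check of the new async client path added in this patch, the sketch below is not part of the diff; it assumes the cluster layout from the `OpenAPIEmbeddings` docstring (a `text2vec` worker registered with the controller, and the API server on `localhost:8100` started with `--api_keys my_api_token`):

```python
import asyncio

from dbgpt.rag.embedding import OpenAPIEmbeddings


async def main():
    # Assumed deployment (see the OpenAPIEmbeddings docstring in this patch):
    #   dbgpt start controller --port 8000
    #   dbgpt start worker --model_name text2vec --worker_type text2vec ...
    #   dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
    #       --api_keys my_api_token --port 8100
    embeddings = OpenAPIEmbeddings(
        api_url="http://localhost:8100/api/v1/embeddings",
        api_key="my_api_token",
        model_name="text2vec",
    )
    # aembed_documents posts all texts in one request; the API server splits
    # them into embedding_batch_size batches and embeds them in parallel.
    vectors = await embeddings.aembed_documents(["Hello, world!", "How are you?"])
    query_vector = await embeddings.aembed_query("Hello, world!")
    print(f"embedded {len(vectors)} documents, dimension={len(vectors[0])}")
    print(f"query vector dimension={len(query_vector)}")


if __name__ == "__main__":
    asyncio.run(main())
```

Note the division of labor this relies on: batching lives server-side in `create_embeddings` (controlled by `--embedding_batch_size`), so clients can send arbitrarily many inputs in a single OpenAI-compatible request.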