From 74ec8e52cd8a2a17db02bd031b0a89596a38ebe0 Mon Sep 17 00:00:00 2001
From: Fangyin Cheng <staneyffer@gmail.com>
Date: Tue, 5 Mar 2024 20:21:37 +0800
Subject: [PATCH] feat: APIServer supports embeddings (#1256)
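
APIServer now exposes an OpenAI-compatible `/api/v1/embeddings` endpoint
backed by text2vec workers. The patch also adds an `OpenAPIEmbeddings`
client for any OpenAI-compatible embeddings API, and multi-threaded
document loading for vector stores (`max_threads` in `VectorStoreConfig`).

A minimal usage sketch with the pre-1.0 OpenAI Python SDK, assuming an API
server on port 8100, a registered `text2vec` worker, and `EMPTY` as the
API key (values taken from the docs below):

    import openai

    openai.api_key = "EMPTY"
    openai.api_base = "http://127.0.0.1:8100/api/v1"

    res = openai.Embedding.create(model="text2vec", input=["Hello world!"])
    print(res["data"][0]["embedding"][:8])  # first few vector components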

---
 dbgpt/model/cluster/apiserver/api.py          |  88 ++++++--
 dbgpt/model/parameter.py                      |   3 +
 dbgpt/rag/embedding/__init__.py               |  16 ++
 dbgpt/rag/embedding/embeddings.py             | 191 +++++++++++++++++-
 dbgpt/storage/vector_store/base.py            |  33 ++-
 dbgpt/storage/vector_store/connector.py       |   6 +-
 .../advanced_usage/OpenAI_SDK_call.md         |  22 +-
 examples/rag/embedding_rag_example.py         |   4 +-
 examples/rag/rag_embedding_api_example.py     |  87 ++++++++
 9 files changed, 412 insertions(+), 38 deletions(-)
 create mode 100644 examples/rag/rag_embedding_api_example.py

diff --git a/dbgpt/model/cluster/apiserver/api.py b/dbgpt/model/cluster/apiserver/api.py
index 1a508fddb..c3cdbc5b2 100644
--- a/dbgpt/model/cluster/apiserver/api.py
+++ b/dbgpt/model/cluster/apiserver/api.py
@@ -23,6 +23,8 @@
     ChatCompletionStreamResponse,
     ChatMessage,
     DeltaMessage,
+    EmbeddingsRequest,
+    EmbeddingsResponse,
     ModelCard,
     ModelList,
     ModelPermission,
@@ -51,6 +53,7 @@ def __init__(self, code: int, message: str):
 
 class APISettings(BaseModel):
     api_keys: Optional[List[str]] = None
+    embedding_batch_size: int = 4
 
 
 api_settings = APISettings()
@@ -181,27 +184,29 @@ def get_model_registry(self) -> ModelRegistry:
         return controller
 
     async def get_model_instances_or_raise(
-        self, model_name: str
+        self, model_name: str, worker_type: str = "llm"
     ) -> List[ModelInstance]:
         """Get healthy model instances with request model name
 
         Args:
             model_name (str): Model name
+            worker_type (str, optional): Worker type. Defaults to "llm".
 
         Raises:
             APIServerException: If can't get healthy model instances with request model name
         """
         registry = self.get_model_registry()
-        registry_model_name = f"{model_name}@llm"
+        suffix = f"@{worker_type}"
+        registry_model_name = f"{model_name}{suffix}"
         model_instances = await registry.get_all_instances(
             registry_model_name, healthy_only=True
         )
         if not model_instances:
             all_instances = await registry.get_all_model_instances(healthy_only=True)
             models = [
-                ins.model_name.split("@llm")[0]
+                ins.model_name.split(suffix)[0]
                 for ins in all_instances
-                if ins.model_name.endswith("@llm")
+                if ins.model_name.endswith(suffix)
             ]
             if models:
                 models = "&&".join(models)
@@ -336,6 +341,25 @@ async def chat_completion_generate(
 
         return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)
 
+    async def embeddings_generate(
+        self, model: str, texts: List[str]
+    ) -> List[List[float]]:
+        """Generate embeddings
+
+        Args:
+            model (str): Model name
+            texts (List[str]): Texts to embed
+
+        Returns:
+            List[List[float]]: The embeddings of texts
+        """
+        worker_manager: WorkerManager = self.get_worker_manager()
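+        # Build an OpenAI-style embeddings payload for the worker manager.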
+        params = {
+            "input": texts,
+            "model": model,
+        }
+        return await worker_manager.embeddings(params)
+
 
 def get_api_server() -> APIServer:
     api_server = global_system_app.get_component(
@@ -389,6 +413,40 @@ async def create_chat_completion(
     return await api_server.chat_completion_generate(request.model, params, request.n)
 
 
+@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
+async def create_embeddings(
+    request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
+):
+    await api_server.get_model_instances_or_raise(
+        request.model, worker_type="text2vec"
+    )
+    texts = request.input
+    if isinstance(texts, str):
+        texts = [texts]
+    batch_size = api_settings.embedding_batch_size
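+    # Split the input into fixed-size batches, one worker request per batch.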
+    batches = [
+        texts[i : i + batch_size] for i in range(0, len(texts), batch_size)
+    ]
+    data = []
+    async_tasks = []
+    for batch in batches:
+        async_tasks.append(api_server.embeddings_generate(request.model, batch))
+
+    # Request all embeddings in parallel
+    batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
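+    # Reassemble results, restoring each text's global index across batches.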
+    for num_batch, embeddings in enumerate(batch_embeddings):
+        data += [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": num_batch * batch_size + i,
+            }
+            for i, emb in enumerate(embeddings)
+        ]
+    return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
+        exclude_none=True
+    )
+
+
 def _initialize_all(controller_addr: str, system_app: SystemApp):
     from dbgpt.model.cluster.controller.controller import ModelRegistryClient
     from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
@@ -427,6 +485,7 @@ def initialize_apiserver(
     host: str = None,
     port: int = None,
     api_keys: List[str] = None,
+    embedding_batch_size: Optional[int] = None,
 ):
     global global_system_app
     global api_settings
@@ -434,13 +493,6 @@ def initialize_apiserver(
     if not app:
         embedded_mod = False
         app = FastAPI()
-        app.add_middleware(
-            CORSMiddleware,
-            allow_origins=["*"],
-            allow_credentials=True,
-            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
-            allow_headers=["*"],
-        )
 
     if not system_app:
         system_app = SystemApp(app)
@@ -449,6 +501,9 @@ def initialize_apiserver(
     if api_keys:
         api_settings.api_keys = api_keys
 
+    if embedding_batch_size:
+        api_settings.embedding_batch_size = embedding_batch_size
+
     app.include_router(router, prefix="/api", tags=["APIServer"])
 
     @app.exception_handler(APIServerException)
@@ -464,7 +519,15 @@ async def validation_exception_handler(request, exc):
     if not embedded_mod:
         import uvicorn
 
-        uvicorn.run(app, host=host, port=port, log_level="info")
+        # https://github.com/encode/starlette/issues/617
+        cors_app = CORSMiddleware(
+            app=app,
+            allow_origins=["*"],
+            allow_credentials=True,
+            allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+            allow_headers=["*"],
+        )
+        uvicorn.run(cors_app, host=host, port=port, log_level="info")
 
 
 def run_apiserver():
@@ -488,6 +551,7 @@ def run_apiserver():
         host=apiserver_params.host,
         port=apiserver_params.port,
         api_keys=api_keys,
+        embedding_batch_size=apiserver_params.embedding_batch_size,
     )
 
 
diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py
index 810355f0d..946d65ecf 100644
--- a/dbgpt/model/parameter.py
+++ b/dbgpt/model/parameter.py
@@ -113,6 +113,9 @@ class ModelAPIServerParameters(BaseParameters):
         default=None,
         metadata={"help": "Optional list of comma separated API keys"},
     )
+    embedding_batch_size: Optional[int] = field(
+        default=None, metadata={"help": "Embedding batch size"}
+    )
 
     log_level: Optional[str] = field(
         default=None,
diff --git a/dbgpt/rag/embedding/__init__.py b/dbgpt/rag/embedding/__init__.py
index e69de29bb..a4d6d2914 100644
--- a/dbgpt/rag/embedding/__init__.py
+++ b/dbgpt/rag/embedding/__init__.py
@@ -0,0 +1,16 @@
+from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory
+from .embeddings import (
+    Embeddings,
+    HuggingFaceEmbeddings,
+    JinaEmbeddings,
+    OpenAPIEmbeddings,
+)
+
+__all__ = [
+    "OpenAPIEmbeddings",
+    "Embeddings",
+    "HuggingFaceEmbeddings",
+    "JinaEmbeddings",
+    "EmbeddingFactory",
+    "DefaultEmbeddingFactory",
+]
diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py
index e9fb38784..7ccc78ee5 100644
--- a/dbgpt/rag/embedding/embeddings.py
+++ b/dbgpt/rag/embedding/embeddings.py
@@ -2,8 +2,10 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional
 
+import aiohttp
 import requests
-from pydantic import BaseModel, Extra, Field
+
+from dbgpt._private.pydantic import BaseModel, Extra, Field
 
 DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
 DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
@@ -363,6 +365,29 @@ def embed_query(self, text: str) -> List[float]:
         return self.embed_documents([text])[0]
 
 
+def _handle_request_result(res: requests.Response) -> List[List[float]]:
+    """Parse the result from a request.
+
+    Args:
+        res: The response from the request.
+
+    Returns:
+        List[List[float]]: The embeddings.
+
+    Raises:
+        RuntimeError: If the response is not successful.
+    """
+    res.raise_for_status()
+    resp = res.json()
+    if "data" not in resp:
+        raise RuntimeError(resp["detail"])
+    embeddings = resp["data"]
+    # Sort resulting embeddings by index
+    sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+    # Return just the embeddings
+    return [result["embedding"] for result in sorted_embeddings]
+
+
 class JinaEmbeddings(BaseModel, Embeddings):
     """
     This class is used to get embeddings for a list of texts using the Jina AI API.
@@ -406,20 +431,136 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         # Call Jina AI Embedding API
         resp = self.session.post(  # type: ignore
             self.api_url, json={"input": texts, "model": self.model_name}
-        ).json()
-        if "data" not in resp:
-            raise RuntimeError(resp["detail"])
+        )
+        return _handle_request_result(resp)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using a HuggingFace transformer model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        return self.embed_documents([text])[0]
+
+
+class OpenAPIEmbeddings(BaseModel, Embeddings):
+    """This class is used to get embeddings for a list of texts using the API.
+
+    This API is compatible with the OpenAI Embedding API.
+
+    Examples:
+
+        Using OpenAI's API:
+        .. code-block:: python
+
+            from dbgpt.rag.embedding import OpenAPIEmbeddings
+
+            openai_embeddings = OpenAPIEmbeddings(
+                api_url="https://api.openai.com/v1/embeddings",
+                api_key="your_api_key",
+                model_name="text-embedding-3-small",
+            )
+            texts = ["Hello, world!", "How are you?"]
+            openai_embeddings.embed_documents(texts)
+
+        Using DB-GPT APIServer's embedding API:
+        To use the DB-GPT APIServer's embedding API, you should deploy DB-GPT according
+        to the `Cluster Deploy
+        <https://docs.dbgpt.site/docs/installation/model_service/cluster>`_.
+
+        A simple example:
+        1. Deploy Model Cluster with following command:
+        .. code-block:: bash
+
+            dbgpt start controller --port 8000
+
+        2. Deploy Embedding Model Worker with following command:
+        .. code-block:: bash
+
+            dbgpt start worker --model_name text2vec \
+            --model_path /app/models/text2vec-large-chinese \
+            --worker_type text2vec \
+            --port 8003 \
+            --controller_addr http://127.0.0.1:8000
+
+        3. Deploy API Server with following command:
+        .. code-block:: bash
+
+            dbgpt start apiserver --controller_addr http://127.0.0.1:8000 \
+            --api_keys my_api_token --port 8100
+
+        Now, you can use the API server to get embeddings:
+        .. code-block:: python
+
+            from dbgpt.rag.embedding import OpenAPIEmbeddings
 
-        embeddings = resp["data"]
+            openai_embeddings = OpenAPIEmbeddings(
+                api_url="http://localhost:8100/api/v1/embeddings",
+                api_key="my_api_token",
+                model_name="text2vec",
+            )
+            texts = ["Hello, world!", "How are you?"]
+            openai_embeddings.embed_documents(texts)
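+
+        The asynchronous counterparts can be awaited in the same way:
+        .. code-block:: python
+
+            embeddings = await openai_embeddings.aembed_documents(texts)
+            single = await openai_embeddings.aembed_query("Hello, world!")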
+
+    """
+
+    api_url: str = Field(
+        default="http://localhost:8100/api/v1/embeddings",
+        description="The URL of the embeddings API.",
+    )
+    api_key: Optional[str] = Field(
+        default=None, description="The API key for the embeddings API."
+    )
+    model_name: str = Field(
+        default="text2vec", description="The name of the model to use."
+    )
+    timeout: int = Field(
+        default=60, description="The timeout for the request in seconds."
+    )
+
+    session: Optional[requests.Session] = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def __init__(self, **kwargs):
+        """Initialize the OpenAPIEmbeddings."""
+        super().__init__(**kwargs)
+        try:
+            import requests
+        except ImportError:
+            raise ValueError(
+                "The requests python package is not installed. "
+                "Please install it with `pip install requests`"
+            )
+        self.session = requests.Session()
+        self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Get the embeddings for a list of texts.
 
-        # Sort resulting embeddings by index
-        sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])  # type: ignore
+        Args:
+            texts (List[str]): A list of texts to get embeddings for.
 
-        # Return just the embeddings
-        return [result["embedding"] for result in sorted_embeddings]
+        Returns:
+            Embedded texts as List[List[float]], where each inner List[float]
+                corresponds to a single input text.
+        """
+        # Call OpenAI Embedding API
+        res = self.session.post(  # type: ignore
+            self.api_url,
+            json={"input": texts, "model": self.model_name},
+            timeout=self.timeout,
+        )
+        return _handle_request_result(res)
 
     def embed_query(self, text: str) -> List[float]:
-        """Compute query embeddings using a HuggingFace transformer model.
+        """Compute query embeddings using a OpenAPI embedding model.
 
         Args:
             text: The text to embed.
@@ -428,3 +569,33 @@ def embed_query(self, text: str) -> List[float]:
             Embeddings for the text.
         """
         return self.embed_documents([text])[0]
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronous Embed search docs.
+
+        Args:
+            texts: A list of texts to get embeddings for.
+
+        Returns:
+            List[List[float]]: Embedded texts as List[List[float]], where each inner
+                List[float] corresponds to a single input text.
+        """
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        async with aiohttp.ClientSession(
+            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
+        ) as session:
+            async with session.post(
+                self.api_url, json={"input": texts, "model": self.model_name}
+            ) as resp:
+                resp.raise_for_status()
+                data = await resp.json()
+                if "data" not in data:
+                    raise RuntimeError(data["detail"])
+                embeddings = data["data"]
+                sorted_embeddings = sorted(embeddings, key=lambda e: e["index"])
+                return [result["embedding"] for result in sorted_embeddings]
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Asynchronous Embed query text."""
+        embeddings = await self.aembed_documents([text])
+        return embeddings[0]
diff --git a/dbgpt/storage/vector_store/base.py b/dbgpt/storage/vector_store/base.py
index a97fdd601..769a48b08 100644
--- a/dbgpt/storage/vector_store/base.py
+++ b/dbgpt/storage/vector_store/base.py
@@ -2,6 +2,7 @@
 import math
 import time
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Callable, List, Optional
 
 from pydantic import BaseModel, Field
@@ -24,11 +25,13 @@ class VectorStoreConfig(BaseModel):
     )
     password: Optional[str] = Field(
         default=None,
-        description="The password of vector store, if not set, will use the default password.",
+        description="The password of vector store, if not set, will use the default "
+        "password.",
     )
     embedding_fn: Optional[Any] = Field(
         default=None,
-        description="The embedding function of vector store, if not set, will use the default embedding function.",
+        description="The embedding function of vector store, if not set, will use the "
+        "default embedding function.",
     )
     max_chunks_once_load: int = Field(
         default=10,
@@ -36,6 +39,11 @@ class VectorStoreConfig(BaseModel):
         "large, you can set this value to a larger number to speed up the loading "
         "process. Default is 10.",
     )
+    max_threads: int = Field(
+        default=1,
+        description="The max number of threads to use. Default is 1. If you set this "
+        "bigger than 1, please make sure your vector store is thread-safe.",
+    )
 
 
 class VectorStoreBase(ABC):
@@ -52,12 +60,13 @@ def load_document(self, chunks: List[Chunk]) -> List[str]:
         pass
 
     def load_document_with_limit(
-        self, chunks: List[Chunk], max_chunks_once_load: int = 10
+        self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
     ) -> List[str]:
         """load document in vector database with limit.
         Args:
             chunks: document chunks.
             max_chunks_once_load: Max number of chunks to load at once.
+            max_threads: Max number of threads to use.
         Return:
         """
         # Group the chunks into chunks of size max_chunks
@@ -65,14 +74,22 @@ def load_document_with_limit(
             chunks[i : i + max_chunks_once_load]
             for i in range(0, len(chunks), max_chunks_once_load)
         ]
-        logger.info(f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups")
+        logger.info(
+            f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
+            f"{max_threads} threads."
+        )
         ids = []
         loaded_cnt = 0
         start_time = time.time()
-        for chunk_group in chunk_groups:
-            ids.extend(self.load_document(chunk_group))
-            loaded_cnt += len(chunk_group)
-            logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
+        with ThreadPoolExecutor(max_workers=max_threads) as executor:
+            tasks = []
+            for chunk_group in chunk_groups:
+                tasks.append(executor.submit(self.load_document, chunk_group))
+            for future in tasks:
+                success_ids = future.result()
+                ids.extend(success_ids)
+                loaded_cnt += len(success_ids)
+                logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
         logger.info(
             f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
         )
diff --git a/dbgpt/storage/vector_store/connector.py b/dbgpt/storage/vector_store/connector.py
index ddcfc2f06..0c2ffa064 100644
--- a/dbgpt/storage/vector_store/connector.py
+++ b/dbgpt/storage/vector_store/connector.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, List, Optional
+from typing import Any, List, Optional
 
 from dbgpt.rag.chunk import Chunk
 from dbgpt.storage import vector_store
@@ -65,7 +65,9 @@ def load_document(self, chunks: List[Chunk]) -> List[str]:
         Return chunk ids.
         """
         return self.client.load_document_with_limit(
-            chunks, self._vector_store_config.max_chunks_once_load
+            chunks,
+            self._vector_store_config.max_chunks_once_load,
+            self._vector_store_config.max_threads,
         )
 
     def similar_search(self, doc: str, topk: int) -> List[Chunk]:
diff --git a/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md b/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md
index 25dad39b2..beaaa27bb 100644
--- a/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md
+++ b/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md
@@ -10,7 +10,7 @@ The call of multi-model services is compatible with the OpenAI interface, and th
 ## Start apiserver
 
 After deploying the model service, you need to start the API Server. By default, the model API Server uses port `8100` to start.
-```python
+```bash
 dbgpt start apiserver --controller_addr http://127.0.0.1:8000 --api_keys EMPTY
 
 ```
@@ -25,7 +25,7 @@ After the apiserver is started, the service call can be verified. First, let's l
 :::tip
 List models
 :::
-```python
+```bash
 curl http://127.0.0.1:8100/api/v1/models \
 -H "Authorization: Bearer EMPTY" \
 -H "Content-Type: application/json"
@@ -34,17 +34,31 @@ curl http://127.0.0.1:8100/api/v1/models \
 :::tip
 Chat
 :::
-```python
+```bash
 curl http://127.0.0.1:8100/api/v1/chat/completions \
 -H "Authorization: Bearer EMPTY" \
 -H "Content-Type: application/json" \
 -d '{"model": "vicuna-13b-v1.5", "messages": [{"role": "user", "content": "hello"}]}'
 ```
 
+:::tip
+Embeddings
+:::
+```bash
+curl http://127.0.0.1:8100/api/v1/embeddings \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json" \
+-d '{
+    "model": "text2vec",
+    "input": "Hello world!"
+}'
+```
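+
+The response follows the OpenAI embeddings schema. A sketch of its shape (the
+vector values are truncated, and the usage counters depend on the server):
+
+```json
+{
+  "model": "text2vec",
+  "data": [
+    {"object": "embedding", "index": 0, "embedding": [0.0123, -0.0456]}
+  ],
+  "usage": {"prompt_tokens": 0, "total_tokens": 0}
+}
+```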
+
 
 ## Verify via OpenAI SDK
 
 ```python
 import openai
 openai.api_key = "EMPTY"
 openai.api_base = "http://127.0.0.1:8100/api/v1"
diff --git a/examples/rag/embedding_rag_example.py b/examples/rag/embedding_rag_example.py
index ef2b5c591..de493b59a 100644
--- a/examples/rag/embedding_rag_example.py
+++ b/examples/rag/embedding_rag_example.py
@@ -1,7 +1,7 @@
 import asyncio
 import os
 
-from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
+from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH, ROOT_PATH
 from dbgpt.rag.chunk_manager import ChunkParameters
 from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddingFactory
 from dbgpt.rag.knowledge.factory import KnowledgeFactory
@@ -37,7 +37,7 @@ def _create_vector_connector():
 
 
 async def main():
-    file_path = "docs/docs/awel.md"
+    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
     knowledge = KnowledgeFactory.from_file_path(file_path)
     vector_connector = _create_vector_connector()
     chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
diff --git a/examples/rag/rag_embedding_api_example.py b/examples/rag/rag_embedding_api_example.py
new file mode 100644
index 000000000..b5f669ed8
--- /dev/null
+++ b/examples/rag/rag_embedding_api_example.py
@@ -0,0 +1,87 @@
+"""A RAG example using the OpenAPIEmbeddings.
+
+Example:
+
+    Test with `OpenAI embeddings
+    <https://platform.openai.com/docs/api-reference/embeddings/create>`_.
+
+    .. code-block:: shell
+
+        export API_SERVER_BASE_URL=${OPENAI_API_BASE:-"https://api.openai.com/v1"}
+        export API_SERVER_API_KEY="${OPENAI_API_KEY}"
+        export API_SERVER_EMBEDDINGS_MODEL="text-embedding-ada-002"
+        python examples/rag/rag_embedding_api_example.py
+
+    Test with DB-GPT `API Server
+    <https://docs.dbgpt.site/docs/installation/advanced_usage/OpenAI_SDK_call#start-apiserver>`_.
+
+    .. code-block:: shell
+
+        export API_SERVER_BASE_URL="http://localhost:8100/api/v1"
+        export API_SERVER_API_KEY="your_api_key"
+        export API_SERVER_EMBEDDINGS_MODEL="text2vec"
+        python examples/rag/rag_embedding_api_example.py
+
+"""
+import asyncio
+import os
+from typing import Optional
+
+from dbgpt.configs.model_config import PILOT_PATH, ROOT_PATH
+from dbgpt.rag.chunk_manager import ChunkParameters
+from dbgpt.rag.embedding import OpenAPIEmbeddings
+from dbgpt.rag.knowledge.factory import KnowledgeFactory
+from dbgpt.serve.rag.assembler.embedding import EmbeddingAssembler
+from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
+from dbgpt.storage.vector_store.connector import VectorStoreConnector
+
+
+def _create_embeddings(
+    api_url: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_name: Optional[str] = None,
+) -> OpenAPIEmbeddings:
+    if not api_url:
+        api_server_base_url = os.getenv(
+            "API_SERVER_BASE_URL", "http://localhost:8100/api/v1"
+        )
+        api_url = f"{api_server_base_url.rstrip('/')}/embeddings"
+    if not api_key:
+        api_key = os.getenv("API_SERVER_API_KEY")
+
+    if not model_name:
+        model_name = os.getenv("API_SERVER_EMBEDDINGS_MODEL", "text2vec")
+
+    return OpenAPIEmbeddings(api_url=api_url, api_key=api_key, model_name=model_name)
+
+
+def _create_vector_connector():
+    """Create vector connector."""
+
+    return VectorStoreConnector.from_default(
+        "Chroma",
+        vector_store_config=ChromaVectorConfig(
+            name="example_embedding_api_vector_store_name",
+            persist_path=os.path.join(PILOT_PATH, "data"),
+        ),
+        embedding_fn=_create_embeddings(),
+    )
+
+
+async def main():
+    file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md")
+    knowledge = KnowledgeFactory.from_file_path(file_path)
+    vector_connector = _create_vector_connector()
+    chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
+    # get embedding assembler
+    assembler = EmbeddingAssembler.load_from_knowledge(
+        knowledge=knowledge,
+        chunk_parameters=chunk_parameters,
+        vector_store_connector=vector_connector,
+    )
+    assembler.persist()
+    # get embeddings retriever
+    retriever = assembler.as_retriever(3)
+    chunks = await retriever.aretrieve_with_scores("what is awel talk about", 0.3)
+    print(f"embedding rag example results:{chunks}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())