Skip to content

Commit

Permalink
feat: APIServer supports embeddings (#1256)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Mar 5, 2024
1 parent 5f3ee35 commit 74ec8e5
Show file tree
Hide file tree
Showing 9 changed files with 412 additions and 38 deletions.
88 changes: 76 additions & 12 deletions dbgpt/model/cluster/apiserver/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
EmbeddingsRequest,
EmbeddingsResponse,
ModelCard,
ModelList,
ModelPermission,
Expand Down Expand Up @@ -51,6 +53,7 @@ def __init__(self, code: int, message: str):

class APISettings(BaseModel):
api_keys: Optional[List[str]] = None
embedding_bach_size: int = 4


api_settings = APISettings()
Expand Down Expand Up @@ -181,27 +184,29 @@ def get_model_registry(self) -> ModelRegistry:
return controller

async def get_model_instances_or_raise(
self, model_name: str
self, model_name: str, worker_type: str = "llm"
) -> List[ModelInstance]:
"""Get healthy model instances with request model name
Args:
model_name (str): Model name
worker_type (str, optional): Worker type. Defaults to "llm".
Raises:
APIServerException: If can't get healthy model instances with request model name
"""
registry = self.get_model_registry()
registry_model_name = f"{model_name}@llm"
suffix = f"@{worker_type}"
registry_model_name = f"{model_name}{suffix}"
model_instances = await registry.get_all_instances(
registry_model_name, healthy_only=True
)
if not model_instances:
all_instances = await registry.get_all_model_instances(healthy_only=True)
models = [
ins.model_name.split("@llm")[0]
ins.model_name.split(suffix)[0]
for ins in all_instances
if ins.model_name.endswith("@llm")
if ins.model_name.endswith(suffix)
]
if models:
models = "&&".join(models)
Expand Down Expand Up @@ -336,6 +341,25 @@ async def chat_completion_generate(

return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)

async def embeddings_generate(
self, model: str, texts: List[str]
) -> List[List[float]]:
"""Generate embeddings
Args:
model (str): Model name
texts (List[str]): Texts to embed
Returns:
List[List[float]]: The embeddings of texts
"""
worker_manager: WorkerManager = self.get_worker_manager()
params = {
"input": texts,
"model": model,
}
return await worker_manager.embeddings(params)


def get_api_server() -> APIServer:
api_server = global_system_app.get_component(
Expand Down Expand Up @@ -389,6 +413,40 @@ async def create_chat_completion(
return await api_server.chat_completion_generate(request.model, params, request.n)


@router.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(
request: EmbeddingsRequest, api_server: APIServer = Depends(get_api_server)
):
await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
texts = request.input
if isinstance(texts, str):
texts = [texts]
batch_size = api_settings.embedding_bach_size
batches = [
texts[i : min(i + batch_size, len(texts))]
for i in range(0, len(texts), batch_size)
]
data = []
async_tasks = []
for num_batch, batch in enumerate(batches):
async_tasks.append(api_server.embeddings_generate(request.model, batch))

# Request all embeddings in parallel
batch_embeddings: List[List[List[float]]] = await asyncio.gather(*async_tasks)
for num_batch, embeddings in enumerate(batch_embeddings):
data += [
{
"object": "embedding",
"embedding": emb,
"index": num_batch * batch_size + i,
}
for i, emb in enumerate(embeddings)
]
return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict(
exclude_none=True
)


def _initialize_all(controller_addr: str, system_app: SystemApp):
from dbgpt.model.cluster.controller.controller import ModelRegistryClient
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
Expand Down Expand Up @@ -427,20 +485,14 @@ def initialize_apiserver(
host: str = None,
port: int = None,
api_keys: List[str] = None,
embedding_batch_size: Optional[int] = None,
):
global global_system_app
global api_settings
embedded_mod = True
if not app:
embedded_mod = False
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)

if not system_app:
system_app = SystemApp(app)
Expand All @@ -449,6 +501,9 @@ def initialize_apiserver(
if api_keys:
api_settings.api_keys = api_keys

if embedding_batch_size:
api_settings.embedding_bach_size = embedding_batch_size

app.include_router(router, prefix="/api", tags=["APIServer"])

@app.exception_handler(APIServerException)
Expand All @@ -464,7 +519,15 @@ async def validation_exception_handler(request, exc):
if not embedded_mod:
import uvicorn

uvicorn.run(app, host=host, port=port, log_level="info")
# https://github.com/encode/starlette/issues/617
cors_app = CORSMiddleware(
app=app,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
uvicorn.run(cors_app, host=host, port=port, log_level="info")


def run_apiserver():
Expand All @@ -488,6 +551,7 @@ def run_apiserver():
host=apiserver_params.host,
port=apiserver_params.port,
api_keys=api_keys,
embedding_batch_size=apiserver_params.embedding_batch_size,
)


Expand Down
3 changes: 3 additions & 0 deletions dbgpt/model/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ class ModelAPIServerParameters(BaseParameters):
default=None,
metadata={"help": "Optional list of comma separated API keys"},
)
embedding_batch_size: Optional[int] = field(
default=None, metadata={"help": "Embedding batch size"}
)

log_level: Optional[str] = field(
default=None,
Expand Down
16 changes: 16 additions & 0 deletions dbgpt/rag/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .embedding_factory import DefaultEmbeddingFactory, EmbeddingFactory
from .embeddings import (
Embeddings,
HuggingFaceEmbeddings,
JinaEmbeddings,
OpenAPIEmbeddings,
)

__ALL__ = [
"OpenAPIEmbeddings",
"Embeddings",
"HuggingFaceEmbeddings",
"JinaEmbeddings",
"EmbeddingFactory",
"DefaultEmbeddingFactory",
]
Loading

0 comments on commit 74ec8e5

Please sign in to comment.