Support chat models in dstack-proxy #1953

Merged 2 commits on Nov 5, 2024 (the diff below shows changes from 1 commit).
1 change: 1 addition & 0 deletions requirements_dev.txt
@@ -4,6 +4,7 @@ httpx>=0.23
pytest~=7.2
pytest-asyncio>=0.21
pytest-httpbin==2.1.0
openai>=1.53.0,<2.0.0
freezegun>=1.2.0
ruff==0.5.3 # Should match .pre-commit-config.yaml
testcontainers # testcontainers<4 may not work with asyncpg
6 changes: 5 additions & 1 deletion src/dstack/_internal/core/models/gateways.py
@@ -8,6 +8,8 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel

# TODO(#1595): refactor into different modules: gateway-specific and proxy-specific


class GatewayStatus(str, Enum):
SUBMITTED = "submitted"
@@ -111,6 +113,8 @@ class GatewayProvisioningData(CoreModel):


class BaseChatModel(CoreModel):
# Adding more model types might require rethinking this class,
# since pydantic doesn't work with two discriminators (type and format) at once
type: Annotated[Literal["chat"], Field(description="The type of the model")]
name: Annotated[str, Field(description="The name of the model")]
format: Annotated[
@@ -151,4 +155,4 @@ class OpenAIChatModel(BaseChatModel):


ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]
AnyModel = Annotated[Union[ChatModel], Field(discriminator="type")] # embeddings and etc.
AnyModel = Union[ChatModel] # embeddings and etc.
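
For context, the ChatModel alias above relies on pydantic's discriminated unions: the format field decides which concrete model class gets parsed, which is also why (per the comment in BaseChatModel) a second discriminator such as type cannot be layered onto the same union. A minimal standalone sketch, using illustrative class names rather than dstack's actual TGIChatModel/OpenAIChatModel definitions:

# Illustrative discriminated-union sketch; class names are made up for this example.
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class TGIModelExample(BaseModel):
    format: Literal["tgi"]
    chat_template: str
    eos_token: str


class OpenAIModelExample(BaseModel):
    format: Literal["openai"]
    prefix: str


ModelExample = Annotated[
    Union[TGIModelExample, OpenAIModelExample], Field(discriminator="format")
]


class ServiceConfigExample(BaseModel):
    model: ModelExample


# pydantic selects the union member whose `format` literal matches the input.
parsed = ServiceConfigExample.parse_obj({"model": {"format": "openai", "prefix": "/v1"}})
assert isinstance(parsed.model, OpenAIModelExample)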
9 changes: 5 additions & 4 deletions src/dstack/_internal/proxy/deps.py
@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Optional

from fastapi import Depends, HTTPException, Request, Security, status
from fastapi import Depends, Request, Security, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from typing_extensions import Annotated

from dstack._internal.proxy.errors import ProxyError, UnexpectedProxyError
from dstack._internal.proxy.repos.base import BaseProxyRepo


@@ -26,7 +27,7 @@ async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]:
async def get_injector(request: Request) -> BaseProxyDependencyInjector:
injector = request.app.state.proxy_dependency_injector
if not isinstance(injector, BaseProxyDependencyInjector):
raise RuntimeError(f"Wrong BaseProxyDependencyInjector type {type(injector)}")
raise UnexpectedProxyError(f"Unexpected proxy_dependency_injector type {type(injector)}")
return injector


@@ -47,9 +48,9 @@ async def enforce(self) -> None:
if self._token is None or not await self._repo.is_project_member(
self._project_name, self._token
):
raise HTTPException(
status.HTTP_403_FORBIDDEN,
raise ProxyError(
f"Unauthenticated or unauthorized to access project {self._project_name}",
status.HTTP_403_FORBIDDEN,
)


14 changes: 14 additions & 0 deletions src/dstack/_internal/proxy/errors.py
@@ -0,0 +1,14 @@
from fastapi import HTTPException, status


class ProxyError(HTTPException):
"""Errors in dstack-proxy that are caused by and should be reported to the user"""

def __init__(self, detail: str, code: int = status.HTTP_400_BAD_REQUEST) -> None:
super().__init__(detail=detail, status_code=code)


class UnexpectedProxyError(RuntimeError):
"""Internal errors in dstack-proxy that should have never happened"""

pass
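
Because ProxyError subclasses fastapi.HTTPException, FastAPI's default exception handler turns it into a JSON error response with the given status code, while UnexpectedProxyError propagates as an ordinary internal error. A quick illustrative check; the app and route below are throwaway examples, not part of dstack:

# Throwaway FastAPI app demonstrating how ProxyError surfaces to the client.
from fastapi import FastAPI, status
from fastapi.testclient import TestClient

from dstack._internal.proxy.errors import ProxyError

app = FastAPI()


@app.get("/missing")
async def missing() -> None:
    raise ProxyError("Model llama-3 not found", status.HTTP_404_NOT_FOUND)


client = TestClient(app)
response = client.get("/missing")
assert response.status_code == 404
assert response.json() == {"detail": "Model llama-3 not found"}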
39 changes: 37 additions & 2 deletions src/dstack/_internal/proxy/repos/base.py
@@ -1,7 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from datetime import datetime
from typing import List, Literal, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from dstack._internal.core.models.instances import SSHConnectionParams

@@ -26,6 +28,27 @@ class Project(BaseModel):
ssh_private_key: str


class TGIChatModelFormat(BaseModel):
format: Literal["tgi"]
chat_template: str
eos_token: str


class OpenAIChatModelFormat(BaseModel):
format: Literal["openai"]
prefix: str


AnyModelFormat = Union[TGIChatModelFormat, OpenAIChatModelFormat]


class ChatModel(BaseModel):
name: str
created_at: datetime
run_name: str
format_spec: Annotated[AnyModelFormat, Field(discriminator="format")]


class BaseProxyRepo(ABC):
@abstractmethod
async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
@@ -35,6 +58,18 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
async def add_service(self, project_name: str, service: Service) -> None:
pass

@abstractmethod
async def list_models(self, project_name: str) -> List[ChatModel]:
pass

@abstractmethod
async def get_model(self, project_name: str, name: str) -> Optional[ChatModel]:
pass

@abstractmethod
async def add_model(self, project_name: str, model: ChatModel) -> None:
pass

@abstractmethod
async def get_project(self, name: str) -> Optional[Project]:
pass
14 changes: 12 additions & 2 deletions src/dstack/_internal/proxy/repos/memory.py
@@ -1,11 +1,12 @@
from typing import Dict, Optional
from typing import Dict, List, Optional

from dstack._internal.proxy.repos.base import BaseProxyRepo, Project, Service
from dstack._internal.proxy.repos.base import BaseProxyRepo, ChatModel, Project, Service


class InMemoryProxyRepo(BaseProxyRepo):
def __init__(self) -> None:
self.services: Dict[str, Dict[str, Service]] = {}
self.models: Dict[str, Dict[str, ChatModel]] = {}
self.projects: Dict[str, Project] = {}

async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
@@ -14,6 +15,15 @@ async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
async def add_service(self, project_name: str, service: Service) -> None:
self.services.setdefault(project_name, {})[service.run_name] = service

async def list_models(self, project_name: str) -> List[ChatModel]:
return list(self.models.get(project_name, {}).values())

async def get_model(self, project_name: str, name: str) -> Optional[ChatModel]:
return self.models.get(project_name, {}).get(name)

async def add_model(self, project_name: str, model: ChatModel) -> None:
self.models.setdefault(project_name, {})[model.name] = model

async def get_project(self, name: str) -> Optional[Project]:
return self.projects.get(name)

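A quick usage sketch of the new repo methods (the project, model, and run names below are made up), showing that models are stored per project and keyed by model name:

import asyncio
from datetime import datetime, timezone

from dstack._internal.proxy.repos.base import ChatModel, OpenAIChatModelFormat
from dstack._internal.proxy.repos.memory import InMemoryProxyRepo


async def main() -> None:
    repo = InMemoryProxyRepo()
    await repo.add_model(
        "my-project",
        ChatModel(
            name="llama-3-8b",
            created_at=datetime.now(timezone.utc),
            run_name="llama-run",
            format_spec=OpenAIChatModelFormat(format="openai", prefix="/v1"),
        ),
    )
    assert [m.name for m in await repo.list_models("my-project")] == ["llama-3-8b"]
    assert await repo.get_model("my-project", "llama-3-8b") is not None
    assert await repo.get_model("other-project", "llama-3-8b") is None  # per-project isolation


asyncio.run(main())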
67 changes: 67 additions & 0 deletions src/dstack/_internal/proxy/routers/model_proxy.py
@@ -0,0 +1,67 @@
from typing import AsyncIterator

from fastapi import APIRouter, Depends, status
from fastapi.responses import StreamingResponse
from typing_extensions import Annotated

from dstack._internal.proxy.deps import ProxyAuth, get_proxy_repo
from dstack._internal.proxy.errors import ProxyError, UnexpectedProxyError
from dstack._internal.proxy.repos.base import BaseProxyRepo
from dstack._internal.proxy.schemas.model_proxy import (
ChatCompletionsChunk,
ChatCompletionsRequest,
ChatCompletionsResponse,
Model,
ModelsResponse,
)
from dstack._internal.proxy.services.model_proxy import get_chat_client
from dstack._internal.proxy.services.service_connection import get_service_replica_client

router = APIRouter(dependencies=[Depends(ProxyAuth(auto_enforce=True))])


@router.get("/{project_name}/models")
async def get_models(
project_name: str, repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)]
) -> ModelsResponse:
models = await repo.list_models(project_name)
data = [
Model(id=m.name, created=int(m.created_at.timestamp()), owned_by=project_name)
for m in models
]
return ModelsResponse(data=data)


@router.post("/{project_name}/chat/completions", response_model=ChatCompletionsResponse)
async def post_chat_completions(
project_name: str,
body: ChatCompletionsRequest,
repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)],
):
model = await repo.get_model(project_name, body.model)
if model is None:
raise ProxyError(
f"Model {body.model} not found in project {project_name}", status.HTTP_404_NOT_FOUND
)
service = await repo.get_service(project_name, model.run_name)
if service is None or not service.replicas:
raise UnexpectedProxyError(
f"Model {model.name} in project {project_name} references run {model.run_name}"
" that does not exist or has no replicas"
)
http_client = await get_service_replica_client(project_name, service, repo)
client = get_chat_client(model, http_client)
if not body.stream:
return await client.generate(body)
else:
return StreamingResponse(
stream_chunks(client.stream(body)),
media_type="text/event-stream",
headers={"X-Accel-Buffering": "no"},
)


async def stream_chunks(chunks: AsyncIterator[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
async for chunk in chunks:
yield f"data:{chunk.json()}\n\n".encode()
yield "data: [DONE]\n\n".encode()
Empty file.
79 changes: 79 additions & 0 deletions src/dstack/_internal/proxy/schemas/model_proxy.py
Author comment: "Copied from dstack-gateway with minor adjustments"

@@ -0,0 +1,79 @@
from typing import Any, Dict, List, Literal, Optional, Union

from dstack._internal.core.models.common import CoreModel

FinishReason = Literal["stop", "length", "tool_calls", "eos_token"]


class ChatMessage(CoreModel):
role: str # TODO(egor-s) types
content: str


class ChatCompletionsRequest(CoreModel):
messages: List[ChatMessage]
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Dict[str, float] = {}
max_tokens: Optional[int] = None
n: int = 1
presence_penalty: float = 0.0
response_format: Optional[Dict] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: bool = False
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
tools: List[Any] = []
tool_choice: Union[Literal["none", "auto"], Dict] = {}
user: Optional[str] = None


class ChatCompletionsChoice(CoreModel):
finish_reason: FinishReason
index: int
message: ChatMessage


class ChatCompletionsChunkChoice(CoreModel):
delta: object
logprobs: object = {}
finish_reason: Optional[FinishReason]
index: int


class ChatCompletionsUsage(CoreModel):
completion_tokens: int
prompt_tokens: int
total_tokens: int


class ChatCompletionsResponse(CoreModel):
id: str
choices: List[ChatCompletionsChoice]
created: int
model: str
system_fingerprint: str = ""
object: Literal["chat.completion"] = "chat.completion"
usage: ChatCompletionsUsage


class ChatCompletionsChunk(CoreModel):
id: str
choices: List[ChatCompletionsChunkChoice]
created: int
model: str
system_fingerprint: str = ""
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"


class Model(CoreModel):
object: Literal["model"] = "model"
id: str
created: int
owned_by: str


class ModelsResponse(CoreModel):
object: Literal["list"] = "list"
data: List[Model]
23 changes: 23 additions & 0 deletions src/dstack/_internal/proxy/services/model_proxy/__init__.py
@@ -0,0 +1,23 @@
import httpx

from dstack._internal.proxy.errors import UnexpectedProxyError
from dstack._internal.proxy.repos.base import ChatModel
from dstack._internal.proxy.services.model_proxy.clients import ChatCompletionsClient
from dstack._internal.proxy.services.model_proxy.clients.openai import OpenAIChatCompletions
from dstack._internal.proxy.services.model_proxy.clients.tgi import TGIChatCompletions


def get_chat_client(model: ChatModel, http_client: httpx.AsyncClient) -> ChatCompletionsClient:
if model.format_spec.format == "tgi":
return TGIChatCompletions(
http_client=http_client,
chat_template=model.format_spec.chat_template,
eos_token=model.format_spec.eos_token,
)
elif model.format_spec.format == "openai":
return OpenAIChatCompletions(
http_client=http_client,
prefix=model.format_spec.prefix,
)
else:
raise UnexpectedProxyError(f"Unsupported model format {model.format_spec.format}")
Author comment: "The next 3 files were copied from dstack-gateway with minor adjustments"

18 changes: 18 additions & 0 deletions src/dstack/_internal/proxy/services/model_proxy/clients/__init__.py
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
from typing import AsyncIterator

from dstack._internal.proxy.schemas.model_proxy import (
ChatCompletionsChunk,
ChatCompletionsRequest,
ChatCompletionsResponse,
)


class ChatCompletionsClient(ABC):
@abstractmethod
async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
pass

@abstractmethod
async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
yield
37 changes: 37 additions & 0 deletions src/dstack/_internal/proxy/services/model_proxy/clients/openai.py
@@ -0,0 +1,37 @@
from typing import AsyncIterator

import httpx

from dstack._internal.proxy.errors import ProxyError
from dstack._internal.proxy.schemas.model_proxy import (
ChatCompletionsChunk,
ChatCompletionsRequest,
ChatCompletionsResponse,
)
from dstack._internal.proxy.services.model_proxy.clients import ChatCompletionsClient


class OpenAIChatCompletions(ChatCompletionsClient):
def __init__(self, http_client: httpx.AsyncClient, prefix: str):
self._http = http_client
self._prefix = prefix

async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
resp = await self._http.post(
f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
)
if resp.status_code != 200:
raise ProxyError(resp.text)
return ChatCompletionsResponse.__response__.parse_raw(resp.content)

async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
async with self._http.stream(
"POST", f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
) as resp:
async for line in resp.aiter_lines():
if not line.startswith("data:"):
continue
data = line[len("data:") :].strip()
if data == "[DONE]":
break
yield ChatCompletionsChunk.__response__.parse_raw(data)
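
One detail worth noting in both generate() and stream(): the request is forwarded with dict(exclude_unset=True), so only fields the caller explicitly set are passed upstream and the backing OpenAI-compatible server keeps its own defaults for everything else. A small illustration, with the output shown approximately:

from dstack._internal.proxy.schemas.model_proxy import ChatCompletionsRequest

request = ChatCompletionsRequest(
    model="llama-3-8b",
    messages=[{"role": "user", "content": "Hi"}],
    temperature=0.2,
)
# Fields left at their defaults (max_tokens, top_p, ...) are omitted entirely:
print(request.dict(exclude_unset=True))
# -> {'messages': [{'role': 'user', 'content': 'Hi'}], 'model': 'llama-3-8b', 'temperature': 0.2}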