Improve error handling in model proxy (#1973)
jvstme authored Nov 7, 2024
1 parent fc5286d commit 2912670
Showing 3 changed files with 145 additions and 57 deletions.
43 changes: 37 additions & 6 deletions src/dstack/_internal/proxy/routers/model_proxy.py
@@ -1,4 +1,4 @@
from typing import AsyncIterator
from typing import AsyncIterator, Optional

from fastapi import APIRouter, Depends, status
from fastapi.responses import StreamingResponse
@@ -55,13 +55,44 @@ async def post_chat_completions(
return await client.generate(body)
else:
return StreamingResponse(
stream_chunks(client.stream(body)),
await StreamingAdaptor(client.stream(body)).get_stream(),
media_type="text/event-stream",
headers={"X-Accel-Buffering": "no"},
)


async def stream_chunks(chunks: AsyncIterator[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
async for chunk in chunks:
yield f"data:{chunk.json()}\n\n".encode()
yield "data: [DONE]\n\n".encode()
class StreamingAdaptor:
"""
Converts a stream of ChatCompletionsChunk to an SSE stream.
Also pre-fetches the first chunk **before** streaming to the downstream
client begins, so that upstream request errors can still propagate to it.
"""

def __init__(self, stream: AsyncIterator[ChatCompletionsChunk]) -> None:
self._stream = stream

async def get_stream(self) -> AsyncIterator[bytes]:
try:
first_chunk = await self._stream.__anext__()
except StopAsyncIteration:
first_chunk = None
return self._adaptor(first_chunk)

async def _adaptor(self, first_chunk: Optional[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
if first_chunk is not None:
yield self._encode_chunk(first_chunk)

try:
async for chunk in self._stream:
yield self._encode_chunk(chunk)
except ProxyError as e:
# No standard way to report errors while streaming,
# but we'll at least send them as comments
yield f": {e.detail!r}\n\n".encode() # !r to avoid line breaks
return

yield "data: [DONE]\n\n".encode()

@staticmethod
def _encode_chunk(chunk: ChatCompletionsChunk) -> bytes:
return f"data:{chunk.json()}\n\n".encode()
62 changes: 46 additions & 16 deletions src/dstack/_internal/proxy/services/model_proxy/clients/openai.py
@@ -1,6 +1,8 @@
from typing import AsyncIterator

import httpx
from fastapi import status
from pydantic import ValidationError

from dstack._internal.proxy.errors import ProxyError
from dstack._internal.proxy.schemas.model_proxy import (
@@ -17,21 +19,49 @@ def __init__(self, http_client: httpx.AsyncClient, prefix: str):
self._prefix = prefix

async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
resp = await self._http.post(
f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
)
if resp.status_code != 200:
raise ProxyError(resp.text)
return ChatCompletionsResponse.__response__.parse_raw(resp.content)
try:
resp = await self._http.post(
f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
)
await self._propagate_error(resp)
except httpx.RequestError as e:
raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)

try:
return ChatCompletionsResponse.__response__.parse_raw(resp.content)
except ValidationError as e:
raise ProxyError(f"Invalid response from model: {e}", status.HTTP_502_BAD_GATEWAY)

async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
async with self._http.stream(
"POST", f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
) as resp:
async for line in resp.aiter_lines():
if not line.startswith("data:"):
continue
data = line[len("data:") :].strip()
if data == "[DONE]":
break
yield ChatCompletionsChunk.__response__.parse_raw(data)
try:
async with self._http.stream(
"POST", f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
) as resp:
await self._propagate_error(resp)

async for line in resp.aiter_lines():
if not line.startswith("data:"):
continue
data = line[len("data:") :].strip()
if data == "[DONE]":
break
yield self._parse_chunk_data(data)
except httpx.RequestError as e:
raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)

@staticmethod
def _parse_chunk_data(data: str) -> ChatCompletionsChunk:
try:
return ChatCompletionsChunk.__response__.parse_raw(data)
except ValidationError as e:
raise ProxyError(f"Invalid chunk in model stream: {e}", status.HTTP_502_BAD_GATEWAY)

@staticmethod
async def _propagate_error(resp: httpx.Response) -> None:
"""
Propagates an HTTP error by raising ProxyError if the status is not 200.
May also raise httpx.RequestError if reading the response body fails.
"""
if resp.status_code != 200:
resp_body = await resp.aread()
raise ProxyError(resp_body.decode(errors="replace"), code=resp.status_code)
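The streaming path here consumes OpenAI-style server-sent events: each payload arrives on a data: line, and the stream ends with a data: [DONE] sentinel. A standalone sketch of the same line handling against canned input, with plain dicts standing in for the dstack schema classes:

import json
from typing import Iterator, List


def parse_sse_lines(lines: Iterator[str]) -> List[dict]:
    # Skip comments and other non-data lines, stop at the [DONE]
    # sentinel, and JSON-decode every data payload in between.
    chunks = []
    for line in lines:
        if not line.startswith("data:"):
            continue
        data = line[len("data:") :].strip()
        if data == "[DONE]":
            break
        chunks.append(json.loads(data))
    return chunks


events = [
    ": a keep-alive comment, ignored",
    'data: {"choices": [{"delta": {"content": "Hi"}}]}',
    "data: [DONE]",
]
assert parse_sse_lines(iter(events)) == [{"choices": [{"delta": {"content": "Hi"}}]}]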
97 changes: 62 additions & 35 deletions src/dstack/_internal/proxy/services/model_proxy/clients/tgi.py
@@ -6,6 +6,7 @@
import httpx
import jinja2
import jinja2.sandbox
from fastapi import status

from dstack._internal.proxy.errors import ProxyError
from dstack._internal.proxy.schemas.model_proxy import (
@@ -38,9 +39,12 @@ def __init__(self, http_client: httpx.AsyncClient, chat_template: str, eos_token

async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
payload = self.get_payload(request)
resp = await self.client.post("/generate", json=payload)
if resp.status_code != 200:
raise ProxyError(resp.text) # TODO(egor-s)
try:
resp = await self.client.post("/generate", json=payload)
await self.propagate_error(resp)
except httpx.RequestError as e:
raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)

data = resp.json()

choices = [
@@ -91,38 +95,51 @@ async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
created = int(datetime.datetime.utcnow().timestamp())

payload = self.get_payload(request)
async with self.client.stream("POST", "/generate_stream", json=payload) as resp:
async for line in resp.aiter_lines():
if line.startswith("data:"):
data = json.loads(line[len("data:") :].strip("\n"))
if "error" in data:
raise ProxyError(data["error"])
chunk = ChatCompletionsChunk(
id=completion_id,
choices=[],
created=created,
model=request.model,
system_fingerprint="",
)
if data["details"] is not None:
chunk.choices = [
ChatCompletionsChunkChoice(
delta={},
logprobs=None,
finish_reason=self.finish_reason(data["details"]["finish_reason"]),
index=0,
)
]
else:
chunk.choices = [
ChatCompletionsChunkChoice(
delta={"content": data["token"]["text"], "role": "assistant"},
logprobs=None,
finish_reason=None,
index=0,
)
]
yield chunk
try:
async with self.client.stream("POST", "/generate_stream", json=payload) as resp:
await self.propagate_error(resp)
async for line in resp.aiter_lines():
if line.startswith("data:"):
yield self.parse_chunk(
data=json.loads(line[len("data:") :].strip("\n")),
model=request.model,
completion_id=completion_id,
created=created,
)
except httpx.RequestError as e:
raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)

def parse_chunk(
self, data: dict, model: str, completion_id: str, created: int
) -> ChatCompletionsChunk:
if "error" in data:
raise ProxyError(data["error"])
chunk = ChatCompletionsChunk(
id=completion_id,
choices=[],
created=created,
model=model,
system_fingerprint="",
)
if data["details"] is not None:
chunk.choices = [
ChatCompletionsChunkChoice(
delta={},
logprobs=None,
finish_reason=self.finish_reason(data["details"]["finish_reason"]),
index=0,
)
]
else:
chunk.choices = [
ChatCompletionsChunkChoice(
delta={"content": data["token"]["text"], "role": "assistant"},
logprobs=None,
finish_reason=None,
index=0,
)
]
return chunk

def get_payload(self, request: ChatCompletionsRequest) -> Dict:
try:
@@ -177,6 +194,16 @@ def trim_stop_tokens(text: str, stop_tokens: List[str]) -> str:
return text[: -len(stop_token)]
return text

@staticmethod
async def propagate_error(resp: httpx.Response) -> None:
"""
Propagates an HTTP error by raising ProxyError if the status is not 200.
May also raise httpx.RequestError if reading the response body fails.
"""
if resp.status_code != 200:
resp_body = await resp.aread()
raise ProxyError(resp_body.decode(errors="replace"), code=resp.status_code)


def raise_exception(message: str):
raise jinja2.TemplateError(message)
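For context on parse_chunk: TGI's /generate_stream emits one JSON object per generated token, with details set to null until the final event, which carries a finish reason instead of new text. A rough sketch of that translation using plain dicts in place of the dstack schema classes; the sample payloads are illustrative, and the real client additionally normalizes TGI finish reasons through its finish_reason helper.

def tgi_event_to_openai_chunk(data: dict, model: str, completion_id: str, created: int) -> dict:
    # Error events abort the stream (the real client raises ProxyError here).
    if "error" in data:
        raise ValueError(data["error"])
    if data["details"] is not None:
        # Final event: no new content, just the finish reason.
        choice = {"delta": {}, "finish_reason": data["details"]["finish_reason"], "index": 0}
    else:
        # Intermediate event: one token of assistant content.
        choice = {
            "delta": {"content": data["token"]["text"], "role": "assistant"},
            "finish_reason": None,
            "index": 0,
        }
    return {"id": completion_id, "created": created, "model": model, "choices": [choice]}


token_event = {"token": {"text": "Hello"}, "details": None}
final_event = {"token": {"text": ""}, "details": {"finish_reason": "eos_token"}}
print(tgi_event_to_openai_chunk(token_event, "my-model", "cmpl-1", 0))
print(tgi_event_to_openai_chunk(final_event, "my-model", "cmpl-1", 0))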
