diff --git a/src/dstack/_internal/proxy/routers/model_proxy.py b/src/dstack/_internal/proxy/routers/model_proxy.py
index 918b23bc6..419b70cb1 100644
--- a/src/dstack/_internal/proxy/routers/model_proxy.py
+++ b/src/dstack/_internal/proxy/routers/model_proxy.py
@@ -1,4 +1,4 @@
-from typing import AsyncIterator
+from typing import AsyncIterator, Optional
 
 from fastapi import APIRouter, Depends, status
 from fastapi.responses import StreamingResponse
@@ -55,13 +55,44 @@ async def post_chat_completions(
         return await client.generate(body)
     else:
         return StreamingResponse(
-            stream_chunks(client.stream(body)),
+            await StreamingAdaptor(client.stream(body)).get_stream(),
             media_type="text/event-stream",
             headers={"X-Accel-Buffering": "no"},
         )
 
 
-async def stream_chunks(chunks: AsyncIterator[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
-    async for chunk in chunks:
-        yield f"data:{chunk.json()}\n\n".encode()
-    yield "data: [DONE]\n\n".encode()
+class StreamingAdaptor:
+    """
+    Converts a stream of ChatCompletionsChunk to an SSE stream.
+    Also pre-fetches the first chunk **before** the downstream response starts,
+    so that upstream request errors can still propagate to the downstream client.
+    """
+
+    def __init__(self, stream: AsyncIterator[ChatCompletionsChunk]) -> None:
+        self._stream = stream
+
+    async def get_stream(self) -> AsyncIterator[bytes]:
+        try:
+            first_chunk = await self._stream.__anext__()
+        except StopAsyncIteration:
+            first_chunk = None
+        return self._adaptor(first_chunk)
+
+    async def _adaptor(self, first_chunk: Optional[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
+        if first_chunk is not None:
+            yield self._encode_chunk(first_chunk)
+
+        try:
+            async for chunk in self._stream:
+                yield self._encode_chunk(chunk)
+        except ProxyError as e:
+            # There is no standard way to report errors mid-stream,
+            # but we can at least send them as SSE comments
+            yield f": {e.detail!r}\n\n".encode()  # !r to avoid line breaks
+            return
+
+        yield "data: [DONE]\n\n".encode()
+
+    @staticmethod
+    def _encode_chunk(chunk: ChatCompletionsChunk) -> bytes:
+        return f"data:{chunk.json()}\n\n".encode()
diff --git a/src/dstack/_internal/proxy/services/model_proxy/clients/openai.py b/src/dstack/_internal/proxy/services/model_proxy/clients/openai.py
index bff4eaadf..69b6909cc 100644
--- a/src/dstack/_internal/proxy/services/model_proxy/clients/openai.py
+++ b/src/dstack/_internal/proxy/services/model_proxy/clients/openai.py
@@ -1,6 +1,8 @@
 from typing import AsyncIterator
 
 import httpx
+from fastapi import status
+from pydantic import ValidationError
 
 from dstack._internal.proxy.errors import ProxyError
 from dstack._internal.proxy.schemas.model_proxy import (
@@ -17,21 +19,49 @@ def __init__(self, http_client: httpx.AsyncClient, prefix: str):
         self._prefix = prefix
 
     async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
-        resp = await self._http.post(
-            f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
-        )
-        if resp.status_code != 200:
-            raise ProxyError(resp.text)
-        return ChatCompletionsResponse.__response__.parse_raw(resp.content)
+        try:
+            resp = await self._http.post(
+                f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
+            )
+            await self._propagate_error(resp)
+        except httpx.RequestError as e:
+            raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)
+
+        try:
+            return ChatCompletionsResponse.__response__.parse_raw(resp.content)
+        except ValidationError as e:
+            raise ProxyError(f"Invalid response from model: {e}", status.HTTP_502_BAD_GATEWAY)
 
     async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
-        async with self._http.stream(
-            "POST", f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
-        ) as resp:
-            async for line in resp.aiter_lines():
-                if not line.startswith("data:"):
-                    continue
-                data = line[len("data:") :].strip()
-                if data == "[DONE]":
-                    break
-                yield ChatCompletionsChunk.__response__.parse_raw(data)
+        try:
+            async with self._http.stream(
+                "POST", f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
+            ) as resp:
+                await self._propagate_error(resp)
+
+                async for line in resp.aiter_lines():
+                    if not line.startswith("data:"):
+                        continue
+                    data = line[len("data:") :].strip()
+                    if data == "[DONE]":
+                        break
+                    yield self._parse_chunk_data(data)
+        except httpx.RequestError as e:
+            raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)
+
+    @staticmethod
+    def _parse_chunk_data(data: str) -> ChatCompletionsChunk:
+        try:
+            return ChatCompletionsChunk.__response__.parse_raw(data)
+        except ValidationError as e:
+            raise ProxyError(f"Invalid chunk in model stream: {e}", status.HTTP_502_BAD_GATEWAY)
+
+    @staticmethod
+    async def _propagate_error(resp: httpx.Response) -> None:
+        """
+        Propagates an upstream HTTP error by raising ProxyError if the status is not 200.
+        May also raise httpx.RequestError if reading the response body fails.
+        """
+        if resp.status_code != 200:
+            resp_body = await resp.aread()
+            raise ProxyError(resp_body.decode(errors="replace"), code=resp.status_code)
diff --git a/src/dstack/_internal/proxy/services/model_proxy/clients/tgi.py b/src/dstack/_internal/proxy/services/model_proxy/clients/tgi.py
index d06ac0510..f71a3e66c 100644
--- a/src/dstack/_internal/proxy/services/model_proxy/clients/tgi.py
+++ b/src/dstack/_internal/proxy/services/model_proxy/clients/tgi.py
@@ -6,6 +6,7 @@
 import httpx
 import jinja2
 import jinja2.sandbox
+from fastapi import status
 
 from dstack._internal.proxy.errors import ProxyError
 from dstack._internal.proxy.schemas.model_proxy import (
@@ -38,9 +39,12 @@ def __init__(self, http_client: httpx.AsyncClient, chat_template: str, eos_token
 
     async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
         payload = self.get_payload(request)
-        resp = await self.client.post("/generate", json=payload)
-        if resp.status_code != 200:
-            raise ProxyError(resp.text)  # TODO(egor-s)
+        try:
+            resp = await self.client.post("/generate", json=payload)
+            await self.propagate_error(resp)
+        except httpx.RequestError as e:
+            raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)
+
         data = resp.json()
 
         choices = [
@@ -91,38 +95,51 @@ async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCom
         created = int(datetime.datetime.utcnow().timestamp())
 
         payload = self.get_payload(request)
-        async with self.client.stream("POST", "/generate_stream", json=payload) as resp:
-            async for line in resp.aiter_lines():
-                if line.startswith("data:"):
-                    data = json.loads(line[len("data:") :].strip("\n"))
-                    if "error" in data:
-                        raise ProxyError(data["error"])
-                    chunk = ChatCompletionsChunk(
-                        id=completion_id,
-                        choices=[],
-                        created=created,
-                        model=request.model,
-                        system_fingerprint="",
-                    )
-                    if data["details"] is not None:
-                        chunk.choices = [
-                            ChatCompletionsChunkChoice(
-                                delta={},
-                                logprobs=None,
-                                finish_reason=self.finish_reason(data["details"]["finish_reason"]),
-                                index=0,
-                            )
-                        ]
-                    else:
-                        chunk.choices = [
-                            ChatCompletionsChunkChoice(
-                                delta={"content": data["token"]["text"], "role": "assistant"},
-                                logprobs=None,
-                                finish_reason=None,
-                                index=0,
-                            )
-                        ]
-                    yield chunk
+        try:
+            async with self.client.stream("POST", "/generate_stream", json=payload) as resp:
+                await self.propagate_error(resp)
+                async for line in resp.aiter_lines():
+                    if line.startswith("data:"):
+                        yield self.parse_chunk(
+                            data=json.loads(line[len("data:") :].strip("\n")),
+                            model=request.model,
+                            completion_id=completion_id,
+                            created=created,
+                        )
+        except httpx.RequestError as e:
+            raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)
+
+    def parse_chunk(
+        self, data: dict, model: str, completion_id: str, created: int
+    ) -> ChatCompletionsChunk:
+        if "error" in data:
+            raise ProxyError(data["error"])
+        chunk = ChatCompletionsChunk(
+            id=completion_id,
+            choices=[],
+            created=created,
+            model=model,
+            system_fingerprint="",
+        )
+        if data["details"] is not None:
+            chunk.choices = [
+                ChatCompletionsChunkChoice(
+                    delta={},
+                    logprobs=None,
+                    finish_reason=self.finish_reason(data["details"]["finish_reason"]),
+                    index=0,
+                )
+            ]
+        else:
+            chunk.choices = [
+                ChatCompletionsChunkChoice(
+                    delta={"content": data["token"]["text"], "role": "assistant"},
+                    logprobs=None,
+                    finish_reason=None,
+                    index=0,
+                )
+            ]
+        return chunk
 
     def get_payload(self, request: ChatCompletionsRequest) -> Dict:
         try:
@@ -177,6 +194,16 @@ def trim_stop_tokens(text: str, stop_tokens: List[str]) -> str:
                 return text[: -len(stop_token)]
         return text
 
+    @staticmethod
+    async def propagate_error(resp: httpx.Response) -> None:
+        """
+        Propagates an upstream HTTP error by raising ProxyError if the status is not 200.
+        May also raise httpx.RequestError if reading the response body fails.
+        """
+        if resp.status_code != 200:
+            resp_body = await resp.aread()
+            raise ProxyError(resp_body.decode(errors="replace"), code=resp.status_code)
+
 
 def raise_exception(message: str):
     raise jinja2.TemplateError(message)
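
A note on the `StreamingAdaptor` pre-fetch in the router: FastAPI's `StreamingResponse` commits the status line and headers as soon as it starts iterating the body, so a `ProxyError` raised inside the generator after that point can no longer become a proper HTTP error response. Awaiting `__anext__()` once before constructing the response forces upstream connection and status errors to surface while the handler can still fail normally. A minimal sketch of the idea outside dstack; `UpstreamError`, `failing_stream`, and `prefetch_first` are hypothetical names for illustration:

```python
import asyncio
from typing import AsyncIterator, Optional


class UpstreamError(Exception):
    """Hypothetical stand-in for dstack's ProxyError."""


async def failing_stream() -> AsyncIterator[str]:
    # Simulates an upstream that fails while the request is being opened.
    raise UpstreamError("upstream returned 503")
    yield  # unreachable, but makes this function an async generator


async def prefetch_first(stream: AsyncIterator[str]) -> Optional[str]:
    # Awaiting the first __anext__() here forces upstream errors to
    # surface *before* any response status or headers are committed.
    try:
        return await stream.__anext__()
    except StopAsyncIteration:
        return None  # the upstream stream was empty


async def main() -> None:
    try:
        await prefetch_first(failing_stream())
    except UpstreamError as e:
        print(f"caught before the response started: {e}")


asyncio.run(main())
```

This is also why `get_stream` is a plain coroutine that returns the `_adaptor` generator rather than being a generator itself: the pre-fetch has to run during the `await`, before the `StreamingResponse` is built.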
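On the `: {e.detail!r}` lines yielded by `_adaptor` once streaming has already started: SSE reserves lines beginning with `:` for comments, which conforming clients silently ignore, so mid-stream errors degrade gracefully instead of corrupting the `data:` protocol; the `!r` keeps the detail on one line, since a raw newline would terminate the comment. A rough sketch of a client-side classifier that could still surface such comments for diagnostics; it is illustrative only, not a full SSE parser:

```python
from typing import Tuple


def classify_sse_line(line: str) -> Tuple[str, str]:
    """Split an SSE line into (kind, payload). A sketch, not a spec-complete parser."""
    if line.startswith(":"):
        # Comment line: standard SSE clients ignore it, but the proxy
        # uses it to carry mid-stream error details worth logging.
        return "comment", line[1:].strip()
    if line.startswith("data:"):
        return "data", line[len("data:"):].strip()
    return "other", line


for raw in (
    'data:{"id": "chatcmpl-123", "choices": []}',
    ": ProxyError('connection reset by upstream')",
    "data: [DONE]",
):
    kind, payload = classify_sse_line(raw)
    print(f"{kind}: {payload}")
```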
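Why both error-propagation helpers read the body with `await resp.aread()` rather than touching `resp.text`: for responses opened via `httpx.AsyncClient.stream(...)` the body is not loaded eagerly, and accessing `resp.text` before reading raises `httpx.ResponseNotRead`; `aread()` works for streamed and regular responses alike, so one helper serves both `generate()` and `stream()`. A minimal sketch under that assumption; the `httpbin.org` endpoint is only a placeholder upstream:

```python
import asyncio

import httpx


async def raise_for_upstream_error(resp: httpx.Response) -> None:
    # In streaming mode httpx does not load the body eagerly, and
    # resp.text would raise httpx.ResponseNotRead, so the body is
    # read explicitly before building the error message.
    if resp.status_code != 200:
        body = await resp.aread()
        raise RuntimeError(f"upstream {resp.status_code}: {body.decode(errors='replace')}")


async def main() -> None:
    async with httpx.AsyncClient() as client:
        # Placeholder endpoint that always responds with a 503.
        async with client.stream("GET", "https://httpbin.org/status/503") as resp:
            try:
                await raise_for_upstream_error(resp)
            except RuntimeError as e:
                print(e)


asyncio.run(main())
```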