diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index fb88e0566..130b6c44e 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -3,7 +3,7 @@ from aiohttp import ClientSession, ClientTimeout from pydantic import ValidationError -from typing import Any, Dict, Optional, List, AsyncIterator, Iterator +from typing import Any, Dict, Optional, List, AsyncIterator, Iterator, Union from lorax.types import ( StreamResponse, @@ -79,7 +79,7 @@ def generate( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, - response_format: Optional[ResponseFormat] = None, + response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None, decoder_input_details: bool = False, ) -> Response: """ @@ -125,7 +125,7 @@ def generate( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - response_format (`Optional[Dict[str, Any]]`): + response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`): Optional specification of a format to impose upon the generated text, e.g.,: ``` { @@ -162,7 +162,7 @@ def generate( truncate=truncate, typical_p=typical_p, watermark=watermark, - response_format=response_format, # TODO: make object, ensure all variants of generate are migrated + response_format=response_format, decoder_input_details=decoder_input_details, ) request = Request(inputs=prompt, stream=False, parameters=parameters) @@ -175,9 +175,7 @@ def generate( timeout=self.timeout, ) - # TODO: handle 422 errors - print(resp.text) # Failed to deserialize the JSON body into the target type: parameters.response_format: missing field `type` at line 1 column 854 - + # TODO: expose better error messages for 422 and similar errors payload = resp.json() if resp.status_code != 200: raise parse_error(resp.status_code, payload) @@ -203,7 +201,7 @@ def generate_stream( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, - response_format: Optional[Dict[str, Any]] = None, + response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None, ) -> Iterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens @@ -246,7 +244,7 @@ def generate_stream( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - response_format (`Optional[Dict[str, Any]]`): + response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`): Optional specification of a format to impose upon the generated text, e.g.,: ``` { @@ -282,13 +280,13 @@ def generate_stream( truncate=truncate, typical_p=typical_p, watermark=watermark, - response_format=json.dumps(response_format) if response_format is not None else None, + response_format=response_format, ) request = Request(inputs=prompt, stream=True, parameters=parameters) resp = requests.post( self.base_url, - json=request.dict(), + json=request.dict(by_alias=True), headers=self.headers, cookies=self.cookies, timeout=self.timeout, @@ -384,7 +382,7 @@ async def generate( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, - response_format: Optional[Dict[str, Any]] = None, + response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None, decoder_input_details: bool = False, ) -> Response: """ @@ -430,7 +428,7 @@ async def generate( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - response_format (`Optional[Dict[str, Any]]`): + response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`): Optional specification of a format to impose upon the generated text, e.g.,: ``` { @@ -468,14 +466,14 @@ async def generate( truncate=truncate, typical_p=typical_p, watermark=watermark, - response_format=json.dumps(response_format) if response_format is not None else None, + response_format=response_format, ) request = Request(inputs=prompt, stream=False, parameters=parameters) async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: - async with session.post(self.base_url, json=request.dict()) as resp: + async with session.post(self.base_url, json=request.dict(by_alias=True)) as resp: payload = await resp.json() if resp.status != 200: @@ -501,7 +499,7 @@ async def generate_stream( truncate: Optional[int] = None, typical_p: Optional[float] = None, watermark: bool = False, - response_format: Optional[Dict[str, Any]] = None, + response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None, ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens asynchronously @@ -544,7 +542,7 @@ async def generate_stream( See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) - response_format (`Optional[Dict[str, Any]]`): + response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`): Optional specification of a format to impose upon the generated text, e.g.,: ``` { @@ -580,14 +578,14 @@ async def generate_stream( truncate=truncate, typical_p=typical_p, watermark=watermark, - response_format=json.dumps(response_format) if response_format is not None else None, + response_format=response_format, ) request = Request(inputs=prompt, stream=True, parameters=parameters) async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: - async with session.post(self.base_url, json=request.dict()) as resp: + async with session.post(self.base_url, json=request.dict(by_alias=True)) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.json())