Skip to content

Commit

Permalink
client cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffreyftang committed Feb 15, 2024
1 parent 6021813 commit f4060be
Showing 1 changed file with 17 additions and 19 deletions.
36 changes: 17 additions & 19 deletions clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Any, Dict, Optional, List, AsyncIterator, Iterator
from typing import Any, Dict, Optional, List, AsyncIterator, Iterator, Union

from lorax.types import (
StreamResponse,
Expand Down Expand Up @@ -79,7 +79,7 @@ def generate(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
response_format: Optional[ResponseFormat] = None,
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
decoder_input_details: bool = False,
) -> Response:
"""
Expand Down Expand Up @@ -125,7 +125,7 @@ def generate(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
response_format (`Optional[Dict[str, Any]]`):
response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`):
Optional specification of a format to impose upon the generated text, e.g.,:
```
{
Expand Down Expand Up @@ -162,7 +162,7 @@ def generate(
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
response_format=response_format, # TODO: make object, ensure all variants of generate are migrated
response_format=response_format,
decoder_input_details=decoder_input_details,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)
Expand All @@ -175,9 +175,7 @@ def generate(
timeout=self.timeout,
)

# TODO: handle 422 errors
print(resp.text) # Failed to deserialize the JSON body into the target type: parameters.response_format: missing field `type` at line 1 column 854

# TODO: expose better error messages for 422 and similar errors
payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
Expand All @@ -203,7 +201,7 @@ def generate_stream(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
response_format: Optional[Dict[str, Any]] = None,
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
) -> Iterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens
Expand Down Expand Up @@ -246,7 +244,7 @@ def generate_stream(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
response_format (`Optional[Dict[str, Any]]`):
response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`):
Optional specification of a format to impose upon the generated text, e.g.,:
```
{
Expand Down Expand Up @@ -282,13 +280,13 @@ def generate_stream(
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
response_format=json.dumps(response_format) if response_format is not None else None,
response_format=response_format,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)

resp = requests.post(
self.base_url,
json=request.dict(),
json=request.dict(by_alias=True),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
Expand Down Expand Up @@ -384,7 +382,7 @@ async def generate(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
response_format: Optional[Dict[str, Any]] = None,
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
decoder_input_details: bool = False,
) -> Response:
"""
Expand Down Expand Up @@ -430,7 +428,7 @@ async def generate(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
response_format (`Optional[Dict[str, Any]]`):
response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`):
Optional specification of a format to impose upon the generated text, e.g.,:
```
{
Expand Down Expand Up @@ -468,14 +466,14 @@ async def generate(
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
response_format=json.dumps(response_format) if response_format is not None else None,
response_format=response_format,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)

async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
async with session.post(self.base_url, json=request.dict(by_alias=True)) as resp:
payload = await resp.json()

if resp.status != 200:
Expand All @@ -501,7 +499,7 @@ async def generate_stream(
truncate: Optional[int] = None,
typical_p: Optional[float] = None,
watermark: bool = False,
response_format: Optional[Dict[str, Any]] = None,
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
) -> AsyncIterator[StreamResponse]:
"""
Given a prompt, generate the following stream of tokens asynchronously
Expand Down Expand Up @@ -544,7 +542,7 @@ async def generate_stream(
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
watermark (`bool`):
Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
response_format (`Optional[Dict[str, Any]]`):
response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`):
Optional specification of a format to impose upon the generated text, e.g.,:
```
{
Expand Down Expand Up @@ -580,14 +578,14 @@ async def generate_stream(
truncate=truncate,
typical_p=typical_p,
watermark=watermark,
response_format=json.dumps(response_format) if response_format is not None else None,
response_format=response_format,
)
request = Request(inputs=prompt, stream=True, parameters=parameters)

async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp:
async with session.post(self.base_url, json=request.dict(by_alias=True)) as resp:

if resp.status != 200:
raise parse_error(resp.status, await resp.json())
Expand Down

0 comments on commit f4060be

Please sign in to comment.