Skip to content

Commit

Permalink
add client support
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh committed Apr 12, 2024
1 parent c3ddb44 commit e902c8a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 76 deletions.
57 changes: 56 additions & 1 deletion clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
Response,
Request,
Parameters,
MergedAdapters, ResponseFormat,
MergedAdapters,
ResponseFormat,
EmbedResponse
)
from lorax.errors import parse_error

Expand Down Expand Up @@ -55,6 +57,7 @@ def __init__(
Timeout in seconds
"""
self.base_url = base_url
self.embed_endpoint = f"{base_url}/embed"
self.headers = headers
self.cookies = cookies
self.timeout = timeout
Expand Down Expand Up @@ -334,6 +337,34 @@ def generate_stream(
raise parse_error(resp.status_code, json_payload)
yield response


def embed(self, inputs: str) -> EmbedResponse:
"""
Given inputs, embed the text using the model
Args:
inputs (`str`):
Input text
Returns:
Embeddings: computed embeddings
"""
request = Request(inputs=inputs)

resp = requests.post(
self.embed_endpoint,
json=request.dict(by_alias=True),
headers=self.headers,
cookies=self.cookies,
timeout=self.timeout,
)

payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, resp.json())

return EmbedResponse(**payload)


class AsyncClient:
"""Asynchronous Client to make calls to a LoRAX instance
Expand Down Expand Up @@ -376,6 +407,7 @@ def __init__(
Timeout in seconds
"""
self.base_url = base_url
self.embed_endpoint = f"{base_url}/embed"
self.headers = headers
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60)
Expand Down Expand Up @@ -650,3 +682,26 @@ async def generate_stream(
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status, json_payload)
yield response


async def embed(self, inputs: str) -> EmbedResponse:
"""
Given inputs, embed the text using the model
Args:
inputs (`str`):
Input text
Returns:
Embeddings: computed embeddings
"""
request = Request(inputs=inputs)
async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.embed_endpoint, json=request.dict(by_alias=True)) as resp:
payload = await resp.json()

if resp.status != 200:
raise parse_error(resp.status, payload)
return EmbedResponse(**payload)
4 changes: 4 additions & 0 deletions clients/python/lorax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,7 @@ class StreamResponse(BaseModel):
class DeployedModel(BaseModel):
model_id: str
sha: str

class EmbedResponse(BaseModel):
# Embeddings
embeddings: Optional[List[float]]
75 changes: 0 additions & 75 deletions server/lorax_server/test_embedding_pb.py

This file was deleted.

0 comments on commit e902c8a

Please sign in to comment.