From e902c8ace45f84b51a11575cb5d9e8500d0bb851 Mon Sep 17 00:00:00 2001 From: Magdy Saleh Date: Fri, 12 Apr 2024 21:21:58 +0000 Subject: [PATCH] add client support --- clients/python/lorax/client.py | 57 +++++++++++++++++- clients/python/lorax/types.py | 4 ++ server/lorax_server/test_embedding_pb.py | 75 ------------------------ 3 files changed, 60 insertions(+), 76 deletions(-) delete mode 100644 server/lorax_server/test_embedding_pb.py diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 46627c3d4..b2a98cbed 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -10,7 +10,9 @@ Response, Request, Parameters, - MergedAdapters, ResponseFormat, + MergedAdapters, + ResponseFormat, + EmbedResponse ) from lorax.errors import parse_error @@ -55,6 +57,7 @@ def __init__( Timeout in seconds """ self.base_url = base_url + self.embed_endpoint = f"{base_url}/embed" self.headers = headers self.cookies = cookies self.timeout = timeout @@ -334,6 +337,34 @@ def generate_stream( raise parse_error(resp.status_code, json_payload) yield response + + def embed(self, inputs: str) -> EmbedResponse: + """ + Given inputs, embed the text using the model + + Args: + inputs (`str`): + Input text + + Returns: + Embeddings: computed embeddings + """ + request = Request(inputs=inputs) + + resp = requests.post( + self.embed_endpoint, + json=request.dict(by_alias=True), + headers=self.headers, + cookies=self.cookies, + timeout=self.timeout, + ) + + payload = resp.json() + if resp.status_code != 200: + raise parse_error(resp.status_code, resp.json()) + + return EmbedResponse(**payload) + class AsyncClient: """Asynchronous Client to make calls to a LoRAX instance @@ -376,6 +407,7 @@ def __init__( Timeout in seconds """ self.base_url = base_url + self.embed_endpoint = f"{base_url}/embed" self.headers = headers self.cookies = cookies self.timeout = ClientTimeout(timeout * 60) @@ -650,3 +682,26 @@ async def generate_stream( # If we failed to parse the payload, then it is an error payload raise parse_error(resp.status, json_payload) yield response + + + async def embed(self, inputs: str) -> EmbedResponse: + """ + Given inputs, embed the text using the model + + Args: + inputs (`str`): + Input text + + Returns: + Embeddings: computed embeddings + """ + request = Request(inputs=inputs) + async with ClientSession( + headers=self.headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.post(self.embed_endpoint, json=request.dict(by_alias=True)) as resp: + payload = await resp.json() + + if resp.status != 200: + raise parse_error(resp.status, payload) + return EmbedResponse(**payload) \ No newline at end of file diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index 894a518e0..7b954a0c4 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -335,3 +335,7 @@ class StreamResponse(BaseModel): class DeployedModel(BaseModel): model_id: str sha: str + +class EmbedResponse(BaseModel): + # Embeddings + embeddings: Optional[List[float]] \ No newline at end of file diff --git a/server/lorax_server/test_embedding_pb.py b/server/lorax_server/test_embedding_pb.py deleted file mode 100644 index 1e06ce321..000000000 --- a/server/lorax_server/test_embedding_pb.py +++ /dev/null @@ -1,75 +0,0 @@ -import grpc -from lorax_server.pb import generate_pb2_grpc, generate_pb2 -from google.protobuf import json_format - -def run_prefil(stub): - json_string = ''' - { - "batch": { - "requests": [ - { - "inputs": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]", - "truncate": 1792, - "parameters": { - "temperature": 1, - "top_p": 1, - "typical_p": 1, - "seed": 11242005690274133440, - "repetition_penalty": 1 - }, - "stopping_parameters": { - "max_new_tokens": 64 - }, - "adapter_index": 1 - } - ], - "size": 1, - "max_tokens": 112 - } - } - ''' - - request = generate_pb2.PrefillRequest() - json_format.Parse(json_string, request) - response = stub.Prefill(request) - return response - -def run_embed(stub): - json_string = ''' - { - "inputs": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]" - } - ''' - request = generate_pb2.EmbedRequest() - json_format.Parse(json_string, request) - response = stub.Embed(request) - return response - - -def run(): - # Connect to the server using a Unix domain socket - channel = grpc.insecure_channel('unix:///tmp/lorax-server-0') - - # Create a stub (client) - stub = generate_pb2_grpc.LoraxServiceStub(channel) - embed_resp = run_embed(stub) - breakpoint() - prefil_resp = run_prefil(stub) - batch = prefil_resp.batch - - # # Create a request object - request = generate_pb2.DecodeRequest( - batches=[batch] - ) - - # Call the Decode method - for _ in range(100): - try: - response = stub.Decode(request) - print("Client received: ", response) - except grpc.RpcError as e: - print(f"RPC failed: {e.code()} {e.details()}") - -if __name__ == '__main__': - run() -