From 0cd89948e1b8198abe2e597edf3e42e432b2d9a2 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 15 Nov 2023 21:38:41 -0700 Subject: [PATCH] Updated python client --- clients/python/README.md | 77 +++--------- clients/python/lorax/__init__.py | 3 +- clients/python/lorax/client.py | 42 ++++++- clients/python/lorax/errors.py | 5 +- clients/python/lorax/inference_api.py | 168 -------------------------- clients/python/lorax/types.py | 13 ++ clients/python/pyproject.toml | 8 +- 7 files changed, 72 insertions(+), 244 deletions(-) delete mode 100644 clients/python/lorax/inference_api.py diff --git a/clients/python/README.md b/clients/python/README.md index f4eead097..d6724331b 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -1,80 +1,31 @@ -# Text Generation +# LoRAX Python Client -The Hugging Face Text Generation Python library provides a convenient way of interfacing with a -`lorax-inference` instance running on -[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub. +LoRAX Python client provides a convenient way of interfacing with a +`lorax` instance running in your environment. -## Get Started +## Getting Started ### Install ```shell -pip install lorax +pip install lorax-client ``` -### Inference API Usage - -```python -from lorax import InferenceAPIClient - -client = InferenceAPIClient("bigscience/bloomz") -text = client.generate("Why is the sky blue?").generated_text -print(text) -# ' Rayleigh scattering' - -# Token Streaming -text = "" -for response in client.generate_stream("Why is the sky blue?"): - if not response.token.special: - text += response.token.text - -print(text) -# ' Rayleigh scattering' -``` - -or with the asynchronous client: - -```python -from lorax import InferenceAPIAsyncClient - -client = InferenceAPIAsyncClient("bigscience/bloomz") -response = await client.generate("Why is the sky blue?") -print(response.generated_text) -# ' Rayleigh scattering' - -# Token Streaming -text = "" -async for response in client.generate_stream("Why is the sky blue?"): - if not response.token.special: - text += response.token.text - -print(text) -# ' Rayleigh scattering' -``` - -Check all currently deployed models on the Huggingface Inference API with `Text Generation` support: - -```python -from lorax.inference_api import deployed_models - -print(deployed_models()) -``` - -### Hugging Face Inference Endpoint usage +### Run ```python from lorax import Client -endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud" +endpoint_url = "http://127.0.0.1:8080" client = Client(endpoint_url) -text = client.generate("Why is the sky blue?").generated_text +text = client.generate("Why is the sky blue?", adapter_id="some/adapter").generated_text print(text) # ' Rayleigh scattering' # Token Streaming text = "" -for response in client.generate_stream("Why is the sky blue?"): +for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"): if not response.token.special: text += response.token.text @@ -87,16 +38,16 @@ or with the asynchronous client: ```python from lorax import AsyncClient -endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud" +endpoint_url = "http://127.0.0.1:8080" client = AsyncClient(endpoint_url) -response = await client.generate("Why is the sky blue?") +response = await client.generate("Why is the sky blue?", adapter_id="some/adapter") print(response.generated_text) # ' Rayleigh scattering' # Token Streaming text = "" -async for response in 
client.generate_stream("Why is the sky blue?"): +async for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"): if not response.token.special: text += response.token.text @@ -109,6 +60,10 @@ print(text) ```python # Request Parameters class Parameters: + # The ID of the adapter to use + adapter_id: Optional[str] + # The source of the adapter to use + adapter_source: Optional[str] # Activate logits sampling do_sample: bool # Maximum number of generated tokens diff --git a/clients/python/lorax/__init__.py b/clients/python/lorax/__init__.py index 71e6b381d..0f0f3f3dc 100644 --- a/clients/python/lorax/__init__.py +++ b/clients/python/lorax/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.3.0" +__version__ = "0.1.0" from lorax.client import Client, AsyncClient -from lorax.inference_api import InferenceAPIClient, InferenceAPIAsyncClient diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 4dd160397..cd352fd89 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -22,12 +22,12 @@ class Client: ```python >>> from lorax import Client - >>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz") - >>> client.generate("Why is the sky blue?").generated_text + >>> client = Client("http://127.0.0.1:8080") + >>> client.generate("Why is the sky blue?", adapter_id="some/adapter").generated_text ' Rayleigh scattering' >>> result = "" - >>> for response in client.generate_stream("Why is the sky blue?"): + >>> for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"): >>> if not response.token.special: >>> result += response.token.text >>> result @@ -61,6 +61,8 @@ def __init__( def generate( self, prompt: str, + adapter_id: Optional[str] = None, + adapter_source: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, best_of: Optional[int] = None, @@ -82,6 +84,10 @@ def generate( Args: prompt (`str`): Input text + adapter_id (`Optional[str]`): + Adapter ID to apply to the base model for the request + adapter_source (`Optional[str]`): + Source of the adapter (hub, local, s3) do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): @@ -119,6 +125,8 @@ def generate( """ # Validate parameters parameters = Parameters( + adapter_id=adapter_id, + adapter_source=adapter_source, best_of=best_of, details=True, do_sample=do_sample, @@ -152,6 +160,8 @@ def generate( def generate_stream( self, prompt: str, + adapter_id: Optional[str] = None, + adapter_source: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, @@ -171,6 +181,10 @@ def generate_stream( Args: prompt (`str`): Input text + adapter_id (`Optional[str]`): + Adapter ID to apply to the base model for the request + adapter_source (`Optional[str]`): + Source of the adapter (hub, local, s3) do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): @@ -204,6 +218,8 @@ def generate_stream( """ # Validate parameters parameters = Parameters( + adapter_id=adapter_id, + adapter_source=adapter_source, best_of=None, details=True, decoder_input_details=False, @@ -264,12 +280,12 @@ class AsyncClient: >>> from lorax import AsyncClient >>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz") - >>> response = await client.generate("Why is the sky blue?") + >>> response = await 
client.generate("Why is the sky blue?", adapter_id="some/adapter") >>> response.generated_text ' Rayleigh scattering' >>> result = "" - >>> async for response in client.generate_stream("Why is the sky blue?"): + >>> async for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"): >>> if not response.token.special: >>> result += response.token.text >>> result @@ -303,6 +319,8 @@ def __init__( async def generate( self, prompt: str, + adapter_id: Optional[str] = None, + adapter_source: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, best_of: Optional[int] = None, @@ -324,6 +342,10 @@ async def generate( Args: prompt (`str`): Input text + adapter_id (`Optional[str]`): + Adapter ID to apply to the base model for the request + adapter_source (`Optional[str]`): + Source of the adapter (hub, local, s3) do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): @@ -361,6 +383,8 @@ async def generate( """ # Validate parameters parameters = Parameters( + adapter_id=adapter_id, + adapter_source=adapter_source, best_of=best_of, details=True, decoder_input_details=decoder_input_details, @@ -392,6 +416,8 @@ async def generate( async def generate_stream( self, prompt: str, + adapter_id: Optional[str] = None, + adapter_source: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, @@ -411,6 +437,10 @@ async def generate_stream( Args: prompt (`str`): Input text + adapter_id (`Optional[str]`): + Adapter ID to apply to the base model for the request + adapter_source (`Optional[str]`): + Source of the adapter (hub, local, s3) do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): @@ -444,6 +474,8 @@ async def generate_stream( """ # Validate parameters parameters = Parameters( + adapter_id=adapter_id, + adapter_source=adapter_source, best_of=None, details=True, decoder_input_details=False, diff --git a/clients/python/lorax/errors.py b/clients/python/lorax/errors.py index dbf0b761a..e83ed5b75 100644 --- a/clients/python/lorax/errors.py +++ b/clients/python/lorax/errors.py @@ -50,10 +50,7 @@ def __init__(self, message: str): class NotSupportedError(Exception): def __init__(self, model_id: str): - message = ( - f"Model `{model_id}` is not available for inference with this client. \n" - "Use `huggingface_hub.inference_api.InferenceApi` instead." - ) + message = f"Model `{model_id}` is not available for inference with this client." 
super(NotSupportedError, self).__init__(message) diff --git a/clients/python/lorax/inference_api.py b/clients/python/lorax/inference_api.py deleted file mode 100644 index 51439e765..000000000 --- a/clients/python/lorax/inference_api.py +++ /dev/null @@ -1,168 +0,0 @@ -import os -import requests - -from typing import Dict, Optional, List -from huggingface_hub.utils import build_hf_headers - -from lorax import Client, AsyncClient, __version__ -from lorax.types import DeployedModel -from lorax.errors import NotSupportedError, parse_error - -INFERENCE_ENDPOINT = os.environ.get( - "HF_INFERENCE_ENDPOINT", "https://api-inference.huggingface.co" -) - - -def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]: - """ - Get all currently deployed models with lorax-inference-support - - Returns: - List[DeployedModel]: list of all currently deployed models - """ - resp = requests.get( - f"https://api-inference.huggingface.co/framework/lorax-inference", - headers=headers, - timeout=5, - ) - - payload = resp.json() - if resp.status_code != 200: - raise parse_error(resp.status_code, payload) - - models = [DeployedModel(**raw_deployed_model) for raw_deployed_model in payload] - return models - - -def check_model_support(repo_id: str, headers: Optional[Dict] = None) -> bool: - """ - Check if a given model is supported by lorax-inference - - Returns: - bool: whether the model is supported by this client - """ - resp = requests.get( - f"https://api-inference.huggingface.co/status/{repo_id}", - headers=headers, - timeout=5, - ) - - payload = resp.json() - if resp.status_code != 200: - raise parse_error(resp.status_code, payload) - - framework = payload["framework"] - supported = framework == "lorax-inference" - return supported - - -class InferenceAPIClient(Client): - """Client to make calls to the HuggingFace Inference API. - - Only supports a subset of the available lorax or text2lorax models that are served using - lorax-inference - - Example: - - ```python - >>> from lorax import InferenceAPIClient - - >>> client = InferenceAPIClient("bigscience/bloomz") - >>> client.generate("Why is the sky blue?").generated_text - ' Rayleigh scattering' - - >>> result = "" - >>> for response in client.generate_stream("Why is the sky blue?"): - >>> if not response.token.special: - >>> result += response.token.text - >>> result - ' Rayleigh scattering' - ``` - """ - - def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10): - """ - Init headers and API information - - Args: - repo_id (`str`): - Id of repository (e.g. `bigscience/bloom`). - token (`str`, `optional`): - The API token to use as HTTP bearer authorization. This is not - the authentication token. You can find the token in - https://huggingface.co/settings/token. Alternatively, you can - find both your organizations and personal API tokens using - `HfApi().whoami(token)`. - timeout (`int`): - Timeout in seconds - """ - - headers = build_hf_headers( - token=token, library_name="lorax", library_version=__version__ - ) - - # Text Generation Inference client only supports a subset of the available hub models - if not check_model_support(repo_id, headers): - raise NotSupportedError(repo_id) - - base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" - - super(InferenceAPIClient, self).__init__( - base_url, headers=headers, timeout=timeout - ) - - -class InferenceAPIAsyncClient(AsyncClient): - """Aynschronous Client to make calls to the HuggingFace Inference API. 
- - Only supports a subset of the available lorax or text2lorax models that are served using - lorax-inference - - Example: - - ```python - >>> from lorax import InferenceAPIAsyncClient - - >>> client = InferenceAPIAsyncClient("bigscience/bloomz") - >>> response = await client.generate("Why is the sky blue?") - >>> response.generated_text - ' Rayleigh scattering' - - >>> result = "" - >>> async for response in client.generate_stream("Why is the sky blue?"): - >>> if not response.token.special: - >>> result += response.token.text - >>> result - ' Rayleigh scattering' - ``` - """ - - def __init__(self, repo_id: str, token: Optional[str] = None, timeout: int = 10): - """ - Init headers and API information - - Args: - repo_id (`str`): - Id of repository (e.g. `bigscience/bloom`). - token (`str`, `optional`): - The API token to use as HTTP bearer authorization. This is not - the authentication token. You can find the token in - https://huggingface.co/settings/token. Alternatively, you can - find both your organizations and personal API tokens using - `HfApi().whoami(token)`. - timeout (`int`): - Timeout in seconds - """ - headers = build_hf_headers( - token=token, library_name="lorax", library_version=__version__ - ) - - # Text Generation Inference client only supports a subset of the available hub models - if not check_model_support(repo_id, headers): - raise NotSupportedError(repo_id) - - base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" - - super(InferenceAPIAsyncClient, self).__init__( - base_url, headers=headers, timeout=timeout - ) diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index f61eb3205..c9637873d 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -5,7 +5,14 @@ from lorax.errors import ValidationError +ADAPTER_SOURCES = ["hub", "local", "s3"] + + class Parameters(BaseModel): + # The ID of the adapter to use + adapter_id: Optional[str] + # The source of the adapter to use + adapter_source: Optional[str] # Activate logits sampling do_sample: bool = False # Maximum number of generated tokens @@ -40,6 +47,12 @@ class Parameters(BaseModel): # Get decoder input token logprobs and ids decoder_input_details: bool = False + @validator("adapter_source") + def valid_adapter_source(cls, v): + if v is not None and v not in ADAPTER_SOURCES: + raise ValidationError(f"`adapter_source` must be one of {ADAPTER_SOURCES}") + return v + @validator("best_of") def valid_best_of(cls, field_value, values): if field_value is not None: diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 206540a2e..e8679746d 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,10 +1,10 @@ [tool.poetry] -name = "lorax" -version = "0.6.0" +name = "lorax-client" +version = "0.1.0" description = "LoRAX Python Client" license = "Apache-2.0" -authors = ["Olivier Dehaene "] -maintainers = ["Olivier Dehaene "] +authors = ["Travis Addair ", "Olivier Dehaene "] +maintainers = ["Travis Addair "] readme = "README.md" homepage = "https://github.com/predibase/lorax" repository = "https://github.com/predibase/lorax"
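
---

Below is a minimal usage sketch of the adapter parameters introduced in this patch. It assumes a LoRAX instance is reachable at `http://127.0.0.1:8080`; the adapter ID `some/adapter` is a placeholder, and `adapter_source` must be one of the values in the new `ADAPTER_SOURCES` list (`hub`, `local`, `s3`).

```python
from lorax import Client

# Assumes a running LoRAX server; the endpoint URL and adapter ID below are placeholders.
client = Client("http://127.0.0.1:8080")

response = client.generate(
    "Why is the sky blue?",
    adapter_id="some/adapter",   # placeholder adapter ID
    adapter_source="hub",        # must be one of "hub", "local", "s3"
    max_new_tokens=20,
)
print(response.generated_text)

# An unrecognized adapter_source (e.g. "ftp") is rejected client-side by the
# new validator on Parameters before any request is sent to the server.
```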