Commit 0cd8994: Updated python client

Committed by tgaddair on Nov 16, 2023 (1 parent: 80b96ad)

Showing 7 changed files with 72 additions and 244 deletions.
clients/python/README.md: 77 changes (16 additions, 61 deletions)

````diff
@@ -1,80 +1,31 @@
-# Text Generation
+# LoRAX Python Client
 
-The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
-`lorax-inference` instance running on
-[Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.
+LoRAX Python client provides a convenient way of interfacing with a
+`lorax` instance running in your environment.
 
-## Get Started
+## Getting Started
 
 ### Install
 
 ```shell
-pip install lorax
+pip install lorax-client
 ```
 
-### Inference API Usage
-
-```python
-from lorax import InferenceAPIClient
-
-client = InferenceAPIClient("bigscience/bloomz")
-text = client.generate("Why is the sky blue?").generated_text
-print(text)
-# ' Rayleigh scattering'
-
-# Token Streaming
-text = ""
-for response in client.generate_stream("Why is the sky blue?"):
-    if not response.token.special:
-        text += response.token.text
-
-print(text)
-# ' Rayleigh scattering'
-```
-
-or with the asynchronous client:
-
-```python
-from lorax import InferenceAPIAsyncClient
-
-client = InferenceAPIAsyncClient("bigscience/bloomz")
-response = await client.generate("Why is the sky blue?")
-print(response.generated_text)
-# ' Rayleigh scattering'
-
-# Token Streaming
-text = ""
-async for response in client.generate_stream("Why is the sky blue?"):
-    if not response.token.special:
-        text += response.token.text
-
-print(text)
-# ' Rayleigh scattering'
-```
-
-Check all currently deployed models on the Huggingface Inference API with `Text Generation` support:
-
-```python
-from lorax.inference_api import deployed_models
-
-print(deployed_models())
-```
-
-### Hugging Face Inference Endpoint usage
+### Run
 
 ```python
 from lorax import Client
 
-endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
+endpoint_url = "http://127.0.0.1:8080"
 
 client = Client(endpoint_url)
-text = client.generate("Why is the sky blue?").generated_text
+text = client.generate("Why is the sky blue?", adapter_id="some/adapter").generated_text
 print(text)
 # ' Rayleigh scattering'
 
 # Token Streaming
 text = ""
-for response in client.generate_stream("Why is the sky blue?"):
+for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
     if not response.token.special:
         text += response.token.text
 
@@ -87,16 +38,16 @@ or with the asynchronous client:
 ```python
 from lorax import AsyncClient
 
-endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
+endpoint_url = "http://127.0.0.1:8080"
 
 client = AsyncClient(endpoint_url)
-response = await client.generate("Why is the sky blue?")
+response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
 print(response.generated_text)
 # ' Rayleigh scattering'
 
 # Token Streaming
 text = ""
-async for response in client.generate_stream("Why is the sky blue?"):
+async for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
     if not response.token.special:
         text += response.token.text
 
@@ -109,6 +60,10 @@ print(text)
 ```python
 # Request Parameters
 class Parameters:
+    # The ID of the adapter to use
+    adapter_id: Optional[str]
+    # The source of the adapter to use
+    adapter_source: Optional[str]
     # Activate logits sampling
     do_sample: bool
     # Maximum number of generated tokens
````
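Taken together, the updated README describes a client that talks to a local LoRAX server and can route each request through a LoRA adapter. Below is a minimal sketch combining the new `adapter_id` and `adapter_source` parameters in one call; the endpoint URL and adapter ID are the placeholder values from the docs above, and passing `adapter_source="hub"` assumes the adapter is hosted on the HuggingFace Hub (the parameter docs list hub, local, and s3 as sources).

```python
from lorax import Client

# Placeholder values from the README; point these at your own
# deployment and adapter.
client = Client("http://127.0.0.1:8080")

# adapter_id selects the LoRA adapter applied to the base model;
# adapter_source tells the server where to load it from.
response = client.generate(
    "Why is the sky blue?",
    adapter_id="some/adapter",
    adapter_source="hub",
)
print(response.generated_text)
```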
clients/python/lorax/__init__.py: 3 changes (1 addition, 2 deletions)

```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.3.0"
+__version__ = "0.1.0"
 
 from lorax.client import Client, AsyncClient
-from lorax.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
```
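A quick sanity check of what this leaves in the package's public interface (a sketch, assuming the client was installed with `pip install lorax-client`):

```python
import lorax

# After this commit, the version is 0.1.0 and only the local-server
# clients are re-exported; the Inference API clients are gone.
print(lorax.__version__)  # 0.1.0
print(lorax.Client, lorax.AsyncClient)
```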
clients/python/lorax/client.py: 42 changes (37 additions, 5 deletions)

````diff
@@ -22,12 +22,12 @@ class Client:
     ```python
     >>> from lorax import Client
 
-    >>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
-    >>> client.generate("Why is the sky blue?").generated_text
+    >>> client = Client("http://127.0.0.1:8080")
+    >>> client.generate("Why is the sky blue?", adapter_id="some/adapter").generated_text
     ' Rayleigh scattering'
 
     >>> result = ""
-    >>> for response in client.generate_stream("Why is the sky blue?"):
+    >>> for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
     >>>     if not response.token.special:
     >>>         result += response.token.text
     >>> result
@@ -61,6 +61,8 @@ def __init__(
     def generate(
         self,
         prompt: str,
+        adapter_id: Optional[str] = None,
+        adapter_source: Optional[str] = None,
         do_sample: bool = False,
         max_new_tokens: int = 20,
         best_of: Optional[int] = None,
@@ -82,6 +84,10 @@ def generate(
         Args:
             prompt (`str`):
                 Input text
+            adapter_id (`Optional[str]`):
+                Adapter ID to apply to the base model for the request
+            adapter_source (`Optional[str]`):
+                Source of the adapter (hub, local, s3)
             do_sample (`bool`):
                 Activate logits sampling
             max_new_tokens (`int`):
@@ -119,6 +125,8 @@ def generate(
         """
         # Validate parameters
         parameters = Parameters(
+            adapter_id=adapter_id,
+            adapter_source=adapter_source,
             best_of=best_of,
             details=True,
             do_sample=do_sample,
@@ -152,6 +160,8 @@ def generate(
     def generate_stream(
         self,
         prompt: str,
+        adapter_id: Optional[str] = None,
+        adapter_source: Optional[str] = None,
         do_sample: bool = False,
         max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
@@ -171,6 +181,10 @@ def generate_stream(
         Args:
             prompt (`str`):
                 Input text
+            adapter_id (`Optional[str]`):
+                Adapter ID to apply to the base model for the request
+            adapter_source (`Optional[str]`):
+                Source of the adapter (hub, local, s3)
             do_sample (`bool`):
                 Activate logits sampling
             max_new_tokens (`int`):
@@ -204,6 +218,8 @@ def generate_stream(
         """
         # Validate parameters
         parameters = Parameters(
+            adapter_id=adapter_id,
+            adapter_source=adapter_source,
             best_of=None,
             details=True,
             decoder_input_details=False,
@@ -264,12 +280,12 @@ class AsyncClient:
     >>> from lorax import AsyncClient
 
     >>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
-    >>> response = await client.generate("Why is the sky blue?")
+    >>> response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
     >>> response.generated_text
     ' Rayleigh scattering'
 
     >>> result = ""
-    >>> async for response in client.generate_stream("Why is the sky blue?"):
+    >>> async for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
     >>>     if not response.token.special:
     >>>         result += response.token.text
     >>> result
@@ -303,6 +319,8 @@ def __init__(
     async def generate(
         self,
         prompt: str,
+        adapter_id: Optional[str] = None,
+        adapter_source: Optional[str] = None,
         do_sample: bool = False,
         max_new_tokens: int = 20,
         best_of: Optional[int] = None,
@@ -324,6 +342,10 @@ async def generate(
         Args:
             prompt (`str`):
                 Input text
+            adapter_id (`Optional[str]`):
+                Adapter ID to apply to the base model for the request
+            adapter_source (`Optional[str]`):
+                Source of the adapter (hub, local, s3)
             do_sample (`bool`):
                 Activate logits sampling
             max_new_tokens (`int`):
@@ -361,6 +383,8 @@ async def generate(
         """
         # Validate parameters
         parameters = Parameters(
+            adapter_id=adapter_id,
+            adapter_source=adapter_source,
             best_of=best_of,
             details=True,
             decoder_input_details=decoder_input_details,
@@ -392,6 +416,8 @@ async def generate(
     async def generate_stream(
         self,
         prompt: str,
+        adapter_id: Optional[str] = None,
+        adapter_source: Optional[str] = None,
         do_sample: bool = False,
         max_new_tokens: int = 20,
         repetition_penalty: Optional[float] = None,
@@ -411,6 +437,10 @@ async def generate_stream(
         Args:
             prompt (`str`):
                 Input text
+            adapter_id (`Optional[str]`):
+                Adapter ID to apply to the base model for the request
+            adapter_source (`Optional[str]`):
+                Source of the adapter (hub, local, s3)
             do_sample (`bool`):
                 Activate logits sampling
             max_new_tokens (`int`):
@@ -444,6 +474,8 @@ async def generate_stream(
         """
         # Validate parameters
         parameters = Parameters(
+            adapter_id=adapter_id,
+            adapter_source=adapter_source,
             best_of=None,
             details=True,
             decoder_input_details=False,
````
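The async client mirrors the sync API, so streaming with an adapter looks the same under `async for`. A short sketch, using the documented placeholder endpoint and adapter ID, with `asyncio.run` as one way to drive the coroutine:

```python
import asyncio

from lorax import AsyncClient


async def main() -> str:
    client = AsyncClient("http://127.0.0.1:8080")
    text = ""
    # Each streamed response carries a single token; skip special
    # tokens (e.g. end-of-sequence markers) while accumulating text.
    async for response in client.generate_stream(
        "Why is the sky blue?",
        adapter_id="some/adapter",
    ):
        if not response.token.special:
            text += response.token.text
    return text


print(asyncio.run(main()))
```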
clients/python/lorax/errors.py: 5 changes (1 addition, 4 deletions)

```diff
@@ -50,10 +50,7 @@ def __init__(self, message: str):
 
 class NotSupportedError(Exception):
     def __init__(self, model_id: str):
-        message = (
-            f"Model `{model_id}` is not available for inference with this client. \n"
-            "Use `huggingface_hub.inference_api.InferenceApi` instead."
-        )
+        message = f"Model `{model_id}` is not available for inference with this client."
         super(NotSupportedError, self).__init__(message)
 
 
```
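For callers, the exception type and constructor are unchanged; only the message text is trimmed. A hedged sketch of handling it (it is an assumption here that a generate call against an unsupported model surfaces this error):

```python
from lorax import Client
from lorax.errors import NotSupportedError

client = Client("http://127.0.0.1:8080")
try:
    print(client.generate("Why is the sky blue?").generated_text)
except NotSupportedError as err:
    # The message now reads: Model `<model_id>` is not available
    # for inference with this client.
    print(f"Unsupported model: {err}")
```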
(Diffs for the remaining 3 changed files are not shown.)