Added pbase adapter_source and expose api_token in client #181

Merged 7 commits on Jan 17, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -36,7 +36,7 @@ LoRAX (LoRA eXchange) is a framework that allows users to serve thousands of fin
- 🏋️‍♀️ **Heterogeneous Continuous Batching:** packs requests for different adapters together into the same batch, keeping latency and throughput nearly constant with the number of concurrent adapters.
- 🧁 **Adapter Exchange Scheduling:** asynchronously prefetches and offloads adapters between GPU and CPU memory, schedules request batching to optimize the aggregate throughput of the system.
- 👬 **Optimized Inference:** high throughput and low latency optimizations including tensor parallelism, pre-compiled CUDA kernels ([flash-attention](https://arxiv.org/abs/2307.08691), [paged attention](https://arxiv.org/abs/2309.06180), [SGMV](https://arxiv.org/abs/2310.18547)), quantization, token streaming.
- 🚢 **Ready for Production:** prebuilt Docker images, Helm charts for Kubernetes, Prometheus metrics, and distributed tracing with Open Telemetry. OpenAI compatible API supporting multi-turn chat conversations.
- 🚢 **Ready for Production:** prebuilt Docker images, Helm charts for Kubernetes, Prometheus metrics, and distributed tracing with Open Telemetry. OpenAI compatible API supporting multi-turn chat conversations. Private adapters through per-request tenant isolation.
- 🤯 **Free for Commercial Use:** Apache 2.0 License. Enough said 😎.


2 changes: 1 addition & 1 deletion clients/python/lorax/__init__.py
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.2.0"
__version__ = "0.2.1"

from lorax.client import Client, AsyncClient
16 changes: 16 additions & 0 deletions clients/python/lorax/client.py
@@ -63,6 +63,7 @@ def generate(
prompt: str,
adapter_id: Optional[str] = None,
adapter_source: Optional[str] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
@@ -88,6 +89,8 @@ def generate(
Adapter ID to apply to the base model for the request
adapter_source (`Optional[str]`):
Source of the adapter (hub, local, s3)
api_token (`Optional[str]`):
API token for accessing private adapters
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
@@ -127,6 +130,7 @@ def generate(
parameters = Parameters(
adapter_id=adapter_id,
adapter_source=adapter_source,
api_token=api_token,
best_of=best_of,
details=True,
do_sample=do_sample,
@@ -162,6 +166,7 @@ def generate_stream(
prompt: str,
adapter_id: Optional[str] = None,
adapter_source: Optional[str] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
@@ -185,6 +190,8 @@ def generate_stream(
Adapter ID to apply to the base model for the request
adapter_source (`Optional[str]`):
Source of the adapter (hub, local, s3)
api_token (`Optional[str]`):
API token for accessing private adapters
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
@@ -220,6 +227,7 @@ def generate_stream(
parameters = Parameters(
adapter_id=adapter_id,
adapter_source=adapter_source,
api_token=api_token,
best_of=None,
details=True,
decoder_input_details=False,
@@ -321,6 +329,7 @@ async def generate(
prompt: str,
adapter_id: Optional[str] = None,
adapter_source: Optional[str] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
best_of: Optional[int] = None,
@@ -346,6 +355,8 @@ async def generate(
Adapter ID to apply to the base model for the request
adapter_source (`Optional[str]`):
Source of the adapter (hub, local, s3)
api_token (`Optional[str]`):
API token for accessing private adapters
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
@@ -385,6 +396,7 @@ async def generate(
parameters = Parameters(
adapter_id=adapter_id,
adapter_source=adapter_source,
api_token=api_token,
best_of=best_of,
details=True,
decoder_input_details=decoder_input_details,
@@ -418,6 +430,7 @@ async def generate_stream(
prompt: str,
adapter_id: Optional[str] = None,
adapter_source: Optional[str] = None,
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
repetition_penalty: Optional[float] = None,
@@ -441,6 +454,8 @@ async def generate_stream(
Adapter ID to apply to the base model for the request
adapter_source (`Optional[str]`):
Source of the adapter (hub, local, s3)
api_token (`Optional[str]`):
API token for accessing private adapters
do_sample (`bool`):
Activate logits sampling
max_new_tokens (`int`):
@@ -476,6 +491,7 @@ async def generate_stream(
parameters = Parameters(
adapter_id=adapter_id,
adapter_source=adapter_source,
api_token=api_token,
best_of=None,
details=True,
decoder_input_details=False,
4 changes: 3 additions & 1 deletion clients/python/lorax/types.py
@@ -5,14 +5,16 @@
from lorax.errors import ValidationError


ADAPTER_SOURCES = ["hub", "local", "s3"]
ADAPTER_SOURCES = ["hub", "local", "s3", "pbase"]


class Parameters(BaseModel):
# The ID of the adapter to use
adapter_id: Optional[str]
# The source of the adapter to use
adapter_source: Optional[str]
# API token for accessing private adapters
api_token: Optional[str]
# Activate logits sampling
do_sample: bool = False
# Maximum number of generated tokens
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
@@ -3,7 +3,7 @@ name = "lorax-client"
packages = [
{include = "lorax"}
]
version = "0.2.0"
version = "0.2.1"
description = "LoRAX Python Client"
license = "Apache-2.0"
authors = ["Travis Addair <[email protected]>", "Olivier Dehaene <[email protected]>"]
2 changes: 1 addition & 1 deletion docs/index.md
@@ -31,7 +31,7 @@ LoRAX (LoRA eXchange) is a framework that allows users to serve thousands of fin
- 🏋️‍♀️ **Heterogeneous Continuous Batching:** packs requests for different adapters together into the same batch, keeping latency and throughput nearly constant with the number of concurrent adapters.
- 🧁 **Adapter Exchange Scheduling:** asynchronously prefetches and offloads adapters between GPU and CPU memory, schedules request batching to optimize the aggregate throughput of the system.
- 👬 **Optimized Inference:** high throughput and low latency optimizations including tensor parallelism, pre-compiled CUDA kernels ([flash-attention](https://arxiv.org/abs/2307.08691), [paged attention](https://arxiv.org/abs/2309.06180), [SGMV](https://arxiv.org/abs/2310.18547)), quantization, token streaming.
- 🚢 **Ready for Production:** prebuilt Docker images, Helm charts for Kubernetes, Prometheus metrics, and distributed tracing with Open Telemetry. OpenAI compatible API supporting multi-turn chat conversations.
- 🚢 **Ready for Production:** prebuilt Docker images, Helm charts for Kubernetes, Prometheus metrics, and distributed tracing with Open Telemetry. OpenAI compatible API supporting multi-turn chat conversations. Private adapters through per-request tenant isolation.
- 🤯 **Free for Commercial Use:** Apache 2.0 License. Enough said 😎.


45 changes: 41 additions & 4 deletions docs/models/adapters.md
@@ -73,10 +73,26 @@ Usage:
```json
"parameters": {
"adapter_id": "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
"adapter_source": "hub",
"adapter_source": "hub"
}
```

### Predibase

Any adapter hosted in [Predibase](https://predibase.com/) can be used in LoRAX by setting `adapter_source="pbase"`.

When using Predibase-hosted adapters, the `adapter_id` format is `<model_repo>/<model_version>`. If the `model_version` is
omitted, the latest version in the [Model Repository](https://docs.predibase.com/ui-guide/Supervised-ML/models/model-repos)
will be used.

Usage:

```json
"parameters": {
"adapter_id": "model_repo/model_version",
"adapter_source": "pbase"
}
```
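
For reference, the same request can be made through the Python client. A minimal sketch, assuming a LoRAX server at a placeholder endpoint and a hypothetical repo name:

```python
from lorax import Client

# Endpoint, repo name, and token below are placeholders for illustration.
client = Client("http://127.0.0.1:8080")

response = client.generate(
    "What is the capital of France?",
    adapter_id="my_model_repo/3",  # <model_repo>/<model_version>
    adapter_source="pbase",
    api_token="<YOUR_PREDIBASE_API_TOKEN>",
)
print(response.generated_text)
```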

### Local

When specifying an adapter in a local path, the `adapter_id` should correspond to the root directory of the adapter containing the following files:
@@ -97,7 +113,7 @@ Usage:
```json
"parameters": {
"adapter_id": "/data/adapters/vineetsharma--qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
"adapter_source": "local",
"adapter_source": "local"
}
```

Expand All @@ -110,6 +126,27 @@ Usage:
```json
"parameters": {
"adapter_id": "s3://adapters_bucket/vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k",
"adapter_source": "s3",
"adapter_source": "s3"
}
```

## Private Adapter Repositories

For hosted adapter repositories like HuggingFace Hub and [Predibase](https://predibase.com/), you can perform inference using private adapters per request.

Usage:

```json
"parameters": {
"adapter_id": "my-repo/private-adapter",
"api_token": "<auth_token>"
}
```
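
Through the Python client the token is passed the same way. A sketch, assuming a placeholder endpoint and repo name, here with streaming:

```python
from lorax import Client

client = Client("http://127.0.0.1:8080")  # placeholder endpoint

# Stream tokens from a private adapter hosted on the HuggingFace Hub.
for chunk in client.generate_stream(
    "Summarize the incident report.",
    adapter_id="my-repo/private-adapter",
    adapter_source="hub",
    api_token="<auth_token>",
):
    if not chunk.token.special:
        print(chunk.token.text, end="")
```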

The authorization check runs in the background on every request (prior to batching, to avoid slowing down inference), so even if the
adapter is cached locally or the authorization token has since been invalidated, the check is still performed and handled appropriately.

For details on generating API tokens, see:

- [HuggingFace docs](https://huggingface.co/docs/hub/security-tokens)
- [Predibase docs](https://docs.predibase.com/)
2 changes: 2 additions & 0 deletions docs/reference/python_client.md
@@ -95,6 +95,8 @@ class Parameters:
adapter_id: Optional[str]
# The source of the adapter to use
adapter_source: Optional[str]
# API token for accessing private adapters
api_token: Optional[str]
# Activate logits sampling
do_sample: bool
# Maximum number of generated tokens
6 changes: 5 additions & 1 deletion router/src/lib.rs
@@ -78,7 +78,11 @@ pub(crate) struct GenerateParameters {
#[schema(nullable = true, default = "null", example = "hub")]
pub adapter_source: Option<String>,
#[serde(default)]
#[schema(nullable = true, default = "null", example = "<token from predibase>")]
#[schema(
nullable = true,
default = "null",
example = "<token for private adapters>"
)]
pub api_token: Option<String>,
#[serde(default)]
#[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
8 changes: 5 additions & 3 deletions server/lorax_server/cli.py
@@ -92,11 +92,12 @@ def _download_weights(
extension: str = ".safetensors",
auto_convert: bool = True,
source: str = "hub",
api_token: Optional[str] = None,
):
# Import here after the logger is added to log potential import exceptions
from lorax_server import utils
from lorax_server.utils import sources
model_source = sources.get_model_source(source, model_id, revision, extension)
model_source = sources.get_model_source(source, model_id, revision, extension, api_token)

# Test if files were already download
try:
@@ -186,6 +187,7 @@ def download_weights(
source: str = "hub",
adapter_id: str = "",
adapter_source: str = "hub",
api_token: Optional[str] = None,
):
# Remove default handler
logger.remove()
@@ -198,9 +200,9 @@
backtrace=True,
diagnose=False,
)
_download_weights(model_id, revision, extension, auto_convert, source)
_download_weights(model_id, revision, extension, auto_convert, source, api_token)
if adapter_id:
_download_weights(adapter_id, revision, extension, auto_convert, adapter_source)
_download_weights(adapter_id, revision, extension, auto_convert, adapter_source, api_token)


@app.command()
4 changes: 2 additions & 2 deletions server/lorax_server/models/model.py
@@ -137,7 +137,7 @@ def get_num_layers_for_type(self, layer_type: str) -> int:
def is_row_parallel(self, layer_type: str) -> bool:
return False

def load_adapter(self, adapter_id, adapter_source, adapter_index):
def load_adapter(self, adapter_id, adapter_source, adapter_index, api_token):
"""Physically loads the adapter weights into the model.

adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded
@@ -163,7 +163,7 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index):
logger.info(f"Loading adapter weights into model: {adapter_id}")
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map(
self.model_id, adapter_id, adapter_source, weight_names
self.model_id, adapter_id, adapter_source, weight_names, api_token
)

unused_weight_names = adapter_weight_names.copy()
24 changes: 15 additions & 9 deletions server/lorax_server/server.py
@@ -2,6 +2,7 @@
import os
import shutil
import torch
from huggingface_hub import HfApi
from peft import PeftConfig

from grpc import aio
@@ -134,19 +135,23 @@ async def DownloadAdapter(self, request, context):
adapter_source=request.adapter_source,
)

api_token = request.api_token
adapter_source = _adapter_source_enum_to_string(request.adapter_source)
if adapter_source == PBASE:
adapter_id = map_pbase_model_id_to_s3(adapter_id, request.api_token)
adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token)
adapter_source = S3
try:
# fail fast if ID is not an adapter (i.e. it is a full model)
# TODO(geoffrey): do this for S3– can't do it this way because the
# files are not yet downloaded locally at this point.
if adapter_source == HUB:
# Quick auth check on the repo against the token
HfApi(token=api_token).model_info(adapter_id, revision=None)

# fail fast if ID is not an adapter (i.e. it is a full model)
# TODO(geoffrey): do this for S3– can't do it this way because the
# files are not yet downloaded locally at this point.
config_path = get_config_path(adapter_id, adapter_source)
PeftConfig.from_pretrained(config_path)
PeftConfig.from_pretrained(config_path, token=api_token)

download_weights(adapter_id, source=adapter_source)
download_weights(adapter_id, source=adapter_source, api_token=api_token)
return generate_pb2.DownloadAdapterResponse(
adapter_id=adapter_id,
adapter_source=request.adapter_source,
@@ -162,18 +167,19 @@
shutil.rmtree(local_path)
except Exception as e:
logger.warning(f"Error cleaning up safetensors files after "
f"download error: {e}\nIgnoring.")
f"download error: {e}\nIgnoring.")
raise

async def LoadAdapter(self, request, context):
try:
adapter_id = request.adapter_id
adapter_source = _adapter_source_enum_to_string(request.adapter_source)
adapter_index = request.adapter_index
api_token = request.api_token
if adapter_source == PBASE:
adapter_id = map_pbase_model_id_to_s3(adapter_id, request.api_token)
adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token)
adapter_source = S3
self.model.load_adapter(adapter_id, adapter_source, adapter_index)
self.model.load_adapter(adapter_id, adapter_source, adapter_index, api_token)

return generate_pb2.LoadAdapterResponse(
adapter_id=adapter_id,
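
Condensed, the new `DownloadAdapter` flow works roughly as follows. This is a paraphrased sketch of the diff above, not the exact server code:

```python
from typing import Optional

from huggingface_hub import HfApi
from peft import PeftConfig


def download_adapter(adapter_id: str, adapter_source: str, api_token: Optional[str] = None):
    # Predibase adapter IDs are resolved to their backing S3 objects up front.
    if adapter_source == "pbase":
        adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token)
        adapter_source = "s3"
    if adapter_source == "hub":
        # Fail fast: verify the token can see the repo at all...
        HfApi(token=api_token).model_info(adapter_id, revision=None)
        # ...and that the ID points at an adapter config, not a full model.
        config_path = get_config_path(adapter_id, adapter_source)
        PeftConfig.from_pretrained(config_path, token=api_token)
    # Only then download, threading the token through to the model source.
    download_weights(adapter_id, source=adapter_source, api_token=api_token)
```

Here `map_pbase_model_id_to_s3`, `get_config_path`, and `download_weights` are the server's own helpers from the diff; the point is that authorization fails before any weights are fetched.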
8 changes: 4 additions & 4 deletions server/lorax_server/utils/adapter.py
@@ -22,12 +22,12 @@


@lru_cache(maxsize=128)
def load_module_map(model_id, adapter_id, adapter_source, weight_names):
def load_module_map(model_id, adapter_id, adapter_source, weight_names, api_token):
# TODO(geoffrey): refactor this and merge parts of this function with
# lorax_server/utils/adapter.py::create_merged_weight_files
source = get_model_source(adapter_source, adapter_id, extension=".safetensors")
source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
config_path = get_config_path(adapter_id, adapter_source)
adapter_config = LoraConfig.from_pretrained(config_path)
adapter_config = LoraConfig.from_pretrained(config_path, token=api_token)
if adapter_config.base_model_name_or_path != model_id:
expected_config = AutoConfig.from_pretrained(model_id)
model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path)
@@ -43,7 +43,7 @@ def load_module_map(model_id, adapter_id, adapter_source, weight_names):
f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.")

try:
adapter_tokenizer = AutoTokenizer.from_pretrained(config_path)
adapter_tokenizer = AutoTokenizer.from_pretrained(config_path, token=api_token)
except Exception:
# Adapter does not have a tokenizer, so fallback to base model tokenizer
adapter_tokenizer = None
4 changes: 2 additions & 2 deletions server/lorax_server/utils/sources/__init__.py
@@ -40,9 +40,9 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str:


# TODO(travis): refactor into registry pattern
def get_model_source(source: str, model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"):
def get_model_source(source: str, model_id: str, revision: Optional[str] = None, extension: str = ".safetensors", api_token: Optional[str] = None):
if source == HUB:
return HubModelSource(model_id, revision, extension)
return HubModelSource(model_id, revision, extension, api_token)
elif source == S3:
return S3ModelSource(model_id, revision, extension)
elif source == LOCAL: