From a46b816371ddeec1afea93623d68e6e871810a13 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 12 Jan 2024 15:01:37 -0800 Subject: [PATCH 1/6] Added pbase adapter_source and expose api_token in client --- clients/python/lorax/client.py | 16 ++++++++++++++++ clients/python/lorax/types.py | 4 +++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index d7c559662..190f004ba 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -63,6 +63,7 @@ def generate( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, best_of: Optional[int] = None, @@ -88,6 +89,8 @@ def generate( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + api_token (`Optional[str]`): + API token for accessing private adapters do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): @@ -127,6 +130,7 @@ def generate( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + api_token=api_token, best_of=best_of, details=True, do_sample=do_sample, @@ -162,6 +166,7 @@ def generate_stream( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, @@ -185,6 +190,8 @@ def generate_stream( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + api_token (`Optional[str]`): + API token for accessing private adapters do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): @@ -220,6 +227,7 @@ def generate_stream( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + api_token=api_token, best_of=None, details=True, 
decoder_input_details=False, @@ -321,6 +329,7 @@ async def generate( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, best_of: Optional[int] = None, @@ -346,6 +355,8 @@ async def generate( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + api_token (`Optional[str]`): + API token for accessing private adapters do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): @@ -385,6 +396,7 @@ async def generate( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + api_token=api_token, best_of=best_of, details=True, decoder_input_details=decoder_input_details, @@ -418,6 +430,7 @@ async def generate_stream( prompt: str, adapter_id: Optional[str] = None, adapter_source: Optional[str] = None, + api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, repetition_penalty: Optional[float] = None, @@ -441,6 +454,8 @@ async def generate_stream( Adapter ID to apply to the base model for the request adapter_source (`Optional[str]`): Source of the adapter (hub, local, s3) + api_token (`Optional[str]`): + API token for accessing private adapters do_sample (`bool`): Activate logits sampling max_new_tokens (`int`): @@ -476,6 +491,7 @@ async def generate_stream( parameters = Parameters( adapter_id=adapter_id, adapter_source=adapter_source, + api_token=api_token, best_of=None, details=True, decoder_input_details=False, diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index a94c01d7e..fe880f5f5 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -5,7 +5,7 @@ from lorax.errors import ValidationError -ADAPTER_SOURCES = ["hub", "local", "s3"] +ADAPTER_SOURCES = ["hub", "local", "s3", "pbase"] class Parameters(BaseModel): @@ -13,6 +13,8 @@ class Parameters(BaseModel): 
adapter_id: Optional[str] # The source of the adapter to use adapter_source: Optional[str] + # API token for accessing private adapters + api_token: Optional[str] # Activate logits sampling do_sample: bool = False # Maximum number of generated tokens From c1ef5d60a345a7a901149245e69442aea1dc1ab2 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 12 Jan 2024 15:02:52 -0800 Subject: [PATCH 2/6] Bump client version --- clients/python/lorax/__init__.py | 2 +- clients/python/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/clients/python/lorax/__init__.py b/clients/python/lorax/__init__.py index 1d324b45a..e0c800858 100644 --- a/clients/python/lorax/__init__.py +++ b/clients/python/lorax/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.0" +__version__ = "0.2.1" from lorax.client import Client, AsyncClient diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 691ac61cd..a4b284764 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -3,7 +3,7 @@ name = "lorax-client" packages = [ {include = "lorax"} ] -version = "0.2.0" +version = "0.2.1" description = "LoRAX Python Client" license = "Apache-2.0" authors = ["Travis Addair ", "Olivier Dehaene "] From 660b259d5c9080d0b8109e224933d788c2dceb55 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 17 Jan 2024 10:20:44 -0800 Subject: [PATCH 3/6] Updated docs --- docs/models/adapters.md | 18 +++++++++++++++++- docs/reference/python_client.md | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/docs/models/adapters.md b/docs/models/adapters.md index dcae3c15c..f92b51ef3 100644 --- a/docs/models/adapters.md +++ b/docs/models/adapters.md @@ -77,6 +77,22 @@ Usage: } ``` +### Predibase + +Any adapter hosted in [Predibase](https://predibase.com/) can be used in LoRAX by setting `adapter_source="pbase"`. 
+ +When using Predibase hosted adapters, the `adapter_id` format is `<model_repo>/<model_version>`. If the `model_version` is +omitted, the latest version in the [Model Repository](https://docs.predibase.com/ui-guide/Supervised-ML/models/model-repos) +will be used. + +Usage: + +```json +"parameters": { + "adapter_id": "model_repo/model_version", + "adapter_source": "pbase", +} + ### Local When specifying an adapter in a local path, the `adapter_id` should correspond to the root directory of the adapter containing the following files: @@ -112,4 +128,4 @@ Usage: "adapter_id": "s3://adapters_bucket/vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k", "adapter_source": "s3", } -``` \ No newline at end of file +``` diff --git a/docs/reference/python_client.md b/docs/reference/python_client.md index 790b8e48a..9e238d5e9 100644 --- a/docs/reference/python_client.md +++ b/docs/reference/python_client.md @@ -95,6 +95,8 @@ class Parameters: adapter_id: Optional[str] # The source of the adapter to use adapter_source: Optional[str] + # API token for accessing private adapters + api_token: Optional[str] # Activate logits sampling do_sample: bool # Maximum number of generated tokens From 142a150a6d133c4083f8aeb97a923504bdfa23b6 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 17 Jan 2024 11:32:23 -0800 Subject: [PATCH 4/6] Plumb token for HF --- README.md | 2 +- router/src/lib.rs | 2 +- server/lorax_server/cli.py | 8 ++++--- server/lorax_server/models/model.py | 4 ++-- server/lorax_server/server.py | 24 ++++++++++++------- server/lorax_server/utils/adapter.py | 8 +++---- server/lorax_server/utils/sources/__init__.py | 4 ++-- server/lorax_server/utils/sources/hub.py | 22 +++++++++-------- 8 files changed, 42 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 5f4e8eee6..c1633cb91 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ curl 127.0.0.1:8080/generate \ "inputs": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in 
May. How many clips did Natalia sell altogether in April and May? [/INST]", "parameters": { "max_new_tokens": 64, - "adapter_id": "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k" + "adapter_id": "tgaddair/test-private-lora" } }' \ -H 'Content-Type: application/json' diff --git a/router/src/lib.rs b/router/src/lib.rs index b66583e18..13d827f13 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -78,7 +78,7 @@ pub(crate) struct GenerateParameters { #[schema(nullable = true, default = "null", example = "hub")] pub adapter_source: Option, #[serde(default)] - #[schema(nullable = true, default = "null", example = "")] + #[schema(nullable = true, default = "null", example = "")] pub api_token: Option, #[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)] diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py index 945f7ab56..4c0fcb36d 100644 --- a/server/lorax_server/cli.py +++ b/server/lorax_server/cli.py @@ -92,11 +92,12 @@ def _download_weights( extension: str = ".safetensors", auto_convert: bool = True, source: str = "hub", + api_token: Optional[str] = None, ): # Import here after the logger is added to log potential import exceptions from lorax_server import utils from lorax_server.utils import sources - model_source = sources.get_model_source(source, model_id, revision, extension) + model_source = sources.get_model_source(source, model_id, revision, extension, api_token) # Test if files were already download try: @@ -186,6 +187,7 @@ def download_weights( source: str = "hub", adapter_id: str = "", adapter_source: str = "hub", + api_token: Optional[str] = None, ): # Remove default handler logger.remove() @@ -198,9 +200,9 @@ def download_weights( backtrace=True, diagnose=False, ) - _download_weights(model_id, revision, extension, auto_convert, source) + _download_weights(model_id, revision, extension, auto_convert, source, api_token) if adapter_id: - _download_weights(adapter_id, revision, 
extension, auto_convert, adapter_source) + _download_weights(adapter_id, revision, extension, auto_convert, adapter_source, api_token) @app.command() diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index 36d49fe84..fd5a6e894 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -137,7 +137,7 @@ def get_num_layers_for_type(self, layer_type: str) -> int: def is_row_parallel(self, layer_type: str) -> bool: return False - def load_adapter(self, adapter_id, adapter_source, adapter_index): + def load_adapter(self, adapter_id, adapter_source, adapter_index, api_token): """Physically loads the adapter weights into the model. adapter_id must be `BASE_MODEL_ADAPTER_ID` if adapter statically loaded @@ -163,7 +163,7 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index): logger.info(f"Loading adapter weights into model: {adapter_id}") weight_names = tuple([v[0] for v in self.target_to_layer.values()]) module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map( - self.model_id, adapter_id, adapter_source, weight_names + self.model_id, adapter_id, adapter_source, weight_names, api_token ) unused_weight_names = adapter_weight_names.copy() diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 12a80d144..5225b9ce1 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -2,6 +2,7 @@ import os import shutil import torch +from huggingface_hub import HfApi from peft import PeftConfig from grpc import aio @@ -134,19 +135,23 @@ async def DownloadAdapter(self, request, context): adapter_source=request.adapter_source, ) + api_token = request.api_token adapter_source = _adapter_source_enum_to_string(request.adapter_source) if adapter_source == PBASE: - adapter_id = map_pbase_model_id_to_s3(adapter_id, request.api_token) + adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) adapter_source = S3 try: - # fail fast 
if ID is not an adapter (i.e. it is a full model) - # TODO(geoffrey): do this for S3– can't do it this way because the - # files are not yet downloaded locally at this point. if adapter_source == HUB: + # Quick auth check on the repo against the token + HfApi(token=api_token).model_info(adapter_id, revision=None) + + # fail fast if ID is not an adapter (i.e. it is a full model) + # TODO(geoffrey): do this for S3– can't do it this way because the + # files are not yet downloaded locally at this point. config_path = get_config_path(adapter_id, adapter_source) - PeftConfig.from_pretrained(config_path) + PeftConfig.from_pretrained(config_path, token=api_token) - download_weights(adapter_id, source=adapter_source) + download_weights(adapter_id, source=adapter_source, api_token=api_token) return generate_pb2.DownloadAdapterResponse( adapter_id=adapter_id, adapter_source=request.adapter_source, @@ -162,7 +167,7 @@ async def DownloadAdapter(self, request, context): shutil.rmtree(local_path) except Exception as e: logger.warning(f"Error cleaning up safetensors files after " - f"download error: {e}\nIgnoring.") + f"download error: {e}\nIgnoring.") raise async def LoadAdapter(self, request, context): @@ -170,10 +175,11 @@ async def LoadAdapter(self, request, context): adapter_id = request.adapter_id adapter_source = _adapter_source_enum_to_string(request.adapter_source) adapter_index = request.adapter_index + api_token = request.api_token if adapter_source == PBASE: - adapter_id = map_pbase_model_id_to_s3(adapter_id, request.api_token) + adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token) adapter_source = S3 - self.model.load_adapter(adapter_id, adapter_source, adapter_index) + self.model.load_adapter(adapter_id, adapter_source, adapter_index, api_token) return generate_pb2.LoadAdapterResponse( adapter_id=adapter_id, diff --git a/server/lorax_server/utils/adapter.py b/server/lorax_server/utils/adapter.py index d568c74c4..51dfa9021 100644 --- 
a/server/lorax_server/utils/adapter.py +++ b/server/lorax_server/utils/adapter.py @@ -22,12 +22,12 @@ @lru_cache(maxsize=128) -def load_module_map(model_id, adapter_id, adapter_source, weight_names): +def load_module_map(model_id, adapter_id, adapter_source, weight_names, api_token): # TODO(geoffrey): refactor this and merge parts of this function with # lorax_server/utils/adapter.py::create_merged_weight_files - source = get_model_source(adapter_source, adapter_id, extension=".safetensors") + source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) config_path = get_config_path(adapter_id, adapter_source) - adapter_config = LoraConfig.from_pretrained(config_path) + adapter_config = LoraConfig.from_pretrained(config_path, token=api_token) if adapter_config.base_model_name_or_path != model_id: expected_config = AutoConfig.from_pretrained(model_id) model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path) @@ -43,7 +43,7 @@ def load_module_map(model_id, adapter_id, adapter_source, weight_names): f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.") try: - adapter_tokenizer = AutoTokenizer.from_pretrained(config_path) + adapter_tokenizer = AutoTokenizer.from_pretrained(config_path, token=api_token) except Exception: # Adapter does not have a tokenizer, so fallback to base model tokenizer adapter_tokenizer = None diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index 33e342d5b..5cb4e632a 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -40,9 +40,9 @@ def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: # TODO(travis): refactor into registry pattern -def get_model_source(source: str, model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"): +def get_model_source(source: str, model_id: str, revision: Optional[str] = None, 
extension: str = ".safetensors", api_token: Optional[str] = None): if source == HUB: - return HubModelSource(model_id, revision, extension) + return HubModelSource(model_id, revision, extension, api_token) elif source == S3: return S3ModelSource(model_id, revision, extension) elif source == LOCAL: diff --git a/server/lorax_server/utils/sources/hub.py b/server/lorax_server/utils/sources/hub.py index ac4ee377d..8c909b8a6 100644 --- a/server/lorax_server/utils/sources/hub.py +++ b/server/lorax_server/utils/sources/hub.py @@ -26,10 +26,10 @@ def get_hub_model_local_dir(model_id: str) -> Path: def weight_hub_files( - model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors", api_token: Optional[str] = None ) -> List[str]: """Get the weights filenames on the hub""" - api = HfApi() + api = HfApi(token=api_token) info = api.model_info(model_id, revision=revision) filenames = [ s.rfilename @@ -52,7 +52,7 @@ def weight_hub_files( def weight_files( - model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors", api_token: Optional[str] = None ) -> List[Path]: """Get the local files""" # Local model @@ -65,12 +65,12 @@ def weight_files( return local_files try: - filenames = weight_hub_files(model_id, revision, extension) + filenames = weight_hub_files(model_id, revision, extension, api_token) except EntryNotFoundError as e: if extension != ".safetensors": raise e # Try to see if there are pytorch weights - pt_filenames = weight_hub_files(model_id, revision, extension=".bin") + pt_filenames = weight_hub_files(model_id, revision, extension=".bin", api_token=api_token) # Change pytorch extension to safetensors extension # It is possible that we have safetensors weights locally even though they are not on the # hub if we converted weights locally without pushing them @@ -107,7 +107,7 
@@ def weight_files( def download_weights( - filenames: List[str], model_id: str, revision: Optional[str] = None + filenames: List[str], model_id: str, revision: Optional[str] = None, api_token: Optional[str] = None ) -> List[Path]: """Download the safetensors files from the hub""" @@ -127,6 +127,7 @@ def download_file(filename, tries=5, backoff: int = 5): repo_id=model_id, revision=revision, local_files_only=False, + token=api_token, ) logger.info( f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}." @@ -157,21 +158,22 @@ def download_file(filename, tries=5, backoff: int = 5): class HubModelSource(BaseModelSource): - def __init__(self, model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"): + def __init__(self, model_id: str, revision: Optional[str] = None, extension: str = ".safetensors", api_token: Optional[str] = None): self.model_id = model_id self.revision = revision self.extension = extension + self.api_token = api_token def remote_weight_files(self, extension: str = None): extension = extension or self.extension - return weight_hub_files(self.model_id, self.revision, extension) + return weight_hub_files(self.model_id, self.revision, extension, self.api_token) def weight_files(self, extension=None): extension = extension or self.extension - return weight_files(self.model_id, self.revision, extension) + return weight_files(self.model_id, self.revision, extension, self.api_token) def download_weights(self, filenames): - return download_weights(filenames, self.model_id, self.revision) + return download_weights(filenames, self.model_id, self.revision, self.api_token) def download_model_assets(self): filenames = self.remote_weight_files() From 734b893507937cc4524fd6b09024d94ec4b1e61b Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 17 Jan 2024 12:39:27 -0800 Subject: [PATCH 5/6] Docs --- README.md | 4 ++-- docs/index.md | 2 +- docs/models/adapters.md | 29 +++++++++++++++++++++++++---- 3 files changed, 
28 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index c1633cb91..36820bc41 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ LoRAX (LoRA eXchange) is a framework that allows users to serve thousands of fin - 🏋️‍♀️ **Heterogeneous Continuous Batching:** packs requests for different adapters together into the same batch, keeping latency and throughput nearly constant with the number of concurrent adapters. - 🧁 **Adapter Exchange Scheduling:** asynchronously prefetches and offloads adapters between GPU and CPU memory, schedules request batching to optimize the aggregate throughput of the system. - 👬 **Optimized Inference:** high throughput and low latency optimizations including tensor parallelism, pre-compiled CUDA kernels ([flash-attention](https://arxiv.org/abs/2307.08691), [paged attention](https://arxiv.org/abs/2309.06180), [SGMV](https://arxiv.org/abs/2310.18547)), quantization, token streaming. -- 🚢 **Ready for Production** prebuilt Docker images, Helm charts for Kubernetes, Prometheus metrics, and distributed tracing with Open Telemetry. OpenAI compatible API supporting multi-turn chat conversations. +- 🚢 **Ready for Production** prebuilt Docker images, Helm charts for Kubernetes, Prometheus metrics, and distributed tracing with Open Telemetry. OpenAI compatible API supporting multi-turn chat conversations. Private adapters through per-request tenant isolation. - 🤯 **Free for Commercial Use:** Apache 2.0 License. Enough said 😎. @@ -99,7 +99,7 @@ curl 127.0.0.1:8080/generate \ "inputs": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 
[/INST]", "parameters": { "max_new_tokens": 64, - "adapter_id": "tgaddair/test-private-lora" + "adapter_id": "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k" } }' \ -H 'Content-Type: application/json' diff --git a/docs/index.md b/docs/index.md index 7212f3823..cefe4df12 100644 --- a/docs/index.md +++ b/docs/index.md @@ -31,7 +31,7 @@ LoRAX (LoRA eXchange) is a framework that allows users to serve thousands of fin - 🏋️‍♀️ **Heterogeneous Continuous Batching:** packs requests for different adapters together into the same batch, keeping latency and throughput nearly constant with the number of concurrent adapters. - 🧁 **Adapter Exchange Scheduling:** asynchronously prefetches and offloads adapters between GPU and CPU memory, schedules request batching to optimize the aggregate throughput of the system. - 👬 **Optimized Inference:** high throughput and low latency optimizations including tensor parallelism, pre-compiled CUDA kernels ([flash-attention](https://arxiv.org/abs/2307.08691), [paged attention](https://arxiv.org/abs/2309.06180), [SGMV](https://arxiv.org/abs/2310.18547)), quantization, token streaming. -- 🚢 **Ready for Production** prebuilt Docker images, Helm charts for Kubernetes, Prometheus metrics, and distributed tracing with Open Telemetry. OpenAI compatible API supporting multi-turn chat conversations. +- 🚢 **Ready for Production** prebuilt Docker images, Helm charts for Kubernetes, Prometheus metrics, and distributed tracing with Open Telemetry. OpenAI compatible API supporting multi-turn chat conversations. Private adapters through per-request tenant isolation. - 🤯 **Free for Commercial Use:** Apache 2.0 License. Enough said 😎. 
diff --git a/docs/models/adapters.md b/docs/models/adapters.md index f92b51ef3..bcddf9b6f 100644 --- a/docs/models/adapters.md +++ b/docs/models/adapters.md @@ -73,7 +73,7 @@ Usage: ```json "parameters": { "adapter_id": "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k", - "adapter_source": "hub", + "adapter_source": "hub" } ``` @@ -90,7 +90,7 @@ Usage: ```json "parameters": { "adapter_id": "model_repo/model_version", - "adapter_source": "pbase", + "adapter_source": "pbase" } ### Local @@ -113,7 +113,7 @@ Usage: ```json "parameters": { "adapter_id": "/data/adapters/vineetsharma--qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k", - "adapter_source": "local", + "adapter_source": "local" } ``` @@ -126,6 +126,27 @@ Usage: ```json "parameters": { "adapter_id": "s3://adapters_bucket/vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k", - "adapter_source": "s3", + "adapter_source": "s3" } ``` + +## Private Adapter Repositories + +For hosted adapter repositories like HuggingFace Hub and [Predibase](https://predibase.com/), you can perform inference using private adapters per request. + +Usage: + +```json +"parameters": { + "adapter_id": "my-repo/private-adapter", + "api_token": "<api_token>" +} +``` + +The authorization check is performed per-request in the background (prior to batching to prevent slowing down inference) every time, so even if the +adapter is cached locally or the authorization token has been invalidated, the check will be performed and handled appropriately. 
+ +For details on generating API tokens, see: + +- [HuggingFace docs](https://huggingface.co/docs/hub/security-tokens) +- [Predibase docs](https://docs.predibase.com/) From 5ff8f3592077caf61392a64c3dc108505864008a Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 17 Jan 2024 12:41:04 -0800 Subject: [PATCH 6/6] cargo fmt --- router/src/lib.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 13d827f13..ec75aa66b 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -78,7 +78,11 @@ pub(crate) struct GenerateParameters { #[schema(nullable = true, default = "null", example = "hub")] pub adapter_source: Option, #[serde(default)] - #[schema(nullable = true, default = "null", example = "")] + #[schema( + nullable = true, + default = "null", + example = "" + )] pub api_token: Option, #[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]