From 143f30308ef0f8c51ef949d016a5c0694e20268c Mon Sep 17 00:00:00 2001 From: Magdy Saleh <17618143+magdyksaleh@users.noreply.github.com> Date: Thu, 14 Dec 2023 18:55:51 -0500 Subject: [PATCH] Add predibase as a source for adapters (#125) --- .gitignore | 3 +- build.sh | 36 +------------------ docs/reference/openapi.json | 4 +++ launcher/src/main.rs | 3 +- proto/generate.proto | 8 ++++- router/client/src/client.rs | 4 +++ router/client/src/sharded_client.rs | 5 ++- router/src/adapter.rs | 15 ++++++-- router/src/infer.rs | 8 ++++- router/src/lib.rs | 4 +++ router/src/loader.rs | 7 +++- router/src/validation.rs | 8 ++--- server/lorax_server/server.py | 10 +++++- server/lorax_server/utils/__init__.py | 6 +++- server/lorax_server/utils/sources/__init__.py | 31 ++++++++++++++++ 15 files changed, 103 insertions(+), 49 deletions(-) diff --git a/.gitignore b/.gitignore index d041675b4..09e2b1b5f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ target router/tokenizer.json *__pycache__* -run.sh \ No newline at end of file +run.sh +data/ diff --git a/build.sh b/build.sh index c8ea41f82..8e6170648 100755 --- a/build.sh +++ b/build.sh @@ -17,7 +17,7 @@ COMMIT_SHA=$(git rev-parse --short HEAD) TAG="${COMMIT_SHA}${DIRTY}" # Name of the Docker image -IMAGE_NAME="kubellm" +IMAGE_NAME="lorax" # ECR Repository URL (replace with your actual ECR repository URL) ECR_REPO="474375891613.dkr.ecr.us-west-2.amazonaws.com" @@ -28,37 +28,3 @@ echo "Building ${IMAGE_NAME}:${TAG}" docker build -t ${IMAGE_NAME}:${TAG} . 
docker tag ${IMAGE_NAME}:${TAG} ${IMAGE_NAME}:latest -# Tag the Docker image for ECR repository -docker tag ${IMAGE_NAME}:${TAG} ${ECR_REPO}/${IMAGE_NAME}:${TAG} - -# Log in to the ECR registry (assumes AWS CLI and permissions are set up) -aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin ${ECR_REPO} - -# Push to ECR -docker push ${ECR_REPO}/${IMAGE_NAME}:${TAG} - - -latest_flag=false - -# Parse command line arguments -while [[ $# -gt 0 ]]; do - case "$1" in - --latest) - latest_flag=true - shift - ;; - *) - echo "Unknown option: $1" - exit 1 - ;; - esac -done - -# Check if the --latest flag has been passed -if $latest_flag; then - # Tag and push as 'latest' - docker tag ${IMAGE_NAME}:${TAG} ${ECR_REPO}/${IMAGE_NAME}:latest - docker push ${ECR_REPO}/${IMAGE_NAME}:latest -else - echo "The --latest flag has not been passed. Skipping push to ECR as latest." -fi diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json index cd1d9f24f..4d7322618 100644 --- a/docs/reference/openapi.json +++ b/docs/reference/openapi.json @@ -596,6 +596,10 @@ "adapter_source": { "type": "string", "nullable": true + }, + "api_token": { + "type": "string", + "nullable": true } } }, diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 0cb1d2ad7..9e9ce877e 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -100,9 +100,10 @@ struct Args { source: String, /// The source of the model to load. - /// Can be `hub` or `s3`. + /// Can be `hub` or `s3` or `pbase` /// `hub` will load the model from the huggingface hub. /// `s3` will load the model from the predibase S3 bucket. 
+ /// `pbase` will load an s3 model but resolve the metadata from a predibase server #[clap(default_value = "hub", long, env)] adapter_source: String, diff --git a/proto/generate.proto b/proto/generate.proto index 5dc5b3ec8..ae4caa3f9 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -224,6 +224,8 @@ enum AdapterSource { S3 = 1; /// Adapters loaded via local filesystem path LOCAL = 2; + /// Adapters loaded via predibase + PBASE = 3; } message DownloadAdapterRequest { @@ -231,6 +233,8 @@ message DownloadAdapterRequest { string adapter_id = 1; /// Adapter source AdapterSource adapter_source = 2; + /// Token for external API (predibase / HuggingFace) + optional string api_token = 3; } message DownloadAdapterResponse { @@ -247,6 +251,8 @@ message LoadAdapterRequest { AdapterSource adapter_source = 2; /// Adapter index uint32 adapter_index = 3; + /// Token for external API (predibase / HuggingFace) + optional string api_token = 4; } message LoadAdapterResponse { @@ -274,4 +280,4 @@ message OffloadAdapterResponse { AdapterSource adapter_source = 2; /// Adapter index uint32 adapter_index = 3; -} \ No newline at end of file +} diff --git a/router/client/src/client.rs b/router/client/src/client.rs index b6f1c34c2..6442c97c6 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -183,6 +183,7 @@ impl Client { &mut self, adapter_id: String, adapter_source: String, + api_token: Option, ) -> Result { if let Some(adapter_source_enum) = AdapterSource::from_str_name(adapter_source.to_uppercase().as_str()) @@ -190,6 +191,7 @@ impl Client { let request = tonic::Request::new(DownloadAdapterRequest { adapter_id, adapter_source: adapter_source_enum.into(), + api_token: api_token, }) .inject_context(); let response = self.stub.download_adapter(request).await?.into_inner(); @@ -210,6 +212,7 @@ impl Client { adapter_id: String, adapter_source: String, adapter_index: u32, + api_token: Option, ) -> Result { if let Some(adapter_source_enum) = 
AdapterSource::from_str_name(adapter_source.to_uppercase().as_str()) @@ -218,6 +221,7 @@ impl Client { adapter_id, adapter_source: adapter_source_enum.into(), adapter_index, + api_token: api_token, }) .inject_context(); let response = self.stub.load_adapter(request).await?.into_inner(); diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 02f73b65c..75bfd5c3f 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -151,10 +151,11 @@ impl ShardedClient { &mut self, adapter_id: String, adapter_source: String, + api_token: Option, ) -> Result { // Only download the adapter with one client, since they share a single disk self.clients[0] - .download_adapter(adapter_id, adapter_source) + .download_adapter(adapter_id, adapter_source, api_token) .await } @@ -163,6 +164,7 @@ impl ShardedClient { adapter_id: String, adapter_source: String, adapter_index: u32, + api_token: Option, ) -> Result { // Load the adapter in all clients since there is sharding done between them let futures: Vec<_> = self @@ -173,6 +175,7 @@ impl ShardedClient { adapter_id.clone(), adapter_source.clone(), adapter_index, + api_token.clone(), )) }) .collect(); diff --git a/router/src/adapter.rs b/router/src/adapter.rs index 8a23559e6..a7d81156f 100644 --- a/router/src/adapter.rs +++ b/router/src/adapter.rs @@ -17,11 +17,18 @@ pub(crate) struct Adapter { source: String, /// index of the adapter index: u32, + /// Optional - External api token + api_token: Option, } impl Adapter { - pub(crate) fn new(id: String, source: String, index: u32) -> Self { - Self { id, source, index } + pub(crate) fn new(id: String, source: String, index: u32, api_token: Option) -> Self { + Self { + id, + source, + index, + api_token, + } } pub(crate) fn id(&self) -> &str { @@ -32,6 +39,10 @@ impl Adapter { &self.source } + pub(crate) fn api_token(&self) -> &std::option::Option { + &self.api_token + } + pub(crate) fn index(&self) -> u32 { self.index } 
diff --git a/router/src/infer.rs b/router/src/infer.rs index b2c0cc647..a93897f5a 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -138,7 +138,13 @@ impl Infer { } } - let adapter = Adapter::new(adapter_id.unwrap(), adapter_source.unwrap(), adapter_idx); + let api_token = request.parameters.api_token.clone(); + let adapter = Adapter::new( + adapter_id.unwrap(), + adapter_source.unwrap(), + adapter_idx, + api_token, + ); // Validate request let valid_request = self diff --git a/router/src/lib.rs b/router/src/lib.rs index f73ceef93..8637f9b2c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -78,6 +78,9 @@ pub(crate) struct GenerateParameters { #[schema(nullable = true, default = "null", example = "hub")] pub adapter_source: Option, #[serde(default)] + #[schema(nullable = true, default = "null", example = "")] + pub api_token: Option, + #[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)] pub best_of: Option, #[serde(default)] @@ -159,6 +162,7 @@ fn default_parameters() -> GenerateParameters { GenerateParameters { adapter_id: None, adapter_source: None, + api_token: None, best_of: None, temperature: None, repetition_penalty: None, diff --git a/router/src/loader.rs b/router/src/loader.rs index bf86cafad..74c39b690 100644 --- a/router/src/loader.rs +++ b/router/src/loader.rs @@ -139,7 +139,11 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver { @@ -185,6 +189,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver str: return S3 elif adapter_source == generate_pb2.AdapterSource.LOCAL: return LOCAL + elif adapter_source == generate_pb2.AdapterSource.PBASE: + return PBASE else: raise ValueError(f"Unknown adapter source {adapter_source}") \ No newline at end of file diff --git a/server/lorax_server/utils/__init__.py b/server/lorax_server/utils/__init__.py index 336702ae9..910ae613f 100644 --- a/server/lorax_server/utils/__init__.py +++ 
b/server/lorax_server/utils/__init__.py @@ -11,10 +11,12 @@ get_config_path, get_local_dir, download_weights, + map_pbase_model_id_to_s3, weight_hub_files, weight_files, EntryNotFoundError, HUB, + PBASE, LOCAL, LocalEntryNotFoundError, RevisionNotFoundError, @@ -41,17 +43,19 @@ "get_local_dir", "get_start_stop_idxs_for_rank", "initialize_torch_distributed", + "map_pbase_model_id_to_s3", "download_weights", "weight_hub_files", "EntryNotFoundError", "HeterogeneousNextTokenChooser", "HUB", "LOCAL", + "PBASE", + "S3", "LocalEntryNotFoundError", "RevisionNotFoundError", "Greedy", "NextTokenChooser", - "S3", "Sampling", "StoppingCriteria", "StopSequenceCriteria", diff --git a/server/lorax_server/utils/sources/__init__.py b/server/lorax_server/utils/sources/__init__.py index a212bd83e..33e342d5b 100644 --- a/server/lorax_server/utils/sources/__init__.py +++ b/server/lorax_server/utils/sources/__init__.py @@ -1,4 +1,8 @@ +import os from typing import Optional +from functools import lru_cache + +import requests from .hub import EntryNotFoundError, LocalEntryNotFoundError, RevisionNotFoundError, get_hub_model_local_dir, weight_files, download_weights, weight_hub_files, HubModelSource from .local import LocalModelSource, get_model_local_dir @@ -7,6 +11,32 @@ HUB = "hub" S3 = "s3" LOCAL = "local" +PBASE = "pbase" + +PREDIBASE_MODEL_URL_ENDPOINT = "/v1/models/version/name/{}" +PREDIBASE_MODEL_VERSION_URL_ENDPOINT = "/v1/models/version/name/{}?version={}" +PREDIBASE_GATEWAY_ENDPOINT = os.getenv("PREDIBASE_GATEWAY_ENDPOINT", "https://api.predibase.com") + + +@lru_cache(maxsize=256) +def map_pbase_model_id_to_s3(model_id: str, api_token: str) -> str: + if api_token is None: + raise ValueError("api_token must be provided for a model of source pbase") + headers = {"Authorization": f"Bearer {api_token}"} + name_components = model_id.split("/") + # version is optional + if len(name_components) == 1: + name = name_components[0] + url = PREDIBASE_GATEWAY_ENDPOINT + 
PREDIBASE_MODEL_URL_ENDPOINT.format(name) + elif len(name_components) == 2: + name, version = name_components + url = PREDIBASE_GATEWAY_ENDPOINT + PREDIBASE_MODEL_VERSION_URL_ENDPOINT.format(name, version) + else: + raise ValueError(f"Invalid model id {model_id}") + resp = requests.get(url, headers=headers) + resp.raise_for_status() + uuid, best_run_id = resp.json()["uuid"], resp.json()["bestRunID"] + return f"{uuid}/{best_run_id}/artifacts/model/model_weights/" # TODO(travis): refactor into registry pattern @@ -53,4 +83,5 @@ def get_local_dir(model_id: str, source: str): "RevisionNotFoundError", "get_hub_model_local_dir", "get_s3_model_local_dir", + "map_pbase_model_id_to_s3", ] \ No newline at end of file