From cc2e0a90380c1342ea39cc483f3db8230cbf8d05 Mon Sep 17 00:00:00 2001 From: Noah Yoshida Date: Tue, 16 Apr 2024 20:02:28 -0700 Subject: [PATCH] lorax launcher now has --default-adapter-source (#419) --- launcher/src/main.rs | 24 +++++++++++++++++++++--- server/lorax_server/cli.py | 1 + server/lorax_server/server.py | 2 +- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index c0bbaeaca..d5548030e 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -4,7 +4,7 @@ use nix::unistd::Pid; use serde::Deserialize; use std::env; use std::ffi::OsString; -use std::io::{BufRead, BufReader, Lines, Read}; +use std::io::{BufRead, BufReader, Lines}; use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::path::Path; use std::process::{Child, Command, ExitStatus, Stdio}; @@ -115,6 +115,15 @@ struct Args { #[clap(default_value = "hub", long, env)] source: String, + /// The default source of the dynamic adapters to load. + /// If not defined, we fallback to the value from `adapter_source` + /// Can be `hub` or `s3` or `pbase` + /// `hub` will load the model from the huggingface hub. + /// `s3` will load the model from the predibase S3 bucket. + /// `pbase` will load an s3 model but resolve the metadata from a predibase server + #[clap(long, env)] + default_adapter_source: Option<String>, + /// The source of the static adapter to load. /// Can be `hub` or `s3` or `pbase` /// `hub` will load the model from the huggingface hub. /// `s3` will load the model from the predibase S3 bucket. /// `pbase` will load an s3 model but resolve the metadata from a predibase server @@ -1041,9 +1050,18 @@ fn spawn_webserver( format!("{}-0", args.shard_uds_path), "--tokenizer-name".to_string(), args.model_id, - "--adapter-source".to_string(), - args.adapter_source, ]; + // Set the default adapter source as "default_adapter_source" if defined, otherwise, "adapter_source" + // adapter_source in the router is used to set the default adapter source for dynamically loaded adapters. 
+ let adapter_source; + if let Some(default_adapter_source) = args.default_adapter_source { + adapter_source = default_adapter_source + } else { + adapter_source = args.adapter_source + } + + router_args.push("--adapter-source".to_string()); + router_args.push(adapter_source.to_string()); // Model optional max batch total tokens if let Some(max_batch_total_tokens) = args.max_batch_total_tokens { diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py index 5fe7532d1..5bee9a3f6 100644 --- a/server/lorax_server/cli.py +++ b/server/lorax_server/cli.py @@ -47,6 +47,7 @@ def serve( adapter_source: str = "hub", speculative_tokens: int = 0, ): + if sharded: assert os.getenv("RANK", None) is not None, "RANK must be set when sharded is True" assert os.getenv("WORLD_SIZE", None) is not None, "WORLD_SIZE must be set when sharded is True" diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py index 869b7e02f..7950433a7 100644 --- a/server/lorax_server/server.py +++ b/server/lorax_server/server.py @@ -250,7 +250,7 @@ async def serve_inner( dtype, trust_remote_code, source, - adapter_source, + adapter_source ) except Exception: logger.exception("Error when initializing model")