Skip to content

Commit

Permalink
lorax launcher now has --default-adapter-source (#419)
Browse files Browse the repository at this point in the history
  • Loading branch information
noyoshi authored Apr 17, 2024
1 parent 04f96b2 commit cc2e0a9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
24 changes: 21 additions & 3 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use nix::unistd::Pid;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Lines, Read};
use std::io::{BufRead, BufReader, Lines};
use std::os::unix::process::{CommandExt, ExitStatusExt};
use std::path::Path;
use std::process::{Child, Command, ExitStatus, Stdio};
Expand Down Expand Up @@ -115,6 +115,15 @@ struct Args {
#[clap(default_value = "hub", long, env)]
source: String,

/// The default source of the dynamic adapters to load.
/// If not defined, we fallback to the value from `adapter_source`
/// Can be `hub` or `s3` or `pbase`
/// `hub` will load the model from the huggingface hub.
/// `s3` will load the model from the predibase S3 bucket.
/// `pbase` will load an s3 model but resolve the metadata from a predibase server
#[clap(long, env)]
default_adapter_source: Option<String>,

/// The source of the static adapter to load.
/// Can be `hub` or `s3` or `pbase`
/// `hub` will load the model from the huggingface hub.
Expand Down Expand Up @@ -1041,9 +1050,18 @@ fn spawn_webserver(
format!("{}-0", args.shard_uds_path),
"--tokenizer-name".to_string(),
args.model_id,
"--adapter-source".to_string(),
args.adapter_source,
];
// Set the default adapter source as "default_adapter_source" if defined, otherwise, "adapter_source"
// adapter_source in the router is used to set the default adapter source for dynamically loaded adapters.
let adapter_source;
if let Some(default_adapter_source) = args.default_adapter_source {
adapter_source = default_adapter_source
} else {
adapter_source = args.adapter_source
}

router_args.push("--adapter-source".to_string());
router_args.push(adapter_source.to_string());

// Model optional max batch total tokens
if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
Expand Down
1 change: 1 addition & 0 deletions server/lorax_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def serve(
adapter_source: str = "hub",
speculative_tokens: int = 0,
):

if sharded:
assert os.getenv("RANK", None) is not None, "RANK must be set when sharded is True"
assert os.getenv("WORLD_SIZE", None) is not None, "WORLD_SIZE must be set when sharded is True"
Expand Down
2 changes: 1 addition & 1 deletion server/lorax_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ async def serve_inner(
dtype,
trust_remote_code,
source,
adapter_source,
adapter_source
)
except Exception:
logger.exception("Error when initializing model")
Expand Down

0 comments on commit cc2e0a9

Please sign in to comment.