Skip to content

Commit

Permalink
Preload adapters during init (#543)
Browse files: browse the repository at this point in the history
Co-authored-by: Noah Yoshida <[email protected]>
  • Loading branch information
tgaddair and noyoshi authored Jul 17, 2024
1 parent 2dd5277 commit 5c25e26
Show file tree
Hide file tree
Showing 13 changed files with 196 additions and 70 deletions.
19 changes: 19 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ struct Args {
#[clap(long, env)]
speculative_tokens: Option<usize>,

/// The list of adapter ids to preload during initialization (to avoid cold start times).
#[clap(long, env)]
preloaded_adapter_ids: Vec<String>,

/// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
#[clap(long, env, value_enum)]
dtype: Option<Dtype>,
Expand Down Expand Up @@ -416,6 +420,7 @@ fn shard_manager(
quantize: Option<Quantization>,
compile: bool,
speculative_tokens: Option<usize>,
preloaded_adapter_ids: Vec<String>,
dtype: Option<Dtype>,
trust_remote_code: bool,
uds_path: String,
Expand Down Expand Up @@ -493,6 +498,12 @@ fn shard_manager(
shard_args.push(speculative_tokens.to_string())
}

// Preloaded adapters
for adapter_id in preloaded_adapter_ids {
shard_args.push("--preloaded-adapter-ids".to_string());
shard_args.push(adapter_id);
}

if let Some(dtype) = dtype {
shard_args.push("--dtype".to_string());
shard_args.push(dtype.to_string())
Expand Down Expand Up @@ -959,6 +970,7 @@ fn spawn_shards(
let quantize = args.quantize;
let compile = args.compile;
let speculative_tokens = args.speculative_tokens;
let preloaded_adapter_ids = args.preloaded_adapter_ids.clone();
let dtype = args.dtype;
let trust_remote_code = args.trust_remote_code;
let master_port = args.master_port;
Expand All @@ -977,6 +989,7 @@ fn spawn_shards(
quantize,
compile,
speculative_tokens,
preloaded_adapter_ids,
dtype,
trust_remote_code,
uds_path,
Expand Down Expand Up @@ -1140,6 +1153,12 @@ fn spawn_webserver(
router_args.push("--eager-prefill".to_string());
}

// Preloaded adapters
for adapter_id in args.preloaded_adapter_ids {
router_args.push("--preloaded-adapter-ids".to_string());
router_args.push(adapter_id);
}

// Ngrok
if args.ngrok {
router_args.push("--ngrok".to_string());
Expand Down
16 changes: 14 additions & 2 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ impl Infer {
window_size: Option<u32>,
generation_health: Arc<AtomicBool>,
eager_prefill: bool,
preloaded_adapter_ids: Vec<String>,
) -> Self {
let adapter_event = Arc::new(AdapterEvent {
batching_task: Notify::new(),
Expand All @@ -76,13 +77,24 @@ impl Infer {
);

// Initialize with base model adapter (empty) mapping to index 0
let adapter_to_index = Arc::new(Mutex::new(HashMap::from([(
let mut adapter_to_index = HashMap::from([(
AdapterParameters {
adapter_ids: vec![BASE_MODEL_ADAPTER_ID.to_string()],
..Default::default()
},
0,
)])));
)]);

// Pre-populate the adapter_to_index with the preloaded adapters
for (idx, adapter_id) in preloaded_adapter_ids.iter().enumerate() {
let adapter_key = AdapterParameters {
adapter_ids: vec![adapter_id.clone()],
..Default::default()
};
adapter_to_index.insert(adapter_key, (idx + 1) as u32);
}

let adapter_to_index = Arc::new(Mutex::new(adapter_to_index));

// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
Expand Down
4 changes: 4 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ struct Args {
adapter_source: String,
#[clap(long, env)]
eager_prefill: bool,
#[clap(long, env)]
preloaded_adapter_ids: Vec<String>,
}

#[tokio::main]
Expand Down Expand Up @@ -125,6 +127,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge,
adapter_source,
eager_prefill,
preloaded_adapter_ids,
} = args;

init_logging(otlp_endpoint, json_output);
Expand Down Expand Up @@ -380,6 +383,7 @@ async fn main() -> Result<(), RouterError> {
adapter_source,
embedding_model,
eager_prefill,
preloaded_adapter_ids,
)
.await?;
Ok(())
Expand Down
2 changes: 2 additions & 0 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,7 @@ pub async fn run(
adapter_source: String,
embedding_model: bool,
eager_prefill: bool,
preloaded_adapter_ids: Vec<String>,
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
Expand Down Expand Up @@ -1060,6 +1061,7 @@ pub async fn run(
shard_info.window_size,
generation_health,
eager_prefill,
preloaded_adapter_ids,
);

// Duration buckets
Expand Down
2 changes: 1 addition & 1 deletion server/lorax_server/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from lorax_server.utils.weights import download_weights


def download_adapter(
def download_adapter_weights(
adapter_id: str,
adapter_source: str,
api_token: Optional[str] = None,
Expand Down
6 changes: 5 additions & 1 deletion server/lorax_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from enum import Enum
from pathlib import Path
from typing import Optional
from typing import List, Optional

import typer
from loguru import logger
Expand Down Expand Up @@ -47,7 +47,10 @@ def serve(
source: str = "hub",
adapter_source: str = "hub",
speculative_tokens: int = 0,
preloaded_adapter_ids: Optional[List[str]] = typer.Option(None),
):
preloaded_adapter_ids = preloaded_adapter_ids or []

if sharded:
assert os.getenv("RANK", None) is not None, "RANK must be set when sharded is True"
assert os.getenv("WORLD_SIZE", None) is not None, "WORLD_SIZE must be set when sharded is True"
Expand Down Expand Up @@ -94,6 +97,7 @@ def serve(
source,
adapter_source,
speculative_tokens,
preloaded_adapter_ids,
)


Expand Down
7 changes: 3 additions & 4 deletions server/lorax_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,11 @@ def get_model(
if model_type == "distilbert":
from lorax_server.models.flash_distilbert import FlashDistilBert

if config_dict['architectures'][0] == 'DistilBertForMaskedLM':
if config_dict["architectures"][0] == "DistilBertForMaskedLM":
return FlashDistilBert(model_id, revision=revision, dtype=dtype)

if config_dict['architectures'][0] == 'DistilBertForTokenClassification':
return FlashDistilBert(model_id, revision=revision, dtype=dtype, classifcation_head=True)

if config_dict["architectures"][0] == "DistilBertForTokenClassification":
return FlashDistilBert(model_id, revision=revision, dtype=dtype, classifcation_head=True)

if model_id.startswith("bigcode/") or model_type == "gpt_bigcode":
from lorax_server.models.flash_santacoder import FlashSantacoderSharded
Expand Down
1 change: 1 addition & 0 deletions server/lorax_server/models/flash_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ]
ROW_PARALLEL = {O_PROJ, DOWN_PROJ}


class FlashGemma2(FlashCausalLM):
def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions server/lorax_server/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from loguru import logger
from transformers import PreTrainedTokenizerBase

from lorax_server.adapters.utils import download_adapter
from lorax_server.adapters.utils import download_adapter_weights
from lorax_server.adapters.weights import LayerAdapterWeights
from lorax_server.models.types import Batch, GeneratedText
from lorax_server.pb.generate_pb2 import AdapterParameters, AdapterSource, InfoResponse
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(
self.has_position_ids = inspect.signature(model.forward).parameters.get("position_ids", None) is not None

if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID:
download_adapter(adapter_id, adapter_source, api_token=None)
download_adapter_weights(adapter_id, adapter_source, api_token=None)
self.load_adapter(
AdapterParameters(adapter_ids=[adapter_id]),
adapter_source,
Expand Down
Loading

0 comments on commit 5c25e26

Please sign in to comment.