Fix: short circuit download, load, offload for preloaded adapters (#552)
tgaddair authored Jul 23, 2024
1 parent 59631a0 commit 07addea
Showing 11 changed files with 106 additions and 53 deletions.
6 changes: 0 additions & 6 deletions launcher/src/main.rs
@@ -1164,12 +1164,6 @@ fn spawn_webserver(
         router_args.push("--eager-prefill".to_string());
     }
 
-    // Preloaded adapters
-    for adapter_id in args.preloaded_adapter_ids {
-        router_args.push("--preloaded-adapter-ids".to_string());
-        router_args.push(adapter_id);
-    }
-
     // Ngrok
     if args.ngrok {
         router_args.push("--ngrok".to_string());
10 changes: 10 additions & 0 deletions proto/generate.proto
@@ -34,6 +34,15 @@ service LoraxService {
 message HealthRequest {}
 message HealthResponse {}
 
+message PreloadedAdapter {
+  /// Adapter params
+  AdapterParameters adapter_parameters = 1;
+  /// Adapter source
+  AdapterSource adapter_source = 2;
+  /// Adapter index
+  uint32 adapter_index = 3;
+}
+
 /// Empty request
 message InfoRequest {}
 
@@ -44,6 +53,7 @@ message InfoResponse {
   optional uint32 window_size = 4;
   uint32 block_size = 5;
   uint32 speculate = 6;
+  repeated PreloadedAdapter preloaded_adapters = 7;
 }
 
 /// Empty request
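For orientation, here is a minimal sketch of how the new messages fit together, using plain dataclasses as stand-ins for the protoc-generated `generate_pb2` classes (field names mirror the proto above; the adapter id, source value, and sizes are illustrative, not from the commit):

```python
from dataclasses import dataclass, field
from typing import List

# Dataclass stand-ins for the protoc-generated messages; the real code
# uses the generate_pb2 bindings with these same field names.
@dataclass
class AdapterParameters:
    adapter_ids: List[str] = field(default_factory=list)

@dataclass
class PreloadedAdapter:
    adapter_parameters: AdapterParameters
    adapter_source: int  # AdapterSource enum value; 0 is illustrative here
    adapter_index: int

@dataclass
class InfoResponse:
    block_size: int
    speculate: int
    preloaded_adapters: List[PreloadedAdapter] = field(default_factory=list)

# A shard reports its startup-time adapters in InfoResponse, so the router
# can map requests to adapter_index without downloading anything again.
info = InfoResponse(
    block_size=16,
    speculate=0,
    preloaded_adapters=[
        PreloadedAdapter(
            adapter_parameters=AdapterParameters(adapter_ids=["org/my-adapter"]),
            adapter_source=0,
            adapter_index=1,  # indices start at 1, matching the i + 1 assignment in server.py below
        )
    ],
)
print(info.preloaded_adapters[0].adapter_index)  # -> 1
```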
4 changes: 2 additions & 2 deletions router/client/src/lib.rs
@@ -11,8 +11,8 @@ pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     AdapterParameters, AlternativeTokens, Batch, CachedBatch, DownloadAdapterResponse, Embedding,
     Entity, EntityList, FinishReason, GeneratedText, Generation, MajoritySignMethod, MergeStrategy,
-    NextTokenChooserParameters, NextTokens, PrefillTokens, Request, StoppingCriteriaParameters,
-    TokenizedInputs,
+    NextTokenChooserParameters, NextTokens, PrefillTokens, PreloadedAdapter, Request,
+    StoppingCriteriaParameters, TokenizedInputs,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
20 changes: 12 additions & 8 deletions router/src/infer.rs
@@ -17,7 +17,7 @@ use futures::stream::StreamExt;
 use itertools::multizip;
 use lorax_client::{
     Batch, CachedBatch, ClientError, Embedding, EntityList, GeneratedText, Generation,
-    PrefillTokens, ShardedClient,
+    PrefillTokens, PreloadedAdapter, ShardedClient,
 };
 use minijinja::{Environment, ErrorKind, Template};
 use minijinja_contrib::pycompat;
@@ -151,9 +151,9 @@ impl Infer {
         generation_health: Arc<AtomicBool>,
         eager_prefill: bool,
         tokenizer_config: HubTokenizerConfig,
-        preloaded_adapter_ids: Vec<String>,
         block_size: u32,
         speculate: u32,
+        preloaded_adapters: Vec<PreloadedAdapter>,
     ) -> Self {
         let adapter_event = Arc::new(AdapterEvent {
             batching_task: Notify::new(),
@@ -182,12 +182,16 @@ impl Infer {
         )]);
 
         // Pre-populate the adapter_to_index with the preloaded adapters
-        for (idx, adapter_id) in preloaded_adapter_ids.iter().enumerate() {
-            let adapter_key = AdapterParameters {
-                adapter_ids: vec![adapter_id.clone()],
-                ..Default::default()
-            };
-            adapter_to_index.insert(adapter_key, (idx + 1) as u32);
+        for adapter in preloaded_adapters.iter() {
+            if let Some(adapter_parameters) = &adapter.adapter_parameters {
+                adapter_to_index.insert(
+                    AdapterParameters {
+                        adapter_ids: adapter_parameters.adapter_ids.clone(),
+                        ..Default::default()
+                    },
+                    adapter.adapter_index,
+                );
+            }
         }
 
         let adapter_to_index = Arc::new(Mutex::new(adapter_to_index));
4 changes: 0 additions & 4 deletions router/src/main.rs
@@ -89,8 +89,6 @@ struct Args {
     adapter_source: String,
     #[clap(long, env)]
     eager_prefill: bool,
-    #[clap(long, env)]
-    preloaded_adapter_ids: Vec<String>,
 }
 
 #[tokio::main]
@@ -130,7 +128,6 @@ async fn main() -> Result<(), RouterError> {
         ngrok_edge,
         adapter_source,
         eager_prefill,
-        preloaded_adapter_ids,
     } = args;
 
     init_logging(otlp_endpoint, json_output);
@@ -465,7 +462,6 @@ async fn main() -> Result<(), RouterError> {
         adapter_source,
         embedding_model,
         eager_prefill,
-        preloaded_adapter_ids,
     )
     .await?;
     Ok(())
3 changes: 1 addition & 2 deletions router/src/server.rs
@@ -1010,7 +1010,6 @@ pub async fn run(
     adapter_source: String,
     embedding_model: bool,
     eager_prefill: bool,
-    preloaded_adapter_ids: Vec<String>,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -1107,9 +1106,9 @@ pub async fn run(
         generation_health,
         eager_prefill,
         tokenizer_config,
-        preloaded_adapter_ids,
         shard_info.block_size,
         shard_info.speculate,
+        shard_info.preloaded_adapters,
     );
 
     // Duration buckets
4 changes: 1 addition & 3 deletions server/lorax_server/models/flash_causal_lm.py
@@ -137,9 +137,7 @@ def from_pb(
             max_truncation = max(max_truncation, r.truncate)
 
         if all(r.HasField("tokenized_inputs") for r in pb.requests):
-            batch_tokenized_inputs = [
-                r.tokenized_inputs.ids[-max_truncation :] for r in pb.requests
-            ]
+            batch_tokenized_inputs = [r.tokenized_inputs.ids[-max_truncation:] for r in pb.requests]
         else:
            batch_tokenized_inputs = tokenizer(batch_inputs, truncation=True, max_length=max_truncation)["input_ids"]
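This hunk is formatting-only: `[-max_truncation :]` and `[-max_truncation:]` are the same slice in Python, keeping just the trailing `max_truncation` token ids. The values below are illustrative:

```python
ids = [101, 2023, 2003, 1037, 2742, 102]
max_truncation = 3
# Left-truncation: the newest tokens survive, older ones are dropped.
print(ids[-max_truncation:])  # -> [1037, 2742, 102]
```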
26 changes: 24 additions & 2 deletions server/lorax_server/models/model.py
@@ -10,6 +10,7 @@
 from lorax_server.adapters.utils import download_adapter_weights
 from lorax_server.adapters.weights import LayerAdapterWeights
 from lorax_server.models.types import Batch, GeneratedText
+from lorax_server.pb import generate_pb2
 from lorax_server.pb.generate_pb2 import AdapterParameters, AdapterSource, InfoResponse
 from lorax_server.utils.adapter import (
     BASE_MODEL_ADAPTER_ID,
@@ -64,6 +65,9 @@ def __init__(
         self.target_to_layer = self.adapter_target_to_layer()
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id
+        self.preloaded_adapter_indices = set()
+        self.preloaded_adapter_memory_fractions = {}
+        self.preloaded_adapters = []
 
         self.trust_remote_code = trust_remote_code
 
@@ -93,6 +97,7 @@ def info(self) -> InfoResponse:
             window_size=self.sliding_window,
             block_size=self.block_size,
             speculate=get_speculative_tokens(),
+            preloaded_adapters=self.preloaded_adapters,
         )
 
     @property
@@ -192,6 +197,18 @@ def max_speculative_tokens(self) -> int:
             default=0,
         )
 
+    def register_preloaded_adapters(
+        self, preloaded_adapters: List[generate_pb2.PreloadedAdapter], adapter_memory_fractions: List[float]
+    ):
+        self.preloaded_adapter_indices.update({adapter.adapter_index for adapter in preloaded_adapters})
+        self.preloaded_adapter_memory_fractions.update(
+            {
+                adapter.adapter_parameters.adapter_ids[0]: memory_fraction
+                for adapter, memory_fraction in zip(preloaded_adapters, adapter_memory_fractions)
+            }
+        )
+        self.preloaded_adapters.extend(preloaded_adapters)
+
     def load_adapter(
         self,
         adapter_parameters: AdapterParameters,
@@ -282,11 +299,15 @@ def offload_adapter(
         adapter_parameters: AdapterParameters,
         adapter_source: AdapterSource,
         adapter_index: int,
-    ):
+    ) -> bool:
         """Offloads the adapter weights from GPU to CPU or disk."""
         if adapter_index not in self.loaded_adapters:
             # Adapter already offloaded
-            return
+            return False
+
+        if adapter_index in self.preloaded_adapter_indices:
+            # Adapter was preloaded and should not be offloaded
+            return False
 
         if not self.supports_adapter_loading:
             raise ValueError("This model does not support adapter loading.")
@@ -304,3 +325,4 @@ def offload_adapter(
             self.layer_to_adapter_weights[layer_name].remove_adapter(adapter_index)
 
         self.loaded_adapters.remove(adapter_index)
+        return True
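A self-contained sketch of the bookkeeping this adds, with a stripped-down stand-in for `Model` (weight handling, sources, and the proto types are omitted; the attribute names and return-value semantics mirror the diff, everything else is illustrative):

```python
from typing import Dict, Set

class AdapterRegistryStub:
    """Stand-in for Model's adapter bookkeeping; illustrative only."""

    def __init__(self):
        self.loaded_adapters: Set[int] = set()
        self.preloaded_adapter_indices: Set[int] = set()
        self.preloaded_adapter_memory_fractions: Dict[str, float] = {}

    def register_preloaded_adapters(self, indices_to_ids: Dict[int, str], fractions: Dict[str, float]):
        # Record which indices are pinned and how much memory each adapter uses.
        self.preloaded_adapter_indices.update(indices_to_ids.keys())
        self.preloaded_adapter_memory_fractions.update(fractions)

    def offload_adapter(self, adapter_index: int) -> bool:
        if adapter_index not in self.loaded_adapters:
            return False  # already offloaded: nothing to do
        if adapter_index in self.preloaded_adapter_indices:
            return False  # preloaded adapters are pinned and never offloaded
        self.loaded_adapters.remove(adapter_index)
        return True  # only now should the caller reclaim GPU memory

registry = AdapterRegistryStub()
registry.loaded_adapters.update({1, 2})  # pretend both were loaded at startup
registry.register_preloaded_adapters({1: "org/my-adapter"}, {"org/my-adapter": 0.05})

assert registry.offload_adapter(1) is False  # pinned: short-circuited
assert registry.offload_adapter(2) is True   # dynamic adapter: actually offloaded
assert registry.offload_adapter(2) is False  # second call is a no-op
```

The boolean return matters downstream: in server.py below, `torch.cuda.empty_cache()` and the device synchronize only run when an offload actually happened.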
73 changes: 52 additions & 21 deletions server/lorax_server/server.py
@@ -162,6 +162,18 @@ async def Decode(self, request: generate_pb2.DecodeRequest, context):
         )
 
     async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, context):
+        if (
+            len(request.adapter_parameters.adapter_ids) == 1
+            and request.adapter_parameters.adapter_ids[0] in self.model.preloaded_adapter_memory_fractions
+        ):
+            logger.info("Adapter is already preloaded. Skipping.")
+            return generate_pb2.DownloadAdapterResponse(
+                downloaded=True,
+                memory_fraction=self.model.preloaded_adapter_memory_fractions[
+                    request.adapter_parameters.adapter_ids[0]
+                ],
+            )
+
         return download_adapter(request, self.model)
 
     async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context):
@@ -170,6 +182,10 @@ async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context):
             logger.info("No adapter to load for base model. Skipping.")
             return generate_pb2.LoadAdapterResponse(loaded=False)
 
+        if request.adapter_index in self.model.loaded_adapters:
+            logger.info(f"Adapter {request.adapter_index} is already loaded. Skipping.")
+            return generate_pb2.LoadAdapterResponse(loaded=True)
+
         try:
             adapter_source = adapter_source_enum_to_string(request.adapter_source)
             adapter_index = request.adapter_index
@@ -199,13 +215,14 @@ async def OffloadAdapter(self, request: generate_pb2.OffloadAdapterRequest, context):
             adapter_idx = request.adapter_index
             adapter_source = adapter_source_enum_to_string(request.adapter_source)
             adapter_index = request.adapter_index
-            self.model.offload_adapter(adapter_idx, adapter_source, adapter_index)
-
-            # Ensure there is enough memory for the next adapter
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize(self.model.device)
+            offloaded = self.model.offload_adapter(adapter_idx, adapter_source, adapter_index)
+            if offloaded:
+                # Ensure there is enough memory for the next adapter
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize(self.model.device)
 
-            return generate_pb2.OffloadAdapterResponse(offloaded=True)
+            return generate_pb2.OffloadAdapterResponse(offloaded=offloaded)
         except Exception:
             logger.exception("Error when offloading adapter")
             raise
@@ -288,50 +305,64 @@ async def serve_inner(
     # Derive the predibase token from an env variable if we are using predibase adapters.
    adapter_preload_api_token = os.getenv("PREDIBASE_API_TOKEN")
 
-    requests = [
-        generate_pb2.DownloadAdapterRequest(
+    preloaded_adapters = [
+        generate_pb2.PreloadedAdapter(
             adapter_parameters=generate_pb2.AdapterParameters(adapter_ids=[adapter_id]),
             adapter_source=_adapter_source,
+            adapter_index=i + 1,
+        )
+        for i, adapter_id in enumerate(preloaded_adapter_ids)
+    ]
+
+    download_requests = [
+        generate_pb2.DownloadAdapterRequest(
+            adapter_parameters=adapter_info.adapter_parameters,
+            adapter_source=adapter_info.adapter_source,
             api_token=adapter_preload_api_token,
         )
-        for adapter_id in preloaded_adapter_ids
+        for adapter_info in preloaded_adapters
     ]
-    models = [model] * len(requests)
+    models = [model] * len(download_requests)
 
     # Download adapters
     t0 = time.time()
     with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
-        responses = list(tqdm(executor.map(download_adapter, requests, models), total=len(requests)))
-    logger.info(f"Downloaded {len(preloaded_adapter_ids)} adapters in {time.time() - t0:.2f}s")
+        download_responses = list(
+            tqdm(executor.map(download_adapter, download_requests, models), total=len(download_requests))
+        )
+    logger.info(f"Downloaded {len(download_requests)} adapters in {time.time() - t0:.2f}s")
 
-    if not all(responses):
+    if not all(download_responses):
         raise RuntimeError("Failed to download all adapters")
 
-    def load_adapter(adapter_id: str, i: int) -> bool:
-        _adapter_source = adapter_source
-        if adapter_source == PBASE:
-            adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token=adapter_preload_api_token)
+    def load_adapter(adapter_info: generate_pb2.PreloadedAdapter) -> bool:
+        _adapter_source = adapter_source_enum_to_string(adapter_info.adapter_source)
+        _adapter_id = adapter_info.adapter_parameters.adapter_ids[0]
+        if _adapter_source == PBASE:
+            _adapter_id = map_pbase_model_id_to_s3(_adapter_id, api_token=adapter_preload_api_token)
             _adapter_source = S3
 
         model.load_adapter(
-            generate_pb2.AdapterParameters(adapter_ids=[adapter_id]),
+            generate_pb2.AdapterParameters(adapter_ids=[_adapter_id]),
             _adapter_source,
-            adapter_index=i + 1,
+            adapter_index=adapter_info.adapter_index,
             api_token=None,
             dynamic=True,
         )
         return True
 
     # Load adapters
     t0 = time.time()
-    indices = list(range(len(preloaded_adapter_ids)))
     with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
-        responses = list(tqdm(executor.map(load_adapter, preloaded_adapter_ids, indices), total=len(indices)))
+        responses = list(tqdm(executor.map(load_adapter, preloaded_adapters), total=len(preloaded_adapters)))
 
     if not all(responses):
         raise RuntimeError("Failed to preload all adapters")
 
-    logger.info(f"Preloaded {len(preloaded_adapter_ids)} adapters in {time.time() - t0:.2f}s")
+    logger.info(f"Preloaded {len(preloaded_adapters)} adapters in {time.time() - t0:.2f}s")
+
+    adapter_memory_fractions = [r.memory_fraction for r in download_responses]
+    model.register_preloaded_adapters(preloaded_adapters, adapter_memory_fractions)
 
     # set speculative decoding tokens
     speculative_tokens = max(model.max_speculative_tokens, speculative_tokens)
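The startup path now builds the `PreloadedAdapter` entries once, fans the download and load work out over thread pools, and registers everything with the model at the end. A stripped-down sketch of that flow, with stub `download`/`load` functions standing in for `download_adapter` and `model.load_adapter` (the adapter ids and memory fraction are illustrative):

```python
import concurrent.futures
from dataclasses import dataclass
from typing import List

@dataclass
class PreloadedAdapter:  # mirrors the proto message, minus the source field
    adapter_id: str
    adapter_index: int

def download(adapter: PreloadedAdapter) -> float:
    """Stub for download_adapter(); returns the adapter's memory fraction."""
    return 0.05

def load(adapter: PreloadedAdapter) -> bool:
    """Stub for model.load_adapter(); True on success."""
    return True

adapter_ids = ["org/adapter-a", "org/adapter-b"]
# adapter_index starts at 1, matching the i + 1 assignment above.
preloaded = [PreloadedAdapter(a, i + 1) for i, a in enumerate(adapter_ids)]

with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
    fractions = list(executor.map(download, preloaded))

with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
    responses = list(executor.map(load, preloaded))

if not all(responses):
    raise RuntimeError("Failed to preload all adapters")

# model.register_preloaded_adapters(preloaded, fractions) runs here in the
# real code, pinning these indices so later RPCs short-circuit.
print({a.adapter_id: f for a, f in zip(preloaded, fractions)})
```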
4 changes: 3 additions & 1 deletion server/lorax_server/utils/adapter.py
@@ -164,7 +164,9 @@ def load_module_map(
     return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
 
 
-def download_adapter(request: generate_pb2.DownloadAdapterRequest, model: "Model") -> bool:
+def download_adapter(
+    request: generate_pb2.DownloadAdapterRequest, model: "Model"
+) -> generate_pb2.DownloadAdapterResponse:
     adapter_parameters = request.adapter_parameters
     if is_base_model(adapter_parameters):
         logger.info("No adapter to download for base model. Skipping.")
5 changes: 1 addition & 4 deletions server/lorax_server/utils/layers.py
@@ -447,9 +447,7 @@ def static(cls, config, dim, base, device, dtype):
                 scaling_factor=rope_scaling["factor"],
                 low_freq_factor=rope_scaling["low_freq_factor"],
                 high_freq_factor=rope_scaling["high_freq_factor"],
-                original_max_position_embeddings=rope_scaling[
-                    "original_max_position_embeddings"
-                ],
+                original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
             )
         return cls(
             inv_freq,
@@ -749,7 +747,6 @@ def apply_llama3_scaling(
         elif wavelen > low_freq_wavelen:
             new_freqs.append(freq / scaling_factor)
         else:
-
             assert low_freq_wavelen != high_freq_wavelen
             smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                 high_freq_factor - low_freq_factor
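For context on the `apply_llama3_scaling` fragment above (the diff only removes a stray blank line), here is the full per-frequency rule as it appears in Llama 3.1 reference implementations. This reconstruction is an assumption beyond the lines shown, though the `elif`/`else` branches and the `smooth` formula match the fragment exactly:

```python
import math
from typing import List

def apply_llama3_scaling(
    freqs: List[float],
    scaling_factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    original_max_position_embeddings: int,
) -> List[float]:
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)  # short wavelengths pass through unscaled
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scaling_factor)  # long wavelengths fully scaled
        else:
            assert low_freq_wavelen != high_freq_wavelen
            # mid band: blend smoothly between scaled and unscaled
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
    return new_freqs

# Llama 3.1 defaults: factor=8, low=1, high=4, original context 8192.
print(apply_llama3_scaling([1.0, 0.01, 0.0001], 8.0, 1.0, 4.0, 8192))
```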
