From 6dea404e479a72cf91a73bbd6ef786a3229ec30d Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Wed, 6 Mar 2024 16:14:36 -0800
Subject: [PATCH] Enforce adapters cannot be loaded past
 `--adapter-memory-fraction` (#306)

---
 docs/reference/launcher.md                    |  2 +-
 launcher/src/main.rs                          |  4 +-
 proto/generate.proto                          |  7 ++
 router/client/src/client.rs                   |  4 +-
 router/client/src/lib.rs                      |  6 +-
 router/client/src/sharded_client.rs           |  7 +-
 router/src/loader.rs                          |  3 +-
 router/src/queue.rs                           | 68 ++++++++++++++++++-
 server/lorax_server/cli.py                    |  1 +
 server/lorax_server/models/flash_causal_lm.py |  6 +-
 server/lorax_server/models/model.py           |  3 +
 server/lorax_server/server.py                 | 37 +++++++++-
 server/lorax_server/utils/sources/source.py   | 62 ++++++++++++++++-
 13 files changed, 191 insertions(+), 19 deletions(-)

diff --git a/docs/reference/launcher.md b/docs/reference/launcher.md
index f73ad76b0..8cf0a314d 100644
--- a/docs/reference/launcher.md
+++ b/docs/reference/launcher.md
@@ -147,7 +147,7 @@ Options:
           Maximum number of adapters that can be placed on the GPU and accept requests at a time

           [env: MAX_ACTIVE_ADAPTERS=]
-          [default: 128]
+          [default: 1024]

       --adapter-cycle-time-s
           The time in seconds between adapter exchanges
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 2ed19cb5b..f7a6d475c 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -264,7 +264,7 @@ struct Args {
     max_waiting_tokens: usize,

     /// Maximum number of adapters that can be placed on the GPU and accept requests at a time.
-    #[clap(default_value = "128", long, env)]
+    #[clap(default_value = "1024", long, env)]
     max_active_adapters: usize,

     /// The time in seconds between adapter exchanges.
@@ -275,7 +275,7 @@ struct Args {
     /// Increasing this value will reduce the size of the KV cache in exchange for allowing more
     /// adapters to be loaded onto the GPU at once.
     /// This value is NOT scaled relative to `cuda_memory_fraction`, but is expressed in absolute terms.
-    #[clap(default_value = "0.0", long, env)]
+    #[clap(default_value = "0.1", long, env)]
     adapter_memory_fraction: f32,

     /// The IP address to listen on
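Taken together, the two new defaults mean up to 1024 adapters can be tracked as active, but only as many as fit within the 10% memory reservation are actually resident on the GPU at once. A minimal sketch of the arithmetic, assuming a hypothetical 80 GiB device (the size is illustrative, not from this patch):

```python
# Hedged sketch: what the new --adapter-memory-fraction default reserves in
# absolute terms. The device size below is a made-up example value.
ADAPTER_MEMORY_FRACTION = 0.1       # new default
total_gpu_memory = 80 * 1024**3     # hypothetical 80 GiB GPU

reservation = ADAPTER_MEMORY_FRACTION * total_gpu_memory
print(f"adapter reservation: {reservation / 1024**3:.1f} GiB")  # 8.0 GiB

# Per the doc comment above, the fraction is absolute: it is NOT rescaled by
# --cuda-memory-fraction, so 0.1 always means 10% of the physical device.
```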
diff --git a/proto/generate.proto b/proto/generate.proto
index 7a1a76892..fe5d17363 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -288,6 +288,13 @@ message DownloadAdapterRequest {
 message DownloadAdapterResponse {
   /// True if download occurred, false if skipped
   bool downloaded = 1;
+
+  /// Fraction of the adapter memory limit consumed by the adapter.
+  /// If no limit is set, will return 0.
+  /// When the total across all loaded adapters exceeds
+  /// the adapter_memory_fraction limit, no more adapters
+  /// will be loaded to GPU and LoRAX will begin swapping.
+  float memory_fraction = 2;
 }

 message LoadAdapterRequest {
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 0c3a3a46a..b50e14f56 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -193,7 +193,7 @@ impl Client {
         adapter_parameters: AdapterParameters,
         adapter_source: String,
         api_token: Option<String>,
-    ) -> Result<bool> {
+    ) -> Result<DownloadAdapterResponse> {
         if let Some(adapter_source_enum) =
             AdapterSource::from_str_name(adapter_source.to_uppercase().as_str())
         {
@@ -204,7 +204,7 @@ impl Client {
                 })
                 .inject_context();
             let response = self.stub.download_adapter(request).await?.into_inner();
-            Ok(response.downloaded)
+            Ok(response)
         } else {
             let err_string = format!(
                 "Invalid source '{}' when downloading adapter '{}'",
diff --git a/router/client/src/lib.rs b/router/client/src/lib.rs
index 3aa4e10ae..0dcd1f944 100644
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -9,9 +9,9 @@ pub use client::Client;
 pub use pb::generate::v1::HealthResponse;
 pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
-    AdapterParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation,
-    MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, PrefillTokens, Request,
-    StoppingCriteriaParameters,
+    AdapterParameters, Batch, CachedBatch, DownloadAdapterResponse, FinishReason, GeneratedText,
+    Generation, MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, PrefillTokens,
+    Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 13524fc19..7427114f4 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,5 +1,8 @@
 /// Multi shard Client
-use crate::{AdapterParameters, Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
+use crate::{
+    AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation,
+    HealthResponse, ShardInfo,
+};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
@@ -155,7 +158,7 @@ impl ShardedClient {
         adapter_parameters: AdapterParameters,
         adapter_source: String,
         api_token: Option<String>,
-    ) -> Result<bool> {
+    ) -> Result<DownloadAdapterResponse> {
         // Only download the adapter with one client, since they share a single disk
         self.clients[0]
             .download_adapter(adapter_parameters, adapter_source, api_token)
diff --git a/router/src/loader.rs b/router/src/loader.rs
index 7f7ac0b1b..96b7268a7 100644
--- a/router/src/loader.rs
+++ b/router/src/loader.rs
@@ -146,12 +146,13 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver
-            Ok(_) => {
+            Ok(resp) => {
                 tracing::info!("adapter {} downloaded", adapter.as_string());
                 let mut locked_state = queues_state.lock().unwrap();
                 if locked_state.has_adapter(&adapter) {
                     // Above check guards against the case where the adapter was terminated between the initial
                     // time of request and the time of adapter download
                     locked_state.set_cost(&adapter, resp.memory_fraction);
                     locked_state.set_status(&adapter, AdapterStatus::Downloaded);
                 }
             }
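On the router side, the `memory_fraction` carried by `DownloadAdapterResponse` becomes the adapter's cost, and it is recorded before the status flips to `Downloaded`, so the scheduler never activates an adapter whose cost is unknown. A hedged Python model of that flow (the names are simplified stand-ins for the Rust types above, not LoRAX APIs):

```python
class QueuesState:
    """Simplified stand-in for AdapterQueuesState on the router."""
    def __init__(self):
        self.cost = {}      # adapter -> fraction of the adapter budget
        self.status = {}    # adapter -> "Downloading" | "Downloaded" | ...

    def has_adapter(self, adapter):
        return adapter in self.status

def on_download_complete(state, adapter, memory_fraction):
    # Mirrors the Ok(resp) arm above: the guard protects against adapters
    # terminated between the download request and its completion.
    if state.has_adapter(adapter):
        state.cost[adapter] = memory_fraction    # set_cost first...
        state.status[adapter] = "Downloaded"     # ...then set_status

state = QueuesState()
state.status["adapter-a"] = "Downloading"
on_download_complete(state, "adapter-a", 0.25)
print(state.cost)  # {'adapter-a': 0.25}
```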
diff --git a/router/src/queue.rs b/router/src/queue.rs
index 4d1cf5d78..b8e344375 100644
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -66,6 +66,9 @@ pub(crate) struct QueueState {
     /// Adapter status
     status: AdapterStatus,

+    /// Cost as a fraction of the adapter memory budget
+    cost: Option<f32>,
+
     /// Timestamp when the adapter was last activated
     activation_ts: Option<Instant>,

@@ -80,6 +83,7 @@ impl QueueState {
             entries: VecDeque::with_capacity(128),
             adapter,
             status,
+            cost: None,
             activation_ts: None,
             event,
         }
@@ -143,6 +147,14 @@ impl QueueState {
         &self.status
     }

+    pub(crate) fn set_cost(&mut self, cost: f32) {
+        self.cost = Some(cost);
+    }
+
+    pub(crate) fn cost(&self) -> Option<f32> {
+        self.cost
+    }
+
     pub(crate) fn set_activation_ts(&mut self, ts: Instant) {
         self.activation_ts = Some(ts);
     }
@@ -169,6 +181,9 @@ pub(crate) struct AdapterQueuesState {
     /// Number of adapters that can be active at a time
     max_active_adapters: usize,

+    /// Fraction of adapter memory budget remaining to allocate to new adapters
+    memory_budget_remaining: f32,
+
     /// Maximum time an adapter is allowed to be active before exchanging out
     max_active_time: Duration,

@@ -189,6 +204,7 @@ impl AdapterQueuesState {
             active_adapters,
             tracked_adapters,
             max_active_adapters: max_active_adapters,
+            memory_budget_remaining: 1.0,
             max_active_time: Duration::from_secs(adapter_cycle_time_s),
             next_id: 0,
         }
@@ -255,6 +271,17 @@ impl AdapterQueuesState {
         errored_adapters
     }

+    pub(crate) fn set_cost(&mut self, adapter: &Adapter, cost: f32) {
+        let q = self.queue_map.get_mut(adapter);
+        if q.is_none() {
+            // TODO(travis): remove this
+            tracing::error!("adapter {} not found in queue_map", adapter.as_string());
+            println!("{:?}", Backtrace::force_capture());
+        }
+        let queue = q.unwrap();
+        queue.set_cost(cost);
+    }
+
     pub(crate) fn set_status(&mut self, adapter: &Adapter, status: AdapterStatus) {
         let q = self.queue_map.get_mut(adapter);
         if q.is_none() {
@@ -388,20 +415,57 @@ impl AdapterQueuesState {
             }
         }

+        // Add back cost for all offload adapters
+        for adapter in offload_adapters.iter() {
+            let queue = self.queue_map.get(adapter).unwrap().clone();
+            let cost = queue.cost().unwrap();
+            self.memory_budget_remaining += cost;
+            tracing::info!(
+                "offloading adapter {} with cost {} (memory budget remaining: {})",
+                adapter.as_string(),
+                cost,
+                self.memory_budget_remaining
+            );
+        }
+
         // Add pending adapters to the active set until we reach the max
         while self.active_adapters.len() < self.max_active_adapters
             && self.pending_adapters.len() > 0
         {
-            let adapter = self.pending_adapters.pop_front().unwrap();
+            let queue = self
+                .queue_map
+                .get_mut(self.pending_adapters.front().unwrap())
+                .unwrap();
+            if queue.cost().is_none() {
+                // Adapter has not been downloaded yet
+                break;
+            }
+
+            // Check to see that we have enough memory budget remaining to load the adapter
+            let cost = queue.cost().unwrap();
+            if cost > self.memory_budget_remaining {
+                // Adapter is too expensive to load
+                break;
+            }

             // Update activation timestamp
-            let queue = self.queue_map.get_mut(&adapter).unwrap();
+            let adapter = self.pending_adapters.pop_front().unwrap();
             queue.set_activation_ts(now);

+            // Calculate remaining memory budget
+            self.memory_budget_remaining -= cost;
+
             // Start async loading process
             load_adapters.push(adapter.clone());

             self.active_adapters.push_back(adapter.clone());
+
+            tracing::info!(
+                "loading adapter {} with cost {} (memory budget remaining: {})",
+                adapter.as_string(),
+                cost,
+                self.memory_budget_remaining
+            );
         }

         (offload_adapters, load_adapters)
diff --git a/server/lorax_server/cli.py b/server/lorax_server/cli.py
index 632384da2..7ad076bbc 100644
--- a/server/lorax_server/cli.py
+++ b/server/lorax_server/cli.py
@@ -90,6 +90,7 @@ def serve(
         model_id, adapter_id, revision, sharded, quantize, compile, dtype, trust_remote_code, uds_path, source, adapter_source
     )

+
 def _download_weights(
     model_id: str,
     revision: Optional[str] = None,
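The `router/src/queue.rs` changes above are the heart of the enforcement: pending adapters are admitted only while their summed cost fits the remaining budget, and offloaded adapters return their cost to the pool, which produces the swapping behavior described in the proto comment. A hedged Python simulation of that loop, omitting the activation timestamps and `max_active_adapters` cap that the real code also tracks:

```python
from collections import deque

# Hedged simulation of the queue.rs admission loop. Costs are fractions of
# the adapter memory budget, as reported by DownloadAdapterResponse.
def exchange(pending, costs, budget_remaining, offloaded):
    # Offloaded adapters give their cost back to the budget.
    for adapter in offloaded:
        budget_remaining += costs[adapter]

    loaded = []
    while pending:
        cost = costs.get(pending[0])
        if cost is None:                 # not downloaded yet: cost unknown
            break
        if cost > budget_remaining:      # too expensive: wait for an offload
            break
        adapter = pending.popleft()
        budget_remaining -= cost
        loaded.append(adapter)
    return loaded, budget_remaining

loaded, rem = exchange(deque(["a", "b", "c"]),
                       {"a": 0.4, "b": 0.5, "c": 0.4}, 1.0, [])
print(loaded, round(rem, 2))  # ['a', 'b'] 0.1 -- 'c' waits for an offload
```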
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index e79d9029c..2fe53ca1f 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -34,7 +34,7 @@
 from lorax_server.utils.tokenizer import TokenizerManager

-ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.0"))
+ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.1"))

 tracer = trace.get_tracer(__name__)
@@ -722,6 +722,10 @@ def __init__(
     @property
     def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
+
+    def adapter_memory_size(self) -> int:
+        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        return ADAPTER_MEMORY_FRACTION * total_gpu_memory

     def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
         torch.cuda.empty_cache()
diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py
index da6c57d03..82423f4b4 100644
--- a/server/lorax_server/models/model.py
+++ b/server/lorax_server/models/model.py
@@ -81,6 +81,9 @@ def info(self) -> InfoResponse:
     @abstractmethod
     def batch_type(self) -> Type[B]:
         raise NotImplementedError
+
+    def adapter_memory_size(self) -> int:
+        return 0

     @abstractmethod
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
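The server decides the reservation through `adapter_memory_size()`: the base `Model` returns 0 (no limit, so the reported fraction stays 0.0 and enforcement is disabled), while `FlashCausalLM` reserves `ADAPTER_MEMORY_FRACTION` of total GPU memory. A sketch of that contract, assuming a made-up device size in place of `torch.cuda.get_device_properties(device).total_memory`; the `int()` cast is our addition, since the patched method returns a float despite its `-> int` annotation:

```python
# Hedged sketch of the adapter_memory_size() contract introduced above.
ADAPTER_MEMORY_FRACTION = 0.1

class Model:
    def adapter_memory_size(self) -> int:
        return 0  # no reservation: DownloadAdapter reports memory_fraction 0.0

class FlashCausalLM(Model):
    total_gpu_memory = 80 * 1024**3  # hypothetical 80 GiB device

    def adapter_memory_size(self) -> int:
        return int(ADAPTER_MEMORY_FRACTION * self.total_gpu_memory)

print(Model().adapter_memory_size())          # 0
print(FlashCausalLM().adapter_memory_size())  # 8589934592 (8 GiB)
```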
Skipping.") return generate_pb2.DownloadAdapterResponse(downloaded=False) + adapter_bytes = 0 api_token = request.api_token adapter_source = _adapter_source_enum_to_string(request.adapter_source) for adapter_id in adapter_parameters.adapter_ids: @@ -153,7 +155,13 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, co config_path = get_config_path(adapter_id, adapter_source) PeftConfig.from_pretrained(config_path, token=api_token) - download_weights(adapter_id, source=adapter_source, api_token=api_token) + _download_weights( + adapter_id, source=adapter_source, api_token=api_token + ) + + # Calculate size of adapter to be loaded + source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token) + adapter_bytes += source.get_weight_bytes() except Exception: logger.exception("Error when downloading adapter") @@ -168,7 +176,26 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, co f"download error: {e}\nIgnoring.") raise - return generate_pb2.DownloadAdapterResponse(downloaded=True) + adapter_memory_size = self.model.adapter_memory_size() + if adapter_memory_size > 0: + logger.info(f"Downloaded adapter {adapter_id} memory size: {adapter_bytes} bytes " + f"(reservation: {adapter_memory_size} bytes)") + adapter_memory_fraction = adapter_bytes / adapter_memory_size + if adapter_memory_fraction > 1: + raise ValueError( + f"Adapter {adapter_id} is larger than adapter memory reservation: " + f"{adapter_bytes} / {adapter_memory_size} bytes" + ) + else: + # Assume 0.0 memory fraction if adapter memory size is not set + logger.info(f"Downloaded adapter {adapter_id} memory size: {adapter_bytes} bytes " + f"(no reservation limit)") + adapter_memory_fraction = 0.0 + + return generate_pb2.DownloadAdapterResponse( + downloaded=True, + memory_fraction=adapter_memory_fraction + ) async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context): adapter_parameters = request.adapter_parameters @@ -206,6 +233,10 @@ async def OffloadAdapter(self, request: generate_pb2.OffloadAdapterRequest, cont adapter_source = _adapter_source_enum_to_string(request.adapter_source) adapter_index = request.adapter_index self.model.offload_adapter(adapter_idx, adapter_source, adapter_index) + + # Ensure there is enough memory for the next adapter + torch.cuda.empty_cache() + torch.cuda.synchronize(self.model.device) return generate_pb2.OffloadAdapterResponse(offloaded=True) except Exception: diff --git a/server/lorax_server/utils/sources/source.py b/server/lorax_server/utils/sources/source.py index 38bac0b48..f847309c0 100644 --- a/server/lorax_server/utils/sources/source.py +++ b/server/lorax_server/utils/sources/source.py @@ -1,3 +1,4 @@ +import json import os from typing import Optional, List from pathlib import Path @@ -41,7 +42,7 @@ class BaseModelSource: def remote_weight_files(self, extension: str = None): raise NotImplementedError - def weight_files(self, extension: str = None): + def weight_files(self, extension: str = None) -> List[Path]: raise NotImplementedError def download_weights(self, filenames: List[str]): @@ -54,4 +55,61 @@ def download_model_assets(self): for other future sources we might need something different. 
diff --git a/server/lorax_server/utils/sources/source.py b/server/lorax_server/utils/sources/source.py
index 38bac0b48..f847309c0 100644
--- a/server/lorax_server/utils/sources/source.py
+++ b/server/lorax_server/utils/sources/source.py
@@ -1,3 +1,4 @@
+import json
 import os
 from typing import Optional, List
 from pathlib import Path
@@ -41,7 +42,7 @@ class BaseModelSource:
     def remote_weight_files(self, extension: str = None):
         raise NotImplementedError

-    def weight_files(self, extension: str = None):
+    def weight_files(self, extension: str = None) -> List[Path]:
         raise NotImplementedError

     def download_weights(self, filenames: List[str]):
@@ -54,4 +55,61 @@ def download_model_assets(self):
         for other future sources we might need something different.
         So this function will take the necessary steps to download the needed files for any source
         """
-        raise NotImplementedError
\ No newline at end of file
+        raise NotImplementedError
+
+    def get_weight_bytes(self) -> int:
+        total_size = 0
+        for path in self.weight_files():
+            fname = str(path)
+
+            # safetensor format explained here: https://huggingface.co/docs/safetensors/en/index
+            # parsing taken from: https://github.com/by321/safetensors_util/blob/main/safetensors_file.py
+            st = os.stat(fname)
+            if st.st_size < 8:
+                raise RuntimeError(f"Length of safetensor file less than 8 bytes: {fname}")
+
+            with open(fname, "rb") as f:
+                # read header size
+                b8 = f.read(8)
+                if len(b8) != 8:
+                    raise RuntimeError(f"Failed to read first 8 bytes of safetensor file: {fname}")
+
+                headerlen = int.from_bytes(b8, 'little', signed=False)
+                if 8 + headerlen > st.st_size:
+                    raise RuntimeError(f"Header extends past end of file: {fname}")
+
+                hdrbuf = f.read(headerlen)
+                header = json.loads(hdrbuf)
+                metadata = header.get('__metadata__', {})
+                total_size_bytes = metadata.get('total_size')
+                if total_size_bytes is None:
+                    # Fallback to determining this value from the data offsets
+                    min_data_offset = None
+                    max_data_offset = None
+                    for v in header.values():
+                        if not isinstance(v, dict):
+                            continue
+
+                        data_offsets = v.get('data_offsets')
+                        if data_offsets is None:
+                            continue
+
+                        if min_data_offset is not None:
+                            min_data_offset = min(min_data_offset, data_offsets[0])
+                        else:
+                            min_data_offset = data_offsets[0]
+
+                        if max_data_offset is not None:
+                            max_data_offset = max(max_data_offset, data_offsets[1])
+                        else:
+                            max_data_offset = data_offsets[1]
+
+                    if min_data_offset is None or max_data_offset is None:
+                        # Fallback to determining total bytes from file size
+                        total_size_bytes = st.st_size
+                    else:
+                        total_size_bytes = max_data_offset - min_data_offset
+
+            total_size += total_size_bytes
+
+        return total_size
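Finally, `get_weight_bytes()` sizes an adapter without loading it, by reading only the safetensors header: an 8-byte little-endian header length followed by a JSON table of tensors with `data_offsets`. A self-contained demo of that parsing path, hand-crafting a tiny file rather than depending on the `safetensors` library:

```python
import json, os, struct, tempfile

# Hedged, self-contained demo of the header parsing get_weight_bytes() relies
# on: craft a minimal safetensors file by hand, then recover the payload size
# from the 8-byte header length and the JSON data_offsets.
header = {"weight": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]}}
hdr = json.dumps(header).encode()
payload = b"\x00" * 16  # 2x2 float32 tensor = 16 bytes

with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as f:
    f.write(struct.pack("<Q", len(hdr)) + hdr + payload)
    fname = f.name

with open(fname, "rb") as f:
    headerlen = int.from_bytes(f.read(8), "little")   # 8-byte LE header size
    parsed = json.loads(f.read(headerlen))

offsets = [v["data_offsets"] for v in parsed.values() if isinstance(v, dict)]
size = max(end for _, end in offsets) - min(start for start, _ in offsets)
print(f"tensor bytes: {size}")  # -> 16
os.unlink(fname)
```

This mirrors the fallback branch above; real checkpoints that carry a `__metadata__.total_size` entry skip the offset scan entirely.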