
Commit

Enforce adapters cannot be loaded past --adapter-memory-fraction (#306)
tgaddair authored Mar 7, 2024
1 parent 21631fa commit 6dea404
Showing 13 changed files with 191 additions and 19 deletions.
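
Note on the change: --adapter-memory-fraction now acts as an enforced budget. Each downloaded adapter reports its size as a fraction of the reserved adapter memory, and the router stops loading adapters onto the GPU once that budget is spent, swapping instead. A minimal sketch of the arithmetic, with hypothetical sizes (the 24 GiB device and 600 MiB adapter below are made up, not from this diff):

# Sketch of the budget accounting introduced by this commit (hypothetical sizes).
total_gpu_memory = 24 * 1024**3        # 24 GiB device, for illustration only
adapter_memory_fraction = 0.1          # new default set by this commit
reservation = adapter_memory_fraction * total_gpu_memory  # ~2.4 GiB for adapters

adapter_bytes = 600 * 1024**2          # one 600 MiB adapter, hypothetical
memory_fraction = adapter_bytes / reservation             # ~0.24 of the budget

# The router starts with a remaining budget of 1.0, subtracts each loaded
# adapter's fraction, and adds it back when the adapter is offloaded.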
2 changes: 1 addition & 1 deletion docs/reference/launcher.md
@@ -147,7 +147,7 @@ Options:
Maximum number of adapters that can be placed on the GPU and accept requests at a time

[env: MAX_ACTIVE_ADAPTERS=]
-          [default: 128]
+          [default: 1024]

--adapter-cycle-time-s <ADAPTER_CYCLE_TIME_S>
The time in seconds between adapter exchanges
4 changes: 2 additions & 2 deletions launcher/src/main.rs
@@ -264,7 +264,7 @@ struct Args {
max_waiting_tokens: usize,

/// Maximum number of adapters that can be placed on the GPU and accept requests at a time.
-    #[clap(default_value = "128", long, env)]
+    #[clap(default_value = "1024", long, env)]
max_active_adapters: usize,

/// The time in seconds between adapter exchanges.
@@ -275,7 +275,7 @@
/// Increasing this value will reduce the size of the KV cache in exchange for allowing more
/// adapters to be loaded onto the GPU at once.
/// This value is NOT scaled relative to `cuda_memory_fraction`, but is expressed in absolute terms.
-    #[clap(default_value = "0.0", long, env)]
+    #[clap(default_value = "0.1", long, env)]
adapter_memory_fraction: f32,

/// The IP address to listen on
7 changes: 7 additions & 0 deletions proto/generate.proto
@@ -288,6 +288,13 @@ message DownloadAdapterRequest {
message DownloadAdapterResponse {
/// True if download occurred, false if skipped
bool downloaded = 1;

+  /// Fraction of the adapter memory limit consumed by the adapter.
+  /// If no limit is set, will return 0.
+  /// When the total across all loaded adapters exceeds
+  /// the adapter_memory_fraction limit, no more adapters
+  /// will be loaded to GPU and LoRAX will begin swapping.
+  float memory_fraction = 2;
}

message LoadAdapterRequest {
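The memory_fraction comment above describes cumulative accounting: the router sums the fractions of resident adapters and begins swapping once the next adapter would push the total past 1.0. A hedged illustration with made-up fractions:

# Illustration of the proto comment (made-up numbers, not from this diff).
loaded_fractions = [0.4, 0.35]   # adapters currently resident on the GPU
incoming = 0.3                   # memory_fraction reported for the next adapter

budget_remaining = 1.0 - sum(loaded_fractions)  # 0.25
will_load = incoming <= budget_remaining        # False: swap instead of load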
4 changes: 2 additions & 2 deletions router/client/src/client.rs
@@ -193,7 +193,7 @@ impl Client {
adapter_parameters: AdapterParameters,
adapter_source: String,
api_token: Option<String>,
-    ) -> Result<bool> {
+    ) -> Result<DownloadAdapterResponse> {
if let Some(adapter_source_enum) =
AdapterSource::from_str_name(adapter_source.to_uppercase().as_str())
{
@@ -204,7 +204,7 @@
})
.inject_context();
let response = self.stub.download_adapter(request).await?.into_inner();
-            Ok(response.downloaded)
+            Ok(response)
} else {
let err_string = format!(
"Invalid source '{}' when downloading adapter '{}'",
6 changes: 3 additions & 3 deletions router/client/src/lib.rs
@@ -9,9 +9,9 @@ pub use client::Client;
pub use pb::generate::v1::HealthResponse;
pub use pb::generate::v1::InfoResponse as ShardInfo;
pub use pb::generate::v1::{
-    AdapterParameters, Batch, CachedBatch, FinishReason, GeneratedText, Generation,
-    MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, PrefillTokens, Request,
-    StoppingCriteriaParameters,
+    AdapterParameters, Batch, CachedBatch, DownloadAdapterResponse, FinishReason, GeneratedText,
+    Generation, MajoritySignMethod, MergeStrategy, NextTokenChooserParameters, PrefillTokens,
+    Request, StoppingCriteriaParameters,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
7 changes: 5 additions & 2 deletions router/client/src/sharded_client.rs
@@ -1,5 +1,8 @@
/// Multi shard Client
-use crate::{AdapterParameters, Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
+use crate::{
+    AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation,
+    HealthResponse, ShardInfo,
+};
use crate::{ClientError, Result};
use futures::future::join_all;
use tonic::transport::Uri;
@@ -155,7 +158,7 @@ impl ShardedClient {
adapter_parameters: AdapterParameters,
adapter_source: String,
api_token: Option<String>,
-    ) -> Result<bool> {
+    ) -> Result<DownloadAdapterResponse> {
// Only download the adapter with one client, since they share a single disk
self.clients[0]
.download_adapter(adapter_parameters, adapter_source, api_token)
3 changes: 2 additions & 1 deletion router/src/loader.rs
@@ -146,12 +146,13 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver<Adapte
)
.await
{
- Ok(_) => {
+ Ok(resp) => {
tracing::info!("adapter {} downloaded", adapter.as_string());
let mut locked_state = queues_state.lock().unwrap();
if locked_state.has_adapter(&adapter) {
// Above check guards against the case where the adapter was terminated between the initial
// time of request and the time of adapter download
+ locked_state.set_cost(&adapter, resp.memory_fraction);
locked_state.set_status(&adapter, AdapterStatus::Downloaded);
}
}
68 changes: 66 additions & 2 deletions router/src/queue.rs
@@ -66,6 +66,9 @@ pub(crate) struct QueueState {
/// Adapter status
status: AdapterStatus,

+    /// Cost as a fraction of the adapter memory budget
+    cost: Option<f32>,

/// Timestamp when the adapter was last activated
activation_ts: Option<Instant>,

@@ -80,6 +83,7 @@ impl QueueState {
entries: VecDeque::with_capacity(128),
adapter,
status,
+            cost: None,
activation_ts: None,
event,
}
@@ -143,6 +147,14 @@ impl QueueState {
&self.status
}

+    pub(crate) fn set_cost(&mut self, cost: f32) {
+        self.cost = Some(cost);
+    }
+
+    pub(crate) fn cost(&self) -> Option<f32> {
+        self.cost
+    }

pub(crate) fn set_activation_ts(&mut self, ts: Instant) {
self.activation_ts = Some(ts);
}
@@ -169,6 +181,9 @@ pub(crate) struct AdapterQueuesState {
/// Number of adapters that can be active at a time
max_active_adapters: usize,

+    /// Fraction of adapter memory budget remaining to allocate to new adapters
+    memory_budget_remaining: f32,

/// Maximum time an adapter is allowed to be active before exchanging out
max_active_time: Duration,

@@ -189,6 +204,7 @@ impl AdapterQueuesState {
active_adapters,
tracked_adapters,
max_active_adapters: max_active_adapters,
+            memory_budget_remaining: 1.0,
max_active_time: Duration::from_secs(adapter_cycle_time_s),
next_id: 0,
}
@@ -255,6 +271,17 @@ impl AdapterQueuesState {
errored_adapters
}

+    pub(crate) fn set_cost(&mut self, adapter: &Adapter, cost: f32) {
+        let q = self.queue_map.get_mut(adapter);
+        if q.is_none() {
+            // TODO(travis): remove this
+            tracing::error!("adapter {} not found in queue_map", adapter.as_string());
+            println!("{:?}", Backtrace::force_capture());
+        }
+        let queue = q.unwrap();
+        queue.set_cost(cost);
+    }

pub(crate) fn set_status(&mut self, adapter: &Adapter, status: AdapterStatus) {
let q = self.queue_map.get_mut(adapter);
if q.is_none() {
@@ -388,20 +415,57 @@
}
}

+        // Add back cost for all offload adapters
+        for adapter in offload_adapters.iter() {
+            let queue = self.queue_map.get(adapter).unwrap().clone();
+            let cost = queue.cost().unwrap();
+            self.memory_budget_remaining += cost;
+            tracing::info!(
+                "offloading adapter {} with cost {} (memory budget remaining: {})",
+                adapter.as_string(),
+                cost,
+                self.memory_budget_remaining
+            );
+        }

// Add pending adapters to the active set until we reach the max
while self.active_adapters.len() < self.max_active_adapters
&& self.pending_adapters.len() > 0
{
-            let adapter = self.pending_adapters.pop_front().unwrap();
+            let queue = self
+                .queue_map
+                .get_mut(self.pending_adapters.front().unwrap())
+                .unwrap();
+            if queue.cost().is_none() {
+                // Adapter has not been downloaded yet
+                break;
+            }
+
+            // Check to see that we have enough memory budget remaining to load the adapter
+            let cost = queue.cost().unwrap();
+            if cost > self.memory_budget_remaining {
+                // Adapter is too expensive to load
+                break;
+            }
+
+            let adapter = self.pending_adapters.pop_front().unwrap();

            // Update activation timestamp
            let queue = self.queue_map.get_mut(&adapter).unwrap();
            queue.set_activation_ts(now);

+            // Calculate remaining memory budget
+            self.memory_budget_remaining -= cost;

// Start async loading process
load_adapters.push(adapter.clone());

self.active_adapters.push_back(adapter.clone());

+            tracing::info!(
+                "loading adapter {} with cost {} (memory budget remaining: {})",
+                adapter.as_string(),
+                cost,
+                self.memory_budget_remaining
+            );
}

(offload_adapters, load_adapters)
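For readers skimming the Rust above: offloading an adapter returns its cost to the budget, and pending adapters are admitted only while their (known) cost fits. A compact Python re-sketch of that control flow; names mirror queue.rs, but this is illustrative only, and the real offload policy also considers activation time:

from collections import deque

def exchange_adapters(active, pending, costs, budget_remaining, max_active):
    # Sketch of the budget-aware swap loop in queue.rs (illustrative only).
    offload, load = [], []

    # Offloading an adapter returns its cost to the budget.
    while len(active) > max_active:
        adapter = active.popleft()
        budget_remaining += costs[adapter]
        offload.append(adapter)

    # Admit pending adapters until the active cap or the budget blocks us.
    while len(active) < max_active and pending:
        cost = costs.get(pending[0])
        if cost is None:             # not downloaded yet: cost unknown
            break
        if cost > budget_remaining:  # too expensive to load right now
            break
        adapter = pending.popleft()
        budget_remaining -= cost
        load.append(adapter)
        active.append(adapter)

    return offload, load, budget_remaining

# Example: the 0.5-cost adapter is admitted; the 0.6-cost one must wait.
print(exchange_adapters(deque(), deque(["a", "b"]), {"a": 0.5, "b": 0.6}, 1.0, 4))
# -> ([], ['a'], 0.5)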
1 change: 1 addition & 0 deletions server/lorax_server/cli.py
@@ -90,6 +90,7 @@ def serve(
model_id, adapter_id, revision, sharded, quantize, compile, dtype, trust_remote_code, uds_path, source, adapter_source
)


def _download_weights(
model_id: str,
revision: Optional[str] = None,
6 changes: 5 additions & 1 deletion server/lorax_server/models/flash_causal_lm.py
@@ -34,7 +34,7 @@
from lorax_server.utils.tokenizer import TokenizerManager


-ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.0"))
+ADAPTER_MEMORY_FRACTION = float(os.getenv("ADAPTER_MEMORY_FRACTION", "0.1"))

tracer = trace.get_tracer(__name__)

@@ -722,6 +722,10 @@ def __init__(
@property
def batch_type(self) -> Type[FlashCausalLMBatch]:
return FlashCausalLMBatch

+    def adapter_memory_size(self) -> int:
+        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        return ADAPTER_MEMORY_FRACTION * total_gpu_memory

def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
torch.cuda.empty_cache()
3 changes: 3 additions & 0 deletions server/lorax_server/models/model.py
@@ -81,6 +81,9 @@ def info(self) -> InfoResponse:
@abstractmethod
def batch_type(self) -> Type[B]:
raise NotImplementedError

+    def adapter_memory_size(self) -> int:
+        return 0

@abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
37 changes: 34 additions & 3 deletions server/lorax_server/server.py
@@ -13,13 +13,14 @@
from typing import List, Optional

from lorax_server.cache import Cache
-from lorax_server.cli import download_weights
+from lorax_server.cli import _download_weights
from lorax_server.interceptor import ExceptionInterceptor
from lorax_server.models import Model, get_model
from lorax_server.pb import generate_pb2_grpc, generate_pb2
from lorax_server.tracing import UDSOpenTelemetryAioServerInterceptor
from lorax_server.utils import HUB, LOCAL, S3, PBASE, get_config_path, get_local_dir, map_pbase_model_id_to_s3
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, is_base_model
+from lorax_server.utils.sources import get_model_source


class LoraxService(generate_pb2_grpc.LoraxServiceServicer):
@@ -132,6 +133,7 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, co
logger.info("No adapter to download for base model. Skipping.")
return generate_pb2.DownloadAdapterResponse(downloaded=False)

+        adapter_bytes = 0
api_token = request.api_token
adapter_source = _adapter_source_enum_to_string(request.adapter_source)
for adapter_id in adapter_parameters.adapter_ids:
@@ -153,7 +155,13 @@
config_path = get_config_path(adapter_id, adapter_source)
PeftConfig.from_pretrained(config_path, token=api_token)

-                    download_weights(adapter_id, source=adapter_source, api_token=api_token)
+                    _download_weights(
+                        adapter_id, source=adapter_source, api_token=api_token
+                    )
+
+                    # Calculate size of adapter to be loaded
+                    source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
+                    adapter_bytes += source.get_weight_bytes()
except Exception:
logger.exception("Error when downloading adapter")

@@ -168,7 +176,26 @@
f"download error: {e}\nIgnoring.")
raise

-            return generate_pb2.DownloadAdapterResponse(downloaded=True)
+            adapter_memory_size = self.model.adapter_memory_size()
+            if adapter_memory_size > 0:
+                logger.info(f"Downloaded adapter {adapter_id} memory size: {adapter_bytes} bytes "
+                            f"(reservation: {adapter_memory_size} bytes)")
+                adapter_memory_fraction = adapter_bytes / adapter_memory_size
+                if adapter_memory_fraction > 1:
+                    raise ValueError(
+                        f"Adapter {adapter_id} is larger than adapter memory reservation: "
+                        f"{adapter_bytes} / {adapter_memory_size} bytes"
+                    )
+            else:
+                # Assume 0.0 memory fraction if adapter memory size is not set
+                logger.info(f"Downloaded adapter {adapter_id} memory size: {adapter_bytes} bytes "
+                            f"(no reservation limit)")
+                adapter_memory_fraction = 0.0
+
+            return generate_pb2.DownloadAdapterResponse(
+                downloaded=True,
+                memory_fraction=adapter_memory_fraction
+            )

async def LoadAdapter(self, request: generate_pb2.LoadAdapterRequest, context):
adapter_parameters = request.adapter_parameters
@@ -206,6 +233,10 @@ async def OffloadAdapter(self, request: generate_pb2.OffloadAdapterRequest, cont
adapter_source = _adapter_source_enum_to_string(request.adapter_source)
adapter_index = request.adapter_index
self.model.offload_adapter(adapter_idx, adapter_source, adapter_index)

+            # Ensure there is enough memory for the next adapter
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize(self.model.device)

return generate_pb2.OffloadAdapterResponse(offloaded=True)
except Exception:
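Putting the pieces together: the server computes an adapter's fraction of the reservation at download time, loader.rs records it as the queue cost, and queue.rs spends a budget of 1.0 against those costs. A hedged end-to-end trace with made-up numbers:

# End-to-end trace of the new flow (hypothetical sizes, illustrative only).
gpu_total = 40 * 1024**3            # 40 GiB GPU, for illustration
reservation = 0.1 * gpu_total       # ADAPTER_MEMORY_FRACTION = 0.1
adapter_bytes = 1 * 1024**3         # 1 GiB of adapter weights

# 1) server.py: DownloadAdapter returns this fraction in the response.
memory_fraction = adapter_bytes / reservation  # 0.25

# 2) loader.rs: set_cost(&adapter, resp.memory_fraction) stores it on the queue.
# 3) queue.rs: a budget of 1.0 admits at most four such adapters at once;
#    a fifth waits until one is offloaded and its cost is returned.
assert int(1.0 // memory_fraction) == 4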