Skip to content

Commit

Permalink
add max-active-adapters, adapter-cycle-time-s as cli args (#17)
Browse files: browse the repository at this point in the history
  • Loading branch information
noyoshi authored Nov 16, 2023
1 parent f9e1e70 commit 7d46c4a
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 7 deletions.
4 changes: 3 additions & 1 deletion router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ impl Infer {
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
max_active_adapters: usize,
adapter_cycle_time_s: u64,
requires_padding: bool,
window_size: Option<u32>,
generation_health: Arc<AtomicBool>,
Expand All @@ -56,7 +58,7 @@ impl Infer {
});

// Routes requests to the appropriate adapter queue
let adapter_scheduler = AdapterScheduler::new(client.clone(), adapter_event.clone(), requires_padding, 16, window_size);
let adapter_scheduler = AdapterScheduler::new(client.clone(), adapter_event.clone(), requires_padding, 16, window_size, max_active_adapters, adapter_cycle_time_s);

// Initialize with base model adapter (empty) mapping to index 0
let adapter_to_index = Arc::new(Mutex::new(HashMap::from([("".to_string(), 0)])));
Expand Down
8 changes: 8 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ struct Args {
max_batch_total_tokens: Option<u32>,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "128", long, env)]
max_active_adapters: usize,
#[clap(default_value = "2", long, env)]
adapter_cycle_time_s: u64,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
Expand Down Expand Up @@ -81,6 +85,8 @@ fn main() -> Result<(), RouterError> {
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_active_adapters,
adapter_cycle_time_s,
hostname,
port,
master_shard_uds_path,
Expand Down Expand Up @@ -261,6 +267,8 @@ fn main() -> Result<(), RouterError> {
max_batch_prefill_tokens,
max_supported_batch_total_tokens,
max_waiting_tokens,
max_active_adapters,
adapter_cycle_time_s,
sharded_client,
tokenizer,
validation_workers,
Expand Down
6 changes: 3 additions & 3 deletions router/src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ pub(crate) struct AdapterQueuesState {
}

impl AdapterQueuesState {
pub(crate) fn new() -> Self {
pub(crate) fn new(max_active_adapters: usize, adapter_cycle_time_s: u64) -> Self {
let queue_map = HashMap::new();
let pending_adapters = VecDeque::new();
let active_adapters = VecDeque::new();
Expand All @@ -174,8 +174,8 @@ impl AdapterQueuesState {
pending_adapters,
active_adapters,
tracked_adapters,
max_active_adapters: 128,
max_active_time: Duration::from_secs(2),
max_active_adapters: max_active_adapters,
max_active_time: Duration::from_secs(adapter_cycle_time_s),
next_id: 0,
}
}
Expand Down
12 changes: 9 additions & 3 deletions router/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ impl AdapterScheduler {
requires_padding: bool,
block_size: u32,
window_size: Option<u32>,
max_active_adapters: usize,
adapter_cycle_time_s: u64,
) -> Self {
let (sender, receiver) = flume::unbounded();

Expand All @@ -43,6 +45,8 @@ impl AdapterScheduler {
block_size,
window_size,
receiver,
max_active_adapters,
adapter_cycle_time_s,
));

Self {
Expand Down Expand Up @@ -102,8 +106,10 @@ async fn adapter_scheduler_task(
block_size: u32,
window_size: Option<u32>,
receiver: flume::Receiver<AdapterSchedulerCommand>,
max_active_adapters: usize,
adapter_cycle_time_s: u64,
) {
let mut state = AdapterSchedulerState::new(client, requires_padding, block_size, window_size);
let mut state = AdapterSchedulerState::new(client, requires_padding, block_size, window_size, max_active_adapters, adapter_cycle_time_s);

while let Ok(cmd) = receiver.recv_async().await {
match cmd {
Expand Down Expand Up @@ -152,8 +158,8 @@ struct AdapterSchedulerState {
}

impl AdapterSchedulerState {
fn new(client: ShardedClient, requires_padding: bool, block_size: u32, window_size: Option<u32>) -> Self {
let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new()));
fn new(client: ShardedClient, requires_padding: bool, block_size: u32, window_size: Option<u32>, max_active_adapters: usize, adapter_cycle_time_s: u64) -> Self {
let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new(max_active_adapters, adapter_cycle_time_s)));
let loader = AdapterLoader::new(client.clone());

Self {
Expand Down
4 changes: 4 additions & 0 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ pub async fn run(
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_active_adapters: usize,
adapter_cycle_time_s: u64,
client: ShardedClient,
tokenizer: Option<Tokenizer>,
validation_workers: usize,
Expand Down Expand Up @@ -590,6 +592,8 @@ pub async fn run(
max_batch_total_tokens,
max_waiting_tokens,
max_concurrent_requests,
max_active_adapters,
adapter_cycle_time_s,
shard_info.requires_padding,
shard_info.window_size,
generation_health,
Expand Down

0 comments on commit 7d46c4a

Please sign in to comment.