From 037ea55af3e68f65c1397ff89b54d8a02b50a727 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Sun, 8 Dec 2024 12:36:46 +0100
Subject: [PATCH] Attempt for cleverer auto batch_prefill values (some simplifications).
---
 .../models/test_flash_phi35_moe.py            |   1 -
 launcher/src/main.rs                          | 223 +++++++++++++++---
 .../models/flash_causal_lm.py                 |  51 ++--
 .../text_generation_server/models/globals.py  |   2 +-
 4 files changed, 230 insertions(+), 47 deletions(-)

diff --git a/integration-tests/models/test_flash_phi35_moe.py b/integration-tests/models/test_flash_phi35_moe.py
index 0cb8f85d8ec..d3043b028a8 100644
--- a/integration-tests/models/test_flash_phi35_moe.py
+++ b/integration-tests/models/test_flash_phi35_moe.py
@@ -6,7 +6,6 @@ def flash_phi35_moe_handle(launcher):
     with launcher(
         "microsoft/Phi-3.5-MoE-instruct",
         num_shard=4,
-        max_batch_prefill_tokens=10000,
     ) as handle:
         yield handle
 
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 671ec2ee5f5..3c18959c255 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -30,21 +30,64 @@ mod env_runtime;
 mod gpu;
 
 fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Option<usize> {
-    if let (Some(config), Some(compute)) = (config, compute) {
-        if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
-            tracing::debug!("MAx compute {f16_max_compute} model compute {model_compute}");
-            let optimal_size = (f16_max_compute / model_compute) as usize;
-            if optimal_size > 100 {
-                // Ignore calculations that's too low
-                // Most likely an error
-                Some(optimal_size)
-            } else {
-                None
-            }
+    let config = config?;
+    let compute = compute?;
+    let f16_max_compute = compute.f16_flop()?;
+    let model_compute = config.flop()?;
+    tracing::debug!(
+        "Max compute {} model compute {}",
+        human_size(f16_max_compute as usize, "flop"),
+        human_size(model_compute as usize, "flop")
+    );
+    let optimal_size = (f16_max_compute / model_compute) as usize;
+    if optimal_size > 100 {
+        // Ignore calculations that's too low
+        // Most likely an error
+        Some(optimal_size)
+    } else {
+        None
+    }
+}
+
+fn human_size(size: usize, suffix: &str) -> String {
+    let mut size: f64 = size as f64;
+    let mut p = "";
+    for prefix in ["", "K", "M", "G", "T"] {
+        p = prefix;
+        if size > 1_000.0 {
+            size /= 1_000.0;
         } else {
-            None
+            break;
         }
+    }
+    format!("{size:.2}{p}{suffix}")
+}
+
+fn vram_maximum(
+    config: Option<&Config>,
+    compute: Option<&ComputeType>,
+    memory_fraction: f32,
+) -> Option<usize> {
+    let config = config?;
+    let compute = compute?;
+    let available = compute.vram(memory_fraction)?;
+    let model = config.model_vram()?;
+    let token_vram = config.token_vram()?;
+    if let Some(vram) = available.checked_sub(model) {
+        let tokens_allowed = vram / token_vram;
+        tracing::debug!(
+            "Available vram {}: model needs {}, every token requires {}, maximum allocatable tokens {tokens_allowed}",
+            human_size(available, "B"),
+            human_size(model, "B"),
+            human_size(token_vram, "B"),
+        );
+        Some(tokens_allowed)
     } else {
+        tracing::warn!(
+            "Not enough VRAM to run the model: Available: {} - Model {}.",
+            human_size(available, "B"),
+            human_size(model, "B")
+        );
         None
     }
 }
@@ -175,6 +218,9 @@ struct RawConfig {
     num_experts_per_token: Option<usize>,
     #[serde(rename = "n_shared_experts")]
     num_shared_experts: Option<usize>,
+    #[serde(rename = "num_local_experts")]
+    num_experts: Option<usize>,
+    vocab_size: Option<usize>,
 }
 
 #[derive(Deserialize)]
@@ -200,6 +246,8 @@ struct Config {
     is_encoder_decoder: bool,
     num_experts_per_token: usize,
     num_shared_experts: usize,
+    num_experts: usize,
+    vocab_size: Option<usize>,
 }
 
 impl Config {
@@ -231,6 +279,47 @@ impl Config {
         let total = layer_flops * num_layers;
         Some(total)
     }
+
+    fn kv_vram_per_tok(&self) -> Option<usize> {
+        if self.quantize.is_some() {
+            // TODO handle quantization
+            return None;
+        }
+        // 2 for key and values
+        // 2 for f16 dtype?
+        Some(self.num_kv_heads? * 2 * self.head_dim? * 2 * self.num_layers?)
+    }
+
+    fn mlp_vram_per_tok(&self) -> Option<usize> {
+        // TODO handle quantization
+        // TODO This calculation depends on the actual implementation
+        let dtype_size = 2;
+        let mlp_size = self.intermediate_size?;
+        Some((mlp_size + mlp_size / 2) * self.num_experts * dtype_size * 3)
+    }
+
+    fn token_vram(&self) -> Option<usize> {
+        let kv = self.kv_vram_per_tok()?;
+        let mlp_intermediary = self.mlp_vram_per_tok()?;
+        let per_tok = kv + mlp_intermediary;
+        Some(per_tok)
+    }
+
+    fn model_vram(&self) -> Option<usize> {
+        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.head_dim?;
+        let o_vram = self.num_heads? * self.head_dim? * self.hidden_size?;
+        // gate + up + down = 3
+        let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
+        let layer_vram = mlp_vram + attn_vram + o_vram;
+        let vocab = self.hidden_size? * self.vocab_size?;
+        let params = layer_vram * self.num_layers? + 2 * vocab;
+        let dtype_size = 2;
+        if self.quantize.is_some() {
+            // TODO handle quantization
+            return None;
+        }
+        Some(params * dtype_size)
+    }
 }
 
 impl From<RawConfig> for Config {
@@ -260,6 +349,8 @@ impl From<RawConfig> for Config {
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
         let num_shared_experts = other.num_shared_experts.unwrap_or(0);
+        let num_experts = other.num_experts.unwrap_or(1);
+        let vocab_size = other.vocab_size;
         Config {
             max_position_embeddings,
             quantize,
@@ -274,6 +365,8 @@ impl From<RawConfig> for Config {
             num_layers,
             num_experts_per_token,
             num_shared_experts,
+            num_experts,
+            vocab_size,
         }
     }
 }
@@ -1528,37 +1621,101 @@ fn spawn_shards(
     Ok(())
 }
 
+#[derive(Debug)]
+enum Gpu {
+    RTX4090,
+    T4,
+    L4,
+    A10G,
+    H100,
+    A100,
+    Unknown(String),
+}
+
 #[derive(Debug)]
 struct ComputeType {
     count: usize,
-    card: String,
+    card: Gpu,
+}
+
+impl From<&str> for Gpu {
+    fn from(value: &str) -> Self {
+        match value {
+            "nvidia-4090" => Gpu::RTX4090,
+            "nvidia-t4" => Gpu::T4,
+            "nvidia-l4" => Gpu::L4,
+            "nvidia-a10g" => Gpu::A10G,
+            "nvidia-h100-80gb-hbm3" => Gpu::H100,
+            "nvidia-a100-sxm4-80gb" => Gpu::A100,
+            "nvidia-a100" => Gpu::A100,
+            card => Gpu::Unknown(card.to_string()),
+        }
+    }
+}
+
+impl std::fmt::Display for Gpu {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Gpu::RTX4090 => write!(f, "nvidia-4090"),
+            Gpu::T4 => write!(f, "nvidia-t4"),
+            Gpu::L4 => write!(f, "nvidia-l4"),
+            Gpu::A10G => write!(f, "nvidia-a10g"),
+            Gpu::H100 => write!(f, "nvidia-h100-80gb-hbm3"),
+            Gpu::A100 => write!(f, "nvidia-a100-sxm4-80gb"),
+            Gpu::Unknown(card) => write!(f, "{}", card),
+        }
+    }
 }
 
 impl ComputeType {
     fn f16_flop(&self) -> Option<u64> {
-        let card_flop = match &self.card[..] {
+        let card_flop = match &self.card {
             // https://www.nvidia.com/en-us/geforce/graphics-cards/40-series/rtx-4090/
             // Specs are unclear https://www.itcreations.com/nvidia-gpu/nvidia-geforce-rtx-4090-gpu
-            "nvidia-4090" => Some(82 * 10u64.pow(12)),
+            Gpu::RTX4090 => Some(82 * 10u64.pow(12)),
             // https://www.nvidia.com/en-us/data-center/tesla-t4/
-            "nvidia-t4" => Some(65 * 10u64.pow(12)),
+            Gpu::T4 => Some(65 * 10u64.pow(12)),
             // https://www.nvidia.com/en-us/data-center/l4/
-            "nvidia-l4" => Some(121 * 10u64.pow(12)),
+            Gpu::L4 => Some(121 * 10u64.pow(12)),
             // https://www.nvidia.com/en-us/data-center/products/a10-gpu/
-            "nvidia-a10g" => Some(125 * 10u64.pow(12)),
+            Gpu::A10G => Some(125 * 10u64.pow(12)),
             // https://www.nvidia.com/en-us/data-center/h100/
             // https://www.techpowerup.com/gpu-specs/docs/nvidia-gh100-architecture.pdf
-            "nvidia-h100-80gb-hbm3" => Some(900 * 10u64.pow(12)),
+            Gpu::H100 => Some(900 * 10u64.pow(12)),
             // https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
-            "nvidia-a100-sxm4-80gb" => Some(312 * 10u64.pow(12)),
-            "nvidia-a100" => Some(312 * 10u64.pow(12)),
-            card => {
+            Gpu::A100 => Some(312 * 10u64.pow(12)),
+            Gpu::Unknown(card) => {
                 tracing::warn!("Unkown compute for card {card}");
                 None
             }
         };
         card_flop.map(|f| f * self.count as u64)
     }
+
+    fn vram(&self, memory_fraction: f32) -> Option<usize> {
+        let output = Command::new("nvidia-smi")
+            .args(["--query-gpu=memory.total", "--format=csv"])
+            .output()
+            .ok()?;
+        let output = String::from_utf8(output.stdout).ok()?;
+        let fullname = output.split('\n').nth(1)?;
+        let mut tokens = fullname.split(' ');
+        let amount = tokens.next()?;
+        let unit = tokens.next()?;
+        if unit != "MiB" {
+            tracing::warn!("Unexpected memory unit {unit}, expected MiB");
+            return None;
+        }
+        let amount: usize = amount.parse().ok()?;
+        let amount = amount * 2usize.pow(20);
+        let wiggle_room: f32 = env::var("TGI_WIGGLE_ROOM")
+            .ok()
+            .and_then(|wiggle| wiggle.parse().ok())
+            .unwrap_or(0.95);
+        let total = amount * self.count;
+        let adjusted = ((total as f32) * memory_fraction * wiggle_room) as usize;
+        Some(adjusted)
+    }
 }
 
 impl From<ComputeType> for OsString {
@@ -1567,7 +1724,7 @@ impl From<ComputeType> for OsString {
     }
 }
 
-fn compute_type(num_shard: usize) -> Option<ComputeType> {
+fn compute_type(count: usize) -> Option<ComputeType> {
     let output = Command::new("nvidia-smi")
         .args(["--query-gpu=gpu_name", "--format=csv"])
         .output()
@@ -1575,10 +1732,8 @@ fn compute_type(num_shard: usize) -> Option<ComputeType> {
     let output = String::from_utf8(output.stdout).ok()?;
     let fullname = output.split('\n').nth(1)?;
     let cardname = fullname.replace(' ', "-").to_lowercase();
-    Some(ComputeType {
-        count: num_shard,
-        card: cardname,
-    })
+    let card = (&*cardname).into();
+    Some(ComputeType { count, card })
 }
 
 fn spawn_webserver(
@@ -1864,16 +2019,28 @@ fn main() -> Result<(), LauncherError> {
         match args.max_batch_prefill_tokens {
             Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
             None => {
-                // TODO figure out hardware optimal value
                 let compute_type = compute_type(num_shard);
                 let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
                 let default = compute_optimal.unwrap_or(4096);
+                let vram_maximum = vram_maximum(
+                    config.as_ref(),
+                    compute_type.as_ref(),
+                    args.cuda_memory_fraction,
+                );
                 let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
                 let value = if let Some(max_position_embeddings) = max_position_embeddings {
                     default.min(max_position_embeddings)
                 } else {
                     default
                 };
+                let value = if let Some(vram_maximum) = vram_maximum {
+                    if vram_maximum < value {
+                        tracing::warn!("Reducing the max batch prefill from {default} to {vram_maximum} because there is not enough VRAM to support it.");
+                    }
+                    value.min(vram_maximum)
+                } else {
+                    value
+                };
                 tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                 value as u32
             }
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 8989110a7ad..07b7604d693 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1557,11 +1557,22 @@ def warmup(
                 self.kv_cache_dtype,
                 self.device,
             )
+            batch_num_blocks = batch.num_blocks
             num_tokens = batch.to_pb().current_tokens
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
                 torch.cuda.tunable.tuning_enable(False)
+            synchronize(self.device)
+            free_memory = get_free_memory(
+                self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
+            )
+            real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+            log_master(
+                logger.debug,
+                f"Free memory {free_memory/1e9:.2f}GB, (real: {real_free_memory/1e9:.2f}GB)",
+            )
+
             _, _batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -1570,12 +1581,11 @@ def warmup(
             ) from e
 
         synchronize(self.device)
-
-        free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-
+        free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
+        kv_memory = free_memory
         num_blocks = (
             # Leave 5% for some wiggle room
-            int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
+            int(kv_memory // total_cache_size)
             # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
             + batch_num_blocks
         )
@@ -1584,21 +1594,11 @@ def warmup(
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_input_tokens = (
-                    min((num_blocks * BLOCK_SIZE - 1), model_max_length)
-                    if max_input_tokens is None
-                    else max_input_tokens
-                )
-                max_total_tokens = num_blocks * BLOCK_SIZE
-
+                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
             else:
                 max_total_tokens = sum(batch.cache_lengths)
-                max_input_tokens = (
-                    max_total_tokens - 1
-                    if max_input_tokens is None
-                    else max_input_tokens
-                )
-        elif max_input_tokens is None:
+
+        if max_input_tokens is None:
             max_input_tokens = max_total_tokens - 1
 
         del _batch, batch
@@ -1676,8 +1676,25 @@ def warmup(
                 )
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
+                    synchronize(self.device)
+                    free_memory = get_free_memory(
+                        self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
+                    )
+                    log_master(
+                        logger.debug,
+                        f"Free RAM before cuda graph {bs} {free_memory / 1e9:.2f}GB",
+                    )
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
+                empty_cache()
+                synchronize(self.device)
+                free_memory = get_free_memory(
+                    self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
+                )
+                log_master(
+                    logger.debug,
+                    f"Free RAM after cuda graphs {free_memory / 1e9:.2f}GB",
+                )
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
         else:
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 8d988ad5870..ce8791411f9 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -24,7 +24,7 @@
     raise RuntimeError("Prefix caching is only supported with flashinfer")
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
-TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
+TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1
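
Taken together, the launcher's new default for `max_batch_prefill_tokens` is: the compute-optimal prefill size (peak f16 FLOPs divided by the model's per-token FLOPs), capped by `max_position_embeddings`, then capped again by how many tokens still fit in VRAM once the weights are resident. The following is a minimal, self-contained sketch of that arithmetic; the hardware and model numbers (H100-class peak compute, a hypothetical 8B-parameter f16 model, ~0.5 MB per token) are illustrative assumptions, not values read from any real config, and the function is not part of the launcher's API.

// Sketch of the default max_batch_prefill_tokens heuristic, with assumed numbers.
fn main() {
    // Hardware (assumed): ~900 TFLOPs f16 peak, 80 GB VRAM,
    // cuda_memory_fraction = 0.95, TGI_WIGGLE_ROOM = 0.95.
    let f16_flops: u64 = 900 * 10u64.pow(12);
    let usable_vram: f64 = 80e9 * 0.95 * 0.95;

    // Model (assumed): ~8B params in f16 => ~16 GB of weights,
    // ~16 GFLOPs per token, ~0.5 MB of KV cache + MLP scratch per token.
    let flop_per_token: u64 = 16 * 10u64.pow(9);
    let model_vram: f64 = 16e9;
    let vram_per_token: f64 = 512_000.0;
    let max_position_embeddings: u64 = 131_072;

    // compute_optimal: how many prefill tokens saturate the card's compute.
    let compute_optimal = f16_flops / flop_per_token; // 56_250

    // vram_maximum: how many tokens fit in the VRAM left after the weights.
    let vram_maximum = ((usable_vram - model_vram) / vram_per_token) as u64; // ~109_765

    // Final default: cap by context length, then by the VRAM budget.
    let value = compute_optimal
        .min(max_position_embeddings)
        .min(vram_maximum);
    println!("default max_batch_prefill_tokens = {value}");
}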