diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index f7a6d475c..3773dc07c 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -497,6 +497,9 @@ fn shard_manager(
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

+    // Disable progress bars to prevent hanging in containers
+    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
+
     // Enable hf transfer for insane download speeds
     let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
     envs.push((
@@ -564,13 +567,20 @@ fn shard_manager(
         }
     };

-    // Redirect STDOUT to the console
-    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
-    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
+    let shard_stdout = BufReader::new(p.stdout.take().unwrap());
+
+    thread::spawn(move || {
+        log_lines(shard_stdout.lines());
+    });
+
+    let shard_stderr = BufReader::new(p.stderr.take().unwrap());

-    //stdout tracing thread
+    // We read stderr in another thread as it seems that lines() can block in some cases
+    let (err_sender, err_receiver) = mpsc::channel();
     thread::spawn(move || {
-        log_lines(shard_stdout_reader.lines());
+        for line in shard_stderr.lines().flatten() {
+            err_sender.send(line).unwrap_or(());
+        }
     });

     let mut ready = false;
@@ -579,13 +589,6 @@ fn shard_manager(
     loop {
         // Process exited
         if let Some(exit_status) = p.try_wait().unwrap() {
-            // We read stderr in another thread as it seems that lines() can block in some cases
-            let (err_sender, err_receiver) = mpsc::channel();
-            thread::spawn(move || {
-                for line in shard_stderr_reader.lines().flatten() {
-                    err_sender.send(line).unwrap_or(());
-                }
-            });
             let mut err = String::new();
             while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                 err = err + "\n" + &line;
@@ -796,6 +799,9 @@ fn download_convert_model(
         envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };

+    // Disable progress bars to prevent hanging in containers
+    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
+
     // Enable hf transfer for insane download speeds
     let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
     envs.push((
@@ -840,12 +846,20 @@ fn download_convert_model(
         }
     };

-    // Redirect STDOUT to the console
-    let download_stdout = download_process.stdout.take().unwrap();
-    let stdout = BufReader::new(download_stdout);
+    let download_stdout = BufReader::new(download_process.stdout.take().unwrap());

     thread::spawn(move || {
-        log_lines(stdout.lines());
+        log_lines(download_stdout.lines());
+    });
+
+    let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
+
+    // We read stderr in another thread as it seems that lines() can block in some cases
+    let (err_sender, err_receiver) = mpsc::channel();
+    thread::spawn(move || {
+        for line in download_stderr.lines().flatten() {
+            err_sender.send(line).unwrap_or(());
+        }
     });

     loop {
@@ -856,12 +870,9 @@
             }
             let mut err = String::new();
-            download_process
-                .stderr
-                .take()
-                .unwrap()
-                .read_to_string(&mut err)
-                .unwrap();
+            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
+                err = err + "\n" + &line;
+            }
             if let Some(signal) = status.signal() {
                 tracing::error!(
                     "Download process was signaled to shutdown with signal {signal}: {err}"
                 )
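The launcher change above replaces a blocking stderr read with a dedicated reader thread that forwards lines over an mpsc channel, which the supervision loop then drains with a short recv_timeout once the child has exited. Below is a minimal, self-contained sketch of that pattern, not the launcher's actual code: the `sh -c` child command, the printed messages, and the 100 ms poll interval are placeholders standing in for the shard/download process and tracing output.

use std::io::{BufRead, BufReader};
use std::process::{Command, Stdio};
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() {
    // Placeholder child process with both pipes captured.
    let mut child = Command::new("sh")
        .args(["-c", "echo out; echo err >&2; sleep 1"])
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .expect("failed to spawn child");

    // Forward stdout from its own thread so the supervision loop below
    // never blocks on a pipe read.
    let stdout = BufReader::new(child.stdout.take().unwrap());
    thread::spawn(move || {
        for line in stdout.lines().flatten() {
            println!("child stdout: {line}");
        }
    });

    // Read stderr on a separate thread and hand each line to the
    // supervisor through a channel; only the reader thread touches the pipe.
    let stderr = BufReader::new(child.stderr.take().unwrap());
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        for line in stderr.lines().flatten() {
            // Ignore send errors: the receiver may already have been dropped.
            err_sender.send(line).unwrap_or(());
        }
    });

    // Supervision loop: poll the child without blocking, then drain
    // whatever stderr arrived once it has exited.
    loop {
        if let Some(status) = child.try_wait().unwrap() {
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }
            println!("child exited with {status}; stderr:{err}");
            break;
        }
        thread::sleep(Duration::from_millis(100));
    }
}

The 10 ms recv_timeout mirrors the diff: draining stops as soon as the reader thread goes quiet, so a process that exits without writing anything to stderr no longer stalls the loop the way the old blocking read_to_string could.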
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index b2d4b88fa..7f8355483 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -784,10 +784,12 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
             )

             with warmup_mode():
+                logger.info("Warming up to max_total_tokens: {}", max_new_tokens)
                 with tqdm(total=max_new_tokens, desc="Warmup to max_total_tokens") as pbar:
                     for _ in range(max_new_tokens):
                         _, batch = self.generate_token(batch, is_warmup=True)
                         pbar.update(1)
+                logger.info("Finished generating warmup tokens")
         except RuntimeError as e:
             if "CUDA out of memory" in str(e) or isinstance(e, torch.cuda.OutOfMemoryError):
                 raise RuntimeError(
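The environment-variable hunks in the launcher diff are simpler: the new HF_HUB_DISABLE_PROGRESS_BARS entry is pushed into the same envs list that already carries SAFETENSORS_FAST_GPU and HF_HUB_ENABLE_HF_TRANSFER before the child process is spawned. A minimal sketch of that list-building pattern follows; the `env` command is an illustrative stand-in for the real shard/download binary, and only the two variables named in the diff comments are included.

use std::env;
use std::process::Command;

fn main() {
    let mut envs: Vec<(String, String)> = Vec::new();

    // Disable progress bars to prevent hanging in containers (the same
    // rationale stated in the launcher hunks above).
    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

    // Enable hf transfer unless the caller has already overridden it.
    let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
    envs.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), enable_hf_transfer));

    // `env` just prints the environment it received, which makes it easy
    // to confirm the variables reached the child process.
    let status = Command::new("env")
        .envs(envs)
        .status()
        .expect("failed to spawn process");
    println!("exited with {status}");
}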