Fix hanging caused by tqdm stderr not being printed #352

Merged · 4 commits · Mar 21, 2024
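Context for the fix: when a child process's stderr is piped but never read, tqdm's progress-bar writes (tqdm defaults to stderr) eventually fill the OS pipe buffer and the child blocks inside write(2), so the launcher's wait loop never sees it exit. A minimal sketch of that failure mode (hypothetical repro, not code from this PR):

```rust
use std::process::{Command, Stdio};

fn main() -> std::io::Result<()> {
    // The child floods stderr with ~1.4 MB of output, far more than a
    // typical 64 KiB Linux pipe buffer can hold.
    let mut child = Command::new("sh")
        .args(["-c", "dd if=/dev/zero bs=1024 count=1024 2>/dev/null | base64 >&2"])
        .stderr(Stdio::piped())
        .spawn()?;

    // BUG: stderr is piped but nothing ever reads it. Once the pipe buffer
    // fills, the child blocks on write(2) and this wait() never returns.
    let status = child.wait()?;
    println!("child exited with {status}");
    Ok(())
}
```

The PR attacks this from both ends: the launcher now drains stderr on a dedicated thread for the whole lifetime of the child, and HF_HUB_DISABLE_PROGRESS_BARS=1 suppresses the noisiest stderr writer during downloads.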
55 changes: 33 additions & 22 deletions launcher/src/main.rs
@@ -497,6 +497,9 @@ fn shard_manager(
     // Safetensors load fast
     envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
 
+    // Disable progress bars to prevent hanging in containers
+    envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));
+
     // Enable hf transfer for insane download speeds
     let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
     envs.push((
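For reference, a minimal sketch (assumed setup, not the launcher's actual code) of how an env list like the `envs` vector built above lands on the spawned shard process; huggingface_hub reads HF_HUB_DISABLE_PROGRESS_BARS and skips creating its tqdm bars entirely:

```rust
use std::process::{Command, Stdio};

fn main() -> std::io::Result<()> {
    // Mirrors the `envs` vector built in shard_manager (names assumed).
    let envs: Vec<(String, String)> = vec![
        ("SAFETENSORS_FAST_GPU".into(), "1".into()),
        // Progress bars are pure stderr noise inside a container, and a full
        // stderr pipe is exactly what caused the hang this PR fixes.
        ("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()),
    ];

    let mut child = Command::new("python")
        .args(["-c", "import os; print(os.environ['HF_HUB_DISABLE_PROGRESS_BARS'])"])
        .envs(envs)
        .stdout(Stdio::inherit())
        .spawn()?;
    child.wait()?;
    Ok(())
}
```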
@@ -564,13 +567,20 @@ fn shard_manager(
         }
     };
 
-    // Redirect STDOUT to the console
-    let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap());
-    let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap());
+    let shard_stdout = BufReader::new(p.stdout.take().unwrap());
 
-    //stdout tracing thread
     thread::spawn(move || {
-        log_lines(shard_stdout_reader.lines());
+        log_lines(shard_stdout.lines());
     });
 
+    let shard_stderr = BufReader::new(p.stderr.take().unwrap());
+
+    // We read stderr in another thread as it seems that lines() can block in some cases
+    let (err_sender, err_receiver) = mpsc::channel();
+    thread::spawn(move || {
+        for line in shard_stderr.lines().flatten() {
+            err_sender.send(line).unwrap_or(());
+        }
+    });
+
     let mut ready = false;
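This hunk is the core of the fix: the shard's stderr now has a dedicated reader from the moment the process starts, so a chatty writer can never fill the pipe and stall the shard, and the monitoring loop polls a channel instead of blocking in lines(). A self-contained sketch of the same pattern (toy child command assumed; names mirror the diff but this is not the PR's code):

```rust
use std::io::{BufRead, BufReader};
use std::process::{Command, Stdio};
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn main() -> std::io::Result<()> {
    let mut child = Command::new("sh")
        .args(["-c", "echo starting >&2; sleep 1; echo done >&2"])
        .stderr(Stdio::piped())
        .spawn()?;

    // Dedicated reader thread: keeps the pipe drained for the child's whole
    // lifetime and forwards lines over a channel.
    let stderr = BufReader::new(child.stderr.take().unwrap());
    let (err_sender, err_receiver) = mpsc::channel();
    thread::spawn(move || {
        // flatten() drops read errors; the loop ends at pipe EOF.
        for line in stderr.lines().flatten() {
            // Ignore send failures if the receiver side is gone.
            err_sender.send(line).unwrap_or(());
        }
    });

    loop {
        if let Some(status) = child.try_wait()? {
            // Drain whatever stderr is still queued; recv_timeout stops
            // returning Ok once the channel is empty or disconnected.
            let mut err = String::new();
            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                err = err + "\n" + &line;
            }
            println!("child exited with {status}, stderr:{err}");
            break;
        }
        thread::sleep(Duration::from_millis(100));
    }
    Ok(())
}
```

Run this against a child that produces megabytes of stderr and the loop still makes progress, which is exactly the property the old code lacked.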
@@ -579,13 +589,6 @@ fn shard_manager(
     loop {
         // Process exited
         if let Some(exit_status) = p.try_wait().unwrap() {
-            // We read stderr in another thread as it seems that lines() can block in some cases
-            let (err_sender, err_receiver) = mpsc::channel();
-            thread::spawn(move || {
-                for line in shard_stderr_reader.lines().flatten() {
-                    err_sender.send(line).unwrap_or(());
-                }
-            });
             let mut err = String::new();
             while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
                 err = err + "\n" + &line;
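Why the 10 ms drain loop terminates cleanly: once the shard exits, its stderr pipe hits EOF, the reader thread finishes, and the sender is dropped, so recv_timeout returns Err (Disconnected, or Timeout during a lull) and the while let ends. Pulled out as a hypothetical helper for clarity (the PR inlines this loop at each call site):

```rust
use std::sync::mpsc::Receiver;
use std::time::Duration;

// Collect whatever stderr lines are still queued after the child exits.
fn drain_stderr(err_receiver: &Receiver<String>) -> String {
    let mut err = String::new();
    // Ok while lines keep arriving; Err on a 10 ms lull or once the reader
    // thread has dropped its sender after pipe EOF.
    while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
        err = err + "\n" + &line;
    }
    err
}
```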
@@ -796,6 +799,9 @@ fn download_convert_model(
envs.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};

// Disable progress bars to prevent hanging in containers
envs.push(("HF_HUB_DISABLE_PROGRESS_BARS".into(), "1".into()));

// Enable hf transfer for insane download speeds
let enable_hf_transfer = env::var("HF_HUB_ENABLE_HF_TRANSFER").unwrap_or("1".to_string());
envs.push((
@@ -840,12 +846,20 @@ fn download_convert_model(
         }
     };
 
-    // Redirect STDOUT to the console
-    let download_stdout = download_process.stdout.take().unwrap();
-    let stdout = BufReader::new(download_stdout);
+    let download_stdout = BufReader::new(download_process.stdout.take().unwrap());
 
     thread::spawn(move || {
-        log_lines(stdout.lines());
+        log_lines(download_stdout.lines());
     });
 
+    let download_stderr = BufReader::new(download_process.stderr.take().unwrap());
+
+    // We read stderr in another thread as it seems that lines() can block in some cases
+    let (err_sender, err_receiver) = mpsc::channel();
+    thread::spawn(move || {
+        for line in download_stderr.lines().flatten() {
+            err_sender.send(line).unwrap_or(());
+        }
+    });
+
     loop {
@@ -856,12 +870,9 @@ fn download_convert_model(
             }
 
             let mut err = String::new();
-            download_process
-                .stderr
-                .take()
-                .unwrap()
-                .read_to_string(&mut err)
-                .unwrap();
+            while let Ok(line) = err_receiver.recv_timeout(Duration::from_millis(10)) {
+                err = err + "\n" + &line;
+            }
             if let Some(signal) = status.signal() {
                 tracing::error!(
                     "Download process was signaled to shutdown with signal {signal}: {err}"
2 changes: 2 additions & 0 deletions server/lorax_server/models/flash_causal_lm.py
@@ -784,10 +784,12 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):
             )
 
             with warmup_mode():
+                logger.info("Warming up to max_total_tokens: {}", max_new_tokens)
                 with tqdm(total=max_new_tokens, desc="Warmup to max_total_tokens") as pbar:
                     for _ in range(max_new_tokens):
                         _, batch = self.generate_token(batch, is_warmup=True)
                         pbar.update(1)
+                logger.info("Finished generating warmup tokens")
         except RuntimeError as e:
             if "CUDA out of memory" in str(e) or isinstance(e, torch.cuda.OutOfMemoryError):
                 raise RuntimeError(