hil: Use multi-part download for hil (#253)
* hil: Use multi-part download for hil

The AWS Rust SDK is slow compared to the AWS CLI; a multi-part download can
speed up the process significantly.
See awslabs/aws-sdk-rust#1024

Relates: ORBP-275

* fixup! hil: Use multi-part download for hil

* fixup! hil: Use multi-part download for hil
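
The idea, in brief: a single GetObject streams the whole object over one
connection, so throughput is capped by that stream, while many ranged
GetObject requests issued concurrently aggregate bandwidth. Below is a
minimal sketch of one such ranged request, using the same aws-sdk-s3 calls
as the diff; the helper name and simplified error handling are illustrative,
not part of the commit:

// Illustrative sketch, not part of the commit: fetch one byte range of an
// S3 object. Running many of these concurrently is what makes the
// multi-part download faster than a single sequential GetObject stream.
async fn fetch_range(
    client: &aws_sdk_s3::Client,
    bucket: &str,
    key: &str,
    start: u64,
    end: u64, // inclusive, as in HTTP Range headers
) -> color_eyre::Result<bytes::Bytes> {
    use color_eyre::eyre::WrapErr as _;
    let resp = client
        .get_object()
        .bucket(bucket)
        .key(key)
        .range(format!("bytes={start}-{end}")) // HTTP Range header value
        .send()
        .await
        .wrap_err("ranged get_object request failed")?;
    let body = resp.body.collect().await.wrap_err("failed to collect body")?;
    Ok(body.into_bytes())
}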
AlexKaravaev authored Oct 15, 2024
1 parent b34fcbf commit 5404dfd
Showing 1 changed file with 211 additions and 63 deletions.
274 changes: 211 additions & 63 deletions hil/src/download_s3.rs
@@ -1,28 +1,54 @@
-use std::{io::IsTerminal, str::FromStr, time::Duration};
+use std::{
+    fs::File,
+    io::IsTerminal,
+    ops::RangeInclusive,
+    os::unix::fs::FileExt,
+    str::FromStr,
+    sync::atomic::{AtomicU64, Ordering},
+    sync::Arc,
+    time::Duration,
+};

use aws_config::{
    meta::{credentials::CredentialsProviderChain, region::RegionProviderChain},
+    retry::RetryConfig,
+    stalled_stream_protection::StalledStreamProtectionConfig,
    BehaviorVersion,
};
use aws_sdk_s3::config::ProvideCredentials;
+use aws_sdk_s3::Client;
use camino::Utf8Path;
use color_eyre::{
    eyre::{ensure, ContextCompat, OptionExt, WrapErr},
    Result, Section,
};
-use indicatif::{ProgressState, ProgressStyle};
-use tempfile::NamedTempFile;
-use tracing::info;
+use tokio::sync::Mutex;
+use tokio::task::JoinSet;
+use tokio::time::timeout;
+use tracing::{info, warn};

-#[derive(Debug, Eq, PartialEq)]
+const PART_SIZE: u64 = 25 * 1024 * 1024; // 25 MiB
+const CONCURRENCY: usize = 16;
+const TIMEOUT_RETRY_ATTEMPTS: u32 = 5;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExistingFileBehavior {
    /// If a file exists, overwrite it
    Overwrite,
    /// If a file exists, abort
    Abort,
}

/// `out_path` is the final path of the file after downloading.
+struct ContentRange(RangeInclusive<u64>);
+
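+// Formats as an HTTP `Range` header value, e.g. `bytes=0-26214399` for the
+// first 25 MiB part.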
+impl std::fmt::Display for ContentRange {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let start = self.0.start();
+        let end = self.0.end();
+        write!(f, "bytes={}-{}", start, end)
+    }
+}

pub async fn download_url(
    url: &str,
    out_path: &Utf8Path,
@@ -33,84 +59,157 @@ pub async fn download_url(
    }
    let parent_dir = out_path
        .parent()
-        .expect("please provide the path to a file");
+        .expect("please provide the path to a file")
+        .to_owned();
    ensure!(
        parent_dir.try_exists().unwrap_or(false),
        "parent directory {parent_dir} doesn't exist"
    );
    let s3_parts: S3UrlParts = url.parse().wrap_err("invalid s3 url")?;
-    let (tmp_file, tmp_file_path) =
-        tempfile::NamedTempFile::new_in(out_path.parent().unwrap())
-            .wrap_err("failed to create tempfile")?
-            .into_parts();
-    let mut tmp_file: tokio::fs::File = tmp_file.into();

    let start_time = std::time::Instant::now();
-    let resp = client()
-        .await?
-        .get_object()
-        .bucket(s3_parts.bucket)
-        .key(s3_parts.key)
+    let client = client().await?;
+    let head_resp = client
+        .head_object()
+        .bucket(&s3_parts.bucket)
+        .key(&s3_parts.key)
        .send()
        .await
-        .wrap_err("failed to make aws get_object request")?;
-    let bytes_to_download = resp
-        .content_length()
-        .ok_or_eyre("expected a content length")?;
+        .wrap_err("failed to make aws head_object request")?;

-    let is_interactive = std::io::stdout().is_terminal();
-    if is_interactive {
-        info!("we are interactive");
-    } else {
-        info!("we are not interactive");
-    }
+    let bytes_to_download = head_resp.content_length().unwrap();
+
+    let bytes_to_download: u64 = bytes_to_download
+        .try_into()
+        .expect("Download size is too large to fit into u64");

-    let bytes_to_download: u64 = bytes_to_download.try_into().expect("overflow");
+    let is_interactive = std::io::stdout().is_terminal();
    let pb = indicatif::ProgressBar::new(bytes_to_download);
-    pb.set_style(ProgressStyle::with_template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})")
+    pb.set_style(
+        indicatif::ProgressStyle::with_template(
+            "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})",
+        )
        .unwrap()
-        .with_key("eta", |state: &ProgressState, w: &mut dyn std::fmt::Write| write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap())
-        .progress_chars("#>-"));
-
-    let mut bytes_so_far = 0;
-    let mut pct = 0;
-    let reader =
-        tokio_util::io::InspectReader::new(resp.body.into_async_read(), |bytes| {
-            if !is_interactive {
-                bytes_so_far += bytes.len() as u64;
-                let new_pct = bytes_so_far * 100 / bytes_to_download;
-                if new_pct > pct {
-                    info!(
-                        "Downloaded: ({}/{} MiB) {}%",
-                        bytes_so_far >> 20,
-                        bytes_to_download >> 20,
-                        new_pct,
-                    );
+        .with_key("eta", |state: &indicatif::ProgressState, w: &mut dyn std::fmt::Write| {
+            write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
+        })
+        .progress_chars("#>-"),
+    );

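+    // Split the object into PART_SIZE-sized inclusive byte ranges; the last
+    // range is clamped to the final byte of the object.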
+    let step_size = PART_SIZE
+        .try_into()
+        .expect("PART_SIZE is too large to fit into usize");
+    let ranges = (0..bytes_to_download).step_by(step_size).map(move |start| {
+        let end = std::cmp::min(start + PART_SIZE - 1, bytes_to_download - 1);
+        ContentRange(start..=end)
+    });

+    let (tmp_file, tmp_file_path) = tokio::task::spawn_blocking(move || {
+        let tmp_file = tempfile::NamedTempFile::new_in(parent_dir)
+            .wrap_err("failed to create tempfile")?;
+        tmp_file.as_file().set_len(bytes_to_download)?;
+        Ok::<_, color_eyre::Report>(tmp_file.into_parts())
+    })
+    .await?
+    .wrap_err("failed to create tempfile")?;
+
+    let tmp_file: Arc<File> = Arc::new(tmp_file);
+    let tmp_file_path = Arc::new(tmp_file_path);
+
+    let ranges = Arc::new(Mutex::new(ranges));
+    let mut tasks = JoinSet::new();
+    let bytes_downloaded = Arc::new(AtomicU64::new(0));

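+    // Fixed pool of CONCURRENCY workers; each one repeatedly pulls the next
+    // unclaimed range from the shared iterator and downloads that part.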
+    for _ in 0..CONCURRENCY {
+        let ranges = Arc::clone(&ranges);
+        let client = client.clone();
+        let bucket = s3_parts.bucket.clone();
+        let key = s3_parts.key.clone();
+        let tmp_file = Arc::clone(&tmp_file);
+        let pb = pb.clone();
+        let bytes_downloaded = Arc::clone(&bytes_downloaded);
+
+        tasks.spawn(async move {
+            loop {
+                let range_option = {
+                    let mut ranges_lock = ranges.lock().await;
+                    ranges_lock.next()
+                };
+
+                let Some(range) = range_option else {
+                    break;
+                };
+
+                let body =
+                    download_part_retry_on_timeout(&range, &client, &bucket, &key)
+                        .await?;
+                let chunk_size = body.len() as u64;

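+                // Positioned write (pwrite-style): each part lands at its
+                // absolute offset, and the ranges are disjoint, so concurrent
+                // workers never touch a shared file cursor.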
+                tokio::task::spawn_blocking({
+                    let tmp_file = Arc::clone(&tmp_file);
+                    move || {
+                        tmp_file.write_all_at(&body, *range.0.start())?;
+                        Ok::<(), std::io::Error>(())
+                    }
+                })
+                .await??;
+
+                if is_interactive {
+                    pb.inc(chunk_size);
+                } else {
+                    let bytes_so_far = bytes_downloaded
+                        .fetch_add(chunk_size, Ordering::Relaxed)
+                        + chunk_size;
+                    let pct = (bytes_so_far * 100) / bytes_to_download;
+                    if pct % 5 == 0 {
+                        info!(
+                            "Downloaded: ({}/{} MiB) {}%",
+                            bytes_so_far >> 20,
+                            bytes_to_download >> 20,
+                            pct,
+                        );
+                    }
+                }
-                    pct = new_pct;
-                }
-            }
-        });
+            }
+
+            Ok::<(), color_eyre::Report>(())
+        });
+    }

-    tokio::io::copy(&mut pb.wrap_async_read(reader), &mut tmp_file)
-        .await
-        .wrap_err("failed to download file")?;
-    tmp_file
-        .sync_all()
-        .await
-        .wrap_err("failed to finish saving file to disk")?;
-    let file_size = tmp_file
-        .metadata()
-        .await
-        .wrap_err("failed to inspect downloaded file size")?
-        .len();
+    while let Some(res) = tasks.join_next().await {
+        res??;
+    }

    pb.finish_and_clear();

+    tokio::task::spawn_blocking({
+        let tmp_file = tmp_file.clone();
+        move || {
+            tmp_file.sync_all()?;
+            Ok::<(), std::io::Error>(())
+        }
+    })
+    .await??;
+
+    let file_size = tokio::task::spawn_blocking({
+        let tmp_file = tmp_file.clone();
+        move || {
+            let metadata = tmp_file.metadata()?;
+            Ok::<_, std::io::Error>(metadata.len())
+        }
+    })
+    .await??;
    assert_eq!(bytes_to_download, file_size);

    info!(
        "Downloaded {}MiB, took {}",
        bytes_to_download >> 20,
        elapsed_time_as_str(start_time.elapsed(),)
    );

-    let tmp_file = NamedTempFile::from_parts(tmp_file.into_std().await, tmp_file_path);
+    let tmp_file_path =
+        Arc::try_unwrap(tmp_file_path).expect("Multiple references to tmp_file_path");

    let out_path_clone = out_path.to_owned();
    tokio::task::spawn_blocking(move || {
        if existing_file_behavior == ExistingFileBehavior::Abort {
@@ -119,8 +218,8 @@ pub async fn download_url(
"{out_path_clone:?} already exists!"
);
}
tmp_file
.persist(out_path_clone)
tmp_file_path
.persist_noclobber(&out_path_clone)
.wrap_err("failed to persist temporary file")
})
.await
@@ -129,6 +228,49 @@ pub async fn download_url(
    Ok(())
}

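+// Retries a part indefinitely as long as the failure is a timeout; any other
+// error propagates to the caller. Each attempt gets a 30-second budget.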
+async fn download_part_retry_on_timeout(
+    range: &ContentRange,
+    client: &Client,
+    bucket: &str,
+    key: &str,
+) -> Result<bytes::Bytes> {
+    loop {
+        match timeout(
+            Duration::from_secs(30), // Timeout for downloading one part
+            download_part(range, client, bucket, key),
+        )
+        .await
+        {
+            Ok(result) => return result,
+            Err(e) => warn!("get part timeout for part {}", e),
+        }
+    }
+}

+async fn download_part(
+    range: &ContentRange,
+    client: &Client,
+    bucket: &str,
+    key: &str,
+) -> Result<bytes::Bytes> {
+    let part = client
+        .get_object()
+        .bucket(bucket)
+        .key(key)
+        .range(range.to_string())
+        .send()
+        .await
+        .wrap_err("failed to make aws get_object request")?;
+
+    let body = part
+        .body
+        .collect()
+        .await
+        .wrap_err("failed to collect body")?;
+
+    Ok(body.into_bytes())
+}

async fn client() -> Result<aws_sdk_s3::Client> {
    let region_provider = RegionProviderChain::default_provider().or_else("us-east-1");
    let region = region_provider.region().await.expect("infallible");
@@ -146,9 +288,15 @@ async fn client() -> Result<aws_sdk_s3::Client> {
            https://worldcoin.github.io/orb-software/hil/cli."
        })
        .with_suggestion(|| "try running `AWS_PROFILE=hil aws sso login`")?;

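+    // Cap SDK-level retry attempts. Stalled-stream protection is disabled so
+    // the SDK does not abort slow range requests; the 30-second per-part
+    // timeout in download_part_retry_on_timeout handles stalls instead.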
+    let retry_config =
+        RetryConfig::standard().with_max_attempts(TIMEOUT_RETRY_ATTEMPTS);
+
    let config = aws_config::defaults(BehaviorVersion::v2024_03_28())
        .region(region_provider)
        .credentials_provider(credentials_provider)
+        .retry_config(retry_config)
+        .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
        .load()
        .await;

@@ -226,7 +374,7 @@ mod test {
    fn test_parse() -> color_eyre::Result<()> {
        let examples = [
            (
-                "s3://worldcoin-orb-update-packages-stage/worldcoin/orb-os/2024-05-07-heads-main-0-g4b8aae5/rts/rts-dev.tar.zst",
+                "s3://worldcoin-orb-update-packages-stage/worldcoin/orb-os/2024-05-07-heads-main-0-g4b8aae5/rts/rts-dev.tar.zst",
                "2024-05-07-heads-main-0-g4b8aae5-rts-dev.tar.zst"
            ),
            (