Skip to content

Commit

Permalink
Refactor download state
Browse files Browse the repository at this point in the history
  • Loading branch information
james58899 committed Dec 25, 2023
1 parent 0255fdd commit b211b79
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ struct Args {
proxy: Option<String>,
}

type DownloadState = RwLock<HashMap<[u8; 20], (watch::Receiver<Option<Arc<TempPath>>>, Arc<watch::Sender<u64>>)>>;
type DownloadState = Mutex<HashMap<[u8; 20], (watch::Receiver<Option<Arc<TempPath>>>, Arc<watch::Sender<u64>>)>>;

struct AppState {
runtime: Handle,
Expand Down
25 changes: 16 additions & 9 deletions src/route/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,24 +83,31 @@ async fn hath(
let file_size = info.size() as u64;

// Check if the file is already downloading
let download_state = data.download_state.read().get(&info.hash()).cloned();
let (temp_path, mut rx) = if let Some((mut tempfile, progress)) = download_state {
let (temp_tx, temp_rx) = watch::channel(None); // Tempfile
let tx = Arc::new(watch::channel(0).0); // Download progress
let state;
{
let mut download_state = data.download_state.lock();
state = download_state.get(&info.hash()).cloned();
// Tracking download progress
if state.is_none() {
download_state.insert(info.hash(), (temp_rx.clone(), tx.clone()));
}
}

let (temp_path, mut rx) = if let Some((mut tempfile, progress)) = state {
let tempfile = tempfile.wait_for(Option::is_some).await;
if let Err(err) = tempfile {
error!("Waiting tempfile create error: {}", err);
data.download_state.write().remove(&info.hash());
data.download_state.lock().remove(&info.hash());
return HttpResponse::NotFound().body("An error has occurred. (404)");
}
(tempfile.unwrap().as_ref().unwrap().clone(), progress.subscribe())
} else {
// Tracking download progress
let (temp_tx, temp_rx) = watch::channel(None); // Tempfile
let tx = Arc::new(watch::channel(0).0); // Download progress
data.download_state.write().insert(info.hash(), (temp_rx.clone(), tx.clone()));
// Make sure the state will be removed when cancellation.
let data2 = data.clone();
let state_guard = scopeguard::guard(info.hash(), move |hash| {
data2.download_state.write().remove(&hash);
data2.download_state.lock().remove(&hash);
});

let temp_path = Arc::new(data.cache_manager.create_temp_file().await);
Expand Down Expand Up @@ -185,7 +192,7 @@ async fn hath(
let hash = hasher.finish();
tx2.send_replace(progress);
tx2.closed().await; // Wait all request done
data.download_state.write().remove(&info2.hash());
data.download_state.lock().remove(&info2.hash());
if hash == info2.hash() {
tx2.closed().await; // Wait again to avoid race conditions
data.cache_manager.import_cache(&info2, &temp_path2).await;
Expand Down

0 comments on commit b211b79

Please sign in to comment.