From 272beef806ffc4b31889971add42de05f1cad83b Mon Sep 17 00:00:00 2001 From: james58899 Date: Fri, 15 Dec 2023 10:59:22 +0000 Subject: [PATCH] Keep download state until request completed --- src/main.rs | 2 +- src/route/cache.rs | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/main.rs b/src/main.rs index 7523ff1..208967f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -139,7 +139,7 @@ struct Args { proxy: Option, } -type DownloadState = RwLock>>, watch::Receiver)>>; +type DownloadState = RwLock>>, Arc>)>>; struct AppState { runtime: Handle, diff --git a/src/route/cache.rs b/src/route/cache.rs index 6454806..1ee2635 100644 --- a/src/route/cache.rs +++ b/src/route/cache.rs @@ -90,12 +90,12 @@ async fn hath( data.download_state.write().remove(&info.hash()); return HttpResponse::NotFound().body("An error has occurred. (404)"); } - (tempfile.unwrap().as_ref().unwrap().clone(), progress) + (tempfile.unwrap().as_ref().unwrap().clone(), progress.subscribe()) } else { // Tracking download progress let (temp_tx, temp_rx) = watch::channel(None); // Tempfile - let (tx, rx) = watch::channel(0); // Download progress - data.download_state.write().insert(info.hash(), (temp_rx.clone(), rx.clone())); + let tx = Arc::new(watch::channel(0).0); // Download progress + data.download_state.write().insert(info.hash(), (temp_rx.clone(), tx.clone())); let temp_path = Arc::new(data.cache_manager.create_temp_file().await); temp_tx.send_replace(Some(temp_path.clone())); @@ -109,6 +109,7 @@ async fn hath( }; // Download worker + let tx2: Arc> = tx.clone(); let info2 = info.clone(); let temp_path2 = temp_path.clone(); data.runtime.clone().spawn(async move { @@ -169,7 +170,7 @@ async fn hath( } hasher.update(data); progress += write_size as u64; - tx.send_replace(progress); + tx2.send_replace(progress); } if progress == file_size { if let Err(err) = file.flush().await { @@ -177,10 +178,11 @@ async fn hath( break 'retry; } let hash = hasher.finish(); + tx2.send_replace(progress); + tx2.closed().await; // Wait all request done data.download_state.write().remove(&info2.hash()); - tx.send_replace(progress); if hash == info2.hash() { - tx.closed().await; // Wait all request done + tx2.closed().await; // Wait again to avoid race conditions data.cache_manager.import_cache(&info2, &temp_path2).await; } else { error!("Cache hash mismatch: expected: {:x?}, got: {:x?}", info2.hash(), hash); @@ -194,7 +196,7 @@ async fn hath( data.download_state.write().remove(&info2.hash()); }); - (temp_path, rx) + (temp_path, tx.subscribe()) }; // Wait download start or 404