diff --git a/core/tests/app-updater/tests/update.rs b/core/tests/app-updater/tests/update.rs index fd2123f919d5..3e7a899e3323 100644 --- a/core/tests/app-updater/tests/update.rs +++ b/core/tests/app-updater/tests/update.rs @@ -17,6 +17,7 @@ use hyper::{ Body, Method, Request, Response, StatusCode, }; use serde::Serialize; +use tokio::sync::Mutex; use tokio_util::codec::{BytesCodec, FramedRead}; const UPDATER_PRIVATE_KEY: &str = "dW50cnVzdGVkIGNvbW1lbnQ6IHJzaWduIGVuY3J5cHRlZCBzZWNyZXQga2V5ClJXUlRZMEl5dkpDN09RZm5GeVAzc2RuYlNzWVVJelJRQnNIV2JUcGVXZUplWXZXYXpqUUFBQkFBQUFBQUFBQUFBQUlBQUFBQTZrN2RnWGh5dURxSzZiL1ZQSDdNcktiaHRxczQwMXdQelRHbjRNcGVlY1BLMTBxR2dpa3I3dDE1UTVDRDE4MXR4WlQwa1BQaXdxKy9UU2J2QmVSNXhOQWFDeG1GSVllbUNpTGJQRkhhTnROR3I5RmdUZi90OGtvaGhJS1ZTcjdZU0NyYzhQWlQ5cGM9Cg=="; @@ -383,6 +384,13 @@ fn update_app_flow) -> (PathBuf, TauriVersion)>(build_app_ } }; + let updater_state = UpdaterState { + target: Default::default(), + signature: Default::default(), + updater_path: Default::default(), + }; + let (runtime, shutdown_tx) = start_updater_server(updater_state.clone()); + for (bundle_target, out_bundle_path) in bundle_paths(&app_root, UPDATE_APP_VERSION) { let mut bundle_updater_ext = out_bundle_path .extension() @@ -418,70 +426,10 @@ fn update_app_flow) -> (PathBuf, TauriVersion)>(build_app_ )); std::fs::rename(&out_updater_path, &updater_path).expect("failed to rename bundle"); - let target = target.clone(); - - let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - - let runtime = tokio::runtime::Runtime::new().unwrap(); - - runtime.spawn(async move { - // create the updater server - let addr = "127.0.0.1:3007".parse().unwrap(); - - let make_service = make_service_fn(move |_| { - let updater_path = updater_path.clone(); - let signature = signature.clone(); - let target = target.clone(); - async move { - Ok::<_, hyper::Error>(service_fn(move |req| { - let updater_path = updater_path.clone(); - let signature = signature.clone(); - let target = target.clone(); - async move { - match (req.method(), req.uri().path()) { - (&Method::GET, "/") => { - let mut platforms = HashMap::new(); - - platforms.insert( - target.clone(), - PlatformUpdate { - signature: signature.clone(), - url: "http://localhost:3007/download", - with_elevated_task: false, - }, - ); - let body = serde_json::to_vec(&Update { - version: UPDATE_APP_VERSION, - date: time::OffsetDateTime::now_utc() - .format(&time::format_description::well_known::Rfc3339) - .unwrap(), - platforms, - }) - .unwrap(); - - Ok(Response::new(hyper::Body::from(body))) - } - (&Method::GET, "/download") => { - let file = tokio::fs::File::open(&updater_path).await.unwrap(); - let stream = FramedRead::new(file, BytesCodec::new()); - let body = Body::wrap_stream(stream); - return Ok(Response::new(body)); - } - _ => Response::builder() - .status(StatusCode::NOT_FOUND) - .body("Not Found".into()), - } - } - })) - } - }); - let server = hyper::Server::bind(&addr).serve(make_service); - - let graceful = server.with_graceful_shutdown(async { - rx.await.ok(); - }); - - graceful.await.unwrap(); + runtime.block_on(async { + *updater_state.target.lock().await = target.clone(); + *updater_state.signature.lock().await = signature.clone(); + *updater_state.updater_path.lock().await = updater_path.clone(); }); let config = Config { @@ -575,8 +523,89 @@ fn update_app_flow) -> (PathBuf, TauriVersion)>(build_app_ // force Rust to rebuild the binary so it doesn't conflict with other test runs #[cfg(windows)] std::fs::remove_file(tauri_v1_fixture_dir.join("target/debug/app-updater.exe")).unwrap(); - - // graceful shutdown - tx.send(()).unwrap(); } + + // graceful shutdown + shutdown_tx.send(()).unwrap(); +} + +#[derive(Clone)] +struct UpdaterState { + target: Arc>, + signature: Arc>, + updater_path: Arc>, +} + +fn start_updater_server( + state: UpdaterState, +) -> (tokio::runtime::Runtime, tokio::sync::oneshot::Sender<()>) { + let (tx, rx) = tokio::sync::oneshot::channel::<()>(); + + let runtime = tokio::runtime::Runtime::new().unwrap(); + + runtime.spawn(async move { + // create the updater server + let addr = "127.0.0.1:3007".parse().unwrap(); + + let make_service = make_service_fn(move |_| { + let state = state.clone(); + async move { + Ok::<_, hyper::Error>(service_fn(move |req| { + let state = state.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::GET, "/") => { + let mut platforms = HashMap::new(); + + platforms.insert( + state.target.lock().await.clone(), + PlatformUpdate { + signature: state.signature.lock().await.clone(), + url: "http://localhost:3007/download", + with_elevated_task: false, + }, + ); + let body = serde_json::to_vec(&Update { + version: UPDATE_APP_VERSION, + date: time::OffsetDateTime::now_utc() + .format(&time::format_description::well_known::Rfc3339) + .unwrap(), + platforms, + }) + .unwrap(); + + Ok(Response::new(hyper::Body::from(body))) + } + (&Method::GET, "/download") => { + println!("downloading updater"); + let file = tokio::fs::File::open(&*state.updater_path.lock().await) + .await + .unwrap(); + println!("opened updater file"); + let stream = FramedRead::new(file, BytesCodec::new()); + let body = Body::wrap_stream(stream); + println!("sending updater response"); + return Ok(Response::new(body)); + } + _ => Response::builder() + .status(StatusCode::NOT_FOUND) + .body("Not Found".into()), + } + } + })) + } + }); + let server = hyper::Server::bind(&addr).serve(make_service); + + let graceful = server.with_graceful_shutdown(async { + println!("received shutdown"); + rx.await.ok(); + }); + + graceful.await.unwrap(); + + println!("done serving updates"); + }); + + (runtime, tx) }