From 2fb6be9d6403e20540d00f90075e644b1f6ed842 Mon Sep 17 00:00:00 2001 From: Jason Tsai Date: Wed, 10 Jul 2024 18:18:24 +0800 Subject: [PATCH] fix(macos): prevent unsafe async custom protocol panic --- src/wkwebview/mod.rs | 101 +++++++++++++++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 22 deletions(-) diff --git a/src/wkwebview/mod.rs b/src/wkwebview/mod.rs index aef2a79c1..702d21e73 100644 --- a/src/wkwebview/mod.rs +++ b/src/wkwebview/mod.rs @@ -45,8 +45,8 @@ use objc2_web_kit::{ WKAudiovisualMediaTypes, WKDownload, WKDownloadDelegate, WKFrameInfo, WKMediaCaptureType, WKNavigation, WKNavigationAction, WKNavigationActionPolicy, WKNavigationDelegate, WKNavigationResponse, WKNavigationResponsePolicy, WKOpenPanelParameters, WKPermissionDecision, - WKScriptMessage, WKScriptMessageHandler, WKSecurityOrigin, WKUIDelegate, WKURLSchemeTask, - WKUserContentController, WKUserScript, WKUserScriptInjectionTime, WKWebView, + WKScriptMessage, WKScriptMessageHandler, WKSecurityOrigin, WKUIDelegate, WKURLSchemeHandler, + WKURLSchemeTask, WKUserContentController, WKUserScript, WKUserScriptInjectionTime, WKWebView, WKWebViewConfiguration, WKWebsiteDataStore, }; use once_cell::sync::Lazy; @@ -54,7 +54,7 @@ use raw_window_handle::{HasWindowHandle, RawWindowHandle}; use std::{ borrow::Cow, - collections::HashSet, + collections::{HashMap, HashSet}, ffi::{c_void, CStr}, os::raw::c_char, panic::{catch_unwind, AssertUnwindSafe}, @@ -190,15 +190,16 @@ impl InnerWebView { // Task handler for custom protocol extern "C" fn start_task<'a>( this: &AnyObject, - _: objc2::runtime::Sel, - _webview: &WKWebView, - task: *mut ProtocolObject, // FIXME: not sure if this work. + _sel: objc2::runtime::Sel, + webview: *mut WryWebView, + task: *mut ProtocolObject, ) { unsafe { #[cfg(feature = "tracing")] tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty).entered(); - let webview_id = *this.get_ivar::("webview_id"); + let task_key = (*task).hash(); // hash by task object address + let task_uuid = (*webview).add_custom_task_key(task_key); let ivar = this.class().instance_variable("webview_id").unwrap(); let webview_id: u32 = ivar.load::(this).clone(); @@ -279,23 +280,49 @@ impl InnerWebView { // send response match http_request.body(sent_form_body) { Ok(final_request) => { - // [objc2] FIXME: retain the task? - // let () = msg_send![task, retain]; let responder: Box>)> = Box::new(move |sent_response| { fn check_webview_id_valid(webview_id: u32) -> crate::Result<()> { - match WEBVIEW_IDS.lock().unwrap().contains(&webview_id) { - true => Ok(()), - false => Err(crate::Error::CustomProtocolTaskInvalid), + if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) { + return Err(crate::Error::CustomProtocolTaskInvalid); } + Ok(()) + } + /// Task may not live longer than async custom protocol handler. + /// + /// There are roughly 2 ways to cause segfault: + /// 1. Task has stopped. pointer of the task not valid anymore. + /// 2. Task had stopped, but the pointer of the task has allocated to a new task. + /// Outdated custom handler may call to the new task instance and cause segfault. + fn check_task_is_valud( + webview: &WryWebView, + task_key: usize, + current_uuid: Retained, + ) -> crate::Result<()> { + let latest_task_uuid = webview.get_custom_task_uuid(task_key); + if let Some(latest_uuid) = latest_task_uuid { + if latest_uuid != current_uuid { + return Err(crate::Error::CustomProtocolTaskInvalid); + } + } else { + return Err(crate::Error::CustomProtocolTaskInvalid); + } + Ok(()) } + // FIXME: This is 10000% unsafe. `task` and `webview` are not guaranteed to be valid. + // We should consider use sync command only. unsafe fn response( task: *mut ProtocolObject, + webview: *mut WryWebView, + task_key: usize, + task_uuid: Retained, webview_id: u32, url: Retained, sent_response: HttpResponse>, ) -> crate::Result<()> { + check_task_is_valud(&*webview, task_key, task_uuid.clone())?; + let content = sent_response.body(); // default: application/octet-stream, but should be provided by the client let wanted_mime = sent_response.headers().get(CONTENT_TYPE); @@ -338,8 +365,8 @@ impl InnerWebView { ) .unwrap(); - // [objc2] FIXME: https://github.com/tauri-apps/wry/commit/8b691df1ac57eb5eb15082c5f6d72e871965c61e check_webview_id_valid(webview_id)?; + check_task_is_valud(&*webview, task_key, task_uuid.clone())?; (*task).didReceiveResponse(&response); // Send data @@ -348,19 +375,28 @@ impl InnerWebView { // MIGRATE NOTE: we copied the content to the NSData because content will be freed // when out of scope but NSData will also free the content when it's done and cause doube free. let data = NSData::initWithBytes_length(data, bytes, content.len()); - // [objc2] FIXME: https://github.com/tauri-apps/wry/commit/8b691df1ac57eb5eb15082c5f6d72e871965c61e check_webview_id_valid(webview_id)?; + check_task_is_valud(&*webview, task_key, task_uuid.clone())?; (*task).didReceiveData(&data); // Finish - // [objc2] FIXME: https://github.com/tauri-apps/wry/commit/8b691df1ac57eb5eb15082c5f6d72e871965c61e check_webview_id_valid(webview_id)?; + check_task_is_valud(&*webview, task_key, task_uuid.clone())?; (*task).didFinish(); + (*webview).remove_custom_task_key(task_key); Ok(()) } - let _ = response(task, webview_id, url.clone(), sent_response); + let _ = response( + task, + webview, + task_key, + task_uuid, + webview_id, + url.clone(), + sent_response, + ); }); #[cfg(feature = "tracing")] @@ -378,11 +414,12 @@ impl InnerWebView { } } extern "C" fn stop_task( - _: &AnyObject, - _: objc2::runtime::Sel, - _webview: &WKWebView, - _task: &ProtocolObject, + _this: &ProtocolObject, + _sel: objc2::runtime::Sel, + webview: &mut WryWebView, + task: &ProtocolObject, ) { + webview.remove_custom_task_key(task.hash()); } let mut wv_ids = WEBVIEW_IDS.lock().unwrap(); @@ -456,8 +493,8 @@ impl InnerWebView { if catch_unwind(|| { config_unwind_safe.setURLSchemeHandler_forURLScheme( Some(&*(handler_unwind_safe.cast::>())), - &NSString::from_str(&name), - ); + &NSString::from_str(&name), + ); }) .is_err() { @@ -476,6 +513,7 @@ impl InnerWebView { }, #[cfg(target_os = "macos")] accept_first_mouse: Bool::new(attributes.accept_first_mouse), + custom_protocol_task_ids: HashMap::new(), }); config.setWebsiteDataStore(&data_store); @@ -1177,6 +1215,7 @@ pub struct WryWebViewIvars { drag_drop_handler: Box bool>, #[cfg(target_os = "macos")] accept_first_mouse: objc2::runtime::Bool, + custom_protocol_task_ids: HashMap>, } declare_class!( @@ -1283,6 +1322,24 @@ declare_class!( } ); +// Custom Protocol Task Checker +impl WryWebView { + fn add_custom_task_key(&mut self, task_id: usize) -> Retained { + let task_uuid = NSUUID::new(); + self + .ivars_mut() + .custom_protocol_task_ids + .insert(task_id, task_uuid.clone()); + task_uuid + } + fn remove_custom_task_key(&mut self, task_id: usize) { + self.ivars_mut().custom_protocol_task_ids.remove(&task_id); + } + fn get_custom_task_uuid(&self, task_id: usize) -> Option> { + self.ivars().custom_protocol_task_ids.get(&task_id).cloned() + } +} + struct WryWebViewDelegateIvars { controller: Retained, ipc_handler: Box)>,