Skip to content

Commit

Permalink
fix(macos): prevent unsafe async custom protocol panic
Browse files Browse the repository at this point in the history
  • Loading branch information
pewsheen committed Jul 10, 2024
1 parent 986cecd commit 2fb6be9
Showing 1 changed file with 79 additions and 22 deletions.
101 changes: 79 additions & 22 deletions src/wkwebview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ use objc2_web_kit::{
WKAudiovisualMediaTypes, WKDownload, WKDownloadDelegate, WKFrameInfo, WKMediaCaptureType,
WKNavigation, WKNavigationAction, WKNavigationActionPolicy, WKNavigationDelegate,
WKNavigationResponse, WKNavigationResponsePolicy, WKOpenPanelParameters, WKPermissionDecision,
WKScriptMessage, WKScriptMessageHandler, WKSecurityOrigin, WKUIDelegate, WKURLSchemeTask,
WKUserContentController, WKUserScript, WKUserScriptInjectionTime, WKWebView,
WKScriptMessage, WKScriptMessageHandler, WKSecurityOrigin, WKUIDelegate, WKURLSchemeHandler,
WKURLSchemeTask, WKUserContentController, WKUserScript, WKUserScriptInjectionTime, WKWebView,
WKWebViewConfiguration, WKWebsiteDataStore,
};
use once_cell::sync::Lazy;
use raw_window_handle::{HasWindowHandle, RawWindowHandle};

use std::{
borrow::Cow,
collections::HashSet,
collections::{HashMap, HashSet},
ffi::{c_void, CStr},
os::raw::c_char,
panic::{catch_unwind, AssertUnwindSafe},
Expand Down Expand Up @@ -190,15 +190,16 @@ impl InnerWebView {
// Task handler for custom protocol
extern "C" fn start_task<'a>(
this: &AnyObject,
_: objc2::runtime::Sel,
_webview: &WKWebView,
task: *mut ProtocolObject<dyn WKURLSchemeTask>, // FIXME: not sure if this work.
_sel: objc2::runtime::Sel,
webview: *mut WryWebView,
task: *mut ProtocolObject<dyn WKURLSchemeTask>,
) {
unsafe {
#[cfg(feature = "tracing")]
tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty).entered();

let webview_id = *this.get_ivar::<u32>("webview_id");
let task_key = (*task).hash(); // hash by task object address
let task_uuid = (*webview).add_custom_task_key(task_key);

let ivar = this.class().instance_variable("webview_id").unwrap();
let webview_id: u32 = ivar.load::<u32>(this).clone();
Expand Down Expand Up @@ -279,23 +280,49 @@ impl InnerWebView {
// send response
match http_request.body(sent_form_body) {
Ok(final_request) => {
// [objc2] FIXME: retain the task?
// let () = msg_send![task, retain];
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
Box::new(move |sent_response| {
fn check_webview_id_valid(webview_id: u32) -> crate::Result<()> {
match WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
true => Ok(()),
false => Err(crate::Error::CustomProtocolTaskInvalid),
if !WEBVIEW_IDS.lock().unwrap().contains(&webview_id) {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}
/// Task may not live longer than async custom protocol handler.
///
/// There are roughly 2 ways to cause segfault:
/// 1. Task has stopped. pointer of the task not valid anymore.
/// 2. Task had stopped, but the pointer of the task has allocated to a new task.
/// Outdated custom handler may call to the new task instance and cause segfault.
fn check_task_is_valud(
webview: &WryWebView,
task_key: usize,
current_uuid: Retained<NSUUID>,
) -> crate::Result<()> {
let latest_task_uuid = webview.get_custom_task_uuid(task_key);
if let Some(latest_uuid) = latest_task_uuid {
if latest_uuid != current_uuid {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
} else {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}

// FIXME: This is 10000% unsafe. `task` and `webview` are not guaranteed to be valid.
// We should consider use sync command only.
unsafe fn response(
task: *mut ProtocolObject<dyn WKURLSchemeTask>,
webview: *mut WryWebView,
task_key: usize,
task_uuid: Retained<NSUUID>,
webview_id: u32,
url: Retained<NSURL>,
sent_response: HttpResponse<Cow<'_, [u8]>>,
) -> crate::Result<()> {
check_task_is_valud(&*webview, task_key, task_uuid.clone())?;

let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
Expand Down Expand Up @@ -338,8 +365,8 @@ impl InnerWebView {
)
.unwrap();

// [objc2] FIXME: https://github.com/tauri-apps/wry/commit/8b691df1ac57eb5eb15082c5f6d72e871965c61e
check_webview_id_valid(webview_id)?;
check_task_is_valud(&*webview, task_key, task_uuid.clone())?;
(*task).didReceiveResponse(&response);

// Send data
Expand All @@ -348,19 +375,28 @@ impl InnerWebView {
// MIGRATE NOTE: we copied the content to the NSData because content will be freed
// when out of scope but NSData will also free the content when it's done and cause doube free.
let data = NSData::initWithBytes_length(data, bytes, content.len());
// [objc2] FIXME: https://github.com/tauri-apps/wry/commit/8b691df1ac57eb5eb15082c5f6d72e871965c61e
check_webview_id_valid(webview_id)?;
check_task_is_valud(&*webview, task_key, task_uuid.clone())?;
(*task).didReceiveData(&data);

// Finish
// [objc2] FIXME: https://github.com/tauri-apps/wry/commit/8b691df1ac57eb5eb15082c5f6d72e871965c61e
check_webview_id_valid(webview_id)?;
check_task_is_valud(&*webview, task_key, task_uuid.clone())?;
(*task).didFinish();

(*webview).remove_custom_task_key(task_key);
Ok(())
}

let _ = response(task, webview_id, url.clone(), sent_response);
let _ = response(
task,
webview,
task_key,
task_uuid,
webview_id,
url.clone(),
sent_response,
);
});

#[cfg(feature = "tracing")]
Expand All @@ -378,11 +414,12 @@ impl InnerWebView {
}
}
extern "C" fn stop_task(
_: &AnyObject,
_: objc2::runtime::Sel,
_webview: &WKWebView,
_task: &ProtocolObject<dyn WKURLSchemeTask>,
_this: &ProtocolObject<dyn WKURLSchemeHandler>,
_sel: objc2::runtime::Sel,
webview: &mut WryWebView,
task: &ProtocolObject<dyn WKURLSchemeTask>,
) {
webview.remove_custom_task_key(task.hash());
}

let mut wv_ids = WEBVIEW_IDS.lock().unwrap();
Expand Down Expand Up @@ -456,8 +493,8 @@ impl InnerWebView {
if catch_unwind(|| {
config_unwind_safe.setURLSchemeHandler_forURLScheme(
Some(&*(handler_unwind_safe.cast::<ProtocolObject<dyn WKURLSchemeHandler>>())),
&NSString::from_str(&name),
);
&NSString::from_str(&name),
);
})
.is_err()
{
Expand All @@ -476,6 +513,7 @@ impl InnerWebView {
},
#[cfg(target_os = "macos")]
accept_first_mouse: Bool::new(attributes.accept_first_mouse),
custom_protocol_task_ids: HashMap::new(),
});

config.setWebsiteDataStore(&data_store);
Expand Down Expand Up @@ -1177,6 +1215,7 @@ pub struct WryWebViewIvars {
drag_drop_handler: Box<dyn Fn(DragDropEvent) -> bool>,
#[cfg(target_os = "macos")]
accept_first_mouse: objc2::runtime::Bool,
custom_protocol_task_ids: HashMap<usize, Retained<NSUUID>>,
}

declare_class!(
Expand Down Expand Up @@ -1283,6 +1322,24 @@ declare_class!(
}
);

// Custom Protocol Task Checker
impl WryWebView {
fn add_custom_task_key(&mut self, task_id: usize) -> Retained<NSUUID> {
let task_uuid = NSUUID::new();
self
.ivars_mut()
.custom_protocol_task_ids
.insert(task_id, task_uuid.clone());
task_uuid
}
fn remove_custom_task_key(&mut self, task_id: usize) {
self.ivars_mut().custom_protocol_task_ids.remove(&task_id);
}
fn get_custom_task_uuid(&self, task_id: usize) -> Option<Retained<NSUUID>> {
self.ivars().custom_protocol_task_ids.get(&task_id).cloned()
}
}

struct WryWebViewDelegateIvars {
controller: Retained<WKUserContentController>,
ipc_handler: Box<dyn Fn(Request<String>)>,
Expand Down

0 comments on commit 2fb6be9

Please sign in to comment.