Skip to content

Commit

Permalink
refactor!: ignore duplicate custom protocols
Browse files Browse the repository at this point in the history
  • Loading branch information
amrbashir committed Sep 18, 2024
1 parent 5915341 commit c12ddd0
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 169 deletions.
5 changes: 5 additions & 0 deletions .changes/duplicate-protocol-error-linux.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"wry": "minor"
---

Ignore duplicate custom protocols in `WebviewBuilder::with_custom_protocol` and `WebviewBuilder::with_async_custom_protocol` and use the last registered one.
5 changes: 5 additions & 0 deletions .changes/duplicate-protocol-error.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"wry": "minor"
---

Removed `Error::DuplicateCustomProtocol` variant.
2 changes: 0 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ pub enum Error {
#[cfg(target_os = "windows")]
#[error("WebView2 error: {0}")]
WebView2Error(webview2_com::Error),
#[error("Duplicate custom protocol registered: {0}")]
DuplicateCustomProtocol(String),
#[error(transparent)]
HttpError(#[from] http::Error),
#[error("Infallible error, something went really wrong: {0}")]
Expand Down
14 changes: 7 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ use self::webview2::*;
#[cfg(target_os = "windows")]
use webview2_com::Microsoft::Web::WebView2::Win32::ICoreWebView2Controller;

use std::{borrow::Cow, path::PathBuf, rc::Rc};
use std::{borrow::Cow, collections::HashMap, path::PathBuf, rc::Rc};

use http::{Request, Response};

Expand Down Expand Up @@ -377,7 +377,7 @@ pub struct WebViewAttributes {
/// - Android: Android has `assets` and `resource` path finder to
/// locate your files in those directories. For more information, see [Loading in-app content](https://developer.android.com/guide/webapps/load-local-content) page.
/// - iOS: To get the path of your assets, you can call [`CFBundle::resources_path`](https://docs.rs/core-foundation/latest/core_foundation/bundle/struct.CFBundle.html#method.resources_path). So url like `wry://assets/index.html` could get the html file in assets directory.
pub custom_protocols: Vec<(String, Box<dyn Fn(Request<Vec<u8>>, RequestAsyncResponder)>)>,
pub custom_protocols: HashMap<String, Box<dyn Fn(Request<Vec<u8>>, RequestAsyncResponder)>>,

/// The IPC handler to receive the message from Javascript on webview
/// using `window.ipc.postMessage("insert_message_here")` to host Rust code.
Expand Down Expand Up @@ -515,8 +515,8 @@ impl Default for WebViewAttributes {
url: None,
headers: None,
html: None,
initialization_scripts: vec![],
custom_protocols: vec![],
initialization_scripts: Default::default(),
custom_protocols: Default::default(),
ipc_handler: None,
drag_drop_handler: None,
navigation_handler: None,
Expand Down Expand Up @@ -719,13 +719,13 @@ impl<'a> WebViewBuilder<'a> {
where
F: Fn(Request<Vec<u8>>) -> Response<Cow<'static, [u8]>> + 'static,
{
self.attrs.custom_protocols.push((
self.attrs.custom_protocols.insert(
name,
Box::new(move |request, responder| {
let http_response = handler(request);
responder.respond(http_response);
}),
));
);
self
}

Expand Down Expand Up @@ -761,7 +761,7 @@ impl<'a> WebViewBuilder<'a> {
where
F: Fn(Request<Vec<u8>>, RequestAsyncResponder) + 'static,
{
self.attrs.custom_protocols.push((name, Box::new(handler)));
self.attrs.custom_protocols.insert(name, Box::new(handler));
self
}

Expand Down
280 changes: 120 additions & 160 deletions src/webkitgtk/web_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use webkit2gtk::{
pub struct WebContextImpl {
context: WebContext,
webview_uri_loader: Rc<WebViewUriLoader>,
registered_protocols: HashSet<String>,
automation: bool,
app_info: Option<ApplicationInfo>,
}
Expand Down Expand Up @@ -89,7 +88,6 @@ impl WebContextImpl {
Self {
context,
automation,
registered_protocols: Default::default(),
webview_uri_loader: Rc::default(),
app_info: Some(app_info),
}
Expand All @@ -107,22 +105,10 @@ pub trait WebContextExt {
fn context(&self) -> &WebContext;

/// Register a custom protocol to the web context.
///
/// When duplicate schemes are registered, the duplicate handler will still be submitted and the
/// `Err(Error::DuplicateCustomProtocol)` will be returned. It is safe to ignore if you are
/// relying on the platform's implementation to properly handle duplicated scheme handlers.
fn register_uri_scheme<F>(&mut self, name: &str, handler: F) -> crate::Result<()>
where
F: Fn(Request<Vec<u8>>, RequestAsyncResponder) + 'static;

/// Register a custom protocol to the web context, only if it is not a duplicate scheme.
///
/// If a duplicate scheme has been passed, its handler will **NOT** be registered and the
/// function will return `Err(Error::DuplicateCustomProtocol)`.
fn try_register_uri_scheme<F>(&mut self, name: &str, handler: F) -> crate::Result<()>
where
F: Fn(Request<Vec<u8>>, RequestAsyncResponder) + 'static;

/// Add a [`WebView`] to the queue waiting to be opened.
///
/// See the [`WebViewUriLoader`] for more information.
Expand Down Expand Up @@ -156,23 +142,127 @@ impl WebContextExt for super::WebContext {
where
F: Fn(Request<Vec<u8>>, RequestAsyncResponder) + 'static,
{
actually_register_uri_scheme(self, name, handler)?;
if self.os.registered_protocols.insert(name.to_string()) {
Ok(())
} else {
Err(Error::DuplicateCustomProtocol(name.to_string()))
}
}
// Enable secure context
self
.os
.context
.security_manager()
.ok_or(Error::MissingManager)?
.register_uri_scheme_as_secure(name);

self.os.context.register_uri_scheme(name, move |request| {
#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty).entered();

if let Some(uri) = request.uri() {
let uri = uri.as_str();

#[cfg(feature = "tracing")]
span.record("uri", uri);

#[allow(unused_mut)]
let mut http_request = Request::builder().uri(uri).method("GET");

// Set request http headers
if let Some(headers) = request.http_headers() {
if let Some(map) = http_request.headers_mut() {
headers.foreach(move |k, v| {
if let Ok(name) = HeaderName::from_bytes(k.as_bytes()) {
if let Ok(value) = HeaderValue::from_bytes(v.as_bytes()) {
map.insert(name, value);
}
}
});
}
}

// Set request http method
if let Some(method) = request.http_method() {
http_request = http_request.method(method.as_str());
}

let body;
#[cfg(feature = "linux-body")]
{
use gtk::{gdk::prelude::InputStreamExtManual, gio::Cancellable};

// Set request http body
let cancellable: Option<&Cancellable> = None;
body = request
.http_body()
.map(|s| {
const BUFFER_LEN: usize = 1024;
let mut result = Vec::new();
let mut buffer = vec![0; BUFFER_LEN];
while let Ok(count) = s.read(&mut buffer[..], cancellable) {
if count == BUFFER_LEN {
result.append(&mut buffer);
buffer.resize(BUFFER_LEN, 0);
} else {
buffer.truncate(count);
result.append(&mut buffer);
break;
}
}
result
})
.unwrap_or_default();
}
#[cfg(not(feature = "linux-body"))]
{
body = Vec::new();
}

let http_request = match http_request.body(body) {
Ok(req) => req,
Err(_) => {
request.finish_error(&mut gtk::glib::Error::new(
glib::UriError::Failed,
"Internal server error: could not create request.",
));
return;
}
};

let request_ = MainThreadRequest(request.clone());
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
Box::new(move |http_response| {
MainContext::default().invoke(move || {
let buffer = http_response.body();
let input = gtk::gio::MemoryInputStream::from_bytes(&gtk::glib::Bytes::from(buffer));
let content_type = http_response
.headers()
.get(CONTENT_TYPE)
.and_then(|h| h.to_str().ok());

let response = URISchemeResponse::new(&input, buffer.len() as i64);
response.set_status(http_response.status().as_u16() as u32, None);
if let Some(content_type) = content_type {
response.set_content_type(content_type);
}

let headers = MessageHeaders::new(MessageHeadersType::Response);
for (name, value) in http_response.headers().into_iter() {
headers.append(name.as_str(), value.to_str().unwrap_or(""));
}
response.set_http_headers(headers);
request_.finish_with_response(&response);
});

});

#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();
handler(http_request, RequestAsyncResponder { responder });
} else {
request.finish_error(&mut glib::Error::new(
glib::FileError::Exist,
"Could not get uri.",
));
}
});

fn try_register_uri_scheme<F>(&mut self, name: &str, handler: F) -> crate::Result<()>
where
F: Fn(Request<Vec<u8>>, RequestAsyncResponder) + 'static,
{
if self.os.registered_protocols.insert(name.to_string()) {
actually_register_uri_scheme(self, name, handler)
} else {
Err(Error::DuplicateCustomProtocol(name.to_string()))
}
Ok(())
}

fn queue_load_uri(&self, webview: WebView, url: String, headers: Option<http::HeaderMap>) {
Expand Down Expand Up @@ -263,136 +353,6 @@ impl WebContextExt for super::WebContext {
}
}

fn actually_register_uri_scheme<F>(
context: &mut super::WebContext,
name: &str,
handler: F,
) -> crate::Result<()>
where
F: Fn(Request<Vec<u8>>, RequestAsyncResponder) + 'static,
{
let context = &context.os.context;
// Enable secure context
context
.security_manager()
.ok_or(Error::MissingManager)?
.register_uri_scheme_as_secure(name);

context.register_uri_scheme(name, move |request| {
#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty).entered();

if let Some(uri) = request.uri() {
let uri = uri.as_str();

#[cfg(feature = "tracing")]
span.record("uri", uri);

#[allow(unused_mut)]
let mut http_request = Request::builder().uri(uri).method("GET");

// Set request http headers
if let Some(headers) = request.http_headers() {
if let Some(map) = http_request.headers_mut() {
headers.foreach(move |k, v| {
if let Ok(name) = HeaderName::from_bytes(k.as_bytes()) {
if let Ok(value) = HeaderValue::from_bytes(v.as_bytes()) {
map.insert(name, value);
}
}
});
}
}

// Set request http method
if let Some(method) = request.http_method() {
http_request = http_request.method(method.as_str());
}

let body;
#[cfg(feature = "linux-body")]
{
use gtk::{gdk::prelude::InputStreamExtManual, gio::Cancellable};

// Set request http body
let cancellable: Option<&Cancellable> = None;
body = request
.http_body()
.map(|s| {
const BUFFER_LEN: usize = 1024;
let mut result = Vec::new();
let mut buffer = vec![0; BUFFER_LEN];
while let Ok(count) = s.read(&mut buffer[..], cancellable) {
if count == BUFFER_LEN {
result.append(&mut buffer);
buffer.resize(BUFFER_LEN, 0);
} else {
buffer.truncate(count);
result.append(&mut buffer);
break;
}
}
result
})
.unwrap_or_default();
}
#[cfg(not(feature = "linux-body"))]
{
body = Vec::new();
}

let http_request = match http_request.body(body) {
Ok(req) => req,
Err(_) => {
request.finish_error(&mut gtk::glib::Error::new(
glib::UriError::Failed,
"Internal server error: could not create request.",
));
return;
}
};

let request_ = MainThreadRequest(request.clone());
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
Box::new(move |http_response| {
MainContext::default().invoke(move || {
let buffer = http_response.body();
let input = gtk::gio::MemoryInputStream::from_bytes(&gtk::glib::Bytes::from(buffer));
let content_type = http_response
.headers()
.get(CONTENT_TYPE)
.and_then(|h| h.to_str().ok());

let response = URISchemeResponse::new(&input, buffer.len() as i64);
response.set_status(http_response.status().as_u16() as u32, None);
if let Some(content_type) = content_type {
response.set_content_type(content_type);
}

let headers = MessageHeaders::new(MessageHeadersType::Response);
for (name, value) in http_response.headers().into_iter() {
headers.append(name.as_str(), value.to_str().unwrap_or(""));
}
response.set_http_headers(headers);
request_.finish_with_response(&response);
});

});

#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();
handler(http_request, RequestAsyncResponder { responder });
} else {
request.finish_error(&mut glib::Error::new(
glib::FileError::Exist,
"Could not get uri.",
));
}
});

Ok(())
}

struct MainThreadRequest(URISchemeRequest);

impl MainThreadRequest {
Expand Down

0 comments on commit c12ddd0

Please sign in to comment.