Skip to content

Commit

Permalink
feat: add TSend generic to Channel (#10139)
Browse files Browse the repository at this point in the history
* add TSend to Channel

* add changeset

* fix tray Channel

* Update .changes/ipc-channel-generic.md

---------

Co-authored-by: Lucas Fernandes Nogueira <[email protected]>
  • Loading branch information
Brendonovich and lucasfernog authored Jul 10, 2024
1 parent 15e1259 commit 57612ab
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 16 deletions.
5 changes: 5 additions & 0 deletions .changes/ipc-channel-generic.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"tauri": major:breaking
---

Add `TSend` generic to `ipc::Channel` for typesafe `send` calls and type inspection in `tauri-specta`
32 changes: 24 additions & 8 deletions core/tauri/src/ipc/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,19 @@ pub struct ChannelDataIpcQueue(pub(crate) Arc<Mutex<HashMap<u32, InvokeBody>>>);

/// An IPC channel.
#[derive(Clone)]
pub struct Channel {
pub struct Channel<TSend = InvokeBody> {
id: u32,
on_message: Arc<dyn Fn(InvokeBody) -> crate::Result<()> + Send + Sync>,
phantom: std::marker::PhantomData<TSend>,
}

#[cfg(feature = "specta")]
const _: () = {
#[derive(specta::Type)]
#[specta(remote = Channel, rename = "TAURI_CHANNEL")]
struct Channel<TSend>(std::marker::PhantomData<TSend>);
};

impl Serialize for Channel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
Expand Down Expand Up @@ -88,7 +96,7 @@ impl FromStr for JavaScriptChannelId {

impl JavaScriptChannelId {
/// Gets a [`Channel`] for this channel ID on the given [`Webview`].
pub fn channel_on<R: Runtime>(&self, webview: Webview<R>) -> Channel {
pub fn channel_on<R: Runtime, TSend>(&self, webview: Webview<R>) -> Channel<TSend> {
let callback_id = self.0;
let counter = AtomicUsize::new(0);

Expand Down Expand Up @@ -128,7 +136,7 @@ impl<'de> Deserialize<'de> for JavaScriptChannelId {
}
}

impl Channel {
impl<TSend> Channel<TSend> {
/// Creates a new channel with the given message handler.
pub fn new<F: Fn(InvokeBody) -> crate::Result<()> + Send + Sync + 'static>(
on_message: F,
Expand All @@ -144,10 +152,15 @@ impl Channel {
let channel = Self {
id,
on_message: Arc::new(on_message),
phantom: Default::default(),
};

#[cfg(mobile)]
crate::plugin::mobile::register_channel(channel.clone());
crate::plugin::mobile::register_channel(Channel {
id,
on_message: channel.on_message.clone(),
phantom: Default::default(),
});

channel
}
Expand Down Expand Up @@ -178,13 +191,16 @@ impl Channel {
}

/// Sends the given data through the channel.
pub fn send<T: IpcResponse>(&self, data: T) -> crate::Result<()> {
pub fn send(&self, data: TSend) -> crate::Result<()>
where
TSend: IpcResponse,
{
let body = data.body()?;
(self.on_message)(body)
}
}

impl<'de, R: Runtime> CommandArg<'de, R> for Channel {
impl<'de, R: Runtime, TSend: Clone> CommandArg<'de, R> for Channel<TSend> {
/// Grabs the [`Webview`] from the [`CommandItem`] and returns the associated [`Channel`].
fn from_command(command: CommandItem<'de, R>) -> Result<Self, InvokeError> {
let name = command.name;
Expand All @@ -196,8 +212,8 @@ impl<'de, R: Runtime> CommandArg<'de, R> for Channel {
.map(|id| id.channel_on(webview))
.map_err(|_| {
InvokeError::from_anyhow(anyhow::anyhow!(
"invalid channel value `{value}`, expected a string in the `{IPC_PAYLOAD_PREFIX}ID` format"
))
"invalid channel value `{value}`, expected a string in the `{IPC_PAYLOAD_PREFIX}ID` format"
))
})
}
}
Expand Down
6 changes: 3 additions & 3 deletions core/tauri/src/menu/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ fn new<R: Runtime>(
kind: ItemKind,
options: Option<NewOptions>,
channels: State<'_, MenuChannels>,
handler: Channel,
handler: Channel<MenuId>,
) -> crate::Result<(ResourceId, MenuId)> {
let options = options.unwrap_or_default();
let mut resources_table = app.resources_table();
Expand Down Expand Up @@ -866,7 +866,7 @@ fn set_icon<R: Runtime>(
}
}

struct MenuChannels(Mutex<HashMap<MenuId, Channel>>);
struct MenuChannels(Mutex<HashMap<MenuId, Channel<MenuId>>>);

pub(crate) fn init<R: Runtime>() -> TauriPlugin<R> {
Builder::new("menu")
Expand All @@ -877,7 +877,7 @@ pub(crate) fn init<R: Runtime>() -> TauriPlugin<R> {
.on_event(|app, e| {
if let RunEvent::MenuEvent(e) = e {
if let Some(channel) = app.state::<MenuChannels>().0.lock().unwrap().get(&e.id) {
let _ = channel.send(&e.id);
let _ = channel.send(e.id.clone());
}
}
})
Expand Down
4 changes: 2 additions & 2 deletions core/tauri/src/plugin/mobile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type PendingPluginCallHandler = Box<dyn FnOnce(PluginResponse) + Send + 'static>
static PENDING_PLUGIN_CALLS_ID: AtomicI32 = AtomicI32::new(0);
static PENDING_PLUGIN_CALLS: OnceLock<Mutex<HashMap<i32, PendingPluginCallHandler>>> =
OnceLock::new();
static CHANNELS: OnceLock<Mutex<HashMap<u32, Channel>>> = OnceLock::new();
static CHANNELS: OnceLock<Mutex<HashMap<u32, Channel<serde_json::Value>>>> = OnceLock::new();

/// Possible errors when invoking a plugin.
#[derive(Debug, thiserror::Error)]
Expand All @@ -53,7 +53,7 @@ pub enum PluginInvokeError {
CannotSerializePayload(serde_json::Error),
}

pub(crate) fn register_channel(channel: Channel) {
pub(crate) fn register_channel(channel: Channel<serde_json::Value>) {
CHANNELS
.get_or_init(Default::default)
.lock()
Expand Down
4 changes: 2 additions & 2 deletions core/tauri/src/tray/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
AppHandle, Manager, Runtime, Webview,
};

use super::TrayIcon;
use super::{TrayIcon, TrayIconEvent};

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
Expand All @@ -36,7 +36,7 @@ struct TrayIconOptions {
fn new<R: Runtime>(
webview: Webview<R>,
options: TrayIconOptions,
handler: Channel,
handler: Channel<TrayIconEvent>,
) -> crate::Result<(ResourceId, String)> {
let mut builder = if let Some(id) = options.id {
TrayIconBuilder::<R>::with_id(id)
Expand Down
2 changes: 1 addition & 1 deletion core/tauri/src/webview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,7 @@ fn main() {
for v in map.values() {
if let serde_json::Value::String(s) = v {
let _ = crate::ipc::JavaScriptChannelId::from_str(s)
.map(|id| id.channel_on(webview.clone()));
.map(|id| id.channel_on::<R, ()>(webview.clone()));
}
}
}
Expand Down

0 comments on commit 57612ab

Please sign in to comment.