diff --git a/platforms/windows/src/adapter.rs b/platforms/windows/src/adapter.rs index 1a6ad4009..4e69aac0c 100644 --- a/platforms/windows/src/adapter.rs +++ b/platforms/windows/src/adapter.rs @@ -4,11 +4,13 @@ // the LICENSE-MIT file), at your option. use accesskit::{ - ActionHandler, ActionRequest, ActivationHandler, Live, NodeBuilder, NodeId, Role, - Tree as TreeData, TreeUpdate, + ActionHandler, ActivationHandler, Live, NodeBuilder, NodeId, Role, Tree as TreeData, TreeUpdate, }; use accesskit_consumer::{DetachedNode, FilterResult, Node, Tree, TreeChangeHandler, TreeState}; -use std::{collections::HashSet, sync::Arc}; +use std::{ + collections::HashSet, + sync::{atomic::Ordering, Arc}, +}; use windows::Win32::{ Foundation::*, UI::{Accessibility::*, WindowsAndMessaging::*}, @@ -162,20 +164,10 @@ enum State { is_window_focused: bool, action_handler: Arc, }, - Placeholder { - placeholder_context: Arc, - is_window_focused: bool, - action_handler: Arc, - }, + Placeholder(Arc), Active(Arc), } -struct PlaceholderActionHandler; - -impl ActionHandler for PlaceholderActionHandler { - fn do_action(&mut self, _request: ActionRequest) {} -} - pub struct Adapter { state: State, } @@ -239,20 +231,17 @@ impl Adapter { ) -> Option { match &self.state { State::Inactive { .. } => None, - State::Placeholder { - placeholder_context, - is_window_focused, - action_handler, - } => { - let tree = Tree::new(update_factory(), *is_window_focused); - let context = - Context::new(placeholder_context.hwnd, tree, Arc::clone(action_handler)); + State::Placeholder(context) => { + let is_window_focused = context.read_tree().state().is_host_focused(); + let tree = Tree::new(update_factory(), is_window_focused); + *context.tree.write().unwrap() = tree; + context.is_placeholder.store(false, Ordering::SeqCst); let result = context .read_tree() .state() .focus_id() - .map(|id| QueuedEvents(vec![focus_event(&context, id)])); - self.state = State::Active(context); + .map(|id| QueuedEvents(vec![focus_event(context, id)])); + self.state = State::Active(Arc::clone(context)); result } State::Active(context) => { @@ -280,11 +269,11 @@ impl Adapter { *is_window_focused = is_focused; None } - State::Placeholder { - is_window_focused, .. - } => { - *is_window_focused = is_focused; - None + State::Placeholder(context) => { + let mut handler = AdapterChangeHandler::new(context); + let mut tree = context.tree.write().unwrap(); + tree.update_host_focus_state_and_process_changes(is_focused, &mut handler); + Some(QueuedEvents(handler.queue)) } State::Active(context) => { let mut handler = AdapterChangeHandler::new(context); @@ -330,7 +319,7 @@ impl Adapter { Some(initial_state) => { let hwnd = *hwnd; let tree = Tree::new(initial_state, *is_window_focused); - let context = Context::new(hwnd, tree, Arc::clone(action_handler)); + let context = Context::new(hwnd, tree, Arc::clone(action_handler), false); let node_id = context.read_tree().state().root_id(); let platform_node = PlatformNode::new(&context, node_id); self.state = State::Active(context); @@ -343,29 +332,15 @@ impl Adapter { tree: Some(TreeData::new(PLACEHOLDER_ROOT_ID)), focus: PLACEHOLDER_ROOT_ID, }; - let placeholder_tree = Tree::new(placeholder_update, false); - let placeholder_context = Context::new( - hwnd, - placeholder_tree, - Arc::new(ActionHandlerWrapper::new(PlaceholderActionHandler {})), - ); - let platform_node = - PlatformNode::new(&placeholder_context, PLACEHOLDER_ROOT_ID); - self.state = State::Placeholder { - placeholder_context, - is_window_focused: *is_window_focused, - action_handler: Arc::clone(action_handler), - }; + let placeholder_tree = Tree::new(placeholder_update, *is_window_focused); + let context = + Context::new(hwnd, placeholder_tree, Arc::clone(action_handler), true); + let platform_node = PlatformNode::unspecified_root(&context); + self.state = State::Placeholder(context); (hwnd, platform_node) } }, - State::Placeholder { - placeholder_context, - .. - } => ( - placeholder_context.hwnd, - PlatformNode::new(placeholder_context, PLACEHOLDER_ROOT_ID), - ), + State::Placeholder(context) => (context.hwnd, PlatformNode::unspecified_root(context)), State::Active(context) => { let node_id = context.read_tree().state().root_id(); (context.hwnd, PlatformNode::new(context, node_id)) diff --git a/platforms/windows/src/context.rs b/platforms/windows/src/context.rs index 523200705..d2acbd748 100644 --- a/platforms/windows/src/context.rs +++ b/platforms/windows/src/context.rs @@ -5,7 +5,7 @@ use accesskit::{ActionHandler, ActionRequest, Point}; use accesskit_consumer::Tree; -use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; +use std::sync::{atomic::AtomicBool, Arc, Mutex, RwLock, RwLockReadGuard}; use windows::Win32::Foundation::*; use crate::util::*; @@ -32,6 +32,7 @@ pub(crate) struct Context { pub(crate) hwnd: HWND, pub(crate) tree: RwLock, pub(crate) action_handler: Arc, + pub(crate) is_placeholder: AtomicBool, } impl Context { @@ -39,11 +40,13 @@ impl Context { hwnd: HWND, tree: Tree, action_handler: Arc, + is_placeholder: bool, ) -> Arc { Arc::new(Self { hwnd, tree: RwLock::new(tree), action_handler, + is_placeholder: AtomicBool::new(is_placeholder), }) } diff --git a/platforms/windows/src/node.rs b/platforms/windows/src/node.rs index 9e6b6c383..9ec43371b 100644 --- a/platforms/windows/src/node.rs +++ b/platforms/windows/src/node.rs @@ -15,7 +15,7 @@ use accesskit::{ }; use accesskit_consumer::{DetachedNode, FilterResult, Node, NodeState, TreeState}; use paste::paste; -use std::sync::{Arc, Weak}; +use std::sync::{atomic::Ordering, Arc, Weak}; use windows::{ core::*, Win32::{Foundation::*, System::Com::*, UI::Accessibility::*}, @@ -492,14 +492,21 @@ impl<'a> NodeWrapper<'a> { )] pub(crate) struct PlatformNode { pub(crate) context: Weak, - pub(crate) node_id: NodeId, + pub(crate) node_id: Option, } impl PlatformNode { pub(crate) fn new(context: &Arc, node_id: NodeId) -> Self { Self { context: Arc::downgrade(context), - node_id, + node_id: Some(node_id), + } + } + + pub(crate) fn unspecified_root(context: &Arc) -> Self { + Self { + context: Arc::downgrade(context), + node_id: None, } } @@ -523,16 +530,25 @@ impl PlatformNode { self.with_tree_state_and_context(|state, _| f(state)) } + fn node<'a>(&self, state: &'a TreeState) -> Result> { + if let Some(id) = self.node_id { + if let Some(node) = state.node_by_id(id) { + Ok(node) + } else { + Err(element_not_available()) + } + } else { + Ok(state.root()) + } + } + fn resolve_with_context(&self, f: F) -> Result where for<'a> F: FnOnce(Node<'a>, &Context) -> Result, { self.with_tree_state_and_context(|state, context| { - if let Some(node) = state.node_by_id(self.node_id) { - f(node, context) - } else { - Err(element_not_available()) - } + let node = self.node(state)?; + f(node, context) }) } @@ -541,11 +557,8 @@ impl PlatformNode { for<'a> F: FnOnce(Node<'a>, &TreeState, &Context) -> Result, { self.with_tree_state_and_context(|state, context| { - if let Some(node) = state.node_by_id(self.node_id) { - f(node, state, context) - } else { - Err(element_not_available()) - } + let node = self.node(state)?; + f(node, state, context) }) } @@ -561,10 +574,8 @@ impl PlatformNode { for<'a> F: FnOnce(Node<'a>, &Context) -> Result, { self.with_tree_state_and_context(|state, context| { - if let Some(node) = state - .node_by_id(self.node_id) - .filter(Node::supports_text_ranges) - { + let node = self.node(state)?; + if node.supports_text_ranges() { f(node, context) } else { Err(element_not_available()) @@ -581,34 +592,46 @@ impl PlatformNode { fn do_action(&self, f: F) -> Result<()> where - F: FnOnce() -> ActionRequest, + F: FnOnce() -> (Action, Option), { let context = self.upgrade_context()?; + if context.is_placeholder.load(Ordering::SeqCst) { + return Ok(()); + } let tree = context.read_tree(); - if tree.state().has_node(self.node_id) { - drop(tree); - let request = f(); - context.do_action(request); - Ok(()) + let node_id = if let Some(id) = self.node_id { + if !tree.state().has_node(id) { + return Err(element_not_available()); + } + id } else { - Err(element_not_available()) - } + tree.state().root_id() + }; + drop(tree); + let (action, data) = f(); + let request = ActionRequest { + target: node_id, + action, + data, + }; + context.do_action(request); + Ok(()) } fn do_default_action(&self) -> Result<()> { - self.do_action(|| ActionRequest { - action: Action::Default, - target: self.node_id, - data: None, - }) + self.do_action(|| (Action::Default, None)) } fn relative(&self, node_id: NodeId) -> Self { Self { context: self.context.clone(), - node_id, + node_id: Some(node_id), } } + + fn is_root(&self, state: &TreeState) -> bool { + self.node_id.map_or(false, |id| id == state.root_id()) + } } #[allow(non_snake_case)] @@ -651,7 +674,7 @@ impl IRawElementProviderSimple_Impl for PlatformNode { fn HostRawElementProvider(&self) -> Result { self.with_tree_state_and_context(|state, context| { - if self.node_id == state.root_id() { + if self.is_root(state) { unsafe { UiaHostProviderFromHwnd(context.hwnd) } } else { Err(Error::empty()) @@ -682,7 +705,17 @@ impl IRawElementProviderFragment_Impl for PlatformNode { } fn GetRuntimeId(&self) -> Result<*mut SAFEARRAY> { - let runtime_id = runtime_id_from_node_id(self.node_id); + let node_id = if let Some(id) = self.node_id { + id + } else { + // Since this `PlatformNode` isn't associated with a specific + // node ID, but always uses whatever node is currently the root, + // we shouldn't return a UIA runtime ID calculated from an + // AccessKit node ID, as we normally do. Fortunately, + // UIA doesn't seem to actually call `GetRuntimeId` on the root. + return Err(not_implemented()); + }; + let runtime_id = runtime_id_from_node_id(node_id); Ok(safe_array_from_i32_slice(&runtime_id)) } @@ -706,20 +739,16 @@ impl IRawElementProviderFragment_Impl for PlatformNode { } fn SetFocus(&self) -> Result<()> { - self.do_action(|| ActionRequest { - action: Action::Focus, - target: self.node_id, - data: None, - }) + self.do_action(|| (Action::Focus, None)) } fn FragmentRoot(&self) -> Result { self.with_tree_state(|state| { - let root_id = state.root_id(); - if root_id == self.node_id { + if self.is_root(state) { // SAFETY: We know &self is inside a full COM implementation. unsafe { self.cast() } } else { + let root_id = state.root_id(); Ok(self.relative(root_id).into()) } }) @@ -743,7 +772,12 @@ impl IRawElementProviderFragmentRoot_Impl for PlatformNode { fn GetFocus(&self) -> Result { self.with_tree_state(|state| { if let Some(id) = state.focus_id() { - if id != self.node_id { + let self_id = if let Some(id) = self.node_id { + id + } else { + state.root_id() + }; + if id != self_id { return Ok(self.relative(id).into()); } } @@ -885,11 +919,7 @@ patterns! { fn SetValue(&self, value: &PCWSTR) -> Result<()> { self.do_action(|| { let value = unsafe { value.to_string() }.unwrap(); - ActionRequest { - action: Action::SetValue, - target: self.node_id, - data: Some(ActionData::Value(value.into())), - } + (Action::SetValue, Some(ActionData::Value(value.into()))) }) } )), @@ -903,11 +933,7 @@ patterns! { ), ( fn SetValue(&self, value: f64) -> Result<()> { self.do_action(|| { - ActionRequest { - action: Action::SetValue, - target: self.node_id, - data: Some(ActionData::NumericValue(value)), - } + (Action::SetValue, Some(ActionData::NumericValue(value))) }) } )), diff --git a/platforms/windows/src/text.rs b/platforms/windows/src/text.rs index ef2f22ec8..55a2f12a3 100644 --- a/platforms/windows/src/text.rs +++ b/platforms/windows/src/text.rs @@ -474,7 +474,7 @@ impl ITextRangeProvider_Impl for PlatformRange { // Revisit this if we eventually support embedded objects. Ok(PlatformNode { context: self.context.clone(), - node_id: node.id(), + node_id: Some(node.id()), } .into()) })