Skip to content

Commit

Permalink
Support richer COM interface hierarchies (#1608)
Browse files Browse the repository at this point in the history
  • Loading branch information
rylev authored Mar 17, 2022
1 parent d511ecc commit 86e8088
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 82 deletions.
67 changes: 46 additions & 21 deletions crates/libs/interface/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ macro_rules! expected_token {
/// }
/// ```
struct Interface {
pub visibility: syn::Visibility,
pub name: syn::Ident,
pub parent: Option<syn::Path>,
pub methods: Vec<InterfaceMethod>,
visibility: syn::Visibility,
name: syn::Ident,
parent: syn::Path,
methods: Vec<InterfaceMethod>,
docs: Vec<syn::Attribute>,
}

Expand All @@ -67,8 +67,8 @@ impl Interface {
let vis = &self.visibility;
let name = &self.name;
let docs = &self.docs;
let parent = self.parent.as_ref().map(|p| quote!(#p)).unwrap_or_else(|| quote!(::windows::core::IUnknown));
let vtable_name = quote::format_ident!("{}_Vtbl", name);
let parent = self.parent();
let vtable_name = quote::format_ident!("{}Vtbl", name);
let guid = guid.to_tokens()?;
let implementation = self.gen_implementation();
let com_trait = self.get_com_trait();
Expand Down Expand Up @@ -143,9 +143,11 @@ impl Interface {
}
})
.collect::<Vec<_>>();
let parent = self.parent_trait_constraint();

quote! {
#[allow(non_camel_case_types)]
#vis trait #name: Sized {
#vis trait #name: #parent Sized {
#(#methods)*
}
}
Expand All @@ -154,8 +156,8 @@ impl Interface {
/// Generates the vtable for a COM interface
fn gen_vtable(&self, vtable_name: &syn::Ident) -> proc_macro2::TokenStream {
let name = &self.name;
// TODO
let parent_vtable = quote!(::windows::core::IUnknownVtbl);
let parent_vtable = self.parent_vtable();
let parent_vtable_generics = if self.parent_is_iunknown() { quote!(Identity, OFFSET) } else { quote!(Identity, Impl, OFFSET) };
let vtable_entries = self
.methods
.iter()
Expand Down Expand Up @@ -207,15 +209,14 @@ impl Interface {
#[repr(C)]
#[doc(hidden)]
pub struct #vtable_name {
// TODO: handle non-IUnknown parents
pub base: ::windows::core::IUnknownVtbl,
pub base: #parent_vtable,
#(#vtable_entries)*
}

impl #vtable_name {
pub const fn new<Identity: ::windows::core::IUnknownImpl, Impl: #trait_name, const OFFSET: isize>() -> Self {
#(#functions)*
Self { base: #parent_vtable::new::<Identity, OFFSET>(), #(#entries),* }
Self { base: #parent_vtable::new::<#parent_vtable_generics>(), #(#entries),* }
}

pub fn matches(iid: &windows::core::GUID) -> bool {
Expand All @@ -232,8 +233,7 @@ impl Interface {
quote! {
impl ::core::convert::From<#name> for ::windows::core::IUnknown {
fn from(value: #name) -> Self {
// TODO: handle when direct parent is not IUnknown
value.0
unsafe { ::core::mem::transmute(value) }
}
}
impl ::core::convert::From<&#name> for ::windows::core::IUnknown {
Expand All @@ -248,8 +248,7 @@ impl Interface {
}
impl<'a> ::windows::core::IntoParam<'a, ::windows::core::IUnknown> for &'a #name {
fn into_param(self) -> ::windows::core::Param<'a, ::windows::core::IUnknown> {
// TODO: handle when direct parent is not IUnknown
::windows::core::Param::Borrowed(&self.0)
::windows::core::Param::Borrowed(unsafe { ::core::mem::transmute(self) })
}
}
impl ::core::clone::Clone for #name {
Expand All @@ -270,6 +269,35 @@ impl Interface {
}
}
}

fn parent(&self) -> proc_macro2::TokenStream {
let p = &self.parent;
quote!(#p)
}

fn parent_vtable(&self) -> proc_macro2::TokenStream {
let i = self.parent_ident();
let i = quote::format_ident!("{}Vtbl", i);
quote!(#i)
}

fn parent_is_iunknown(&self) -> bool {
self.parent.is_ident("IUnknown")
}

fn parent_ident(&self) -> &syn::Ident {
&self.parent.segments.last().as_ref().expect("segements should never be empty").ident
}

/// Gets the parent trait constrait which is nothing if the parent is IUnknown
fn parent_trait_constraint(&self) -> proc_macro2::TokenStream {
let i = self.parent_ident();
if i == "IUnknown" {
return quote!();
}
let i = quote::format_ident!("{}_Impl", i);
quote!(#i +)
}
}

impl Parse for Interface {
Expand All @@ -289,11 +317,8 @@ impl Parse for Interface {
let _ = input.parse::<syn::Token![unsafe]>()?;
let _ = input.parse::<syn::Token![trait]>()?;
let name = input.parse::<syn::Ident>()?;
let mut parent = None;
if name != "IUnknown" {
let _ = input.parse::<syn::Token![:]>().map_err(|_| syn::Error::new(name.span(), format!("Interfaces must inherit from another interface like so: `interface {}: IParentInterface`", name)))?;
parent = Some(input.parse::<syn::Path>()?);
}
let _ = input.parse::<syn::Token![:]>().map_err(|_| syn::Error::new(name.span(), format!("Interfaces must inherit from another interface like so: `interface {}: IParentInterface`", name)))?;
let parent = input.parse::<syn::Path>()?;
let content;
syn::braced!(content in input);
let mut methods = Vec::new();
Expand Down
2 changes: 1 addition & 1 deletion crates/libs/windows/src/core/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub unsafe trait Interface: Sized {
#[doc(hidden)]
unsafe fn assume_vtable<T: Interface>(&self) -> &T::Vtable {
let this: RawPtr = core::mem::transmute_copy(self);
&(*(*(this as *mut *mut _) as *mut _))
&**(this as *mut *mut T::Vtable)
}

#[doc(hidden)]
Expand Down
166 changes: 106 additions & 60 deletions crates/tests/nightly_interface/tests/com.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,58 +16,99 @@ pub unsafe trait ICustomUri: IUnknown {
// etc
}

#[implement(ICustomUri)]
struct CustomUri;

impl ICustomUri_Impl for CustomUri {
unsafe fn GetPropertyBSTR(&self, property: Uri_PROPERTY, value: *mut BSTR, flags: u32) -> HRESULT {
assert!(flags == 0);
assert!(property == Uri_PROPERTY_DOMAIN);
*value = "property".into();
S_OK
}
unsafe fn GetPropertyLength(&self) -> HRESULT {
todo!()
/// A custom declaration of implementation of `IPersist`
#[interface("0000010c-0000-0000-C000-000000000046")]
pub unsafe trait ICustomPersist: IUnknown {
unsafe fn GetClassID(&self, clsid: *mut GUID) -> HRESULT;
}

/// A custom declaration of implementation of `IPersistMemory`
#[interface("BD1AE5E0-A6AE-11CE-BD37-504200C10000")]
pub unsafe trait ICustomPersistMemory: ICustomPersist {
unsafe fn IsDirty(&self) -> HRESULT;
unsafe fn Load(&self, input: *const core::ffi::c_void, size: u32) -> HRESULT;
unsafe fn Save(&self, output: *mut core::ffi::c_void, clear_dirty: BOOL, size: u32) -> HRESULT;
unsafe fn GetSizeMax(&self, len: *mut u32) -> HRESULT;
unsafe fn InitNew(&self) -> HRESULT;
}

/// A custom in-memory store
#[implement(ICustomPersistMemory, ICustomPersist)]
#[derive(Default)]
struct Persist(std::sync::RwLock<PersistState>);

impl Persist {
fn new() -> Self {
Self(std::sync::RwLock::new(PersistState::default()))
}
unsafe fn GetPropertyDWORD(&self, property: Uri_PROPERTY, value: *mut u32, flags: u32) -> HRESULT {
assert!(flags == 0);
assert!(property == Uri_PROPERTY_PORT);
*value = 123;
}

#[derive(Default)]
struct PersistState {
memory: [u8; 10],
dirty: bool,
}

impl ICustomPersist_Impl for Persist {
unsafe fn GetClassID(&self, clsid: *mut GUID) -> HRESULT {
*clsid = "117fb826-2155-483a-b50d-bc99a2c7cca3".into();
S_OK
}
unsafe fn HasProperty(&self) {
todo!()
}

impl ICustomPersistMemory_Impl for Persist {
unsafe fn IsDirty(&self) -> HRESULT {
let reader = self.0.read().unwrap();
if reader.dirty {
S_OK
} else {
S_FALSE
}
}
unsafe fn GetAbsoluteUri(&self) -> HRESULT {
todo!()

unsafe fn Load(&self, input: *const core::ffi::c_void, size: u32) -> HRESULT {
let mut writer = self.0.write().unwrap();
if size <= writer.memory.len() as _ {
std::ptr::copy(input, writer.memory.as_mut_ptr() as _, size as _);
writer.dirty = true;
S_OK
} else {
E_OUTOFMEMORY
}
}
unsafe fn GetAuthority(&self) -> HRESULT {
todo!()

unsafe fn Save(&self, output: *mut core::ffi::c_void, clear_dirty: BOOL, size: u32) -> HRESULT {
let mut writer = self.0.write().unwrap();
if size <= writer.memory.len() as _ {
std::ptr::copy(writer.memory.as_mut_ptr() as _, output, size as _);
if clear_dirty.as_bool() {
writer.dirty = false;
}
S_OK
} else {
E_OUTOFMEMORY
}
}
unsafe fn GetDisplayUri(&self) -> i32 {
todo!()

unsafe fn GetSizeMax(&self, len: *mut u32) -> HRESULT {
let reader = self.0.read().unwrap();
*len = reader.memory.len() as _;
S_OK
}
unsafe fn GetDomain(&self, value: *mut BSTR) -> HRESULT {
*value = "kennykerr.ca".into();

unsafe fn InitNew(&self) -> HRESULT {
let mut writer = self.0.write().unwrap();
writer.memory = Default::default();
writer.dirty = false;
S_OK
}
}

#[test]
fn test_custom_interface() -> windows::core::Result<()> {
unsafe {
// Use the OS implementation through the OS interface
// Use the OS implementation of Uri through the custom `ICustomUri` interface
let a: IUri = CreateUri("http://kennykerr.ca", Default::default(), 0)?;
let domain = a.GetDomain()?;
assert_eq!(domain, "kennykerr.ca");
let mut property = BSTR::new();
a.GetPropertyBSTR(Uri_PROPERTY_DOMAIN, &mut property, 0)?;
assert_eq!(property, "kennykerr.ca");
let mut property = 0;
a.GetPropertyDWORD(Uri_PROPERTY_PORT, &mut property, 0)?;
assert_eq!(property, 80);

// Call the OS implementation through the custom interface
let b: ICustomUri = a.cast()?;
let mut domain = BSTR::new();
b.GetDomain(&mut domain).ok()?;
Expand All @@ -79,30 +120,35 @@ fn test_custom_interface() -> windows::core::Result<()> {
a.GetPropertyDWORD(Uri_PROPERTY_PORT, &mut property, 0)?;
assert_eq!(property, 80);

// Use the custom implementation through the OS interface
let c: ICustomUri = CustomUri.into();
// This works because `ICustomUri` and `IUri` share the same guid
let c: IUri = c.cast()?;
let domain = c.GetDomain()?;
assert_eq!(domain, "kennykerr.ca");
let mut property = BSTR::new();
c.GetPropertyBSTR(Uri_PROPERTY_DOMAIN, &mut property, 0)?;
assert_eq!(property, "property");
let mut property = 0;
c.GetPropertyDWORD(Uri_PROPERTY_PORT, &mut property, 0)?;
assert_eq!(property, 123);
// Use the custom implementation of `Persist` through the OS `IPersistMemory` interface
let p: ICustomPersistMemory = Persist::new().into();
// This works because `ICustomPersistMemory` and `IPersistMemory` share the same guid
let p: IPersistMemory = p.cast()?;
assert_eq!(p.GetClassID()?, "117fb826-2155-483a-b50d-bc99a2c7cca3".into());
// TODO: can't test IsDirty until this is fixed: https://github.com/microsoft/win32metadata/issues/838
assert_eq!(p.GetSizeMax()?, 10);
p.Load(&[0xAAu8, 0xBB, 0xCC])?;
let mut memory = [0x00u8, 0x00, 0x00, 0x00];
p.Save(&mut memory, true)?;
assert_eq!(memory, [0xAAu8, 0xBB, 0xCC, 0x00]);

// Call the custom implementation through the custom interface
let d: ICustomUri = c.cast()?;
let mut domain = BSTR::new();
d.GetDomain(&mut domain).ok()?;
assert_eq!(domain, "kennykerr.ca");
let mut property = BSTR::new();
d.GetPropertyBSTR(Uri_PROPERTY_DOMAIN, &mut property, 0).ok()?;
assert_eq!(property, "property");
let mut property = 0;
d.GetPropertyDWORD(Uri_PROPERTY_PORT, &mut property, 0).ok()?;
assert_eq!(property, 123);
// Use the custom implementation of `Persist` through the custom interface of `ICustomPersist`
let p: ICustomPersistMemory = p.cast()?;
let mut size = 0;
p.GetSizeMax(&mut size).ok()?;
assert_eq!(size, 10);
assert_eq!(p.IsDirty(), S_FALSE);
p.Load(&[0xAAu8, 0xBB, 0xCC] as *const _ as *const _, 3).ok()?;
assert_eq!(p.IsDirty(), S_OK);
let mut memory = [0x00u8, 0x00, 0x00, 0x00];
p.Save(&mut memory as *mut _ as *mut _, true.into(), 4).ok()?;
assert_eq!(p.IsDirty(), S_FALSE);
assert_eq!(memory, [0xAAu8, 0xBB, 0xCC, 0x00]);

let p: ICustomPersist = p.cast()?;
let mut b = GUID::default();
p.GetClassID(&mut b).ok()?;
assert_eq!(b, "117fb826-2155-483a-b50d-bc99a2c7cca3".into());

Ok(())
}
Expand Down

0 comments on commit 86e8088

Please sign in to comment.