diff --git a/src/dinput8.rs b/src/dinput8.rs index 0c962b2..64d55dc 100644 --- a/src/dinput8.rs +++ b/src/dinput8.rs @@ -16,7 +16,7 @@ pub unsafe extern "C" fn directinput8_create( dwversion: u32, riid: *const GUID, out: *mut Option, - _outer: Option, + outer: Option, ) -> HRESULT { // Instead of trying to load the original dinput8.dll and calling the original `DirectInput8Create`, // we can simply load the dinput8 interface via COM and return it up to our caller. This is basically @@ -24,19 +24,30 @@ pub unsafe extern "C" fn directinput8_create( // // Reference: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/ee416756(v=vs.85) let f = || -> Result { + // Initialize COM with the default apartment type. + // NOTE: Disabled for now. The documentation does not really specify what the original `DirectInput8Create` + // does in this case, and I'm too lazy to disassemble dinput8.dll to figure it out. + // unsafe { CoInitializeEx(None, COINIT_MULTITHREADED)? }; + match unsafe { riid.as_ref() } { Some(&IDirectInput8A::IID) => { let dinput: IDirectInput8A = - CoCreateInstance(&CLSID_DirectInput8, None, CLSCTX_INPROC_SERVER)?; + CoCreateInstance(&CLSID_DirectInput8, outer.as_ref(), CLSCTX_INPROC_SERVER)?; - dinput.Initialize(hinst, dwversion)?; + // Per the documentation, if pUnkOuter != NULL then the resulting object must be initialized manually. + if outer.is_none() { + dinput.Initialize(hinst, dwversion)?; + } Ok(dinput.cast()?) } Some(&IDirectInput8W::IID) => { let dinput: IDirectInput8W = - CoCreateInstance(&CLSID_DirectInput8, None, CLSCTX_INPROC_SERVER)?; + CoCreateInstance(&CLSID_DirectInput8, outer.as_ref(), CLSCTX_INPROC_SERVER)?; - dinput.Initialize(hinst, dwversion)?; + // Per the documentation, if pUnkOuter != NULL then the resulting object must be initialized manually. + if outer.is_none() { + dinput.Initialize(hinst, dwversion)?; + } Ok(dinput.cast()?) } _ => return Err(E_NOINTERFACE.into()), diff --git a/src/lib.rs b/src/lib.rs index 90d6846..26acc10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,6 @@ use anyhow::Context; use sha2::{Digest, Sha256}; use windows::core::PCWSTR; -use windows::Win32::System::Com::{CoInitializeEx, COINIT_MULTITHREADED}; use windows::Win32::{ Foundation::{HANDLE, HINSTANCE}, System::{ @@ -102,9 +101,6 @@ fn get_module_slice(info: &MODULEINFO) -> *const [u8] { /// Called upon DLL attach. This function verifies the UDK and initializes /// hooks if the UDK matches our known hash. fn dll_attach() -> anyhow::Result<()> { - // Ensure that COM is initialized. - unsafe { CoInitializeEx(None, COINIT_MULTITHREADED) }.context("failed to initialize COM")?; - let process = unsafe { GetCurrentProcess() }; let module = unsafe { GetModuleHandleA(None) }.context("failed to get module handle")?;