Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows unicode (_W) api for all Win32 functions #89

Merged
merged 12 commits into from
Aug 1, 2024
192 changes: 118 additions & 74 deletions src/windows/enumerate.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
use std::collections::HashSet;
use std::ffi::{CStr, CString};
use std::{mem, ptr};

use winapi::ctypes::c_void;
use winapi::shared::guiddef::*;
use winapi::shared::minwindef::*;
use winapi::shared::winerror::*;
use winapi::um::cfgmgr32::*;
use winapi::um::cguid::GUID_NULL;
use winapi::um::errhandlingapi::GetLastError;
use winapi::um::setupapi::*;
use winapi::um::winnt::KEY_READ;
use winapi::um::winnt::{KEY_READ, REG_SZ};
use winapi::um::winreg::*;

use crate::{Error, ErrorKind, Result, SerialPortInfo, SerialPortType, UsbPortInfo};

/// takes normal Rust `str` and outputs a null terminated UTF-16 encoded string
fn as_utf16(utf8: &str) -> Vec<u16> {
utf8.encode_utf16().chain(Some(0)).collect()
}

/// takes a UTF-16 encoded slice (null termination not required)
/// and converts to a UTF8 Rust string. Trailing null chars are removed
fn from_utf16_lossy_trimmed(utf16: &[u16]) -> String {
String::from_utf16_lossy(utf16)
.trim_end_matches(0 as char)
.to_string()
}

/// According to the MSDN docs, we should use SetupDiGetClassDevs, SetupDiEnumDeviceInfo
/// and SetupDiGetDeviceInstanceId in order to enumerate devices.
/// https://msdn.microsoft.com/en-us/windows/hardware/drivers/install/enumerating-installed-devices
Expand All @@ -24,13 +36,10 @@ fn get_ports_guids() -> Result<Vec<GUID>> {
//
// The list of system defined classes can be found here:
// https://learn.microsoft.com/en-us/windows-hardware/drivers/install/system-defined-device-setup-classes-available-to-vendors
let class_names = [
// Note; since names are valid UTF-8, unwrap can't fail
CString::new("Ports").unwrap(),
CString::new("Modem").unwrap(),
];
let class_names = ["Ports", "Modem"];
let mut guids: Vec<GUID> = Vec::new();
for class_name in class_names {
let class_name_w = as_utf16(class_name);
let mut num_guids: DWORD = 1; // Initially assume that there is only 1 guid per name.
let class_start_idx = guids.len(); // start idx for this name (for potential resize with multiple guids)

Expand All @@ -40,8 +49,8 @@ fn get_ports_guids() -> Result<Vec<GUID>> {
let guid_buffer = &mut guids[class_start_idx..];
// Find out how many GUIDs are associated with this class name. num_guids will tell us how many there actually are.
let res = unsafe {
SetupDiClassGuidsFromNameA(
class_name.as_ptr(),
SetupDiClassGuidsFromNameW(
class_name_w.as_ptr(),
guid_buffer.as_mut_ptr(),
guid_buffer.len() as DWORD,
&mut num_guids,
Expand Down Expand Up @@ -186,7 +195,7 @@ impl PortDevices {
// Ports class (given by `guid`).
pub fn new(guid: &GUID) -> Self {
PortDevices {
hdi: unsafe { SetupDiGetClassDevsA(guid, ptr::null(), ptr::null_mut(), DIGCF_PRESENT) },
hdi: unsafe { SetupDiGetClassDevsW(guid, ptr::null(), ptr::null_mut(), DIGCF_PRESENT) },
dev_idx: 0,
}
}
Expand Down Expand Up @@ -236,32 +245,27 @@ struct PortDevice {
}

impl PortDevice {
// Retrieves the device instance id string associated with this device's parent.
// This is useful for determining the serial number of a composite USB device.
/// Retrieves the device instance id string associated with this device's parent.
/// This is useful for determining the serial number of a composite USB device.
fn parent_instance_id(&mut self) -> Option<String> {
let mut result_buf = [0i8; MAX_PATH];
let mut result_buf = [0u16; MAX_PATH];
let mut parent_device_instance_id = 0;

let res =
unsafe { CM_Get_Parent(&mut parent_device_instance_id, self.devinfo_data.DevInst, 0) };
if res == CR_SUCCESS {
let buffer_len = result_buf.len() - 1;
let res = unsafe {
CM_Get_Device_IDA(
CM_Get_Device_IDW(
parent_device_instance_id,
result_buf.as_mut_ptr(),
(result_buf.len() - 1) as ULONG,
buffer_len as ULONG,
0,
)
};

if res == CR_SUCCESS {
let end_of_buffer = result_buf.len() - 1;
result_buf[end_of_buffer] = 0;
Some(unsafe {
CStr::from_ptr(result_buf.as_ptr())
.to_string_lossy()
.into_owned()
})
Some(from_utf16_lossy_trimmed(&result_buf))
} else {
None
}
Expand All @@ -270,35 +274,34 @@ impl PortDevice {
}
}

// Retrieves the device instance id string associated with this device. Some examples of
// instance id strings are:
// MicroPython Board: USB\VID_F055&PID_9802\385435603432
// FTDI USB Adapter: FTDIBUS\VID_0403+PID_6001+A702TB52A\0000
// Black Magic Probe (Composite device with 2 UARTS):
// GDB Port: USB\VID_1D50&PID_6018&MI_00\6&A694CA9&0&0000
// UART Port: USB\VID_1D50&PID_6018&MI_02\6&A694CA9&0&0002
/// Retrieves the device instance id string associated with this device. Some examples of
/// instance id strings are:
/// * MicroPython Board: USB\VID_F055&PID_9802\385435603432
/// * FTDI USB Adapter: FTDIBUS\VID_0403+PID_6001+A702TB52A\0000
/// * Black Magic Probe (Composite device with 2 UARTS):
/// * GDB Port: USB\VID_1D50&PID_6018&MI_00\6&A694CA9&0&0000
/// * UART Port: USB\VID_1D50&PID_6018&MI_02\6&A694CA9&0&0002
///
/// Reference: https://learn.microsoft.com/en-us/windows-hardware/drivers/install/device-instance-ids
fn instance_id(&mut self) -> Option<String> {
let mut result_buf = [0i8; MAX_PATH];
let mut result_buf = [0u16; MAX_DEVICE_ID_LEN];
let working_buffer_len = result_buf.len() - 1; // always null terminated
let mut desired_result_len = 0; // possibly larger than the buffer
let res = unsafe {
SetupDiGetDeviceInstanceIdA(
SetupDiGetDeviceInstanceIdW(
self.hdi,
&mut self.devinfo_data,
result_buf.as_mut_ptr(),
(result_buf.len() - 1) as DWORD,
ptr::null_mut(),
working_buffer_len as DWORD,
&mut desired_result_len,
)
};
if res == FALSE {
// Try to retrieve hardware id property.
self.property(SPDRP_HARDWAREID)
} else {
let end_of_buffer = result_buf.len() - 1;
result_buf[end_of_buffer] = 0;
Some(unsafe {
CStr::from_ptr(result_buf.as_ptr())
.to_string_lossy()
.into_owned()
})
let actual_result_len = working_buffer_len.min(desired_result_len as usize);
Some(from_utf16_lossy_trimmed(&result_buf[..actual_result_len]))
}
}

Expand All @@ -324,6 +327,7 @@ impl PortDevice {

// Retrieves the port name (i.e. COM6) associated with this device.
pub fn name(&mut self) -> String {
// https://learn.microsoft.com/en-us/windows/win32/api/setupapi/nf-setupapi-setupdiopendevregkey
let hkey = unsafe {
SetupDiOpenDevRegKey(
self.hdi,
Expand All @@ -335,27 +339,43 @@ impl PortDevice {
)
};

if hkey as *mut c_void == winapi::um::handleapi::INVALID_HANDLE_VALUE {
// failed to open registry key. Return empty string as the failure case
return String::new();
}

// https://learn.microsoft.com/en-us/windows/win32/api/winreg/nf-winreg-regqueryvalueexw
let mut port_name_buffer = [0u16; MAX_PATH];
let mut port_name_len = port_name_buffer.len() as DWORD;
let value_name: Vec<u16> = "PortName".encode_utf16().chain(Some(0)).collect();
let buffer_byte_len = 2 * port_name_buffer.len() as DWORD;
let mut byte_len = buffer_byte_len;
let mut value_type = 0;

unsafe {
let value_name = as_utf16("PortName");
let err = unsafe {
RegQueryValueExW(
sirhcel marked this conversation as resolved.
Show resolved Hide resolved
hkey,
value_name.as_ptr(),
ptr::null_mut(),
ptr::null_mut(),
&mut value_type,
port_name_buffer.as_mut_ptr() as *mut u8,
&mut port_name_len,
&mut byte_len,
)
};
unsafe { RegCloseKey(hkey) };
if FAILED(err) {
// failed to query registry for some reason. Return empty string as the failure case
return String::new();
}
// https://learn.microsoft.com/en-us/windows/win32/sysinfo/registry-value-types
if value_type != REG_SZ || byte_len % 2 != 0 || byte_len > buffer_byte_len {
// read something but it wasn't the expected registry type
return String::new();
}
// len of u16 chars, not bytes
let len = buffer_byte_len as usize / 2;
let port_name = &port_name_buffer[0..len];

let port_name = &port_name_buffer[0..port_name_len as usize];

String::from_utf16_lossy(port_name)
.trim_end_matches(0 as char)
.to_string()
from_utf16_lossy_trimmed(port_name)
}

// Determines the port_type for this device, and if it's a USB port populate the various fields.
Expand All @@ -374,29 +394,29 @@ impl PortDevice {
// Retrieves a device property and returns it, if it exists. Returns None if the property
// doesn't exist.
fn property(&mut self, property_id: DWORD) -> Option<String> {
let mut value_type = 0;
let mut property_buf = [0u16; MAX_PATH];

let res = unsafe {
SetupDiGetDeviceRegistryPropertyW(
self.hdi,
&mut self.devinfo_data,
property_id,
ptr::null_mut(),
&mut value_type,
property_buf.as_mut_ptr() as PBYTE,
property_buf.len() as DWORD,
ptr::null_mut(),
)
};

if res == FALSE && unsafe { GetLastError() } != ERROR_INSUFFICIENT_BUFFER {
if res == FALSE || value_type != REG_SZ {
return None;
}

// Using the unicode version of 'SetupDiGetDeviceRegistryProperty' seems to report the
// entire mfg registry string. This typically includes some driver information that we should discard.
// Example string: 'FTDI5.inf,%ftdi%;FTDI'
String::from_utf16_lossy(&property_buf)
.trim_end_matches(0 as char)
from_utf16_lossy_trimmed(&property_buf)
sirhcel marked this conversation as resolved.
Show resolved Hide resolved
.split(';')
.last()
.map(str::to_string)
Expand All @@ -411,15 +431,15 @@ impl PortDevice {
fn get_registry_com_ports() -> HashSet<String> {
let mut ports_list = HashSet::new();

let reg_key = b"HARDWARE\\DEVICEMAP\\SERIALCOMM\0";
let key_ptr = reg_key.as_ptr() as *const i8;
let reg_key = as_utf16("HARDWARE\\DEVICEMAP\\SERIALCOMM");
let key_ptr = reg_key.as_ptr();
let mut ports_key = std::ptr::null_mut();

// SAFETY: ffi, all inputs are correct
let open_res =
unsafe { RegOpenKeyExA(HKEY_LOCAL_MACHINE, key_ptr, 0, KEY_READ, &mut ports_key) };
unsafe { RegOpenKeyExW(HKEY_LOCAL_MACHINE, key_ptr, 0, KEY_READ, &mut ports_key) };
if SUCCEEDED(open_res) {
let mut class_name_buff = [0i8; MAX_PATH];
let mut class_name_buff = [0u16; MAX_PATH];
let mut class_name_size = MAX_PATH as u32;
let mut sub_key_count = 0;
let mut largest_sub_key = 0;
Expand All @@ -434,7 +454,7 @@ fn get_registry_com_ports() -> HashSet<String> {
};
// SAFETY: ffi, all inputs are correct
let query_res = unsafe {
RegQueryInfoKeyA(
RegQueryInfoKeyW(
ports_key,
class_name_buff.as_mut_ptr(),
&mut class_name_size,
Expand All @@ -451,36 +471,40 @@ fn get_registry_com_ports() -> HashSet<String> {
};
if SUCCEEDED(query_res) {
for idx in 0..num_key_values {
let mut val_name_buff = [0i8; MAX_PATH];
let mut val_name_buff = [0u16; MAX_PATH];
let mut val_name_size = MAX_PATH as u32;
let mut value_type = 0;
// if 100 chars is not enough for COM<number> something is very wrong
let mut val_data = [0; 100];
let mut data_size = val_data.len() as u32;
let mut val_data = [0u16; MAX_PATH];
let buffer_byte_len = 2 * val_data.len() as DWORD; // len doubled
let mut byte_len = buffer_byte_len;

// SAFETY: ffi, all inputs are correct
let res = unsafe {
RegEnumValueA(
RegEnumValueW(
ports_key,
idx,
val_name_buff.as_mut_ptr(),
&mut val_name_size,
std::ptr::null_mut(),
&mut value_type,
sirhcel marked this conversation as resolved.
Show resolved Hide resolved
val_data.as_mut_ptr(),
&mut data_size,
val_data.as_mut_ptr() as *mut u8,
&mut byte_len,
)
};
if FAILED(res) || val_data.len() < data_size as usize {
if FAILED(res)
|| value_type != REG_SZ // only valid for text values
|| byte_len % 2 != 0 // out byte len should be a multiple of char size
sirhcel marked this conversation as resolved.
Show resolved Hide resolved
|| byte_len > buffer_byte_len
{
break;
}
// key data is returned as u16
// SAFETY: data_size is checked and pointer is valid
let val_data = CStr::from_bytes_with_nul(unsafe {
std::slice::from_raw_parts(val_data.as_ptr(), data_size as usize)
let val_data = from_utf16_lossy_trimmed(unsafe {
let utf16_len = byte_len / 2; // utf16 len
std::slice::from_raw_parts(val_data.as_ptr(), utf16_len as usize)
});

if let Ok(port) = val_data {
ports_list.insert(port.to_string_lossy().into_owned());
}
ports_list.insert(val_data);
}
}
// SAFETY: ffi, all inputs are correct
Expand Down Expand Up @@ -543,6 +567,26 @@ mod test {

use quickcheck_macros::quickcheck;

#[test]
fn from_utf16_lossy_trimmed_trimming_empty() {
assert_eq!("", from_utf16_lossy_trimmed(&[]));
assert_eq!("", from_utf16_lossy_trimmed(&[0]));
}

#[test]
fn from_utf16_lossy_trimmed_trimming() {
let test_str = "Testing";
let wtest_str: Vec<u16> = as_utf16(test_str);
let wtest_str_trailing = wtest_str
.iter()
.copied()
.chain([0, 0, 0, 0]) // add some null chars
.collect::<Vec<_>>();
let and_back = from_utf16_lossy_trimmed(&wtest_str_trailing);

assert_eq!(test_str, and_back);
}

// Check that passing some random data to HwidMatches::new() does not cause a panic.
#[quickcheck]
fn quickcheck_hwidmatches_new_does_not_panic_from_random_input(hwid: String) -> bool {
Expand Down
Loading