diff --git a/talpid-core/src/firewall/windows/mod.rs b/talpid-core/src/firewall/windows/mod.rs index d4b0b883d417..fa596111323a 100644 --- a/talpid-core/src/firewall/windows/mod.rs +++ b/talpid-core/src/firewall/windows/mod.rs @@ -455,11 +455,13 @@ pub extern "system" fn log_sink( } } +/// Convert `mb_string`, with the given character encoding `codepage`, to a UTF-16 string. fn multibyte_to_wide(mb_string: &CStr, codepage: u32) -> Result<Vec<u16>, io::Error> { - if unsafe { *mb_string.as_ptr() } == 0 { + if mb_string.is_empty() { return Ok(vec![]); } + // SAFETY: `mb_string` is null-terminated and valid. let wc_size = unsafe { MultiByteToWideChar( codepage, @@ -475,8 +477,10 @@ fn multibyte_to_wide(mb_string: &CStr, codepage: u32) -> Result<Vec<u16>, io::Er return Err(io::Error::last_os_error()); } - let mut wc_buffer = Vec::with_capacity(wc_size as usize); + let mut wc_buffer = vec![0u16; usize::try_from(wc_size).unwrap()]; + // SAFETY: `wc_buffer` can contain up to `wc_size` characters, including a null + // terminator. let chars_written = unsafe { MultiByteToWideChar( codepage, @@ -492,11 +496,35 @@ fn multibyte_to_wide(mb_string: &CStr, codepage: u32) -> Result<Vec<u16>, io::Er return Err(io::Error::last_os_error()); } - unsafe { wc_buffer.set_len((chars_written - 1) as usize) }; + wc_buffer.truncate(usize::try_from(chars_written - 1).unwrap()); Ok(wc_buffer) } +#[cfg(test)] +mod test { + use super::multibyte_to_wide; + use windows_sys::Win32::Globalization::CP_UTF8; + + #[test] + fn test_multibyte_to_wide() { + // € = 0x20AC in UTF-16 + let converted = multibyte_to_wide(c"€€", CP_UTF8); + const EXPECTED: &[u16] = &[0x20AC, 0x20AC]; + assert!( + matches!(converted.as_deref(), Ok(EXPECTED)), + "expected Ok({EXPECTED:?}), got {converted:?}", + ); + + // boundary case + let converted = multibyte_to_wide(c"", CP_UTF8); + assert!( + matches!(converted.as_deref(), Ok([])), + "unexpected result {converted:?}" + ); + } +} + // Convert `result` into an option and log the error, if any. 
fn consume_and_log_hyperv_err( action: &'static str,