diff --git a/src/de/meta.rs b/src/de/meta.rs index 99039ac..f7e7e66 100644 --- a/src/de/meta.rs +++ b/src/de/meta.rs @@ -50,9 +50,20 @@ where forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes - byte_buf option unit unit_struct seq tuple + byte_buf option unit seq tuple tuple_struct map enum identifier ignored_any } + + fn deserialize_unit_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_struct(name, &[], visitor) + } } let mut name = None; diff --git a/src/de/variant_de.rs b/src/de/variant_de.rs index da5e1d4..468ac80 100644 --- a/src/de/variant_de.rs +++ b/src/de/variant_de.rs @@ -33,7 +33,7 @@ impl<'de> serde::Deserializer<'de> for Variant { { match self { Variant::Null => visitor.visit_none(), - Variant::Empty => visitor.visit_none(), + Variant::Empty => visitor.visit_unit(), Variant::String(s) => visitor.visit_string(s), Variant::I1(n) => visitor.visit_i8(n), Variant::I2(n) => visitor.visit_i16(n), diff --git a/src/de/wbem_class_de.rs b/src/de/wbem_class_de.rs index bd782f2..bb6753c 100644 --- a/src/de/wbem_class_de.rs +++ b/src/de/wbem_class_de.rs @@ -212,9 +212,16 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer { visitor.visit_string(class_name) } + fn deserialize_unit(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_unit() + } + forward_to_deserialize_any! { bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes - byte_buf option unit unit_struct seq tuple + byte_buf option unit_struct seq tuple tuple_struct ignored_any } } diff --git a/src/lib.rs b/src/lib.rs index b6bb975..e306b3a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -194,6 +194,13 @@ //! [WMI Code Creator]: https://www.microsoft.com/en-us/download/details.aspx?id=8572 //! [`notification`]: connection/struct.WMIConnection.html#method.notification //! +//! # Executing Methods +//! +//! The crate also offers support for executing WMI methods on classes and instances. +//! +//! See [`WMIConnection::exec_class_method`], [`WMIConnection::exec_instance_method`] and [`WMIConnection::exec_method_native_wrapper`] +//! for detailed examples. +//! //! # Internals //! //! [`WMIConnection`](WMIConnection) is used to create and execute a WMI query, returning @@ -277,9 +284,11 @@ mod datetime_time; pub mod context; pub mod de; pub mod duration; +pub mod method; pub mod query; pub mod result_enumerator; pub mod safearray; +pub mod ser; pub mod utils; pub mod variant; diff --git a/src/method.rs b/src/method.rs new file mode 100644 index 0000000..532199e --- /dev/null +++ b/src/method.rs @@ -0,0 +1,305 @@ +use std::collections::HashMap; + +use serde::{de, Serialize}; +use windows_core::{BSTR, HSTRING, VARIANT}; + +use crate::{ + de::meta::struct_name_and_fields, result_enumerator::IWbemClassWrapper, + ser::variant_ser::VariantStructSerializer, Variant, WMIConnection, WMIError, WMIResult, +}; + +impl WMIConnection { + /// Wrapper for WMI's [ExecMethod](https://learn.microsoft.com/en-us/windows/win32/api/wbemcli/nf-wbemcli-iwbemservices-execmethod) function. + /// + /// This function is used internally by [`WMIConnection::exec_class_method`] and [`WMIConnection::exec_instance_method`], + /// which are a higher-level abstraction, dealing with Rust data types instead of raw Variants, that should be preferred to use. + /// + /// In the case of a class ("static") method, `object_path` should be the same as `method_class`. + /// + /// Returns `None` if the method has no out parameters and a `void` return type, and an [`IWbemClassWrapper`] containing the output otherwise. + /// A method with a return type other than `void` will always have a generic property named `ReturnValue` in the output class wrapper with the return value of the WMI method call. + /// + /// ```edition2021 + /// # use wmi::{COMLibrary, Variant, WMIConnection, WMIResult}; + /// # fn main() -> WMIResult<()> { + /// # let wmi_con = WMIConnection::new(COMLibrary::new()?)?; + /// let in_params = [ + /// ("CommandLine".to_string(), Variant::from("systeminfo".to_string())) + /// ].into_iter().collect(); + /// + /// // Because Create has a return value and out parameters, the Option returned will never be None. + /// // Note: The Create call can be unreliable, so consider using another means of starting processes. + /// let out = wmi_con.exec_method_native_wrapper("Win32_Process", "Win32_Process", "Create", in_params)?.unwrap(); + /// println!("The return code of the Create call is {:?}", out.get_property("ReturnValue")?); + /// # Ok(()) + /// # } + /// ``` + pub fn exec_method_native_wrapper( + &self, + method_class: impl AsRef, + object_path: impl AsRef, + method: impl AsRef, + in_params: HashMap, + ) -> WMIResult> { + let method_class = BSTR::from(method_class.as_ref()); + let object_path = BSTR::from(object_path.as_ref()); + let method = BSTR::from(method.as_ref()); + + // See https://learn.microsoft.com/en-us/windows/win32/api/wbemcli/nf-wbemcli-iwbemclassobject-getmethod + // GetMethod can only be called on a class definition, so we retrieve that before retrieving a specific object + let mut class_definition = None; + unsafe { + self.svc.GetObject( + &method_class, + Default::default(), + &self.ctx.0, + Some(&mut class_definition), + None, + )?; + } + let class_definition = class_definition.ok_or(WMIError::ResultEmpty)?; + // Retrieve the input signature of the WMI method. + // The fields of the resulting IWbemClassObject will have the names and types of the WMI method's input parameters + let mut input_signature = None; + unsafe { + class_definition.GetMethod( + &method, + Default::default(), + &mut input_signature, + std::ptr::null_mut(), + )?; + } + + // The method may have no input parameters, such as in this case: https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/reboot-method-in-class-win32-operatingsystem + let in_params = match input_signature { + Some(input) => { + let inst; + unsafe { + inst = input.SpawnInstance(Default::default())?; + }; + // Set every field of the input object to the corresponding input parameter passed to this function + for (wszname, value) in in_params { + let wszname = HSTRING::from(wszname); + let value = TryInto::::try_into(value)?; + + // See https://learn.microsoft.com/en-us/windows/win32/api/wbemcli/nf-wbemcli-iwbemclassobject-put + // Note that the example shows the variant is expected to be cleared (dropped) after the call to Put, + // so passing &value is acceptable here + unsafe { + inst.Put(&wszname, Default::default(), &value, 0)?; + } + } + Some(inst) + } + None => None, + }; + + // In the case of a method with no out parameters and a VOID return type, there will be no out-parameters object + let mut output = None; + unsafe { + self.svc.ExecMethod( + &object_path, + &method, + Default::default(), + &self.ctx.0, + in_params.as_ref(), + Some(&mut output), + None, + )?; + } + + Ok(output.map(IWbemClassWrapper::new)) + } + + /// Executes a method of a WMI class not tied to any specific instance. Examples include + /// [Create](https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/create-method-in-class-win32-process) of `Win32_Process` + /// and [AddPrinterConnection](https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/addprinterconnection-method-in-class-win32-printer) of `Win32_Printer`. + /// + /// `MethodClass` should have the name of the class on which the method is being invoked. + /// `In` and `Out` can be `()` or any custom structs supporting (de)serialization containing the input and output parameters of the function. + /// + /// A method with a return type other than `void` will always try to populate a generic property named `ReturnValue` in the output object with the return value of the WMI method call. + /// If the method call has a `void` return type and no out parameters, the only acceptable type for `Out` is `()`. + /// + /// Arrays, Options, unknowns, and nested objects cannot be passed as input parameters due to limitations in how variants are constructed by `windows-rs`. + /// + /// This function uses [`WMIConnection::exec_instance_method`] internally, with the name of the method class being the instance path, as is expected by WMI. + /// + /// ```edition2021 + /// # use serde::{Deserialize, Serialize}; + /// # use wmi::{COMLibrary, Variant, WMIConnection, WMIResult}; + /// #[derive(Serialize)] + /// # #[allow(non_snake_case)] + /// struct CreateInput { + /// CommandLine: String + /// } + /// + /// #[derive(Deserialize)] + /// # #[allow(non_snake_case)] + /// struct CreateOutput { + /// ReturnValue: u32, + /// ProcessId: u32 + /// } + /// + /// #[derive(Deserialize)] + /// # #[allow(non_camel_case_types)] + /// struct Win32_Process; + /// + /// # fn main() -> WMIResult<()> { + /// # let wmi_con = WMIConnection::new(COMLibrary::new()?)?; + /// // Note: The Create call can be unreliable, so consider using another means of starting processes. + /// let input = CreateInput { + /// CommandLine: "systeminfo".to_string() + /// }; + /// let output: CreateOutput = wmi_con.exec_class_method::("Create", input)?; + /// + /// println!("The return code of the Create call is {}", output.ReturnValue); + /// println!("The ID of the created process is: {}", output.ProcessId); + /// # Ok(()) + /// # } + /// ``` + pub fn exec_class_method( + &self, + method: impl AsRef, + in_params: In, + ) -> WMIResult + where + MethodClass: de::DeserializeOwned, + In: Serialize, + Out: de::DeserializeOwned, + { + let (method_class, _) = struct_name_and_fields::()?; + self.exec_instance_method::(method, method_class, in_params) + } + + /// Executes a WMI method on a specific instance of a class. Examples include + /// [GetSupportedSize](https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-Volume-getsupportedsizes) of `MSFT_Volume` + /// and [Pause](https://learn.microsoft.com/en-us/windows/win32/cimwin32prov/pause-method-in-class-win32-printer) of `Win32_Printer`. + /// + /// `MethodClass` should have the name of the class on which the method is being invoked. + /// `In` and `Out` can be `()` or any custom structs supporting (de)serialization containing the input and output parameters of the function. + /// `object_path` is the `__Path` variable of the class instance on which the method is being called, which can be obtained from a WMI query. + /// + /// A method with a return type other than `void` will always try to populate a generic property named `ReturnValue` in the output object with the return value of the WMI method call. + /// If the method call has a `void` return type and no out parameters, the only acceptable type for `Out` is `()`. + /// + /// Arrays, Options, unknowns, and nested objects cannot be passed as input parameters due to limitations in how variants are constructed by `windows-rs`. + /// + /// ```edition2021 + /// # use serde::{Deserialize, Serialize}; + /// # use wmi::{COMLibrary, FilterValue, Variant, WMIConnection, WMIResult}; + /// #[derive(Deserialize)] + /// # #[allow(non_snake_case)] + /// struct PrinterOutput { + /// ReturnValue: u32 + /// } + /// + /// #[derive(Deserialize)] + /// # #[allow(non_camel_case_types, non_snake_case)] + /// struct Win32_Printer { + /// __Path: String + /// } + /// + /// # fn main() -> WMIResult<()> { + /// # let wmi_con = WMIConnection::new(COMLibrary::new()?)?; + /// let printers: Vec = wmi_con.query()?; + /// + /// for printer in printers { + /// let output: PrinterOutput = wmi_con.exec_instance_method::("Pause", &printer.__Path, ())?; + /// println!("Pausing the printer returned {}", output.ReturnValue); + /// + /// let output: PrinterOutput = wmi_con.exec_instance_method::("Resume", &printer.__Path, ())?; + /// println!("Resuming the printer returned {}", output.ReturnValue); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn exec_instance_method( + &self, + method: impl AsRef, + object_path: impl AsRef, + in_params: In, + ) -> WMIResult + where + MethodClass: de::DeserializeOwned, + In: Serialize, + Out: de::DeserializeOwned, + { + let (method_class, _) = struct_name_and_fields::()?; + let serializer = VariantStructSerializer::new(); + match in_params.serialize(serializer) { + Ok(field_map) => { + let output = + self.exec_method_native_wrapper(method_class, object_path, method, field_map)?; + + match output { + Some(class_wrapper) => Ok(class_wrapper.into_desr()?), + None => Out::deserialize(Variant::Empty), + } + } + Err(e) => Err(WMIError::ConvertVariantError(e.to_string())), + } + } +} + +#[cfg(test)] +mod tests { + use crate::tests::fixtures::wmi_con; + use serde::{Deserialize, Serialize}; + + #[derive(Deserialize)] + struct Win32_Process { + __Path: String, + } + + #[derive(Serialize)] + struct CreateParams { + CommandLine: String, + } + + #[derive(Deserialize)] + #[allow(non_snake_case)] + struct CreateOutput { + ReturnValue: u32, + ProcessId: u32, + } + + #[test] + fn it_exec_methods() { + // Create powershell instance + let wmi_con = wmi_con(); + let in_params = CreateParams { + CommandLine: "powershell.exe".to_string(), + }; + let out = wmi_con + .exec_class_method::("Create", in_params) + .unwrap(); + + assert_eq!(out.ReturnValue, 0); + + let process = wmi_con + .raw_query::(format!( + "SELECT * FROM Win32_Process WHERE ProcessId = {}", + out.ProcessId + )) + .unwrap() + .into_iter() + .next() + .unwrap(); + + wmi_con + .exec_instance_method::("Terminate", process.__Path, ()) + .unwrap(); + + assert!( + wmi_con + .raw_query::(format!( + "SELECT * FROM Win32_Process WHERE ProcessId = {}", + out.ProcessId + )) + .unwrap() + .len() + == 0 + ); + } +} diff --git a/src/ser/mod.rs b/src/ser/mod.rs new file mode 100644 index 0000000..9531f66 --- /dev/null +++ b/src/ser/mod.rs @@ -0,0 +1 @@ +pub mod variant_ser; diff --git a/src/ser/variant_ser.rs b/src/ser/variant_ser.rs new file mode 100644 index 0000000..c44416e --- /dev/null +++ b/src/ser/variant_ser.rs @@ -0,0 +1,547 @@ +//! This module implements a custom serializer type, [`VariantStructSerializer`], +//! to serialize a Rust struct into a HashMap mapping field name strings to [`Variant`] values +use std::{any::type_name, collections::HashMap, fmt::Display}; + +use crate::Variant; +use serde::{ + ser::{Impossible, SerializeStruct}, + Serialize, Serializer, +}; +use thiserror::Error; + +macro_rules! serialize_struct_err_stub { + ($signature:ident, $type:ty) => { + fn $signature(self, _v: $type) -> Result { + Err(VariantSerializerError::ExpectedStruct) + } + }; +} + +macro_rules! serialize_variant_err_stub { + ($signature:ident, $type:ty) => { + fn $signature(self, _v: $type) -> Result { + Err(VariantSerializerError::UnsupportedVariantType( + type_name::<$type>().to_string(), + )) + } + }; +} + +macro_rules! serialize_variant { + ($signature:ident, $type:ty) => { + fn $signature(self, v: $type) -> Result { + Ok(Variant::from(v)) + } + }; +} + +struct VariantSerializer {} + +impl Serializer for VariantSerializer { + type Ok = Variant; + type Error = VariantSerializerError; + + type SerializeSeq = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + type SerializeMap = Impossible; + type SerializeStruct = Impossible; + type SerializeStructVariant = Impossible; + + serialize_variant!(serialize_bool, bool); + serialize_variant!(serialize_i8, i8); + serialize_variant!(serialize_i16, i16); + serialize_variant!(serialize_i32, i32); + serialize_variant!(serialize_i64, i64); + serialize_variant!(serialize_u8, u8); + serialize_variant!(serialize_u16, u16); + serialize_variant!(serialize_u32, u32); + serialize_variant!(serialize_u64, u64); + serialize_variant!(serialize_f32, f32); + serialize_variant!(serialize_f64, f64); + + fn serialize_unit(self) -> Result { + Ok(Variant::Empty) + } + + fn serialize_str(self, v: &str) -> Result { + Ok(Variant::from(v.to_string())) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + Ok(Variant::from(variant.to_string())) + } + + // Generic serializer code not relevant to this use case + + serialize_variant_err_stub!(serialize_char, char); + serialize_variant_err_stub!(serialize_bytes, &[u8]); + + fn serialize_none(self) -> Result { + Err(VariantSerializerError::UnsupportedVariantType( + "None".to_string(), + )) + } + + fn serialize_some(self, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + Err(VariantSerializerError::UnsupportedVariantType( + type_name::().to_string(), + )) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(VariantSerializerError::UnsupportedVariantType( + "Sequence".to_string(), + )) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(VariantSerializerError::UnsupportedVariantType( + "Tuple".to_string(), + )) + } + + fn serialize_tuple_struct( + self, + name: &'static str, + _len: usize, + ) -> Result { + Err(VariantSerializerError::UnsupportedVariantType( + name.to_string(), + )) + } + + fn serialize_tuple_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + Err(VariantSerializerError::UnsupportedVariantType(format!( + "{variant}::{name}" + ))) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(VariantSerializerError::UnsupportedVariantType( + "Map".to_string(), + )) + } + + fn serialize_struct( + self, + name: &'static str, + _len: usize, + ) -> Result { + Err(VariantSerializerError::UnsupportedVariantType( + name.to_string(), + )) + } + + fn serialize_struct_variant( + self, + name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + Err(VariantSerializerError::UnsupportedVariantType(format!( + "{variant}::{name}" + ))) + } +} + +/// Serializes a struct to a HashMap of key-value pairs, with the key being the field name, and the value being the field value wrapped in a [`Variant`]. +/// +/// VariantStructSerializer only supports serializing fields with basic Rust data types: `i32`, `()`, etc., as well as any of the former in a newtype. +/// +/// ```edition2021 +/// use serde::Serialize; +/// use wmi::ser::variant_ser::VariantStructSerializer; +/// +/// #[derive(Serialize)] +/// struct TestStruct { +/// number: f32, +/// text: String +/// } +/// +/// let test_struct = TestStruct { +/// number: 0.6, +/// text: "foobar".to_string() +/// }; +/// +/// let fields = test_struct.serialize(VariantStructSerializer::new()).unwrap(); +/// +/// for (field_name, field_value) in fields { +/// println!("{field_name}: {field_value:?}"); +/// } +/// ``` +#[derive(Default)] +pub struct VariantStructSerializer { + variant_map: HashMap, +} + +#[derive(Debug, Error)] +pub enum VariantSerializerError { + #[error("Unknown error when serializing struct:\n{0}")] + Unknown(String), + #[error("VariantStructSerializer can only be used to serialize structs.")] + ExpectedStruct, + #[error("{0} cannot be serialized to a Variant.")] + UnsupportedVariantType(String), +} + +impl VariantStructSerializer { + pub fn new() -> Self { + Self { + variant_map: HashMap::new(), + } + } +} + +impl serde::ser::Error for VariantSerializerError { + fn custom(msg: T) -> Self + where + T: Display, + { + VariantSerializerError::Unknown(msg.to_string()) + } +} + +impl Serializer for VariantStructSerializer { + type Ok = HashMap; + + type Error = VariantSerializerError; + + type SerializeStruct = Self; + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Ok(self) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_unit(self) -> Result { + Ok(HashMap::new()) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + // The following is code for a generic Serializer implementation not relevant to this use case + + type SerializeSeq = Impossible; + type SerializeTuple = Impossible; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + type SerializeMap = Impossible; + type SerializeStructVariant = Impossible; + + serialize_struct_err_stub!(serialize_bool, bool); + serialize_struct_err_stub!(serialize_i8, i8); + serialize_struct_err_stub!(serialize_i16, i16); + serialize_struct_err_stub!(serialize_i32, i32); + serialize_struct_err_stub!(serialize_i64, i64); + serialize_struct_err_stub!(serialize_u8, u8); + serialize_struct_err_stub!(serialize_u16, u16); + serialize_struct_err_stub!(serialize_u32, u32); + serialize_struct_err_stub!(serialize_u64, u64); + serialize_struct_err_stub!(serialize_f32, f32); + serialize_struct_err_stub!(serialize_f64, f64); + serialize_struct_err_stub!(serialize_char, char); + serialize_struct_err_stub!(serialize_str, &str); + serialize_struct_err_stub!(serialize_bytes, &[u8]); + + fn serialize_none(self) -> Result { + Err(VariantSerializerError::ExpectedStruct) + } + + fn serialize_some(self, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + Err(VariantSerializerError::ExpectedStruct) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + Err(VariantSerializerError::ExpectedStruct) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + Err(VariantSerializerError::ExpectedStruct) + } + + fn serialize_seq(self, _len: Option) -> Result { + Err(VariantSerializerError::ExpectedStruct) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(VariantSerializerError::ExpectedStruct) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(VariantSerializerError::ExpectedStruct) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(VariantSerializerError::ExpectedStruct) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(VariantSerializerError::ExpectedStruct) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(VariantSerializerError::ExpectedStruct) + } +} + +impl SerializeStruct for VariantStructSerializer { + type Ok = ::Ok; + + type Error = VariantSerializerError; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + let variant = value.serialize(VariantSerializer {}); + match variant { + Ok(value) => { + self.variant_map.insert(key.to_string(), value); + Ok(()) + } + Err(_) => Err(VariantSerializerError::UnsupportedVariantType( + type_name::().to_string(), + )), + } + } + + fn end(self) -> Result { + Ok(self.variant_map) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Serialize)] + struct TestStruct { + empty: (), + + string: String, + + i1: i8, + i2: i16, + i4: i32, + i8: i64, + + r4: f32, + r8: f64, + + bool: bool, + + ui1: u8, + ui2: u16, + ui4: u32, + ui8: u64, + } + + #[test] + fn it_serialize_struct() { + let test_struct = TestStruct { + empty: (), + string: "Test String".to_string(), + + i1: i8::MAX, + i2: i16::MAX, + i4: i32::MAX, + i8: i64::MAX, + + r4: f32::MAX, + r8: f64::MAX, + + bool: false, + + ui1: u8::MAX, + ui2: u16::MAX, + ui4: u32::MAX, + ui8: u64::MAX, + }; + let expected_field_map: HashMap = [ + ("empty".to_string(), Variant::Empty), + ( + "string".to_string(), + Variant::String("Test String".to_string()), + ), + ("i1".to_string(), Variant::I1(i8::MAX)), + ("i2".to_string(), Variant::I2(i16::MAX)), + ("i4".to_string(), Variant::I4(i32::MAX)), + ("i8".to_string(), Variant::I8(i64::MAX)), + ("r4".to_string(), Variant::R4(f32::MAX)), + ("r8".to_string(), Variant::R8(f64::MAX)), + ("bool".to_string(), Variant::Bool(false)), + ("ui1".to_string(), Variant::UI1(u8::MAX)), + ("ui2".to_string(), Variant::UI2(u16::MAX)), + ("ui4".to_string(), Variant::UI4(u32::MAX)), + ("ui8".to_string(), Variant::UI8(u64::MAX)), + ] + .into_iter() + .collect(); + + let variant_serializer = VariantStructSerializer::new(); + let field_map = test_struct.serialize(variant_serializer).unwrap(); + + assert_eq!(field_map, expected_field_map); + } + + #[derive(Serialize)] + struct NewtypeTest(u32); + #[derive(Serialize)] + struct NewtypeTestWrapper { + newtype: NewtypeTest, + } + #[test] + fn it_serialize_newtype() { + let test_struct = NewtypeTestWrapper { + newtype: NewtypeTest(17), + }; + + let expected_field_map: HashMap = + [("newtype".to_string(), Variant::UI4(17))] + .into_iter() + .collect(); + + let field_map = test_struct + .serialize(VariantStructSerializer::new()) + .unwrap(); + + assert_eq!(field_map, expected_field_map); + } + + #[derive(Serialize)] + struct UnitTest; + + #[test] + fn it_serialize_unit() { + let expected_field_map = HashMap::new(); + let field_map = UnitTest {} + .serialize(VariantStructSerializer::new()) + .unwrap(); + + assert_eq!(field_map, expected_field_map); + } + + #[derive(Serialize)] + #[allow(dead_code)] + enum EnumTest { + NTFS, + FAT32, + ReFS, + } + + #[derive(Serialize)] + struct EnumStructTest { + enum_test: EnumTest, + } + + #[test] + fn it_serialize_enum() { + let test_enum_struct = EnumStructTest { + enum_test: EnumTest::NTFS, + }; + + let expected_field_map = [("enum_test".to_string(), Variant::from("NTFS".to_string()))] + .into_iter() + .collect(); + + let field_map = test_enum_struct + .serialize(VariantStructSerializer::new()) + .unwrap(); + + assert_eq!(field_map, expected_field_map); + } +} diff --git a/src/variant.rs b/src/variant.rs index cbe91a2..ec66a81 100644 --- a/src/variant.rs +++ b/src/variant.rs @@ -270,6 +270,117 @@ impl Variant { } } +impl TryFrom for VARIANT { + type Error = WMIError; + fn try_from(value: Variant) -> WMIResult { + match value { + Variant::Empty => Ok(VARIANT::new()), + + Variant::String(string) => Ok(VARIANT::from(string.as_str())), + Variant::I1(int8) => Ok(VARIANT::from(int8)), + Variant::I2(int16) => Ok(VARIANT::from(int16)), + Variant::I4(int32) => Ok(VARIANT::from(int32)), + Variant::I8(int64) => Ok(VARIANT::from(int64)), + + Variant::R4(float32) => Ok(VARIANT::from(float32)), + Variant::R8(float64) => Ok(VARIANT::from(float64)), + + Variant::Bool(b) => Ok(VARIANT::from(b)), + + Variant::UI1(uint8) => Ok(VARIANT::from(uint8)), + Variant::UI2(uint16) => Ok(VARIANT::from(uint16)), + Variant::UI4(uint32) => Ok(VARIANT::from(uint32)), + Variant::UI8(uint64) => Ok(VARIANT::from(uint64)), + + // windows-rs' VARIANT does not support creating these types of VARIANT at present + Variant::Null => Err(WMIError::ConvertVariantError( + "Cannot convert Variant::Null to a Windows VARIANT".to_string(), + )), + Variant::Array(_) => Err(WMIError::ConvertVariantError( + "Cannot convert Variant::Array to a Windows VARIANT".to_string(), + )), + Variant::Unknown(_) => Err(WMIError::ConvertVariantError( + "Cannot convert Variant::Unknown to a Windows VARIANT".to_string(), + )), + Variant::Object(_) => Err(WMIError::ConvertVariantError( + "Cannot convert Variant::Object to a Windows VARIANT".to_string(), + )), + } + } +} + +macro_rules! impl_try_from_variant { + ($target_type:ty, $variant_type:ident) => { + impl TryFrom for $target_type { + type Error = WMIError; + + fn try_from(value: Variant) -> Result<$target_type, Self::Error> { + match value { + Variant::$variant_type(item) => Ok(item), + other => Err(WMIError::ConvertVariantError(format!( + "Variant {:?} cannot be turned into a {}", + &other, + stringify!($target_type) + ))), + } + } + } + }; +} + +/// Infallible conversion from a Rust type into a Variant wrapper for that type +macro_rules! impl_wrap_type { + ($target_type:ty, $variant_type:ident) => { + impl From<$target_type> for Variant { + fn from(value: $target_type) -> Self { + Variant::$variant_type(value) + } + } + }; +} + +/// Add conversions from a Rust type to its Variant form and vice versa +macro_rules! bidirectional_variant_convert { + ($target_type:ty, $variant_type:ident) => { + impl_try_from_variant!($target_type, $variant_type); + impl_wrap_type!($target_type, $variant_type); + }; +} + +bidirectional_variant_convert!(String, String); +bidirectional_variant_convert!(i8, I1); +bidirectional_variant_convert!(i16, I2); +bidirectional_variant_convert!(i32, I4); +bidirectional_variant_convert!(i64, I8); +bidirectional_variant_convert!(u8, UI1); +bidirectional_variant_convert!(u16, UI2); +bidirectional_variant_convert!(u32, UI4); +bidirectional_variant_convert!(u64, UI8); +bidirectional_variant_convert!(f32, R4); +bidirectional_variant_convert!(f64, R8); +bidirectional_variant_convert!(bool, Bool); + +impl From<()> for Variant { + fn from(_value: ()) -> Self { + Variant::Empty + } +} + +impl TryFrom for () { + type Error = WMIError; + + fn try_from(value: Variant) -> Result<(), Self::Error> { + match value { + Variant::Empty => Ok(()), + other => Err(WMIError::ConvertVariantError(format!( + "Variant {:?} cannot be turned into a {}", + &other, + stringify!(()) + ))), + } + } +} + /// A wrapper around the [`IUnknown`] interface. \ /// Used to retrive [`IWbemClassObject`][winapi::um::Wmi::IWbemClassObject] /// @@ -304,38 +415,6 @@ impl Serialize for IUnknownWrapper { } } -macro_rules! impl_try_from_variant { - ($target_type:ty, $variant_type:ident) => { - impl TryFrom for $target_type { - type Error = WMIError; - - fn try_from(value: Variant) -> Result<$target_type, Self::Error> { - match value { - Variant::$variant_type(item) => Ok(item), - other => Err(WMIError::ConvertVariantError(format!( - "Variant {:?} cannot be turned into a {}", - &other, - stringify!($target_type) - ))), - } - } - } - }; -} - -impl_try_from_variant!(String, String); -impl_try_from_variant!(i8, I1); -impl_try_from_variant!(i16, I2); -impl_try_from_variant!(i32, I4); -impl_try_from_variant!(i64, I8); -impl_try_from_variant!(u8, UI1); -impl_try_from_variant!(u16, UI2); -impl_try_from_variant!(u32, UI4); -impl_try_from_variant!(u64, UI8); -impl_try_from_variant!(f32, R4); -impl_try_from_variant!(f64, R8); -impl_try_from_variant!(bool, Bool); - #[cfg(test)] mod tests { use super::*; @@ -544,4 +623,39 @@ mod tests { let converted = variant.convert_into_cim_type(cim_type).unwrap(); assert_eq!(converted, Variant::Array(vec![])); } + + #[test] + fn it_bidirectional_string_convert() { + let string = "Test String".to_string(); + let variant = Variant::from(string.clone()); + assert_eq!(variant.try_into().ok(), Some(string.clone())); + + let variant = Variant::from(string.clone()); + let ms_variant = VARIANT::try_from(variant).unwrap(); + let variant = Variant::from(string.clone()); + assert_eq!(Variant::from_variant(&ms_variant).unwrap(), variant); + } + + #[test] + fn it_bidirectional_empty_convert() { + let variant = Variant::from(()); + assert_eq!(variant.try_into().ok(), Some(())); + + let variant = Variant::from(()); + let ms_variant = VARIANT::try_from(variant).unwrap(); + let variant = Variant::from(()); + assert_eq!(Variant::from_variant(&ms_variant).unwrap(), variant); + } + + #[test] + fn it_bidirectional_r8_convert() { + let num = 0.123456789; + let variant = Variant::from(num); + assert_eq!(variant.try_into().ok(), Some(num)); + + let variant = Variant::from(num); + let ms_variant = VARIANT::try_from(variant).unwrap(); + let variant = Variant::from(num); + assert_eq!(Variant::from_variant(&ms_variant).unwrap(), variant); + } }