diff --git a/zvariant/src/lib.rs b/zvariant/src/lib.rs index 398db9919..6a3a8e06b 100644 --- a/zvariant/src/lib.rs +++ b/zvariant/src/lib.rs @@ -142,8 +142,8 @@ mod tests { use crate::Fd; use crate::{ serialized::{Context, Format}, - Array, Basic, DeserializeDict, DeserializeValue, Dict, Error, ObjectPath, Result, - SerializeDict, SerializeValue, Signature, Str, Structure, Type, Value, BE, LE, + Array, Basic, DeserializeDict, DeserializeValue, Dict, Error, ObjectPath, OwnedValue, + Result, SerializeDict, SerializeValue, Signature, Str, Structure, Type, Value, BE, LE, NATIVE_ENDIAN, }; @@ -1483,7 +1483,7 @@ mod tests { let _: UnitStruct = encoded.deserialize().unwrap().0; #[repr(u8)] - #[derive(Deserialize_repr, Serialize_repr, Type, Debug, PartialEq)] + #[derive(Deserialize_repr, Serialize_repr, Type, Value, OwnedValue, Debug, PartialEq)] enum Enum { Variant1, Variant2, @@ -1496,8 +1496,12 @@ mod tests { let decoded: Enum = encoded.deserialize().unwrap().0; assert_eq!(decoded, Enum::Variant3); + assert_eq!(Value::from(Enum::Variant1), Value::U8(0)); + assert_eq!(Enum::try_from(Value::U8(2)), Ok(Enum::Variant3)); + assert_eq!(Enum::try_from(Value::U8(4)), Err(Error::IncorrectType)); + #[repr(i64)] - #[derive(Deserialize_repr, Serialize_repr, Type, Debug, PartialEq)] + #[derive(Deserialize_repr, Serialize_repr, Type, Value, OwnedValue, Debug, PartialEq)] enum Enum2 { Variant1, Variant2, @@ -1510,7 +1514,11 @@ mod tests { let decoded: Enum2 = encoded.deserialize().unwrap().0; assert_eq!(decoded, Enum2::Variant2); - #[derive(Deserialize, Serialize, Type, Debug, PartialEq)] + assert_eq!(Value::from(Enum2::Variant1), Value::I64(0)); + assert_eq!(Enum2::try_from(Value::I64(2)), Ok(Enum2::Variant3)); + assert_eq!(Enum2::try_from(Value::I64(4)), Err(Error::IncorrectType)); + + #[derive(Deserialize, Serialize, Type, Value, OwnedValue, Debug, PartialEq)] enum NoReprEnum { Variant1, Variant2, @@ -1527,10 +1535,10 @@ mod tests { let decoded: NoReprEnum = encoded.deserialize().unwrap().0; assert_eq!(decoded, NoReprEnum::Variant2); - #[derive(Deserialize, Serialize, Type, Debug, PartialEq)] - #[zvariant(signature = "s")] + #[derive(Deserialize, Serialize, Type, Value, OwnedValue, Debug, PartialEq)] + #[zvariant(signature = "s", rename_all = "snake_case")] enum StrEnum { - Variant1, + VariantOne, Variant2, Variant3, } @@ -1541,6 +1549,20 @@ mod tests { let decoded: StrEnum = encoded.deserialize().unwrap().0; assert_eq!(decoded, StrEnum::Variant2); + assert_eq!( + StrEnum::try_from(Value::Str("variant_one".into())), + Ok(StrEnum::VariantOne) + ); + assert_eq!( + StrEnum::try_from(Value::Str("variant2".into())), + Ok(StrEnum::Variant2) + ); + assert_eq!( + StrEnum::try_from(Value::Str("variant4".into())), + Err(Error::IncorrectType) + ); + assert_eq!(StrEnum::try_from(Value::U32(0)), Err(Error::IncorrectType)); + #[derive(Deserialize, Serialize, Type)] enum NewType { Variant1(f64), diff --git a/zvariant_derive/src/utils.rs b/zvariant_derive/src/utils.rs index 81b6f4e06..4e6ad56da 100644 --- a/zvariant_derive/src/utils.rs +++ b/zvariant_derive/src/utils.rs @@ -22,4 +22,8 @@ def_attrs! { pub StructAttributes("struct") { signature str, rename_all str, deny_unknown_fields none }; /// Attributes defined on fields. pub FieldAttributes("field") { rename str }; + /// Attributes defined on enumerations. + pub EnumAttributes("enum") { signature str, rename_all str }; + /// Attributes defined on variants. + pub VariantAttributes("variant") { rename str }; } diff --git a/zvariant_derive/src/value.rs b/zvariant_derive/src/value.rs index dfa8fe0a5..6b71226bb 100644 --- a/zvariant_derive/src/value.rs +++ b/zvariant_derive/src/value.rs @@ -1,9 +1,10 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; use syn::{ - spanned::Spanned, Attribute, Data, DataEnum, DeriveInput, Error, Expr, Fields, Generics, Ident, - Lifetime, LifetimeDef, + spanned::Spanned, Attribute, Data, DataEnum, DeriveInput, Error, Fields, Generics, Ident, + Lifetime, LifetimeDef, Variant, }; +use zvariant_utils::case; use crate::utils::*; @@ -236,34 +237,45 @@ fn impl_enum( Some(repr_attr) => repr_attr.parse_args()?, None => quote! { u32 }, }; + let enum_attrs = EnumAttributes::parse(&attrs)?; + let str_enum = enum_attrs + .signature + .map(|sig| sig == "s") + .unwrap_or_default(); let mut variant_names = vec![]; - let mut variant_values = vec![]; + let mut str_values = vec![]; for variant in &data.variants { + let variant_attrs = VariantAttributes::parse(&variant.attrs)?; // Ensure all variants of the enum are unit type match variant.fields { Fields::Unit => { variant_names.push(&variant.ident); - let value = match &variant - .discriminant - .as_ref() - .ok_or_else(|| Error::new(variant.span(), "expected `Name = Value` variants"))? - .1 - { - Expr::Lit(lit_exp) => &lit_exp.lit, - _ => { - return Err(Error::new( - variant.span(), - "expected `Name = Value` variants", - )) - } - }; - variant_values.push(value); + if str_enum { + let str_value = enum_name_for_variant( + variant, + variant_attrs.rename, + enum_attrs.rename_all.as_ref().map(AsRef::as_ref), + )?; + str_values.push(str_value); + } } _ => return Err(Error::new(variant.span(), "must be a unit variant")), } } + let into_val = if str_enum { + quote! { + match e { + #( + #name::#variant_names => #str_values, + )* + } + } + } else { + quote! { e as #repr } + }; + let (value_type, into_value) = match value_type { ValueType::Value => ( quote! { #zv::Value<'_> }, @@ -271,13 +283,7 @@ fn impl_enum( impl ::std::convert::From<#name> for #zv::Value<'_> { #[inline] fn from(e: #name) -> Self { - let u: #repr = match e { - #( - #name::#variant_names => #variant_values - ),* - }; - - <#zv::Value as ::std::convert::From<_>>::from(u).into() + <#zv::Value as ::std::convert::From<_>>::from(#into_val) } } }, @@ -290,14 +296,8 @@ fn impl_enum( #[inline] fn try_from(e: #name) -> #zv::Result { - let u: #repr = match e { - #( - #name::#variant_names => #variant_values - ),* - }; - <#zv::OwnedValue as ::std::convert::TryFrom<_>>::try_from( - <#zv::Value as ::std::convert::From<_>>::from(u) + <#zv::Value as ::std::convert::From<_>>::from(#into_val) ) } } @@ -305,23 +305,68 @@ fn impl_enum( ), }; + let from_val = if str_enum { + quote! { + let v: #zv::Str = ::std::convert::TryInto::try_into(value)?; + + ::std::result::Result::Ok(match v.as_str() { + #( + #str_values => #name::#variant_names, + )* + _ => return ::std::result::Result::Err(#zv::Error::IncorrectType), + }) + } + } else { + quote! { + let v: #repr = ::std::convert::TryInto::try_into(value)?; + + ::std::result::Result::Ok( + #( + if v == #name::#variant_names as #repr { + #name::#variant_names + } else + )* { + return ::std::result::Result::Err(#zv::Error::IncorrectType); + } + ) + } + }; + Ok(quote! { impl ::std::convert::TryFrom<#value_type> for #name { type Error = #zv::Error; #[inline] fn try_from(value: #value_type) -> #zv::Result { - let v: #repr = ::std::convert::TryInto::try_into(value)?; - - ::std::result::Result::Ok(match v { - #( - #variant_values => #name::#variant_names - ),*, - _ => return ::std::result::Result::Err(#zv::Error::IncorrectType), - }) + #from_val } } #into_value }) } + +fn enum_name_for_variant( + v: &Variant, + rename_attr: Option, + rename_all_attr: Option<&str>, +) -> Result { + if let Some(name) = rename_attr { + Ok(name) + } else { + let ident = v.ident.to_string(); + + match rename_all_attr { + Some("lowercase") => Ok(ident.to_ascii_lowercase()), + Some("UPPERCASE") => Ok(ident.to_ascii_uppercase()), + Some("PascalCase") => Ok(case::pascal_or_camel_case(&ident, true)), + Some("camelCase") => Ok(case::pascal_or_camel_case(&ident, false)), + Some("snake_case") => Ok(case::snake_case(&ident)), + None => Ok(ident), + Some(other) => Err(Error::new( + v.span(), + format!("invalid `rename_all` attribute value {other}"), + )), + } + } +}