diff --git a/deku-derive/src/macros/deku_read.rs b/deku-derive/src/macros/deku_read.rs index a7527027..df314728 100644 --- a/deku-derive/src/macros/deku_read.rs +++ b/deku-derive/src/macros/deku_read.rs @@ -181,8 +181,15 @@ fn emit_enum(input: &DekuData) -> Result { .first() .and_then(|v| v.ident.as_ref()) .is_some(); + let disc = &variant.discriminant; let (use_id, variant_id) = if let Some(variant_id) = &variant.id { + if disc.is_some() { + return Err(syn::Error::new( + variant.ident.span(), + "DekuRead: `id` cannot be used with arbitrary_enum_discriminant", + )); + } match variant_id { Id::TokenStream(v) => (false, quote! {&#v}.into_token_stream()), Id::LitByteStr(v) => (false, v.into_token_stream()), @@ -199,12 +206,51 @@ fn emit_enum(input: &DekuData) -> Result { (false, variant_id_pat.clone()) } } else if has_discriminant { - let ident = &variant.ident; - let internal_ident = gen_internal_field_ident("e!(#ident)); - pre_match_tokens.push(quote! { - let #internal_ident = <#id_type>::try_from(Self::#ident as isize)?; - }); - (true, quote! { _ if __deku_variant_id == #internal_ident }) + // Regular enum + if variant.fields.fields.is_empty() { + let ident = &variant.ident; + let internal_ident = gen_internal_field_ident("e!(#ident)); + + // If we have the discriminant, we can use that directly into the id_type + if let Some(disc) = disc { + // the discriminant is just an ident, we use the ident here + pre_match_tokens.push(quote! { + #[allow(non_upper_case_globals)] + const #internal_ident: #id_type = #disc; + }); + } else { + // if not, we need to convert this into the type expressed by a cast + // enum Discriminant { + // Cats = 0x01, + // Dogs, + // } + pre_match_tokens.push(quote! { + let #internal_ident = <#id_type>::try_from(Self::#ident as isize)?; + }); + } + (true, quote! { _ if __deku_variant_id == #internal_ident }) + } else { + // arbitrary_enum_discriminant, such as: + // + // #[repr(u8)] + // #[derive(DekuRead, Debug)] + // #[deku(id_type = "u8")] + // enum Foo { + // A(u8) = 0, + // B(i8) = 1, + // C(bool) = 42, + // } + let ident = &variant.ident; + let disc = &variant.discriminant; + let internal_ident = gen_internal_field_ident("e!(#ident)); + pre_match_tokens.push(quote! { + #[allow(non_upper_case_globals)] + const #internal_ident: #id_type = #disc; + }); + + // In this case, we send back false for use_id as we still need to read fields + (false, quote! { _ if __deku_variant_id == #internal_ident }) + } } else { return Err(syn::Error::new( variant.ident.span(), diff --git a/deku-derive/src/macros/deku_write.rs b/deku-derive/src/macros/deku_write.rs index 871d1623..2c233c63 100644 --- a/deku-derive/src/macros/deku_write.rs +++ b/deku-derive/src/macros/deku_write.rs @@ -154,6 +154,7 @@ fn emit_enum(input: &DekuData) -> Result { .first() .and_then(|v| v.ident.as_ref()) .is_some(); + let disc = &variant.discriminant; let variant_ident = &variant.ident; let variant_writer = &variant.writer; @@ -174,6 +175,12 @@ fn emit_enum(input: &DekuData) -> Result { } } else if id_type.is_some() { if let Some(variant_id) = &variant.id { + if disc.is_some() { + return Err(syn::Error::new( + variant.ident.span(), + "DekuWrite: `id` cannot be used with arbitrary_enum_discriminant", + )); + } match variant_id { Id::TokenStream(v) => { quote! { @@ -197,9 +204,19 @@ fn emit_enum(input: &DekuData) -> Result { } else if variant.id_pat.is_some() { quote! {} } else if has_discriminant { - quote! { - let mut __deku_variant_id: #id_type = Self::#variant_ident as #id_type; - __deku_variant_id.to_writer(__deku_writer, (#id_args))?; + // Discriminant is provided, use it + if let Some(disc) = disc { + quote! { + #[allow(non_upper_case_globals)] + const __deku_variant_id: #id_type = #disc; + __deku_variant_id.to_writer(__deku_writer, (#id_args))?; + } + // Regular enum + } else { + quote! { + let mut __deku_variant_id: #id_type = Self::#variant_ident as #id_type; + __deku_variant_id.to_writer(__deku_writer, (#id_args))?; + } } } else { return Err(syn::Error::new( diff --git a/examples/enums_catch_all.rs b/examples/enums_catch_all.rs index 882be0fd..3a1d353c 100644 --- a/examples/enums_catch_all.rs +++ b/examples/enums_catch_all.rs @@ -10,13 +10,13 @@ use hexlit::hex; pub enum DekuTest { /// A #[deku(id = "1")] - A = 0, + A, /// B #[deku(id = "2")] - B = 1, + B, /// C #[deku(id = "3", default)] - C = 2, + C, } fn main() { diff --git a/tests/test_catch_all.rs b/tests/test_catch_all.rs index 62bc4cfd..1cf09e3e 100644 --- a/tests/test_catch_all.rs +++ b/tests/test_catch_all.rs @@ -27,13 +27,13 @@ mod test { pub enum AdvancedRemapping { /// A #[deku(id = "1")] - A = 0, + A, /// B #[deku(id = "2")] - B = 1, + B, /// C #[deku(id = "3", default)] - C = 2, + C, } #[test] diff --git a/tests/test_compile/cases/attribute_token_stream.stderr b/tests/test_compile/cases/attribute_token_stream.stderr index 2d9e7463..5b038eb2 100644 --- a/tests/test_compile/cases/attribute_token_stream.stderr +++ b/tests/test_compile/cases/attribute_token_stream.stderr @@ -12,12 +12,12 @@ error[E0277]: can't compare `{integer}` with `bool` | = help: the trait `PartialEq` is not implemented for `{integer}` = help: the following other types implement trait `PartialEq`: - isize - i8 + f128 + f16 + f32 + f64 + i128 i16 i32 i64 - i128 - usize - u8 and $N others diff --git a/tests/test_compile/cases/id_arbitrary_enum_discriminant.stderr b/tests/test_compile/cases/id_arbitrary_enum_discriminant.stderr new file mode 100644 index 00000000..af604d97 --- /dev/null +++ b/tests/test_compile/cases/id_arbitrary_enum_discriminant.stderr @@ -0,0 +1,5 @@ +error: DekuRead: `id` cannot be used with arbitrary_enum_discriminant + --> tests/test_compile/cases/id_arbitrary_enum_discriminant.rs:8:5 + | +8 | A(u8) = 0, + | ^ diff --git a/tests/test_enum.rs b/tests/test_enum.rs index cbe39478..d524128d 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -152,3 +152,43 @@ fn test_id_pat_with_id() { ); assert_eq!(input, &*v.to_bytes().unwrap()); } + +#[test] +fn test_arbitrary() { + #[repr(u8)] + #[derive(DekuRead, DekuWrite, Debug, PartialEq)] + #[deku(id_type = "u8")] + enum Foo { + A(u8) = 0, + B(bool) = 42, + I(Inner) = 12, + N = 10, + } + + #[repr(u8)] + #[derive(DekuRead, DekuWrite, Debug, PartialEq)] + #[deku(id_type = "u8")] + enum Inner { + One(u8) = 0, + } + + let bytes = [0, 1]; + let (_, foo) = Foo::from_bytes((&bytes, 0)).unwrap(); + assert_eq!(foo, Foo::A(1)); + assert_eq!(bytes, &*foo.to_bytes().unwrap()); + + let bytes = [42, 1]; + let (_, foo) = Foo::from_bytes((&bytes, 0)).unwrap(); + assert_eq!(foo, Foo::B(true)); + assert_eq!(bytes, &*foo.to_bytes().unwrap()); + + let bytes = [12, 0, 1]; + let (_, foo) = Foo::from_bytes((&bytes, 0)).unwrap(); + assert_eq!(foo, Foo::I(Inner::One(1))); + assert_eq!(bytes, &*foo.to_bytes().unwrap()); + + let bytes = [10]; + let (_, foo) = Foo::from_bytes((&bytes, 0)).unwrap(); + assert_eq!(foo, Foo::N); + assert_eq!(bytes, &*foo.to_bytes().unwrap()); +}