Skip to content

Commit

Permalink
derive(Zeroable) on fieldful enums and repr(C) enums (#257)
Browse files Browse the repository at this point in the history
* Add support for deriving Zeroable for fieldful enums if:

1. the enum is repr(Int), repr(C), or repr(C, Int),
2. the enum has a variant with discriminant 0,
3. and all fields of the variant with discriminant 0 are Zeroable.

* Allow using derive(Zeroable) with explicit bounds. Update documentation and doctests.

* doc update

* doc update

* remove unused

* Factor out get_zero_variant helper function.

* Use i128 to track disciminants instead of i64.

* Add doc-comment for `get_fields`

Co-authored-by: Daniel Henry-Mantilla <[email protected]>

* Update derive/src/traits.rs

Co-authored-by: Daniel Henry-Mantilla <[email protected]>

---------

Co-authored-by: Daniel Henry-Mantilla <[email protected]>
  • Loading branch information
zachs18 and danielhenrymantilla authored Sep 24, 2024
1 parent bb36879 commit a637e1d
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 106 deletions.
74 changes: 60 additions & 14 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,26 @@ pub fn derive_anybitpattern(
proc_macro::TokenStream::from(expanded)
}

/// Derive the `Zeroable` trait for a struct
/// Derive the `Zeroable` trait for a type.
///
/// The macro ensures that the struct follows all the the safety requirements
/// The macro ensures that the type follows all the the safety requirements
/// for the `Zeroable` trait.
///
/// The following constraints need to be satisfied for the macro to succeed
/// The following constraints need to be satisfied for the macro to succeed on a
/// struct:
///
/// - All fields in the struct must implement `Zeroable`
///
/// The following constraints need to be satisfied for the macro to succeed on
/// an enum:
///
/// - All fields in the struct must to implement `Zeroable`
/// - The enum has an explicit `#[repr(Int)]`, `#[repr(C)]`, or `#[repr(C,
/// Int)]`.
/// - The enum has a variant with discriminant 0 (explicitly or implicitly).
/// - All fields in the variant with discriminant 0 (if any) must implement
/// `Zeroable`
///
/// The macro always succeeds on unions.
///
/// ## Example
///
Expand All @@ -134,6 +146,23 @@ pub fn derive_anybitpattern(
/// b: u16,
/// }
/// ```
/// ```rust
/// # use bytemuck_derive::{Zeroable};
/// #[derive(Copy, Clone, Zeroable)]
/// #[repr(i32)]
/// enum Values {
/// A = 0,
/// B = 1,
/// C = 2,
/// }
/// #[derive(Clone, Zeroable)]
/// #[repr(C)]
/// enum Implicit {
/// A(bool, u8, char),
/// B(String),
/// C(std::num::NonZeroU8),
/// }
/// ```
///
/// # Custom bounds
///
Expand All @@ -157,6 +186,18 @@ pub fn derive_anybitpattern(
///
/// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
/// ```
/// ```rust
/// # use bytemuck::{Zeroable};
/// #[derive(Copy, Clone, Zeroable)]
/// #[repr(u8)]
/// #[zeroable(bound = "")]
/// enum MyOption<T> {
/// None,
/// Some(T),
/// }
///
/// assert!(matches!(MyOption::<std::num::NonZeroU8>::zeroed(), MyOption::None));
/// ```
///
/// ```rust,compile_fail
/// # use bytemuck::Zeroable;
Expand Down Expand Up @@ -407,7 +448,8 @@ pub fn derive_byte_eq(
let input = parse_macro_input!(input as DeriveInput);
let crate_name = bytemuck_crate_name(&input);
let ident = input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl();

proc_macro::TokenStream::from(quote! {
impl #impl_generics ::core::cmp::PartialEq for #ident #ty_generics #where_clause {
Expand Down Expand Up @@ -460,7 +502,8 @@ pub fn derive_byte_hash(
let input = parse_macro_input!(input as DeriveInput);
let crate_name = bytemuck_crate_name(&input);
let ident = input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let (impl_generics, ty_generics, where_clause) =
input.generics.split_for_impl();

proc_macro::TokenStream::from(quote! {
impl #impl_generics ::core::hash::Hash for #ident #ty_generics #where_clause {
Expand Down Expand Up @@ -569,26 +612,29 @@ fn derive_marker_trait_inner<Trait: Derivable>(
.flatten()
.collect::<Vec<syn::WherePredicate>>();

let predicates = &mut input.generics.make_where_clause().predicates;

predicates.extend(explicit_bounds);

let fields = match &input.data {
syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
syn::Data::Union(_) => {
let fields = match (Trait::perfect_derive_fields(&input), &input.data) {
(Some(fields), _) => fields,
(None, syn::Data::Struct(syn::DataStruct { fields, .. })) => {
fields.clone()
}
(None, syn::Data::Union(_)) => {
return Err(syn::Error::new_spanned(
trait_,
&"perfect derive is not supported for unions",
));
}
syn::Data::Enum(_) => {
(None, syn::Data::Enum(_)) => {
return Err(syn::Error::new_spanned(
trait_,
&"perfect derive is not supported for enums",
));
}
};

let predicates = &mut input.generics.make_where_clause().predicates;

predicates.extend(explicit_bounds);

for field in fields {
let ty = field.ty;
predicates.push(syn::parse_quote!(
Expand Down
Loading

0 comments on commit a637e1d

Please sign in to comment.