Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make derive macros support generic enums #21

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Please make sure to add your changes to the appropriate categories:
- Added enum-level `#[enumcapsulate(exclude)]`/`#[enumcapsulate(exclude(…))]` helper attribute.
- Added optional selection list to variant-level `#[enumcapsulate(exclude(…))]` helper attribute.
- Added optional selection list to variant-level `#[enumcapsulate(include(…))]` helper attribute.
- Added derive support for generic enums (by force-excluding variants that use generic const/type parameters).

### Changed

Expand Down
2 changes: 1 addition & 1 deletion macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ proc-macro = true
[dependencies]
proc-macro2 = { version = "1.0.81", features = ["span-locations"] }
quote = "1.0.36"
syn = "2.0.60"
syn = { version = "2.0.60", features = ["derive", "visit"] }
91 changes: 91 additions & 0 deletions macros/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,97 @@ This attribute is recognized by the following variant-based derive macros:

- `Encapsulate`

## Generics

There is limited support for generic enums:

Variants using generic const/type parameters are always excluded when deriving generic traits with `enumcapsulate`'s derive macros.

The reason for this behavior is that implementing generic traits for variants that use any of the generic parameters of the enum tends to result in conflicting implementations in Rust, as shown by the following example program:

```rust
use enumcapsulate::FromVariant;

pub struct VariantA;
pub struct VariantB;

#[derive(FromVariant)]
// error[E0119]: conflicting implementations of trait `FromVariant<VariantB>` for type `Enum<VariantB>`
pub enum Enum<T> {
Unit,
Generic(T),
NonGeneric(VariantB),
}

fn main() {
let _: Enum<VariantA> = Enum::from_variant(VariantA);
let _: Enum<VariantA> = Enum::from_variant(VariantB);
}
```

The expanded version of the above makes it easier to see why: The compiler can't prove that `T` and `VariantB` are disjoint types.

```rust
pub struct VariantA;
pub struct VariantB;

pub enum Enum<T> {
Unit,
Generic(T),
NonGeneric(VariantB),
}

impl<T> FromVariant<T> for Enum<T> { // <- first implementation here
fn from_variant(variant: T) -> Self {
Self::Generic(variant)
}
}

// error[E0119]: conflicting implementations of trait `FromVariant<VariantB>` for type `Enum<VariantB>`
impl<T> FromVariant<VariantB> for Enum<T> { // <- conflicting implementation for `Enum<VariantB>`
fn from_variant(variant: VariantB) -> Self {
Self::NonGeneric(variant)
}
}

fn main() {
let _: Enum<VariantA> = Enum::from_variant(VariantA);
let _: Enum<VariantA> = Enum::from_variant(VariantB);
}
```

So in order to avoid such pitfalls altogether `enumcapsulate`'s derive macros will skip `impl<T> FromVariant<T> for Enum<T>`, since it uses a generic type (or const) parameter of `Enum<T>`.

So all you have to do is provide your own non-generic implementations for specific type instances of your generic type yourself, filling any gaps left behind by the derive macro:

```rust
use enumcapsulate::FromVariant;

pub struct VariantA;
pub struct VariantB;

#[derive(FromVariant)]
pub enum Enum<T> {
Unit,
Generic(T),
NonGeneric(VariantB),
}

// Notice how the trait is implemented on
// a specific type of the `Enum<T>` kind,
// rather than on the generic kind itself:
impl From<VariantA> for Enum<VariantA> {
fn from(value: VariantA) -> Self {
Self::Generic(value)
}
}

fn main() {
let _: Enum<VariantA> = Enum::from_variant(VariantA);
let _: Enum<VariantA> = Enum::from_variant(VariantB);
}
```

## Documentation

Please refer to the documentation on [docs.rs](https://docs.rs/enumcapsulate-macros).
Expand Down
86 changes: 71 additions & 15 deletions macros/src/enum_deriver.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_quote_spanned, DataEnum, DeriveInput, Fields, Type, Variant};
use syn::{parse_quote_spanned, visit::Visit as _, DataEnum, DeriveInput, Fields, Type, Variant};

use crate::{
config_for_enum_with_attrs, config_for_variant, macro_name, position_of_selected_field,
TypeVisitor,
};

pub(crate) struct EnumDeriver {
Expand Down Expand Up @@ -43,6 +44,8 @@ impl EnumDeriver {
let outer = enum_ident;
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };

let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();

let mut impls: Vec<TokenStream2> = vec![];

for variant in self.variants()? {
Expand All @@ -65,6 +68,10 @@ impl EnumDeriver {
let inner_field = fields[selection_index];
let inner_ty = &inner_field.ty;

if self.uses_generic_const_or_type(inner_ty) {
continue;
}

let field_expressions: Vec<_> = fields
.iter()
.enumerate()
Expand Down Expand Up @@ -92,7 +99,7 @@ impl EnumDeriver {
};

impls.push(quote! {
impl ::core::convert::From<#inner_ty> for #outer_ty {
impl #impl_generics ::core::convert::From<#inner_ty> for #outer_ty #type_generics #where_clause {
fn from(inner: #inner_ty) -> Self {
#expression
}
Expand All @@ -116,6 +123,8 @@ impl EnumDeriver {
let outer = enum_ident;
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };

let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();

let mut impls: Vec<TokenStream2> = vec![];

for variant in self.variants()? {
Expand All @@ -139,6 +148,10 @@ impl EnumDeriver {
let inner_ident = inner_field.ident.as_ref();
let inner_ty = &inner_field.ty;

if self.uses_generic_const_or_type(inner_ty) {
continue;
}

let pattern = match &variant.fields {
Fields::Named(_) => {
let field = inner_ident;
Expand All @@ -154,10 +167,10 @@ impl EnumDeriver {
};

impls.push(quote! {
impl ::core::convert::TryFrom<#outer_ty> for #inner_ty {
type Error = #outer_ty;
impl #impl_generics ::core::convert::TryFrom<#outer_ty #type_generics> for #inner_ty #where_clause {
type Error = #outer_ty #type_generics;

fn try_from(outer: #outer_ty) -> Result<Self, Self::Error> {
fn try_from(outer: #outer_ty #type_generics) -> Result<Self, Self::Error> {
match outer {
#pattern => Ok(inner),
err => Err(err)
Expand All @@ -183,6 +196,8 @@ impl EnumDeriver {
let outer = enum_ident;
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };

let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();

let mut impls: Vec<TokenStream2> = vec![];

for variant in self.variants()? {
Expand All @@ -205,6 +220,10 @@ impl EnumDeriver {
let inner_field = fields[selection_index];
let inner_ty = &inner_field.ty;

if self.uses_generic_const_or_type(inner_ty) {
continue;
}

let field_expressions: Vec<_> = fields
.iter()
.enumerate()
Expand Down Expand Up @@ -232,7 +251,7 @@ impl EnumDeriver {
};

impls.push(quote! {
impl ::enumcapsulate::FromVariant<#inner_ty> for #outer_ty {
impl #impl_generics ::enumcapsulate::FromVariant<#inner_ty> for #outer_ty #type_generics #where_clause {
fn from_variant(inner: #inner_ty) -> Self {
#expression
}
Expand All @@ -256,6 +275,8 @@ impl EnumDeriver {
let outer = enum_ident;
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };

let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();

let mut impls: Vec<TokenStream2> = vec![];

for variant in self.variants()? {
Expand All @@ -279,6 +300,10 @@ impl EnumDeriver {
let inner_ident = inner_field.ident.as_ref();
let inner_ty = &inner_field.ty;

if self.uses_generic_const_or_type(inner_ty) {
continue;
}

let pattern = match &variant.fields {
Fields::Named(_) => {
let field = inner_ident;
Expand All @@ -293,11 +318,13 @@ impl EnumDeriver {
Fields::Unit => continue,
};

let where_clause = match where_clause {
Some(where_clause) => quote! { #where_clause #inner_ty: Clone },
None => quote! { where #inner_ty: Clone },
};

impls.push(quote! {
impl ::enumcapsulate::AsVariant<#inner_ty> for #outer_ty
where
#inner_ty: Clone
{
impl #impl_generics ::enumcapsulate::AsVariant<#inner_ty> for #outer_ty #type_generics #where_clause {
fn as_variant(&self) -> Option<#inner_ty> {
match self {
#pattern => Some(inner.clone()),
Expand All @@ -324,6 +351,8 @@ impl EnumDeriver {
let outer = enum_ident;
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };

let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();

let mut impls: Vec<TokenStream2> = vec![];

for variant in self.variants()? {
Expand All @@ -347,6 +376,9 @@ impl EnumDeriver {
let inner_ident = inner_field.ident.as_ref();
let inner_ty = &inner_field.ty;

if self.uses_generic_const_or_type(inner_ty) {
continue;
}
let pattern = match &variant.fields {
Fields::Named(_) => {
let field = inner_ident;
Expand All @@ -362,7 +394,7 @@ impl EnumDeriver {
};

impls.push(quote! {
impl ::enumcapsulate::AsVariantRef<#inner_ty> for #outer_ty {
impl #impl_generics ::enumcapsulate::AsVariantRef<#inner_ty> for #outer_ty #type_generics #where_clause {
fn as_variant_ref(&self) -> Option<&#inner_ty> {
match self {
#pattern => Some(inner),
Expand All @@ -389,6 +421,8 @@ impl EnumDeriver {
let outer = enum_ident;
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };

let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();

let mut impls: Vec<TokenStream2> = vec![];

for variant in self.variants()? {
Expand All @@ -412,6 +446,10 @@ impl EnumDeriver {
let inner_ident = inner_field.ident.as_ref();
let inner_ty = &inner_field.ty;

if self.uses_generic_const_or_type(inner_ty) {
continue;
}

let pattern = match &variant.fields {
Fields::Named(_) => {
let field = inner_ident;
Expand All @@ -427,7 +465,7 @@ impl EnumDeriver {
};

impls.push(quote! {
impl ::enumcapsulate::AsVariantMut<#inner_ty> for #outer_ty {
impl #impl_generics ::enumcapsulate::AsVariantMut<#inner_ty> for #outer_ty #type_generics #where_clause {
fn as_variant_mut(&mut self) -> Option<&mut #inner_ty> {
match self {
#pattern => Some(inner),
Expand All @@ -454,6 +492,8 @@ impl EnumDeriver {
let outer = enum_ident;
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };

let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();

let mut impls: Vec<TokenStream2> = vec![];

for variant in self.variants()? {
Expand All @@ -477,6 +517,10 @@ impl EnumDeriver {
let inner_ident = inner_field.ident.as_ref();
let inner_ty = &inner_field.ty;

if self.uses_generic_const_or_type(inner_ty) {
continue;
}

let pattern = match &variant.fields {
Fields::Named(_) => {
let field = inner_ident;
Expand All @@ -492,7 +536,7 @@ impl EnumDeriver {
};

impls.push(quote! {
impl ::enumcapsulate::IntoVariant<#inner_ty> for #outer_ty {
impl #impl_generics ::enumcapsulate::IntoVariant<#inner_ty> for #outer_ty #type_generics #where_clause {
fn into_variant(self) -> Result<#inner_ty, Self> {
match self {
#pattern => Ok(inner),
Expand Down Expand Up @@ -523,8 +567,10 @@ impl EnumDeriver {
let outer = enum_ident;
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };

let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();

let tokens = quote! {
impl ::enumcapsulate::VariantDowncast for #outer_ty {}
impl #impl_generics ::enumcapsulate::VariantDowncast for #outer_ty #type_generics #where_clause {}
};

Ok(tokens)
Expand All @@ -545,6 +591,8 @@ impl EnumDeriver {
let outer = enum_ident;
let outer_ty: Type = parse_quote_spanned! { outer.span() => #outer };

let (impl_generics, type_generics, where_clause) = self.input.generics.split_for_impl();

let variants = self.variants()?;

let discriminant_ident = quote::format_ident!("{outer}Discriminant");
Expand Down Expand Up @@ -592,7 +640,7 @@ impl EnumDeriver {
Ok(quote! {
#discriminant_enum

impl ::enumcapsulate::VariantDiscriminant for #outer_ty {
impl #impl_generics ::enumcapsulate::VariantDiscriminant for #outer_ty #type_generics #where_clause {
type Discriminant = #discriminant_ident;

fn variant_discriminant(&self) -> Self::Discriminant {
Expand All @@ -604,4 +652,12 @@ impl EnumDeriver {
}
})
}

fn uses_generic_const_or_type(&self, ty: &syn::Type) -> bool {
let mut visitor = TypeVisitor::new(&self.input.generics);

visitor.visit_type(ty);

visitor.uses_const_or_type_param()
}
}
2 changes: 2 additions & 0 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ use crate::utils::tokenstream;

mod config;
mod enum_deriver;
mod type_visitor;
mod utils;

use config::*;
use enum_deriver::*;
use type_visitor::*;
use utils::*;

/// Derive macro generating an impl of the trait `From<T>`.
Expand Down
Loading
Loading