Skip to content

Commit

Permalink
support for generics
Browse files Browse the repository at this point in the history
  • Loading branch information
valsteen committed Sep 22, 2022
1 parent 9ad00aa commit 01132e1
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 16 deletions.
24 changes: 19 additions & 5 deletions src/enum_sequence.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use proc_macro2::{Ident, TokenStream};

use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Fields};
use syn::{Data, DeriveInput, Fields, GenericParam};

const ENUM_HELP: &str = "EnumSequence: Only enums are supported";

Expand Down Expand Up @@ -34,12 +34,26 @@ pub fn impl_enum_sequence(input: DeriveInput) -> TokenStream {
})
}

let vis = &input.vis;
let generics = &input.generics;
let where_clause = &input.generics.where_clause;
let generics_params = &input
.generics
.params
.iter()
.flat_map(|p| match p {
GenericParam::Type(t) => Some(&t.ident),
GenericParam::Const(t) => Some(&t.ident),
_ => None,
})
.collect::<Vec<_>>();

quote::quote_spanned! {input_span =>
pub trait #trait_ident {
#vis trait #trait_ident #generics #where_clause {
fn enum_sequence(&self) -> usize;
}

impl #trait_ident for #ident {
impl #generics #trait_ident <#(#generics_params),*> for #ident <#(#generics_params),*> #where_clause {
fn enum_sequence(&self) -> usize {
match self {
#(#match_branches),*
Expand Down Expand Up @@ -69,7 +83,8 @@ mod test {
.unwrap();

assert_eq!(
r#"pub trait EEnumSequence {
output,
r#"trait EEnumSequence {
fn enum_sequence(&self) -> usize;
}
impl EEnumSequence for E {
Expand All @@ -82,7 +97,6 @@ impl EEnumSequence for E {
}
}
"#,
output
)
}
}
24 changes: 19 additions & 5 deletions src/enum_variant_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote_spanned, ToTokens};
use syn::{
self, punctuated::Pair, spanned::Spanned, Attribute, Data, DeriveInput, Expr, ExprParen,
ExprTuple, ExprType, Fields, FieldsNamed, Token, Type, Variant,
ExprTuple, ExprType, Fields, FieldsNamed, GenericParam, Token, Type, Variant,
};

const ATTR_HELP: &str = "EnumAccessor: Invalid accessor declaration, expected #[accessor(field1: type, (VariantWithoutAccessor1,VariantWithoutAccessor2))]";
Expand Down Expand Up @@ -400,12 +400,26 @@ pub fn impl_enum_accessor(input: DeriveInput) -> TokenStream {
}
}

let vis = &input.vis;
let generics = &input.generics;
let where_clause = &input.generics.where_clause;
let generics_params = &input
.generics
.params
.iter()
.flat_map(|p| match p {
GenericParam::Type(t) => Some(&t.ident),
GenericParam::Const(t) => Some(&t.ident),
_ => None,
})
.collect::<Vec<_>>();

syn::parse_quote_spanned! {input_span =>
pub trait #extension_trait {
#vis trait #extension_trait #generics #where_clause {
#(#accessor_defs)*
}

impl #extension_trait for #ident {
impl #generics #extension_trait <#(#generics_params),*> for #ident <#(#generics_params),*> #where_clause {
#(#accessor_impls)*
}
}
Expand Down Expand Up @@ -436,7 +450,7 @@ mod test {
.unwrap();
assert_eq!(
output,
r#"pub trait SomeEnumAccessor {
r#"trait SomeEnumAccessor {
fn acc1(&self) -> std::option::Option<&usize>;
fn acc1_mut(&mut self) -> std::option::Option<&mut usize>;
fn acc2(&self) -> &u8;
Expand Down Expand Up @@ -530,7 +544,7 @@ impl SomeEnumAccessor for SomeEnum {
.unwrap();
assert_eq!(
output,
r#"pub trait SomeEnumAccessor {
r#"trait SomeEnumAccessor {
fn acc1(&self) -> std::option::Option<&usize>;
fn acc1_mut(&mut self) -> std::option::Option<&mut usize>;
}
Expand Down
25 changes: 19 additions & 6 deletions src/sort_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use quote::{quote_spanned, ToTokens};

use syn::{
self, spanned::Spanned, Attribute, Data, DataStruct, DeriveInput, Error, Expr, ExprLit, Fields,
FieldsNamed, Lit, Meta, NestedMeta,
FieldsNamed, GenericParam, Lit, Meta, NestedMeta,
};

const HELP_SORTBY: &str = r#"SortBy: invalid sort_by attribute, expected list form i.e #[sort_by(attr1, attr2, methodcall())]"#;
Expand Down Expand Up @@ -67,28 +67,41 @@ pub fn impl_sort_by_derive(input: DeriveInput) -> TokenStream {
.map(|expr| syn::parse_quote_spanned!(expr.span() => self.#expr.hash(state)))
.collect();

let generics = &input.generics;
let where_clause = &input.generics.where_clause;
let generics_params = &input
.generics
.params
.iter()
.flat_map(|p| match p {
GenericParam::Type(t) => Some(&t.ident),
GenericParam::Const(t) => Some(&t.ident),
_ => None,
})
.collect::<Vec<_>>();

quote_spanned! {input_span =>
impl std::hash::Hash for #struct_name {
impl #generics std::hash::Hash for #struct_name <#(#generics_params),*> #where_clause {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
#(#hash_expressions);*;
}
}

impl core::cmp::Eq for #struct_name {}
impl #generics core::cmp::Eq for #struct_name <#(#generics_params),*> #where_clause {}

impl core::cmp::PartialEq<Self> for #struct_name {
impl #generics core::cmp::PartialEq<Self> for #struct_name <#(#generics_params),*> #where_clause {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}

impl core::cmp::PartialOrd<Self> for #struct_name {
impl #generics core::cmp::PartialOrd<Self> for #struct_name <#(#generics_params),*> #where_clause {
fn partial_cmp(&self, other: &Self) -> core::option::Option<core::cmp::Ordering> {
std::option::Option::Some(self.cmp(other))
}
}

impl core::cmp::Ord for #struct_name {
impl #generics core::cmp::Ord for #struct_name <#(#generics_params),*> #where_clause {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
#ord_statement
}
Expand Down

0 comments on commit 01132e1

Please sign in to comment.