Skip to content

Commit

Permalink
support mut calls
Browse files Browse the repository at this point in the history
  • Loading branch information
valsteen committed Sep 23, 2022
1 parent 801e737 commit 1024e2b
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sort_by_derive"
version = "0.1.8"
version = "0.1.9"
edition = "2021"
license = "Unlicense"
description = "Derive macro SortBy and helper macros EnumAccessor and EnumSequence, deriving traits `Ord`, `PartialOrd`, `Eq`, `PartialEq` and `Hash` for structs and enums that can't automatically derive from those traits."
Expand Down
43 changes: 32 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ enum E {

This will derive the accessor methods `fn name(&self) -> &type;` and`fn name_mut(&mut self) -> &mut type;`, and return a reference to the field of the same name on any variant.

So you can take any `E`, all variants will have `name_of_the_field`, `name_of_the_field_mut`, `name_of_other_field`, `name_of_other_field_mut`

```rust
fn do_something(some_e: &mut E) {
let field_value = *some_e.name_of_the_field() ; // take the value of that field, whatever variant it is
*some_e.name_of_the_field_mut() = "somevalue" ; // use the accessor method returning a &mut to the field
}
```


```rust
#[derive(EnumAccessor)]
#[accessor(name: type, (Variant3,Variant4))]
Expand All @@ -92,7 +102,7 @@ enum E {

This derives the same accessor methods, but the return type will be `Option<&type>` and `Option<&mut type>`. The provided comma-separated list of variants are exceptions and will return `None`.

Methods without arguments ( i.e. only `&self` are also supported ). It takes the form: `#[accessor(method_name(): type)]`.
Methods without arguments ( i.e. only `&self` are also supported ). It takes the form: `#[accessor(method_name(): type)]`. If `type` is a `&mut`, the generated method will take `&mut self` instead of `&self`. This can be useful for accessing mutable derived methods of nested enums.

To avoid name clashes, accessors can be given an alias by using `as`:

Expand Down Expand Up @@ -220,6 +230,9 @@ impl A {
fn sum(&self) -> u8 {
self.f1 + self.f2
}
fn set(&mut self) -> &mut u8 {
&mut self.f1
}
}

struct B {
Expand All @@ -234,24 +247,32 @@ impl B {

#[derive(EnumAccessor)]
#[accessor(sum():u8)]
enum E {
#[accessor(set(): &mut u8, (B,C))]
enum E<Get: Fn() -> u8> {
A(A),
B(B),
C{sum: Box<dyn Fn() -> u8>}
C{sum: Get}
}

#[test]
fn test_sum() {
let a = E::A(A{ f1: 10, f2: 22 });
let b = E::B(B{ values: vec![9,4,3,2] });
let factor = Arc::new(AtomicU8::new(1));

let c = {
let factor = factor.clone();
E::C { sum: Box::new(move || 21 * factor.load(Ordering::Relaxed)) }
};

assert_eq!(32, a.sum());
let [mut a, b, c] = [
E::A(A { f1: 10, f2: 22 }),
E::B(B { values: vec![9, 4, 3, 2] }),
{
let factor = factor.clone();
E::C {
sum: move || 21 * factor.load(Ordering::Relaxed),
}
}];

assert_eq!(32, a.sum()); // sum() is available without matching against E::A, E::B or E::C
if let Some(value) = a.set() { // set() is only available for E::A and returns a &mut u8, so we get a Option<&mut u8>
*value = 0;
}
assert_eq!(22, a.sum());
assert_eq!(18, b.sum());
assert_eq!(21, c.sum());
factor.store(2, Ordering::Relaxed);
Expand Down
170 changes: 125 additions & 45 deletions src/enum_variant_accessor.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
use either::Either;
use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote_spanned, ToTokens};
use syn::token::Mut;
use syn::{
self, punctuated::Pair, spanned::Spanned, Attribute, Data, DeriveInput, Expr, ExprCall,
ExprParen, ExprTuple, ExprType, Fields, FieldsNamed, GenericParam, Token, Type, Variant,
ExprParen, ExprTuple, ExprType, Fields, FieldsNamed, GenericParam, Token, Type, TypeReference,
Variant,
};

const ATTR_HELP: &str = "EnumAccessor: Invalid accessor declaration, expected #[accessor(field1: type, (VariantWithoutAccessor1,VariantWithoutAccessor2))]";
const ENUM_HELP: &str =
"EnumAccessor: only variants with one unnamed parameter, unit and named variants are supported";

#[derive(Clone, Copy)]
enum AccessorType {
Field,
Call,
CallMut,
}

struct Accessor {
ident: Ident,
alias: Ident,
ty: Type,
except: Vec<Ident>,
span: Span,
is_call: bool,
accessor_type: AccessorType,
}

fn parse_tuple(expr: &ExprTuple) -> Option<Accessor> {
Expand Down Expand Up @@ -68,17 +77,25 @@ fn ident_from_call(call: &ExprCall) -> Option<Ident> {
}

fn parse_ty(expr: &ExprType) -> Option<Accessor> {
let (ident, alias, is_call, span) = match expr.expr.as_ref() {
let (ident, alias, accessor_type, span) = match expr.expr.as_ref() {
Expr::Path(p) => {
let ident = p.path.get_ident()?.clone();
(ident.clone(), ident, false, expr.span())
(ident.clone(), ident, AccessorType::Field, expr.span())
}
Expr::Call(call) => {
let ident = ident_from_call(call)?;
if !call.args.is_empty() {
return None;
}
(ident.clone(), ident, true, call.span())
let accessor_type = match expr.ty.as_ref() {
Type::Reference(TypeReference {
mutability: Some(_),
..
}) => AccessorType::CallMut,
_ => AccessorType::Call,
};

(ident.clone(), ident, accessor_type, call.span())
}
Expr::Cast(expr) => {
let alias = if let Type::Path(path) = expr.ty.as_ref() {
Expand All @@ -89,14 +106,14 @@ fn parse_ty(expr: &ExprType) -> Option<Accessor> {
match expr.expr.as_ref() {
Expr::Path(path) => {
let ident = path.path.get_ident()?.clone();
(ident, alias, false, expr.span())
(ident, alias, AccessorType::Field, expr.span())
}
Expr::Call(call) => {
let ident = ident_from_call(call)?;
if !call.args.is_empty() {
return None;
}
(ident, alias, true, call.span())
(ident, alias, AccessorType::Call, call.span())
}
_ => return None,
}
Expand All @@ -110,7 +127,7 @@ fn parse_ty(expr: &ExprType) -> Option<Accessor> {
ty: expr.ty.as_ref().clone(),
except: vec![],
span,
is_call,
accessor_type,
})
}

Expand Down Expand Up @@ -144,10 +161,10 @@ fn make_mut(ident: &Ident, span: Span) -> Ident {
Ident::new(format!("{ident}_mut").as_str(), span)
}

fn get_ret(span: Span, is_optional: bool, access_type: AccessType, typ: &Type) -> TokenStream {
fn get_ret(span: Span, is_optional: bool, access_type: TypeModifier, typ: &Type) -> TokenStream {
let modifier = match access_type {
AccessType::Read => Some(quote_spanned!(span => &)),
AccessType::Mut => Some(quote_spanned!(span => &mut )),
TypeModifier::Ref => Some(quote_spanned!(span => &)),
TypeModifier::RefMut => Some(quote_spanned!(span => &mut )),
_ => None,
};

Expand All @@ -159,16 +176,16 @@ fn get_ret(span: Span, is_optional: bool, access_type: AccessType, typ: &Type) -
}

#[derive(Eq, PartialEq, Copy, Clone)]
enum AccessType {
Read,
Mut,
Call,
enum TypeModifier {
Ref,
RefMut,
Unmodified,
}

fn make_match_arms(
variant: &Variant,
accessor: &Accessor,
access_type: AccessType,
access_type: TypeModifier,
) -> Result<TokenStream, syn::Error> {
let span = variant.span();
let variant_ident = &variant.ident;
Expand All @@ -178,13 +195,13 @@ fn make_match_arms(
let mut call = None;

match access_type {
AccessType::Read => {
TypeModifier::Ref => {
modifier = Some(quote_spanned!(span => &));
}
AccessType::Mut => {
TypeModifier::RefMut => {
modifier = Some(quote_spanned!(span => &mut ));
}
AccessType::Call => call = Some(quote_spanned!(span => ())),
TypeModifier::Unmodified => call = Some(quote_spanned!(span => ())),
}

match (
Expand Down Expand Up @@ -286,31 +303,55 @@ fn get_named_variant_field_span(
Ok(span)
}

fn make_def(span: Span, is_mut: bool, method_name: &Ident, ret: &TokenStream) -> TokenStream {
let modifier = is_mut.then(|| Token![mut](span));
let method_name = is_mut
.then(|| make_mut(method_name, span))
.unwrap_or_else(|| method_name.clone());
#[derive(PartialEq, Clone, Copy)]
enum SignatureType {
ReadOnly,
FieldMut,
CallMut,
}

fn get_method_modifiers(
signature_type: SignatureType,
method_name: &Ident,
span: Span,
) -> (Option<Mut>, Ident) {
let self_modifier = match signature_type {
SignatureType::ReadOnly => None,
SignatureType::FieldMut | SignatureType::CallMut => Some(Token![mut](span)),
};

let method_name = if let SignatureType::FieldMut = signature_type {
make_mut(method_name, span)
} else {
method_name.clone()
};
(self_modifier, method_name)
}

fn make_def(
span: Span,
signature_type: SignatureType,
method_name: &Ident,
ret: &TokenStream,
) -> TokenStream {
let (self_modifier, method_name) = get_method_modifiers(signature_type, method_name, span);

quote_spanned! {span =>
fn #method_name(& #modifier self) -> #ret;
fn #method_name(& #self_modifier self) -> #ret;
}
}

fn make_impl(
span: Span,
is_mut: bool,
signature_type: SignatureType,
method_name: &Ident,
ret: &TokenStream,
arms: Vec<TokenStream>,
) -> TokenStream {
let modifier = is_mut.then(|| Token![mut](span));
let method_name = is_mut
.then(|| make_mut(method_name, span))
.unwrap_or_else(|| method_name.clone());
let (self_modifier, method_name) = get_method_modifiers(signature_type, method_name, span);

quote_spanned! {span =>
fn #method_name(& #modifier self) -> #ret {
fn #method_name(& #self_modifier self) -> #ret {
match self {
#(#arms),*
}
Expand Down Expand Up @@ -393,17 +434,29 @@ pub fn impl_enum_accessor(input: DeriveInput) -> TokenStream {
let span = accessor.span;
let method_name = &accessor.alias;

let variations = if accessor.is_call {
Either::Left([AccessType::Call])
} else {
Either::Right([AccessType::Read, AccessType::Mut])
let variations = match accessor.accessor_type {
AccessorType::Call => {
Either::Left([(SignatureType::ReadOnly, TypeModifier::Unmodified)])
}
AccessorType::CallMut => {
Either::Left([(SignatureType::CallMut, TypeModifier::Unmodified)])
}
AccessorType::Field => Either::Right([
(SignatureType::ReadOnly, TypeModifier::Ref),
(SignatureType::FieldMut, TypeModifier::RefMut),
]),
};

for access_type in variations.into_iter() {
let ret = get_ret(span, !accessor.except.is_empty(), access_type, &accessor.ty);
for (signature_type, self_modifier) in variations.into_iter() {
let ret = get_ret(
span,
!accessor.except.is_empty(),
self_modifier,
&accessor.ty,
);
let match_arms = match variants
.iter()
.map(|variant| make_match_arms(variant, accessor, access_type))
.map(|variant| make_match_arms(variant, accessor, self_modifier))
.collect::<Result<Vec<_>, _>>()
{
Ok(r) => r,
Expand All @@ -412,17 +465,12 @@ pub fn impl_enum_accessor(input: DeriveInput) -> TokenStream {

accessor_impls.push(make_impl(
span,
access_type == AccessType::Mut,
signature_type,
method_name,
&ret,
match_arms,
));
accessor_defs.push(make_def(
span,
access_type == AccessType::Mut,
method_name,
&ret,
));
accessor_defs.push(make_def(span, signature_type, method_name, &ret));
}
}

Expand Down Expand Up @@ -604,4 +652,36 @@ impl SomeEnumAccessor for SomeEnum {
"#
)
}

#[test]
fn test_mut_method() {
let input = syn::parse_quote! {
#[accessor(inner_mut(): &mut usize, (D))]
enum SomeEnum {
A(b),
C(d),
D(g),
}
};
let output = crate::enum_variant_accessor::impl_enum_accessor(syn::parse2(input).unwrap());
let output = rust_format::RustFmt::default()
.format_str(output.to_string())
.unwrap();
assert_eq!(
output,
r"trait SomeEnumAccessor {
fn inner_mut(&mut self) -> std::option::Option<&mut usize>;
}
impl SomeEnumAccessor for SomeEnum {
fn inner_mut(&mut self) -> std::option::Option<&mut usize> {
match self {
Self::A(x, ..) => std::option::Option::Some(x.inner_mut()),
Self::C(x, ..) => std::option::Option::Some(x.inner_mut()),
Self::D(..) => std::option::Option::None,
}
}
}
"
)
}
}

0 comments on commit 1024e2b

Please sign in to comment.