From aa6e41d02e3be926f38055d2d731f172e85261a4 Mon Sep 17 00:00:00 2001 From: Ozgun Ozerk Date: Fri, 18 Oct 2024 15:52:54 +0300 Subject: [PATCH] separate file for each macro --- src/lib.rs | 335 ++-------------------------------------------- src/require.rs | 129 ++++++++++++++++++ src/states.rs | 103 ++++++++++++++ src/switch_to.rs | 41 ++++++ src/type_state.rs | 74 ++++++++++ 5 files changed, 361 insertions(+), 321 deletions(-) create mode 100644 src/require.rs create mode 100644 src/states.rs create mode 100644 src/switch_to.rs create mode 100644 src/type_state.rs diff --git a/src/lib.rs b/src/lib.rs index c3fdc86..66a2af8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,15 +16,17 @@ extern crate proc_macro; +mod require; +mod states; +mod switch_to; +mod type_state; + +use require::require_inner; +use states::states_inner; +use switch_to::switch_to_inner; +use type_state::type_state_inner; + use proc_macro::TokenStream; -use quote::quote; -use syn::{ - parse::{Parse, ParseStream, Parser}, - parse_macro_input, - punctuated::Punctuated, - Expr, Fields, Ident, ImplItem, ItemFn, ItemImpl, ItemStruct, Member, Meta, ReturnType, Stmt, - Token, Type, -}; /// Denotes which state is required for this method to be called. /// @@ -44,126 +46,7 @@ use syn::{ /// providing the name of the struct to the `#[require]` macro behind the scenes. #[proc_macro_attribute] pub fn require(args: TokenStream, input: TokenStream) -> TokenStream { - // Parse the input arguments and function: (ImplType, State1, State2, ...) - let args_parser = Punctuated::::parse_terminated; - let parsed_args = args_parser.parse(args).unwrap(); - - // Extract the first argument (the name of the impl block) - - let struct_name = &parsed_args[0]; - - // Extract the remaining arguments (states and generics) - - let remaining_args: Vec = parsed_args.iter().skip(1).cloned().collect(); - - let input_fn = parse_macro_input!(input as ItemFn); - - // Only the single letter arguments will be used as generic constraints: (A, B, ...) - let generic_idents: Vec = remaining_args - .iter() - .filter(|ident| is_single_letter(ident)) - .map(|ident| quote!(#ident)) - .collect(); - - // Get the full list of arguments as a vec: (A, B, State1, ...) - let concrete_type: Vec = - remaining_args.iter().map(|ident| quote!(#ident)).collect(); - - // put the sealed trait boundary for the generics: - /* - ``` where - A: TypeStateProtector, - B: TypeStateProtector, - */ - let where_clauses: Vec = remaining_args - .iter() - .filter(|ident| is_single_letter(ident)) - .map(|ident| quote!(#ident: TypeStateProtector)) - .collect(); // Collect into a Vec to make `is_empty()` available - - // Generate the `where` clause only if there are any constraints - let where_clause = if !where_clauses.is_empty() { - quote! { where #(#where_clauses),* } - } else { - quote! {} - }; - - // Get the function name and its generics - let fn_name = &input_fn.sig.ident; - let fn_body = &input_fn.block; - let fn_inputs = &input_fn.sig.inputs; - let fn_output = &input_fn.sig.output; - - // Collect all other macros except the `#[require]` attribute itself - let other_attrs: Vec<_> = input_fn - .attrs - .iter() - .filter(|attr| !attr.path().is_ident("require")) - .collect(); - - // Generate PhantomData for the required number of states - let phantom_data_count = remaining_args.len(); - let phantom_data: Vec = (0..phantom_data_count) - .map(|_| quote!(::std::marker::PhantomData)) - .collect(); - - let phantom_expr = if phantom_data.len() == 1 { - quote! { ::std::marker::PhantomData } - } else { - quote! { ( #(#phantom_data),* ) } - }; - - // Convert the method body to modify struct construction - let new_fn_body = fn_body - .stmts - .iter() - .map(|stmt| { - // Check if the statement contains the struct initialization (e.g., `PlayerBuilder {`) - if let Stmt::Expr(Expr::Struct(expr_struct), maybe_semi) = stmt { - let struct_path = &expr_struct.path; - if struct_path.is_ident(struct_name) { - // Append `_state: (PhantomData, PhantomData, ...)` to the struct fields - let mut new_fields = expr_struct.fields.clone(); - new_fields.push(syn::FieldValue { - attrs: Vec::new(), - member: Member::Named(syn::Ident::new("_state", struct_name.span())), - colon_token: Some(::default()), - expr: Expr::Verbatim(phantom_expr.clone()), - }); - - // Return modified struct construction - return Stmt::Expr( - syn::Expr::Struct(syn::ExprStruct { - fields: new_fields, - ..expr_struct.clone() - }), - *maybe_semi, - ); - } - } - // Return the statement unchanged if it's not a struct construction - stmt.clone() - }) - .collect::>(); - - // Construct the `impl` block - let output = quote! { - impl<#(#generic_idents),*> #struct_name<#(#concrete_type),*> - #where_clause - { - #(#other_attrs)* - fn #fn_name(#fn_inputs) #fn_output { - #(#new_fn_body)* - } - } - }; - - output.into() -} - -fn is_single_letter(ident: &Ident) -> bool { - let ident_str = ident.to_string(); - ident_str.len() == 1 + require_inner(args, input) } /// Denotes to which state will the object transition into after this method @@ -176,50 +59,7 @@ fn is_single_letter(ident: &Ident) -> bool { /// - overwrites the return type of the `impl` block generated by the `#[require]` macro #[proc_macro_attribute] pub fn switch_to(args: TokenStream, input: TokenStream) -> TokenStream { - // Parse the input arguments and function: (State1, State2, ...) - let args_parser = Punctuated::::parse_terminated; - let parsed_args = args_parser.parse(args).unwrap(); - let input_fn = parse_macro_input!(input as ItemFn); - - // Get the function name, inputs, and body - let fn_name = &input_fn.sig.ident; - let fn_inputs = &input_fn.sig.inputs; - let fn_body = &input_fn.block; - - // Get the full list of arguments as a vec: (A, B, State1, ...) - let generic_idents: Vec = - parsed_args.iter().map(|i| quote!(#i)).collect(); - - // Parse the original return type from the function signature - let original_return_type = match &input_fn.sig.output { - ReturnType::Type(_, ty) => quote! { #ty }, - _ => panic!("Expected a return type."), - }; - - // Construct the new return type using the original name and the new generics - let return_type = quote! { - #original_return_type<#(#generic_idents),*> - }; - - // Construct the new method with the modified return type - let output = quote! { - fn #fn_name(#fn_inputs) -> #return_type { - #fn_body - } - }; - - output.into() -} - -struct StatesInput { - states: Punctuated, -} - -impl Parse for StatesInput { - fn parse(input: ParseStream) -> syn::Result { - let states = Punctuated::parse_terminated(input)?; - Ok(StatesInput { states }) - } + switch_to_inner(args, input) } /// Denotes which states will be used for the type-state pattern. @@ -238,87 +78,7 @@ impl Parse for StatesInput { /// providing the name of the struct to the `#[require]` macro behind the scenes. #[proc_macro_attribute] pub fn states(attr: TokenStream, item: TokenStream) -> TokenStream { - // Parse the list of states from the attribute: (State1, State2, ...) - let args = parse_macro_input!(attr as StatesInput); - - // Parse the impl block - let mut input = parse_macro_input!(item as ItemImpl); - - // Extract the type name of the impl block (e.g., Player) - let impl_type = match *input.self_ty { - Type::Path(ref type_path) => type_path.path.segments.last().unwrap().ident.clone(), - _ => panic!("Unsupported type for impl block"), - }; - - // Extract the methods from the impl block - let mut methods = Vec::new(); - - for item in input.items.iter_mut() { - if let ImplItem::Fn(ref mut method) = item { - // Check if the method has a `#[require]` attribute - for attr in method.attrs.iter_mut() { - if attr.path().is_ident("require") { - // Parse the tokens of the `#[require]` macro - let mut args: Punctuated = - attr.parse_args_with(Punctuated::parse_terminated).unwrap(); - - // Append the impl block type (e.g., Player) as the first argument - args.insert(0, impl_type.clone()); - - // Update the attribute tokens with the new arguments - let a = match attr.meta { - Meta::List(ref mut list) => list, - _ => panic!("Expected a list of arguments"), - }; - - a.tokens = quote! { #args }; - } - } - - methods.push(quote! { #method }); - } - } - - // Generate the marker structs, and their implementations - let mut markers = Vec::new(); - let mut sealed_impls = Vec::new(); - let mut trait_impls = Vec::new(); - - for state in args.states { - let marker_name = Ident::new(&format!("{}", state), state.span()); - - markers.push(quote! { - struct #marker_name; - }); - - sealed_impls.push(quote! { - impl sealed::Sealed for #marker_name {} - }); - - trait_impls.push(quote! { - impl TypeStateProtector for #marker_name {} - }); - } - - // Generate the full expanded code - let expanded = quote! { - // Private module to seal traits - mod sealed { - pub trait Sealed {} - } - - pub trait TypeStateProtector: sealed::Sealed {} - - #(#markers)* - - #(#sealed_impls)* - - #(#trait_impls)* - - #(#methods)* - }; - - TokenStream::from(expanded) + states_inner(attr, item) } /// Turns your struct into type-state compatible version. @@ -333,72 +93,5 @@ pub fn states(attr: TokenStream, item: TokenStream) -> TokenStream { /// by enforcing sealed-trait bounds on the states. #[proc_macro_attribute] pub fn type_state(args: TokenStream, input: TokenStream) -> TokenStream { - // Parse the `state_slots` and `default_state` from the arguments - /* - Indices: - --- - 0. `state_slots` - 1. `=` - 2. `3` (this is the value you're interested in for state_slots) - 4. `,` - 5. `default_state` - 6. `=` - 7. `Initial` (this is the value you're interested in for default_state) - */ - let input_args: Vec<_> = args.into_iter().collect(); - let state_slots: usize = if let Some(proc_macro::TokenTree::Literal(lit)) = input_args.get(2) { - lit.to_string().parse().unwrap() - } else { - panic!("Expected a valid number for state_slots."); - }; - - let default_state: Ident = if let Some(proc_macro::TokenTree::Ident(ident)) = input_args.get(6) - { - Ident::new(&format!("{}", ident), ident.span().into()) - } else { - panic!("Expected an identifier for default_state."); - }; - - // Parse the input struct - let input_struct = parse_macro_input!(input as ItemStruct); - let struct_name = &input_struct.ident; - // Extract fields from the struct - let struct_fields = match input_struct.fields { - Fields::Named(ref fields) => &fields.named, - Fields::Unnamed(_) => panic!("Expected named fields in struct."), - Fields::Unit => panic!("Expected a struct with fields."), - }; - - // Generate state generics: `struct StructName` - let state_idents: Vec = (0..state_slots) - .map(|i| Ident::new(&format!("State{}", i + 1), struct_name.span())) - .collect(); - - let default_generics = vec![quote!(#default_state); state_slots]; - - let where_clauses = (0..state_slots).map(|i| { - let state_num = Ident::new(&format!("State{}", i + 1), struct_name.span()); - quote!(#state_num: TypeStateProtector) - }); - - // Construct the `_state` field with PhantomData - // `_state: PhantomData T>` - // the reason for using `fn() -> T` is to: https://github.com/ozgunozerk/state-shift/issues/1 - let phantom_fields = state_idents - .iter() - .map(|ident| quote!(::std::marker::PhantomData #ident>)) - .collect::>(); - - let output = quote! { - #[allow(clippy::type_complexity)] - struct #struct_name<#(#state_idents = #default_generics),*> - where - #(#where_clauses),* - { - #struct_fields - _state: (#(#phantom_fields),*), - } - }; - - output.into() + type_state_inner(args, input) } diff --git a/src/require.rs b/src/require.rs new file mode 100644 index 0000000..7951d88 --- /dev/null +++ b/src/require.rs @@ -0,0 +1,129 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{ + parse::Parser, parse_macro_input, punctuated::Punctuated, Expr, Ident, ItemFn, Member, Stmt, + Token, +}; + +pub fn require_inner(args: TokenStream, input: TokenStream) -> TokenStream { + // Parse the input arguments and function: (ImplType, State1, State2, ...) + let args_parser = Punctuated::::parse_terminated; + let parsed_args = args_parser.parse(args).unwrap(); + + // Extract the first argument (the name of the impl block) + + let struct_name = &parsed_args[0]; + + // Extract the remaining arguments (states and generics) + + let remaining_args: Vec = parsed_args.iter().skip(1).cloned().collect(); + + let input_fn = parse_macro_input!(input as ItemFn); + + // Only the single letter arguments will be used as generic constraints: (A, B, ...) + let generic_idents: Vec = remaining_args + .iter() + .filter(|ident| is_single_letter(ident)) + .map(|ident| quote!(#ident)) + .collect(); + + // Get the full list of arguments as a vec: (A, B, State1, ...) + let concrete_type: Vec = + remaining_args.iter().map(|ident| quote!(#ident)).collect(); + + // put the sealed trait boundary for the generics: + /* + ``` where + A: TypeStateProtector, + B: TypeStateProtector, + */ + let where_clauses: Vec = remaining_args + .iter() + .filter(|ident| is_single_letter(ident)) + .map(|ident| quote!(#ident: TypeStateProtector)) + .collect(); // Collect into a Vec to make `is_empty()` available + + // Generate the `where` clause only if there are any constraints + let where_clause = if !where_clauses.is_empty() { + quote! { where #(#where_clauses),* } + } else { + quote! {} + }; + + // Get the function name and its generics + let fn_name = &input_fn.sig.ident; + let fn_body = &input_fn.block; + let fn_inputs = &input_fn.sig.inputs; + let fn_output = &input_fn.sig.output; + + // Collect all other macros except the `#[require]` attribute itself + let other_attrs: Vec<_> = input_fn + .attrs + .iter() + .filter(|attr| !attr.path().is_ident("require")) + .collect(); + + // Generate PhantomData for the required number of states + let phantom_data_count = remaining_args.len(); + let phantom_data: Vec = (0..phantom_data_count) + .map(|_| quote!(::std::marker::PhantomData)) + .collect(); + + let phantom_expr = if phantom_data.len() == 1 { + quote! { ::std::marker::PhantomData } + } else { + quote! { ( #(#phantom_data),* ) } + }; + + // Convert the method body to modify struct construction + let new_fn_body = fn_body + .stmts + .iter() + .map(|stmt| { + // Check if the statement contains the struct initialization (e.g., `PlayerBuilder {`) + if let Stmt::Expr(Expr::Struct(expr_struct), maybe_semi) = stmt { + let struct_path = &expr_struct.path; + if struct_path.is_ident(struct_name) { + // Append `_state: (PhantomData, PhantomData, ...)` to the struct fields + let mut new_fields = expr_struct.fields.clone(); + new_fields.push(syn::FieldValue { + attrs: Vec::new(), + member: Member::Named(syn::Ident::new("_state", struct_name.span())), + colon_token: Some(::default()), + expr: Expr::Verbatim(phantom_expr.clone()), + }); + + // Return modified struct construction + return Stmt::Expr( + syn::Expr::Struct(syn::ExprStruct { + fields: new_fields, + ..expr_struct.clone() + }), + *maybe_semi, + ); + } + } + // Return the statement unchanged if it's not a struct construction + stmt.clone() + }) + .collect::>(); + + // Construct the `impl` block + let output = quote! { + impl<#(#generic_idents),*> #struct_name<#(#concrete_type),*> + #where_clause + { + #(#other_attrs)* + fn #fn_name(#fn_inputs) #fn_output { + #(#new_fn_body)* + } + } + }; + + output.into() +} + +fn is_single_letter(ident: &Ident) -> bool { + let ident_str = ident.to_string(); + ident_str.len() == 1 +} diff --git a/src/states.rs b/src/states.rs new file mode 100644 index 0000000..8debd62 --- /dev/null +++ b/src/states.rs @@ -0,0 +1,103 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, + punctuated::Punctuated, + Ident, ImplItem, ItemImpl, Meta, Token, Type, +}; + +struct StatesInput { + states: Punctuated, +} + +impl Parse for StatesInput { + fn parse(input: ParseStream) -> syn::Result { + let states = Punctuated::parse_terminated(input)?; + Ok(StatesInput { states }) + } +} + +pub fn states_inner(attr: TokenStream, item: TokenStream) -> TokenStream { + // Parse the list of states from the attribute: (State1, State2, ...) + let args = parse_macro_input!(attr as StatesInput); + + // Parse the impl block + let mut input = parse_macro_input!(item as ItemImpl); + + // Extract the type name of the impl block (e.g., Player) + let impl_type = match *input.self_ty { + Type::Path(ref type_path) => type_path.path.segments.last().unwrap().ident.clone(), + _ => panic!("Unsupported type for impl block"), + }; + + // Extract the methods from the impl block + let mut methods = Vec::new(); + + for item in input.items.iter_mut() { + if let ImplItem::Fn(ref mut method) = item { + // Check if the method has a `#[require]` attribute + for attr in method.attrs.iter_mut() { + if attr.path().is_ident("require") { + // Parse the tokens of the `#[require]` macro + let mut args: Punctuated = + attr.parse_args_with(Punctuated::parse_terminated).unwrap(); + + // Append the impl block type (e.g., Player) as the first argument + args.insert(0, impl_type.clone()); + + // Update the attribute tokens with the new arguments + let a = match attr.meta { + Meta::List(ref mut list) => list, + _ => panic!("Expected a list of arguments"), + }; + + a.tokens = quote! { #args }; + } + } + + methods.push(quote! { #method }); + } + } + + // Generate the marker structs, and their implementations + let mut markers = Vec::new(); + let mut sealed_impls = Vec::new(); + let mut trait_impls = Vec::new(); + + for state in args.states { + let marker_name = Ident::new(&format!("{}", state), state.span()); + + markers.push(quote! { + struct #marker_name; + }); + + sealed_impls.push(quote! { + impl sealed::Sealed for #marker_name {} + }); + + trait_impls.push(quote! { + impl TypeStateProtector for #marker_name {} + }); + } + + // Generate the full expanded code + let expanded = quote! { + // Private module to seal traits + mod sealed { + pub trait Sealed {} + } + + pub trait TypeStateProtector: sealed::Sealed {} + + #(#markers)* + + #(#sealed_impls)* + + #(#trait_impls)* + + #(#methods)* + }; + + TokenStream::from(expanded) +} diff --git a/src/switch_to.rs b/src/switch_to.rs new file mode 100644 index 0000000..ef40db8 --- /dev/null +++ b/src/switch_to.rs @@ -0,0 +1,41 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{ + parse::Parser, parse_macro_input, punctuated::Punctuated, Ident, ItemFn, ReturnType, Token, +}; + +pub fn switch_to_inner(args: TokenStream, input: TokenStream) -> TokenStream { + // Parse the input arguments and function: (State1, State2, ...) + let args_parser = Punctuated::::parse_terminated; + let parsed_args = args_parser.parse(args).unwrap(); + let input_fn = parse_macro_input!(input as ItemFn); + + // Get the function name, inputs, and body + let fn_name = &input_fn.sig.ident; + let fn_inputs = &input_fn.sig.inputs; + let fn_body = &input_fn.block; + + // Get the full list of arguments as a vec: (A, B, State1, ...) + let generic_idents: Vec = + parsed_args.iter().map(|i| quote!(#i)).collect(); + + // Parse the original return type from the function signature + let original_return_type = match &input_fn.sig.output { + ReturnType::Type(_, ty) => quote! { #ty }, + _ => panic!("Expected a return type."), + }; + + // Construct the new return type using the original name and the new generics + let return_type = quote! { + #original_return_type<#(#generic_idents),*> + }; + + // Construct the new method with the modified return type + let output = quote! { + fn #fn_name(#fn_inputs) -> #return_type { + #fn_body + } + }; + + output.into() +} diff --git a/src/type_state.rs b/src/type_state.rs new file mode 100644 index 0000000..67aaad8 --- /dev/null +++ b/src/type_state.rs @@ -0,0 +1,74 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::{parse_macro_input, Fields, Ident, ItemStruct}; + +pub fn type_state_inner(args: TokenStream, input: TokenStream) -> TokenStream { + // Parse the `state_slots` and `default_state` from the arguments + /* + Indices: + --- + 0. `state_slots` + 1. `=` + 2. `3` (this is the value you're interested in for state_slots) + 4. `,` + 5. `default_state` + 6. `=` + 7. `Initial` (this is the value you're interested in for default_state) + */ + let input_args: Vec<_> = args.into_iter().collect(); + let state_slots: usize = if let Some(proc_macro::TokenTree::Literal(lit)) = input_args.get(2) { + lit.to_string().parse().unwrap() + } else { + panic!("Expected a valid number for state_slots."); + }; + + let default_state: Ident = if let Some(proc_macro::TokenTree::Ident(ident)) = input_args.get(6) + { + Ident::new(&format!("{}", ident), ident.span().into()) + } else { + panic!("Expected an identifier for default_state."); + }; + + // Parse the input struct + let input_struct = parse_macro_input!(input as ItemStruct); + let struct_name = &input_struct.ident; + // Extract fields from the struct + let struct_fields = match input_struct.fields { + Fields::Named(ref fields) => &fields.named, + Fields::Unnamed(_) => panic!("Expected named fields in struct."), + Fields::Unit => panic!("Expected a struct with fields."), + }; + + // Generate state generics: `struct StructName` + let state_idents: Vec = (0..state_slots) + .map(|i| Ident::new(&format!("State{}", i + 1), struct_name.span())) + .collect(); + + let default_generics = vec![quote!(#default_state); state_slots]; + + let where_clauses = (0..state_slots).map(|i| { + let state_num = Ident::new(&format!("State{}", i + 1), struct_name.span()); + quote!(#state_num: TypeStateProtector) + }); + + // Construct the `_state` field with PhantomData + // `_state: PhantomData T>` + // the reason for using `fn() -> T` is to: https://github.com/ozgunozerk/state-shift/issues/1 + let phantom_fields = state_idents + .iter() + .map(|ident| quote!(::std::marker::PhantomData #ident>)) + .collect::>(); + + let output = quote! { + #[allow(clippy::type_complexity)] + struct #struct_name<#(#state_idents = #default_generics),*> + where + #(#where_clauses),* + { + #struct_fields + _state: (#(#phantom_fields),*), + } + }; + + output.into() +}