From 69b152094b314e3f568aa148a4badc807d21a165 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Pobiar=C5=BCyn?= Date: Tue, 28 Nov 2023 14:24:44 +0100 Subject: [PATCH] Generate odra::Module implementation (#268) * generate Module trait implementation * generate odra::Module in examples * update odra paths, * update ModuleIR --- examples2/src/counter_pack.rs | 42 +----- examples2/src/erc20.rs | 23 +--- odra-macros/src/ast/host_ref_item.rs | 2 +- odra-macros/src/ast/mod.rs | 2 + odra-macros/src/ast/module_item.rs | 192 +++++++++++++++++++++++++++ odra-macros/src/ast/test_parts.rs | 2 +- odra-macros/src/ir/mod.rs | 117 +++++++++++++--- odra-macros/src/lib.rs | 50 +++++-- odra-macros/src/test_utils.rs | 16 ++- odra-macros/src/utils/expr.rs | 9 ++ odra-macros/src/utils/string.rs | 17 +-- odra-macros/src/utils/syn.rs | 89 +++++++++++-- odra-macros/src/utils/ty.rs | 16 +++ 13 files changed, 462 insertions(+), 115 deletions(-) create mode 100644 odra-macros/src/ast/module_item.rs diff --git a/examples2/src/counter_pack.rs b/examples2/src/counter_pack.rs index c0513fbb..6afb1360 100644 --- a/examples2/src/counter_pack.rs +++ b/examples2/src/counter_pack.rs @@ -5,6 +5,7 @@ use odra::Mapping; use odra::Module; use odra::ModuleWrapper; +#[odra_macros::module] pub struct CounterPack { env: Rc, counter0: ModuleWrapper, @@ -61,47 +62,6 @@ impl CounterPack { } } -// autogenerated -mod odra_core_module { - use super::*; - - impl Module for CounterPack { - fn new(env: Rc) -> Self { - let counter0 = ModuleWrapper::new(Rc::clone(&env), 0); - let counter1 = ModuleWrapper::new(Rc::clone(&env), 1); - let counter2 = ModuleWrapper::new(Rc::clone(&env), 2); - let counter3 = ModuleWrapper::new(Rc::clone(&env), 3); - let counter4 = ModuleWrapper::new(Rc::clone(&env), 4); - let counter5 = ModuleWrapper::new(Rc::clone(&env), 5); - let counter6 = ModuleWrapper::new(Rc::clone(&env), 6); - let counter7 = ModuleWrapper::new(Rc::clone(&env), 7); - let counter8 = ModuleWrapper::new(Rc::clone(&env), 8); - let counter9 = ModuleWrapper::new(Rc::clone(&env), 9); - let counters = Mapping::new(Rc::clone(&env), 10); - let counters_map = Mapping::new(Rc::clone(&env), 11); - Self { - env, - counter0, - counter1, - counter2, - counter3, - counter4, - counter5, - counter6, - counter7, - counter8, - counter9, - counters, - counters_map - } - } - - fn env(&self) -> Rc { - self.env.clone() - } - } -} - #[cfg(odra_module = "CounterPack")] #[cfg(target_arch = "wasm32")] mod __counter_pack_wasm_parts { diff --git a/examples2/src/erc20.rs b/examples2/src/erc20.rs index 9d7f1dcd..4d824988 100644 --- a/examples2/src/erc20.rs +++ b/examples2/src/erc20.rs @@ -36,6 +36,7 @@ impl From for OdraError { } } +#[odra_macros::module] pub struct Erc20 { env: Rc, total_supply: Variable, @@ -138,28 +139,6 @@ impl Erc20 { } } -// autogenerated for general purpose module. -mod __erc20_module { - use super::Erc20; - use odra::{module::Module, prelude::*, ContractEnv, Mapping, Variable}; - - impl Module for Erc20 { - fn new(env: Rc) -> Self { - let total_supply = Variable::new(Rc::clone(&env), 1); - let balances = Mapping::new(Rc::clone(&env), 2); - Self { - env, - total_supply, - balances - } - } - - fn env(&self) -> Rc { - self.env.clone() - } - } -} - #[cfg(odra_module = "Erc20")] mod __erc20_schema { use odra::{contract_def::ContractBlueprint2, prelude::String}; diff --git a/odra-macros/src/ast/host_ref_item.rs b/odra-macros/src/ast/host_ref_item.rs index f818117d..00095518 100644 --- a/odra-macros/src/ast/host_ref_item.rs +++ b/odra-macros/src/ast/host_ref_item.rs @@ -198,7 +198,7 @@ mod ref_item_tests { self.env.last_call().contract_last_call(self.address) } - pub fn try_total_supply(&self) -> Result { + pub fn try_total_supply(&self) -> Result { self.env.call_contract( self.address, odra::CallDef::new( diff --git a/odra-macros/src/ast/mod.rs b/odra-macros/src/ast/mod.rs index d396ae5d..bff0073c 100644 --- a/odra-macros/src/ast/mod.rs +++ b/odra-macros/src/ast/mod.rs @@ -1,10 +1,12 @@ mod deployer_item; mod deployer_utils; mod host_ref_item; +mod module_item; mod parts_utils; mod ref_item; mod ref_utils; mod test_parts; +pub(crate) use module_item::ModuleModItem; pub(crate) use ref_item::RefItem; pub(crate) use test_parts::{TestParts, TestPartsReexport}; diff --git a/odra-macros/src/ast/module_item.rs b/odra-macros/src/ast/module_item.rs new file mode 100644 index 00000000..70e5bbe0 --- /dev/null +++ b/odra-macros/src/ast/module_item.rs @@ -0,0 +1,192 @@ +use quote::{ToTokens, TokenStreamExt}; +use syn::parse_quote; + +use crate::{ + ir::{EnumeratedTypedField, StructIR}, + utils +}; + +use super::parts_utils::UseSuperItem; + +#[derive(syn_derive::ToTokens)] +pub struct ModuleModItem { + mod_token: syn::token::Mod, + mod_ident: syn::Ident, + #[syn(braced)] + braces: syn::token::Brace, + #[syn(in = braces)] + use_super: UseSuperItem, + #[syn(in = braces)] + item: ModuleImplItem +} + +impl TryFrom<&'_ StructIR> for ModuleModItem { + type Error = syn::Error; + + fn try_from(ir: &'_ StructIR) -> Result { + Ok(Self { + mod_token: Default::default(), + mod_ident: ir.module_mod_ident(), + use_super: UseSuperItem, + braces: Default::default(), + item: ir.try_into()? + }) + } +} + +#[derive(syn_derive::ToTokens)] +struct ModuleImplItem { + impl_token: syn::token::Impl, + trait_path: syn::Type, + for_token: syn::token::For, + module_path: syn::Ident, + #[syn(braced)] + braces: syn::token::Brace, + #[syn(in = braces)] + new_fn: NewModuleFnItem, + #[syn(in = braces)] + env_fn: EnvFnItem +} + +impl TryFrom<&'_ StructIR> for ModuleImplItem { + type Error = syn::Error; + + fn try_from(ir: &'_ StructIR) -> Result { + Ok(Self { + impl_token: Default::default(), + trait_path: utils::ty::module(), + for_token: Default::default(), + module_path: ir.module_ident(), + braces: Default::default(), + new_fn: ir.try_into()?, + env_fn: EnvFnItem + }) + } +} + +#[derive(syn_derive::ToTokens)] +struct NewModuleFnItem { + sig: syn::Signature, + #[syn(braced)] + braces: syn::token::Brace, + #[syn(in = braces)] + #[to_tokens(|tokens, val| tokens.append_all(val))] + fields: Vec, + #[syn(in = braces)] + instance: ModuleInstanceItem +} + +impl TryFrom<&'_ StructIR> for NewModuleFnItem { + type Error = syn::Error; + + fn try_from(ir: &'_ StructIR) -> Result { + let ty_contract_env = utils::ty::contract_env(); + let env = utils::ident::env(); + let fields = ir.typed_fields()?; + Ok(Self { + sig: parse_quote!(fn new(#env: Rc<#ty_contract_env>) -> Self), + braces: Default::default(), + fields: fields.iter().map(Into::into).collect(), + instance: ir.try_into()? + }) + } +} + +#[derive(syn_derive::ToTokens)] +struct ModuleFieldItem { + let_token: syn::token::Let, + ident: syn::Ident, + assign_token: syn::token::Eq, + field_expr: syn::Expr, + semi_token: syn::token::Semi +} + +impl From<&'_ EnumeratedTypedField> for ModuleFieldItem { + fn from(field: &'_ EnumeratedTypedField) -> Self { + Self { + let_token: Default::default(), + ident: field.ident.clone(), + assign_token: Default::default(), + field_expr: utils::expr::new_type(&field.ty, &utils::ident::env(), field.idx), + semi_token: Default::default() + } + } +} + +#[derive(syn_derive::ToTokens)] +struct ModuleInstanceItem { + self_token: syn::token::SelfType, + #[syn(braced)] + braces: syn::token::Brace, + #[syn(in = braces)] + values: syn::punctuated::Punctuated +} + +impl TryFrom<&'_ StructIR> for ModuleInstanceItem { + type Error = syn::Error; + + fn try_from(ir: &'_ StructIR) -> Result { + Ok(Self { + self_token: Default::default(), + braces: Default::default(), + values: ir.field_names()?.into_iter().collect() + }) + } +} + +struct EnvFnItem; + +impl ToTokens for EnvFnItem { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let ty_contract_env = utils::ty::contract_env(); + let m_env = utils::member::env(); + + tokens.extend(quote::quote!( + fn env(&self) -> Rc<#ty_contract_env> { + #m_env.clone() + } + )) + } +} + +#[cfg(test)] +mod test { + use crate::test_utils; + use quote::quote; + + use super::ModuleModItem; + + #[test] + fn counter_pack() { + let module = test_utils::mock_module_definition(); + let expected = quote!( + mod __counter_pack_module { + use super::*; + + impl odra::Module for CounterPack { + fn new(env: Rc) -> Self { + let counter0 = ModuleWrapper::new(Rc::clone(&env), 0u8); + let counter1 = ModuleWrapper::new(Rc::clone(&env), 1u8); + let counter2 = ModuleWrapper::new(Rc::clone(&env), 2u8); + let counters = Variable::new(Rc::clone(&env), 3u8); + let counters_map = Mapping::new(Rc::clone(&env), 4u8); + Self { + env, + counter0, + counter1, + counter2, + counters, + counters_map + } + } + + fn env(&self) -> Rc { + self.env.clone() + } + } + } + ); + let actual = ModuleModItem::try_from(&module).unwrap(); + test_utils::assert_eq(actual, expected); + } +} diff --git a/odra-macros/src/ast/test_parts.rs b/odra-macros/src/ast/test_parts.rs index 884dd4ff..5de2637d 100644 --- a/odra-macros/src/ast/test_parts.rs +++ b/odra-macros/src/ast/test_parts.rs @@ -117,7 +117,7 @@ mod test { self.env.last_call().contract_last_call(self.address) } - pub fn try_total_supply(&self) -> Result { + pub fn try_total_supply(&self) -> Result { self.env.call_contract( self.address, odra::CallDef::new( diff --git a/odra-macros/src/ir/mod.rs b/odra-macros/src/ir/mod.rs index 4367cbd6..58236c8c 100644 --- a/odra-macros/src/ir/mod.rs +++ b/odra-macros/src/ir/mod.rs @@ -1,26 +1,109 @@ use crate::utils; -use proc_macro2::{Ident, TokenStream}; +use proc_macro2::Ident; use quote::format_ident; -use syn::{parse_quote, ItemImpl}; +use syn::{parse_quote, spanned::Spanned}; const CONSTRUCTOR_NAME: &str = "init"; -pub struct ModuleIR { - code: ItemImpl +macro_rules! try_parse { + ($from:path => $to:ident) => { + pub struct $to { + code: $from + } + + impl TryFrom<&proc_macro2::TokenStream> for $to { + type Error = syn::Error; + + fn try_from(stream: &proc_macro2::TokenStream) -> Result { + Ok(Self { + code: syn::parse2::<$from>(stream.clone())? + }) + } + } + }; } -impl TryFrom<&TokenStream> for ModuleIR { - type Error = syn::Error; +try_parse!(syn::ItemStruct => StructIR); + +impl StructIR { + pub fn self_code(&self) -> &syn::ItemStruct { + &self.code + } + + pub fn field_names(&self) -> Result, syn::Error> { + utils::syn::struct_fields_ident(&self.code) + } + + pub fn module_ident(&self) -> syn::Ident { + utils::syn::ident_from_struct(&self.code) + } + + pub fn module_mod_ident(&self) -> syn::Ident { + format_ident!( + "__{}_module", + utils::string::camel_to_snake(self.module_ident()) + ) + } + + pub fn typed_fields(&self) -> Result, syn::Error> { + let fields = utils::syn::struct_fields(&self.code)?; + let fields = fields + .iter() + .filter(|(i, _)| i != &utils::ident::env()) + .collect::>(); + + for (_, ty) in &fields { + Self::validate_ty(ty)?; + } + + fields + .iter() + .enumerate() + .map(|(idx, (ident, ty))| { + Ok(EnumeratedTypedField { + idx: idx as u8, + ident: ident.clone(), + ty: utils::syn::clear_generics(ty)? + }) + }) + .collect() + } - fn try_from(value: &TokenStream) -> Result { - Ok(Self { - code: syn::parse2::(value.clone())? - }) + fn validate_ty(ty: &syn::Type) -> Result<(), syn::Error> { + let non_generic_ty = utils::syn::clear_generics(ty)?; + + // both odra::Variable and Variable (Mapping, ModuleWrapper) are valid. + let valid_types = vec![ + utils::ty::module_wrapper(), + utils::ty::variable(), + utils::ty::mapping(), + ] + .iter() + .map(|ty| utils::syn::last_segment_ident(ty).map(|i| vec![ty.clone(), parse_quote!(#i)])) + .collect::, _>>()?; + let valid_types = valid_types.into_iter().flatten().collect::>(); + + if valid_types + .iter() + .any(|t| utils::string::eq(t, &non_generic_ty)) + { + return Ok(()); + } + + Err(syn::Error::new(ty.span(), "Invalid module type")) } } +pub struct EnumeratedTypedField { + pub idx: u8, + pub ident: syn::Ident, + pub ty: syn::Type +} + +try_parse!(syn::ItemImpl => ModuleIR); + impl ModuleIR { - pub fn self_code(&self) -> &ItemImpl { + pub fn self_code(&self) -> &syn::ItemImpl { &self.code } @@ -47,7 +130,6 @@ impl ModuleIR { )) } - #[allow(dead_code)] pub fn deployer_ident(&self) -> Result { let module_ident = self.module_ident()?; Ok(Ident::new( @@ -57,9 +139,14 @@ impl ModuleIR { } pub fn test_parts_mod_ident(&self) -> Result { - self.module_ident() - .map(crate::utils::string::camel_to_snake) - .map(|ident| format_ident!("__{}_test_parts", ident)) + let module_ident = self.module_ident()?; + Ok(Ident::new( + &format!( + "__{}_test_parts", + crate::utils::string::camel_to_snake(&module_ident) + ), + module_ident.span() + )) } pub fn functions(&self) -> Vec { diff --git a/odra-macros/src/lib.rs b/odra-macros/src/lib.rs index c4347c49..5f4527aa 100644 --- a/odra-macros/src/lib.rs +++ b/odra-macros/src/lib.rs @@ -1,8 +1,10 @@ -#![feature(box_patterns)] +#![feature(box_patterns, result_flattening)] use ast::*; -use ir::ModuleIR; +use ir::{ModuleIR, StructIR}; use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use syn::spanned::Spanned; mod ast; mod ir; @@ -12,20 +14,24 @@ mod utils; #[proc_macro_attribute] pub fn module(_attr: TokenStream, item: TokenStream) -> TokenStream { - match module_impl(item) { - Ok(result) => result, - Err(e) => e.to_compile_error() + let stream: TokenStream2 = item.into(); + if let Ok(ir) = ModuleIR::try_from(&stream) { + return handle_result(module_impl(ir)); } - .into() + if let Ok(ir) = StructIR::try_from(&stream) { + return handle_result(module_struct(ir)); + } + handle_result(Err(syn::Error::new( + stream.span(), + "Struct or impl block expected" + ))) } -fn module_impl(item: TokenStream) -> Result { - let module_ir = ModuleIR::try_from(&item.into())?; - - let code = module_ir.self_code(); - let ref_item = RefItem::try_from(&module_ir)?; - let test_parts = TestParts::try_from(&module_ir)?; - let test_parts_reexport = TestPartsReexport::try_from(&module_ir)?; +fn module_impl(ir: ModuleIR) -> Result { + let code = ir.self_code(); + let ref_item = RefItem::try_from(&ir)?; + let test_parts = TestParts::try_from(&ir)?; + let test_parts_reexport = TestPartsReexport::try_from(&ir)?; Ok(quote::quote! { #code @@ -34,3 +40,21 @@ fn module_impl(item: TokenStream) -> Result Result { + let code = ir.self_code(); + let module_mod = ModuleModItem::try_from(&ir)?; + + Ok(quote::quote!( + #code + #module_mod + )) +} + +fn handle_result(result: Result) -> TokenStream { + match result { + Ok(stream) => stream, + Err(e) => e.to_compile_error() + } + .into() +} diff --git a/odra-macros/src/test_utils.rs b/odra-macros/src/test_utils.rs index 489f8c58..e91525b2 100644 --- a/odra-macros/src/test_utils.rs +++ b/odra-macros/src/test_utils.rs @@ -1,6 +1,6 @@ use quote::{quote, ToTokens}; -use crate::ir::ModuleIR; +use crate::ir::{ModuleIR, StructIR}; pub fn mock_module() -> ModuleIR { let module = quote! { @@ -20,6 +20,20 @@ pub fn mock_module() -> ModuleIR { ModuleIR::try_from(&module).unwrap() } +pub fn mock_module_definition() -> StructIR { + let module = quote!( + pub struct CounterPack { + env: Rc, + counter0: ModuleWrapper, + counter1: ModuleWrapper, + counter2: ModuleWrapper, + counters: Variable, + counters_map: Mapping + } + ); + StructIR::try_from(&module).unwrap() +} + pub fn assert_eq(a: A, b: B) { fn parse(e: T) -> String { let e = e.to_token_stream().to_string(); diff --git a/odra-macros/src/utils/expr.rs b/odra-macros/src/utils/expr.rs index b7e8d98f..00315c8c 100644 --- a/odra-macros/src/utils/expr.rs +++ b/odra-macros/src/utils/expr.rs @@ -11,3 +11,12 @@ pub fn u512_zero() -> syn::Expr { pub fn parse_bytes(data_ident: &syn::Ident) -> syn::Expr { parse_quote!(odra::ToBytes::to_bytes(&#data_ident).map(Into::into).unwrap()) } + +pub fn new_type(ty: &syn::Type, env_ident: &syn::Ident, idx: u8) -> syn::Expr { + let rc = rc_clone(env_ident); + parse_quote!(#ty::new(#rc, #idx)) +} + +fn rc_clone(ident: &syn::Ident) -> syn::Expr { + parse_quote!(Rc::clone(&#ident)) +} diff --git a/odra-macros/src/utils/string.rs b/odra-macros/src/utils/string.rs index 7805a6b5..ed13d8c4 100644 --- a/odra-macros/src/utils/string.rs +++ b/odra-macros/src/utils/string.rs @@ -1,20 +1,13 @@ use convert_case::{Boundary, Case, Casing}; +use quote::ToTokens; -/// Converts a camel-cased &str to String. -/// -/// # Example -/// -/// ``` -/// use odra_utils::camel_to_snake; -/// -/// let camel = "ContractName"; -/// let result = camel_to_snake(camel); -/// -/// assert_eq!(&result, "contract_name"); -/// ``` pub fn camel_to_snake(text: T) -> String { text.to_string() .from_case(Case::UpperCamel) .without_boundaries(&[Boundary::UpperDigit, Boundary::LowerDigit]) .to_case(Case::Snake) } + +pub fn eq(a: A, b: B) -> bool { + a.to_token_stream().to_string() == b.to_token_stream().to_string() +} diff --git a/odra-macros/src/utils/syn.rs b/odra-macros/src/utils/syn.rs index 55f7676a..e741b0c8 100644 --- a/odra-macros/src/utils/syn.rs +++ b/odra-macros/src/utils/syn.rs @@ -1,15 +1,11 @@ use syn::{parse_quote, spanned::Spanned}; pub fn ident_from_impl(impl_code: &syn::ItemImpl) -> Result { - match &*impl_code.self_ty { - syn::Type::Path(type_path) => { - Ok(type_path.path.segments.last().expect("dupa").ident.clone()) - } - ty => Err(syn::Error::new( - ty.span(), - "Only support impl for type path" - )) - } + last_segment_ident(&impl_code.self_ty) +} + +pub fn ident_from_struct(struct_code: &syn::ItemStruct) -> syn::Ident { + struct_code.ident.clone() } pub fn function_arg_names(function: &syn::ImplItemFn) -> Vec { @@ -67,6 +63,81 @@ pub fn function_return_type(function: &syn::ImplItemFn) -> syn::ReturnType { function.sig.output.clone() } +pub fn struct_fields_ident(item: &syn::ItemStruct) -> Result, syn::Error> { + if let syn::Fields::Named(named) = &item.fields { + let err_msg = "Invalid field. Module fields must be named"; + named + .named + .iter() + .map(|f| f.ident.clone().ok_or(syn::Error::new(f.span(), err_msg))) + .collect::, syn::Error>>() + } else { + Err(syn::Error::new_spanned( + &item.fields, + "Invalid fields. Module fields must be named" + )) + } +} + +pub fn struct_fields(item: &syn::ItemStruct) -> Result, syn::Error> { + if let syn::Fields::Named(named) = &item.fields { + let err_msg = "Invalid field. Module fields must be named"; + named + .named + .iter() + .map(|f| { + f.ident + .clone() + .ok_or(syn::Error::new_spanned(f, err_msg)) + .map(|i| (i, f.ty.clone())) + }) + .collect() + } else { + Err(syn::Error::new_spanned( + &item.fields, + "Invalid fields. Module fields must be named" + )) + } +} + pub fn visibility_pub() -> syn::Visibility { parse_quote!(pub) } + +pub fn last_segment_ident(ty: &syn::Type) -> Result { + match ty { + syn::Type::Path(type_path) => type_path + .path + .segments + .last() + .map(|seg| seg.ident.clone()) + .ok_or(syn::Error::new(type_path.span(), "Invalid type path")), + ty => Err(syn::Error::new( + ty.span(), + "Only support impl for type path" + )) + } +} + +pub fn clear_generics(ty: &syn::Type) -> Result { + match ty { + syn::Type::Path(type_path) => clear_path(type_path).map(syn::Type::Path), + ty => Err(syn::Error::new( + ty.span(), + "Only support impl for type path" + )) + } +} + +fn clear_path(ty: &syn::TypePath) -> Result { + let mut owned_ty = ty.to_owned(); + + let mut segment = owned_ty + .path + .segments + .last_mut() + .ok_or(syn::Error::new(ty.span(), "Invalid type path"))?; + segment.arguments = syn::PathArguments::None; + + Ok(owned_ty) +} diff --git a/odra-macros/src/utils/ty.rs b/odra-macros/src/utils/ty.rs index 8c77b239..a6fdf1a6 100644 --- a/odra-macros/src/utils/ty.rs +++ b/odra-macros/src/utils/ty.rs @@ -43,3 +43,19 @@ pub fn contract_call_result() -> syn::Type { pub fn odra_error() -> syn::Type { parse_quote!(odra::OdraError) } + +pub fn module_wrapper() -> syn::Type { + parse_quote!(odra::ModuleWrapper) +} + +pub fn module() -> syn::Type { + parse_quote!(odra::Module) +} + +pub fn variable() -> syn::Type { + parse_quote!(odra::Variable) +} + +pub fn mapping() -> syn::Type { + parse_quote!(odra::Mapping) +}