Skip to content

Commit

Permalink
feat(extension): proc macro impl for Extension Trait (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
indirection42 authored Jun 26, 2024
1 parent 5d42cb0 commit 7bad98c
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 192 deletions.
13 changes: 11 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 2 additions & 32 deletions xcq-extension-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use parity_scale_codec::{Decode, Encode};
use xcq_extension::Vec;
use xcq_extension::{DispatchError, Dispatchable};
use xcq_extension::{ExtensionId, ExtensionIdTy};
use xcq_extension::extension;

#[extension]
pub trait ExtensionCore {
type Config: Config;
fn has_extension(id: <Self::Config as Config>::ExtensionId) -> bool;
Expand All @@ -16,32 +15,3 @@ pub trait ExtensionCore {
pub trait Config {
type ExtensionId: Decode;
}

// #[extension(ExtensionCore)]
// type Call;

mod generated_by_extension_decl {
use super::*;

type ExtensionIdOf<T> = <<T as ExtensionCore>::Config as Config>::ExtensionId;
#[derive(Decode)]
pub enum ExtensionCoreCall<Impl: ExtensionCore> {
HasExtension { id: ExtensionIdOf<Impl> },
}

impl<Impl: ExtensionCore> Dispatchable for ExtensionCoreCall<Impl> {
fn dispatch(self) -> Result<Vec<u8>, DispatchError> {
match self {
Self::HasExtension { id } => Ok(Impl::has_extension(id).encode()),
}
}
}

impl<Impl: ExtensionCore> ExtensionId for ExtensionCoreCall<Impl> {
const EXTENSION_ID: ExtensionIdTy = 0u64;
}

pub type Call<Impl> = ExtensionCoreCall<Impl>;
}

pub use generated_by_extension_decl::*;
42 changes: 2 additions & 40 deletions xcq-extension-fungibles/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use parity_scale_codec::{Decode, Encode};
use xcq_extension::Vec;
use xcq_extension::{DispatchError, Dispatchable};
use xcq_extension::{ExtensionId, ExtensionIdTy};
use xcq_extension::extension;

pub type AccountIdFor<T> = <<T as ExtensionFungibles>::Config as Config>::AccountId;
pub type BalanceFor<T> = <<T as ExtensionFungibles>::Config as Config>::Balance;
pub type AssetIdFor<T> = <<T as ExtensionFungibles>::Config as Config>::AssetId;

#[extension]
pub trait ExtensionFungibles {
type Config: Config;
// fungibles::Inspect (not extensive)
Expand All @@ -24,40 +23,3 @@ pub trait Config {
type AssetId: Decode;
type Balance: Encode;
}

// #[extension(ExtensionFungibles)]
// type Call;

mod generated_by_extension_decl {

use super::*;

#[derive(Decode)]
pub enum ExtensionFungiblesCall<Impl: ExtensionFungibles> {
// TODO: not extensive
Balance {
asset: AssetIdFor<Impl>,
who: AccountIdFor<Impl>,
},
TotalSupply {
asset: AssetIdFor<Impl>,
},
}

impl<Impl: ExtensionFungibles> Dispatchable for ExtensionFungiblesCall<Impl> {
fn dispatch(self) -> Result<Vec<u8>, DispatchError> {
match self {
Self::Balance { asset, who } => Ok(Impl::balance(asset, who).encode()),
Self::TotalSupply { asset } => Ok(Impl::total_supply(asset).encode()),
}
}
}

impl<Impl: ExtensionFungibles> ExtensionId for ExtensionFungiblesCall<Impl> {
const EXTENSION_ID: ExtensionIdTy = 1u64;
}

pub type Call<Impl> = ExtensionFungiblesCall<Impl>;
}

pub use generated_by_extension_decl::*;
4 changes: 1 addition & 3 deletions xcq-extension/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ scale-info = { workspace = true }
xcq-executor = { workspace = true }
impl-trait-for-tuples = "0.2.2"
tracing = { workspace = true }
xcq-extension-procedural = { path = "procedural" }

[dev-dependencies]
xcq-extension-core = { workspace = true }
xcq-extension-fungibles = { workspace = true }

[features]
default = ["std"]
Expand Down
16 changes: 16 additions & 0 deletions xcq-extension/procedural/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "xcq-extension-procedural"
authors.workspace = true
edition.workspace = true
repository.workspace = true
license.workspace = true
version.workspace = true

[lib]
proc-macro = true

[dependencies]
syn = { version = "2", features = ["full", "extra-traits"] }
quote = "1"
proc-macro2 = "1"
twox-hash = "1.6.3"
182 changes: 182 additions & 0 deletions xcq-extension/procedural/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
use proc_macro2::TokenStream;
use quote::quote;
use std::hash::{Hash, Hasher};
use syn::token::Comma;
use syn::{parse_macro_input, parse_quote, parse_str, spanned::Spanned};
use syn::{punctuated::Punctuated, ExprCall, Field, Ident, ItemImpl, Pat, TraitItem, Variant};

#[derive(Clone)]
struct Method {
/// Function name
pub name: Ident,
/// Information on args: `(name, type)`
pub args: Vec<(Ident, Box<syn::Type>)>,
}

#[proc_macro_attribute]
pub fn extension(_args: proc_macro::TokenStream, input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as syn::ItemTrait);

let methods = match methods(&input.items) {
Ok(method) => method,
Err(e) => return e.to_compile_error().into(),
};

let call_enum_def = match call_enum_def(&input.ident, &methods) {
Ok(call_enum_def) => call_enum_def,
Err(e) => return e.to_compile_error().into(),
};

let dispatchable_impl = dispatchable_impl(&input.ident, &methods);
let extension_id_impl = extension_id_impl(&input.ident, &input.items);

let expanded = quote! {
#input
#call_enum_def
#dispatchable_impl
#extension_id_impl
};
expanded.into()
}

fn call_enum_def(trait_ident: &Ident, methods: &[Method]) -> syn::Result<syn::ItemEnum> {
let mut variants = Punctuated::<Variant, Comma>::new();
for method in methods {
let name = &method.name;
let mut args = Punctuated::<Field, Comma>::new();
for (name, ty) in &method.args {
let ty = replace_self_to_impl(ty)?;
args.push(parse_quote! {
#name: #ty
});
}
variants.push(parse_quote! {
#[allow(non_camel_case_types)]
#name {
#args
}
});
}
// Add phantom data
variants.push(parse_quote!(
#[doc(hidden)]
__Phantom(std::marker::PhantomData<Impl>)
));
Ok(parse_quote!(
#[derive(Decode)]
pub enum Call<Impl: #trait_ident> {
#variants
}
))
}

fn dispatchable_impl(trait_ident: &Ident, methods: &[Method]) -> TokenStream {
let mut pats = Vec::<Pat>::new();
for method in methods {
let name = &method.name;
let mut args = Punctuated::<Ident, Comma>::new();
for (ident, _ty) in &method.args {
args.push(parse_quote! {
#ident
});
}
pats.push(parse_quote! {
Self::#name {
#args
}
});
}

let mut method_calls = Vec::<ExprCall>::new();
for method in methods {
let name = &method.name;
let mut args = Punctuated::<Ident, Comma>::new();
for (ident, _ty) in &method.args {
args.push(parse_quote! {
#ident
});
}
method_calls.push({
parse_quote! {
Impl::#name(#args)
}
});
}

parse_quote! {
impl<Impl: #trait_ident> xcq_extension::Dispatchable for Call<Impl> {
fn dispatch(self) -> Result<Vec<u8>, xcq_extension::DispatchError> {
match self {
#( #pats => Ok(#method_calls.encode()),)*
Self::__Phantom(_) => unreachable!(),
}
}
}
}
}

fn extension_id_impl(trait_ident: &Ident, trait_items: &[TraitItem]) -> ItemImpl {
let extension_id = calculate_hash(trait_ident, trait_items);
parse_quote! {
impl<Impl: #trait_ident> xcq_extension::ExtensionId for Call<Impl> {
const EXTENSION_ID: xcq_extension::ExtensionIdTy = #extension_id;
}
}
}

// helper functions
fn methods(trait_items: &[TraitItem]) -> syn::Result<Vec<Method>> {
let mut methods = vec![];
for item in trait_items {
if let TraitItem::Fn(method) = item {
let method_name = &method.sig.ident;
let mut method_args = vec![];
for arg in method.sig.inputs.iter() {
let arg = if let syn::FnArg::Typed(arg) = arg {
arg
} else {
unreachable!("every argument should be typed instead of receiver(self)")
};
let arg_ident = if let syn::Pat::Ident(pat) = &*arg.pat {
pat.ident.clone()
} else {
let msg = "Invalid call, argument must be ident";
return Err(syn::Error::new(arg.pat.span(), msg));
};
method_args.push((arg_ident, arg.ty.clone()))
}
methods.push(Method {
name: method_name.clone(),
args: method_args,
});
}
}
Ok(methods)
}

// TODO: refine this to make it more stable
fn replace_self_to_impl(ty: &syn::Type) -> syn::Result<Box<syn::Type>> {
let ty_str = quote!(#ty).to_string();

let modified_ty_str = ty_str.replace("Self", "Impl");

let modified_ty = parse_str(&modified_ty_str)?;

Ok(Box::new(modified_ty))
}

// TODO: currently we only hash on trait ident and function names,
fn calculate_hash(trait_ident: &Ident, trait_items: &[TraitItem]) -> u64 {
let mut hasher = twox_hash::XxHash64::default();
// reduce the chance of hash collision
"xcq-ext$".hash(&mut hasher);
trait_ident.hash(&mut hasher);
for trait_item in trait_items {
if let TraitItem::Fn(method) = trait_item {
// reduce the chance of hash collision
"@".hash(&mut hasher);
method.sig.ident.hash(&mut hasher);
}
}
hasher.finish()
}
Loading

0 comments on commit 7bad98c

Please sign in to comment.