Skip to content

Commit

Permalink
SPL errors from hashes
Browse files Browse the repository at this point in the history
  • Loading branch information
buffalojoec committed Aug 30, 2023
1 parent 8ae351d commit 4ef69d7
Show file tree
Hide file tree
Showing 17 changed files with 258 additions and 108 deletions.
6 changes: 2 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libraries/program-error/derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
solana-program = "1.16.3"
syn = { version = "2.0", features = ["full"] }
56 changes: 42 additions & 14 deletions libraries/program-error/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,67 @@
extern crate proc_macro;

mod macro_impl;
mod parser;

use macro_impl::MacroType;
use proc_macro::TokenStream;
use syn::{parse_macro_input, ItemEnum};
use {
crate::parser::SplProgramErrorArgs,
macro_impl::MacroType,
proc_macro::TokenStream,
syn::{parse_macro_input, ItemEnum},
};

/// Derive macro to add `Into<solana_program::program_error::ProgramError>` traits
/// Derive macro to add `Into<solana_program::program_error::ProgramError>`
/// trait
#[proc_macro_derive(IntoProgramError)]
pub fn into_program_error(input: TokenStream) -> TokenStream {
MacroType::IntoProgramError
.generate_tokens(parse_macro_input!(input as ItemEnum))
let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum);
MacroType::IntoProgramError { ident }
.generate_tokens()
.into()
}

/// Derive macro to add `solana_program::decode_error::DecodeError` trait
#[proc_macro_derive(DecodeError)]
pub fn decode_error(input: TokenStream) -> TokenStream {
MacroType::DecodeError
.generate_tokens(parse_macro_input!(input as ItemEnum))
.into()
let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum);
MacroType::DecodeError { ident }.generate_tokens().into()
}

/// Derive macro to add `solana_program::program_error::PrintProgramError` trait
#[proc_macro_derive(PrintProgramError)]
pub fn print_program_error(input: TokenStream) -> TokenStream {
MacroType::PrintProgramError
.generate_tokens(parse_macro_input!(input as ItemEnum))
let ItemEnum {
ident, variants, ..
} = parse_macro_input!(input as ItemEnum);
MacroType::PrintProgramError { ident, variants }
.generate_tokens()
.into()
}

/// Proc macro attribute to turn your enum into a Solana Program Error
///
/// Adds:
/// - `Clone`
/// - `Debug`
/// - `Eq`
/// - `PartialEq`
/// - `thiserror::Error`
/// - `num_derive::FromPrimitive`
/// - `Into<solana_program::program_error::ProgramError>`
/// - `solana_program::decode_error::DecodeError`
/// - `solana_program::program_error::PrintProgramError`
///
/// Optionally, you can add `hash_error_codes: bool` argument to create unique
/// `u32` error codes from the names of the enum variants.
///
/// Syntax: `#[spl_program_error(hash_error_codes = true)]`
/// Hash Input: `spl_program_error:<enum name>:<variant name>`
/// Value: `u32::from_le_bytes(<hash of input>[8..12])`
#[proc_macro_attribute]
pub fn spl_program_error(_: TokenStream, input: TokenStream) -> TokenStream {
MacroType::SplProgramError
.generate_tokens(parse_macro_input!(input as ItemEnum))
pub fn spl_program_error(attr: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as SplProgramErrorArgs);
let item_enum = parse_macro_input!(input as ItemEnum);
MacroType::SplProgramError { args, item_enum }
.generate_tokens()
.into()
}
103 changes: 83 additions & 20 deletions libraries/program-error/derive/src/macro_impl.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,51 @@
//! The actual token generator for the macro
use quote::quote;
use syn::{punctuated::Punctuated, token::Comma, Ident, ItemEnum, LitStr, Variant};

use {
crate::parser::SplProgramErrorArgs,
proc_macro2::Span,
quote::quote,
syn::{
punctuated::Punctuated, token::Comma, Expr, ExprLit, Ident, ItemEnum, Lit, LitInt, LitStr,
Token, Variant,
},
};

const SPL_ERROR_HASH_NAMESPACE: &str = "spl_program_error";

/// The type of macro being called, thus directing which tokens to generate
#[allow(clippy::enum_variant_names)]
pub enum MacroType {
IntoProgramError,
DecodeError,
PrintProgramError,
SplProgramError,
IntoProgramError {
ident: Ident,
},
DecodeError {
ident: Ident,
},
PrintProgramError {
ident: Ident,
variants: Punctuated<Variant, Comma>,
},
SplProgramError {
args: SplProgramErrorArgs,
item_enum: ItemEnum,
},
}

impl MacroType {
/// Generates the corresponding tokens based on variant selection
pub fn generate_tokens(&self, item_enum: ItemEnum) -> proc_macro2::TokenStream {
pub fn generate_tokens(&mut self) -> proc_macro2::TokenStream {
match self {
MacroType::IntoProgramError => into_program_error(&item_enum.ident),
MacroType::DecodeError => decode_error(&item_enum.ident),
MacroType::PrintProgramError => {
print_program_error(&item_enum.ident, &item_enum.variants)
}
MacroType::SplProgramError => spl_program_error(item_enum),
Self::IntoProgramError { ident } => into_program_error(ident),
Self::DecodeError { ident } => decode_error(ident),
Self::PrintProgramError { ident, variants } => print_program_error(ident, variants),
Self::SplProgramError { args, item_enum } => spl_program_error(args, item_enum),
}
}
}

/// Builds the implementation of `Into<solana_program::program_error::ProgramError>`
/// More specifically, implements `From<Self> for solana_program::program_error::ProgramError`
/// Builds the implementation of
/// `Into<solana_program::program_error::ProgramError>` More specifically,
/// implements `From<Self> for solana_program::program_error::ProgramError`
pub fn into_program_error(ident: &Ident) -> proc_macro2::TokenStream {
quote! {
impl From<#ident> for solana_program::program_error::ProgramError {
Expand All @@ -48,7 +67,8 @@ pub fn decode_error(ident: &Ident) -> proc_macro2::TokenStream {
}
}

/// Builds the implementation of `solana_program::program_error::PrintProgramError`
/// Builds the implementation of
/// `solana_program::program_error::PrintProgramError`
pub fn print_program_error(
ident: &Ident,
variants: &Punctuated<Variant, Comma>,
Expand Down Expand Up @@ -96,16 +116,25 @@ fn get_error_message(variant: &Variant) -> Option<String> {

/// The main function that produces the tokens required to turn your
/// error enum into a Solana Program Error
pub fn spl_program_error(input: ItemEnum) -> proc_macro2::TokenStream {
let ident = &input.ident;
let variants = &input.variants;
pub fn spl_program_error(
args: &SplProgramErrorArgs,
item_enum: &mut ItemEnum,
) -> proc_macro2::TokenStream {
if args.hash_error_codes {
build_discriminants(item_enum);
}

let ident = &item_enum.ident;
let variants = &item_enum.variants;
let into_program_error = into_program_error(ident);
let decode_error = decode_error(ident);
let print_program_error = print_program_error(ident, variants);

quote! {
#[repr(u32)]
#[derive(Clone, Debug, Eq, thiserror::Error, num_derive::FromPrimitive, PartialEq)]
#[num_traits = "num_traits"]
#input
#item_enum

#into_program_error

Expand All @@ -114,3 +143,37 @@ pub fn spl_program_error(input: ItemEnum) -> proc_macro2::TokenStream {
#print_program_error
}
}

/// This function adds discriminants to the enum variants based on the
/// hash of the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant
/// name.
///
/// See https://docs.rs/syn/latest/syn/struct.Variant.html
fn build_discriminants(item_enum: &mut ItemEnum) {
let enum_ident = &item_enum.ident;
for variant in item_enum.variants.iter_mut() {
let variant_ident = &variant.ident;
let discriminant = u32_from_hash(enum_ident, variant_ident);
let eq = Token![=](Span::call_site());
let expr = Expr::Lit(ExprLit {
attrs: Vec::new(),
lit: Lit::Int(LitInt::new(&discriminant.to_string(), Span::call_site())),
});
variant.discriminant = Some((eq, expr));
}
}

/// Hashes the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant
/// name and returns four middle bytes (8 through 12) as a u32.
fn u32_from_hash(enum_ident: &Ident, variant_ident: &Ident) -> u32 {
let hash_input = format!(
"{}:{}:{}",
SPL_ERROR_HASH_NAMESPACE, enum_ident, variant_ident
);
let hash = solana_program::hash::hash(hash_input.as_bytes());
u32::from_le_bytes(
hash.to_bytes()[13..17]
.try_into()
.expect("Unable to convert hash to u32"),
)
}
56 changes: 56 additions & 0 deletions libraries/program-error/derive/src/parser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use {
proc_macro2::Ident,
syn::{
parse::{Parse, ParseStream},
token::Comma,
LitBool, Token,
},
};

/// Possible arguments to the `#[spl_program_error]` attribute
pub struct SplProgramErrorArgs {
/// Whether to hash the error codes using `solana_program::hash`
/// or to use the default error code assigned by `num_traits`.
pub hash_error_codes: bool,
}

impl Parse for SplProgramErrorArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
if input.is_empty() {
return Ok(Self {
hash_error_codes: false,
});
}
match SplProgramErrorArgParser::parse(input)? {
SplProgramErrorArgParser::HashErrorCodes { value, .. } => Ok(Self {
hash_error_codes: value.value,
}),
}
}
}

/// Parser for args to the `#[spl_program_error]` attribute
/// ie. `#[spl_program_error(hash_error_codes = true)]`
enum SplProgramErrorArgParser {
HashErrorCodes {
_ident: Ident,
_equals_sign: Token![=],
value: LitBool,
_comma: Option<Comma>,
},
}

impl Parse for SplProgramErrorArgParser {
fn parse(input: ParseStream) -> syn::Result<Self> {
let _ident = input.parse::<Ident>()?;
let _equals_sign = input.parse::<Token![=]>()?;
let value = input.parse::<LitBool>()?;
let _comma: Option<Comma> = input.parse().unwrap_or(None);
Ok(Self::HashErrorCodes {
_ident,
_equals_sign,
value,
_comma,
})
}
}
12 changes: 6 additions & 6 deletions libraries/program-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ extern crate self as spl_program_error;

// Make these available downstream for the macro to work without
// additional imports
pub use num_derive;
pub use num_traits;
pub use solana_program;
pub use spl_program_error_derive::{
spl_program_error, DecodeError, IntoProgramError, PrintProgramError,
pub use {
num_derive, num_traits, solana_program,
spl_program_error_derive::{
spl_program_error, DecodeError, IntoProgramError, PrintProgramError,
},
thiserror,
};
pub use thiserror;
2 changes: 1 addition & 1 deletion libraries/program-error/tests/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Tests `#[derive(DecodeError)]`
//!

use spl_program_error::*;

/// Example error
Expand Down
2 changes: 1 addition & 1 deletion libraries/program-error/tests/into.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Tests `#[derive(IntoProgramError)]`
//!

use spl_program_error::*;

/// Example error
Expand Down
14 changes: 8 additions & 6 deletions libraries/program-error/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ pub mod spl;

#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
use solana_program::{
decode_error::DecodeError,
program_error::{PrintProgramError, ProgramError},
use {
super::*,
serial_test::serial,
solana_program::{
decode_error::DecodeError,
program_error::{PrintProgramError, ProgramError},
},
std::sync::{Arc, RwLock},
};
use std::sync::{Arc, RwLock};

// Used to capture output for `PrintProgramError` for testing
lazy_static::lazy_static! {
Expand Down
2 changes: 1 addition & 1 deletion libraries/program-error/tests/print.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Tests `#[derive(PrintProgramError)]`
//!

use spl_program_error::*;

/// Example error
Expand Down
Loading

0 comments on commit 4ef69d7

Please sign in to comment.