Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPL errors from hashes #5169

Merged
merged 5 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libraries/program-error/derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
solana-program = "1.16.3"
syn = { version = "2.0", features = ["full"] }
61 changes: 47 additions & 14 deletions libraries/program-error/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,72 @@
extern crate proc_macro;

mod macro_impl;
mod parser;

use macro_impl::MacroType;
use proc_macro::TokenStream;
use syn::{parse_macro_input, ItemEnum};
use {
crate::parser::SplProgramErrorArgs,
macro_impl::MacroType,
proc_macro::TokenStream,
syn::{parse_macro_input, ItemEnum},
};

/// Derive macro to add `Into<solana_program::program_error::ProgramError>` traits
/// Derive macro to add `Into<solana_program::program_error::ProgramError>`
/// trait
#[proc_macro_derive(IntoProgramError)]
pub fn into_program_error(input: TokenStream) -> TokenStream {
MacroType::IntoProgramError
.generate_tokens(parse_macro_input!(input as ItemEnum))
let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum);
MacroType::IntoProgramError { ident }
.generate_tokens()
.into()
}

/// Derive macro to add `solana_program::decode_error::DecodeError` trait
#[proc_macro_derive(DecodeError)]
pub fn decode_error(input: TokenStream) -> TokenStream {
MacroType::DecodeError
.generate_tokens(parse_macro_input!(input as ItemEnum))
.into()
let ItemEnum { ident, .. } = parse_macro_input!(input as ItemEnum);
MacroType::DecodeError { ident }.generate_tokens().into()
}

/// Derive macro to add `solana_program::program_error::PrintProgramError` trait
#[proc_macro_derive(PrintProgramError)]
pub fn print_program_error(input: TokenStream) -> TokenStream {
MacroType::PrintProgramError
.generate_tokens(parse_macro_input!(input as ItemEnum))
let ItemEnum {
ident, variants, ..
} = parse_macro_input!(input as ItemEnum);
MacroType::PrintProgramError { ident, variants }
.generate_tokens()
.into()
}

/// Proc macro attribute to turn your enum into a Solana Program Error
///
/// Adds:
/// - `Clone`
/// - `Debug`
/// - `Eq`
/// - `PartialEq`
/// - `thiserror::Error`
/// - `num_derive::FromPrimitive`
/// - `Into<solana_program::program_error::ProgramError>`
/// - `solana_program::decode_error::DecodeError`
/// - `solana_program::program_error::PrintProgramError`
///
/// Optionally, you can add `hash_error_code_start: u32` argument to create
/// a unique `u32` _starting_ error codes from the names of the enum variants.
/// Notes:
/// - The _error_ variant will start at this value, and the rest will be
/// incremented by one
/// - The value provided is only for code readability, the actual error code
/// will be a hash of the input string and is checked against your input
///
/// Syntax: `#[spl_program_error(hash_error_code_start = 1275525928)]`
/// Hash Input: `spl_program_error:<enum name>:<variant name>`
buffalojoec marked this conversation as resolved.
Show resolved Hide resolved
/// Value: `u32::from_le_bytes(<hash of input>[13..17])`
#[proc_macro_attribute]
pub fn spl_program_error(_: TokenStream, input: TokenStream) -> TokenStream {
MacroType::SplProgramError
.generate_tokens(parse_macro_input!(input as ItemEnum))
pub fn spl_program_error(attr: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as SplProgramErrorArgs);
let item_enum = parse_macro_input!(input as ItemEnum);
MacroType::SplProgramError { args, item_enum }
.generate_tokens()
.into()
}
122 changes: 102 additions & 20 deletions libraries/program-error/derive/src/macro_impl.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,52 @@
//! The actual token generator for the macro
use quote::quote;
use syn::{punctuated::Punctuated, token::Comma, Ident, ItemEnum, LitStr, Variant};

use {
crate::parser::SplProgramErrorArgs,
proc_macro2::Span,
quote::quote,
syn::{
punctuated::Punctuated, token::Comma, Expr, ExprLit, Ident, ItemEnum, Lit, LitInt, LitStr,
Token, Variant,
},
};

const SPL_ERROR_HASH_NAMESPACE: &str = "spl_program_error";
const SPL_ERROR_HASH_MIN_VALUE: u32 = 7_0000;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha I just realized that this has an extra 0, should be 7_000. Good thing we have multiple rounds of reviews

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! my bad


/// The type of macro being called, thus directing which tokens to generate
#[allow(clippy::enum_variant_names)]
pub enum MacroType {
IntoProgramError,
DecodeError,
PrintProgramError,
SplProgramError,
IntoProgramError {
ident: Ident,
},
DecodeError {
ident: Ident,
},
PrintProgramError {
ident: Ident,
variants: Punctuated<Variant, Comma>,
},
SplProgramError {
args: SplProgramErrorArgs,
item_enum: ItemEnum,
},
}

impl MacroType {
/// Generates the corresponding tokens based on variant selection
pub fn generate_tokens(&self, item_enum: ItemEnum) -> proc_macro2::TokenStream {
pub fn generate_tokens(&mut self) -> proc_macro2::TokenStream {
match self {
MacroType::IntoProgramError => into_program_error(&item_enum.ident),
MacroType::DecodeError => decode_error(&item_enum.ident),
MacroType::PrintProgramError => {
print_program_error(&item_enum.ident, &item_enum.variants)
}
MacroType::SplProgramError => spl_program_error(item_enum),
Self::IntoProgramError { ident } => into_program_error(ident),
Self::DecodeError { ident } => decode_error(ident),
Self::PrintProgramError { ident, variants } => print_program_error(ident, variants),
Self::SplProgramError { args, item_enum } => spl_program_error(args, item_enum),
}
}
}

/// Builds the implementation of `Into<solana_program::program_error::ProgramError>`
/// More specifically, implements `From<Self> for solana_program::program_error::ProgramError`
/// Builds the implementation of
/// `Into<solana_program::program_error::ProgramError>` More specifically,
/// implements `From<Self> for solana_program::program_error::ProgramError`
pub fn into_program_error(ident: &Ident) -> proc_macro2::TokenStream {
quote! {
impl From<#ident> for solana_program::program_error::ProgramError {
Expand All @@ -48,7 +68,8 @@ pub fn decode_error(ident: &Ident) -> proc_macro2::TokenStream {
}
}

/// Builds the implementation of `solana_program::program_error::PrintProgramError`
/// Builds the implementation of
/// `solana_program::program_error::PrintProgramError`
pub fn print_program_error(
ident: &Ident,
variants: &Punctuated<Variant, Comma>,
Expand Down Expand Up @@ -96,16 +117,25 @@ fn get_error_message(variant: &Variant) -> Option<String> {

/// The main function that produces the tokens required to turn your
/// error enum into a Solana Program Error
pub fn spl_program_error(input: ItemEnum) -> proc_macro2::TokenStream {
let ident = &input.ident;
let variants = &input.variants;
pub fn spl_program_error(
args: &SplProgramErrorArgs,
item_enum: &mut ItemEnum,
) -> proc_macro2::TokenStream {
if let Some(error_code_start) = args.hash_error_code_start {
set_first_discriminant(item_enum, error_code_start);
}

let ident = &item_enum.ident;
let variants = &item_enum.variants;
let into_program_error = into_program_error(ident);
let decode_error = decode_error(ident);
let print_program_error = print_program_error(ident, variants);

quote! {
#[repr(u32)]
#[derive(Clone, Debug, Eq, thiserror::Error, num_derive::FromPrimitive, PartialEq)]
#[num_traits = "num_traits"]
#input
#item_enum

#into_program_error

Expand All @@ -114,3 +144,55 @@ pub fn spl_program_error(input: ItemEnum) -> proc_macro2::TokenStream {
#print_program_error
}
}

/// This function adds a discriminant to the first enum variant based on the
/// hash of the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant
/// name.
/// It will then check to make sure the provided `hash_error_code_start` is
/// equal to the hash-produced `u32`.
///
/// See https://docs.rs/syn/latest/syn/struct.Variant.html
fn set_first_discriminant(item_enum: &mut ItemEnum, error_code_start: u32) {
let enum_ident = &item_enum.ident;
if item_enum.variants.is_empty() {
panic!("Enum must have at least one variant");
}
let first_variant = &mut item_enum.variants[0];
let discriminant = u32_from_hash(enum_ident);
if discriminant == error_code_start {
let eq = Token![=](Span::call_site());
let expr = Expr::Lit(ExprLit {
attrs: Vec::new(),
lit: Lit::Int(LitInt::new(&discriminant.to_string(), Span::call_site())),
});
first_variant.discriminant = Some((eq, expr));
} else {
panic!(
"Error code start value from hash must be {0}. Update your macro attribute to \
`#[spl_program_error(hash_error_code_start = {0})]`.",
discriminant
);
}
}

/// Hashes the `SPL_ERROR_HASH_NAMESPACE` constant, the enum name and variant
/// name and returns four middle bytes (13 through 16) as a u32.
fn u32_from_hash(enum_ident: &Ident) -> u32 {
let hash_input = format!("{}:{}", SPL_ERROR_HASH_NAMESPACE, enum_ident);

// We don't want our error code to start at any number below
// `SPL_ERROR_HASH_MIN_VALUE`!
let mut nonce: u32 = 0;
buffalojoec marked this conversation as resolved.
Show resolved Hide resolved
loop {
let hash = solana_program::hash::hashv(&[hash_input.as_bytes(), &nonce.to_le_bytes()]);
let d = u32::from_le_bytes(
hash.to_bytes()[13..17]
.try_into()
.expect("Unable to convert hash to u32"),
);
if d >= SPL_ERROR_HASH_MIN_VALUE {
return d;
}
nonce += 1;
}
}
64 changes: 64 additions & 0 deletions libraries/program-error/derive/src/parser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//! Token parsing

use {
proc_macro2::Ident,
syn::{
parse::{Parse, ParseStream},
token::Comma,
LitInt, Token,
},
};

/// Possible arguments to the `#[spl_program_error]` attribute
pub struct SplProgramErrorArgs {
/// Whether to hash the error codes using `solana_program::hash`
/// or to use the default error code assigned by `num_traits`.
pub hash_error_code_start: Option<u32>,
}

impl Parse for SplProgramErrorArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
if input.is_empty() {
return Ok(Self {
hash_error_code_start: None,
});
}
match SplProgramErrorArgParser::parse(input)? {
SplProgramErrorArgParser::HashErrorCodes { value, .. } => Ok(Self {
hash_error_code_start: Some(value.base10_parse::<u32>()?),
}),
}
}
}

/// Parser for args to the `#[spl_program_error]` attribute
/// ie. `#[spl_program_error(hash_error_code_start = 1275525928)]`
enum SplProgramErrorArgParser {
HashErrorCodes {
_ident: Ident,
_equals_sign: Token![=],
value: LitInt,
_comma: Option<Comma>,
},
}

impl Parse for SplProgramErrorArgParser {
fn parse(input: ParseStream) -> syn::Result<Self> {
let _ident = {
let ident = input.parse::<Ident>()?;
if ident != "hash_error_code_start" {
return Err(input.error("Expected argument 'hash_error_code_start'"));
}
ident
};
let _equals_sign = input.parse::<Token![=]>()?;
let value = input.parse::<LitInt>()?;
let _comma: Option<Comma> = input.parse().unwrap_or(None);
Ok(Self::HashErrorCodes {
_ident,
_equals_sign,
value,
_comma,
})
}
}
12 changes: 6 additions & 6 deletions libraries/program-error/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ extern crate self as spl_program_error;

// Make these available downstream for the macro to work without
// additional imports
pub use num_derive;
pub use num_traits;
pub use solana_program;
pub use spl_program_error_derive::{
spl_program_error, DecodeError, IntoProgramError, PrintProgramError,
pub use {
num_derive, num_traits, solana_program,
spl_program_error_derive::{
spl_program_error, DecodeError, IntoProgramError, PrintProgramError,
},
thiserror,
};
pub use thiserror;
2 changes: 1 addition & 1 deletion libraries/program-error/tests/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Tests `#[derive(DecodeError)]`
//!

use spl_program_error::*;

/// Example error
Expand Down
2 changes: 1 addition & 1 deletion libraries/program-error/tests/into.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Tests `#[derive(IntoProgramError)]`
//!

use spl_program_error::*;

/// Example error
Expand Down
14 changes: 8 additions & 6 deletions libraries/program-error/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ pub mod spl;

#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
use solana_program::{
decode_error::DecodeError,
program_error::{PrintProgramError, ProgramError},
use {
super::*,
serial_test::serial,
solana_program::{
decode_error::DecodeError,
program_error::{PrintProgramError, ProgramError},
},
std::sync::{Arc, RwLock},
};
use std::sync::{Arc, RwLock};

// Used to capture output for `PrintProgramError` for testing
lazy_static::lazy_static! {
Expand Down
Loading
Loading