Skip to content

Commit

Permalink
feat: Support omitting data parameter
Browse files Browse the repository at this point in the history
Expect `(raw)` parameter in the `sv::payload` attribute.
  • Loading branch information
jawoznia committed Oct 31, 2024
1 parent 9c35162 commit 3c89bcc
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 147 deletions.
72 changes: 49 additions & 23 deletions sylvia-derive/src/contract/communication/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ use syn::{parse_quote, GenericParam, Ident, ItemImpl, Type};

use crate::crate_module;
use crate::parser::attributes::msg::ReplyOn;
use crate::parser::{MsgType, ParsedSylviaAttributes, SylviaAttribute};
use crate::parser::{MsgType, ParsedSylviaAttributes};
use crate::types::msg_field::MsgField;
use crate::types::msg_variant::{MsgVariant, MsgVariants};
use crate::utils::emit_turbofish;

const NUMBER_OF_DATA_FIELDS: usize = 1;

pub struct Reply<'a> {
source: &'a ItemImpl,
generics: &'a [&'a GenericParam],
Expand Down Expand Up @@ -173,7 +175,7 @@ impl<'a> ReplyVariants<'a> for MsgVariants<'a, GenericParam> {
},
)
}
Some(existing_data) => existing_data.add_second_handler(handler),
Some(existing_data) => existing_data.merge(handler),
None => reply_data.push(ReplyData::new(reply_id, handler, handler_id)),
}
});
Expand All @@ -198,9 +200,14 @@ struct ReplyData<'a> {

impl<'a> ReplyData<'a> {
pub fn new(reply_id: Ident, variant: &'a MsgVariant<'a>, handler_id: &'a Ident) -> Self {
let data = variant.fields().first();
// Skip the first field reserved for the `data`.
let payload = variant.fields().iter().skip(1).collect::<Vec<_>>();
let data = variant.as_data_field();
variant.validate_fields_attributes();
let payload = variant.fields().iter();
let payload = if data.is_some() || variant.msg_attr().reply_on() != ReplyOn::Success {
payload.skip(NUMBER_OF_DATA_FIELDS).collect::<Vec<_>>()
} else {
payload.collect::<Vec<_>>()
};
let method_name = variant.function_name();
let reply_on = variant.msg_attr().reply_on();

Expand All @@ -214,13 +221,15 @@ impl<'a> ReplyData<'a> {
}

/// Adds second handler to the reply data provdided their payload signature match.
pub fn add_second_handler(&mut self, new_handler: &'a MsgVariant<'a>) {
pub fn merge(&mut self, new_handler: &'a MsgVariant<'a>) {
let (current_method_name, _) = match self.handlers.first() {
Some(handler) => handler,
_ => return,
};

if self.payload.len() != new_handler.fields().len() - 1 {
let new_reply_data = ReplyData::new(self.reply_id.clone(), new_handler, self.handler_id);

if self.payload.len() != new_reply_data.payload.len() {
emit_error!(current_method_name.span(), "Mismatched quantity of method parameters.";
note = self.handler_id.span() => format!("Both `{}` handlers should have the same number of parameters.", self.handler_id);
note = new_handler.function_name().span() => format!("Previous definition of {} handler.", self.handler_id)
Expand All @@ -229,7 +238,7 @@ impl<'a> ReplyData<'a> {

self.payload
.iter()
.zip(new_handler.fields().iter().skip(1))
.zip(new_reply_data.payload.iter())
.for_each(|(current_field, new_field)|
{
if current_field.ty() != new_field.ty() {
Expand Down Expand Up @@ -377,6 +386,7 @@ impl<'a> ReplyData<'a> {
let payload_values = self.payload.iter().map(|field| field.name());
let payload_deserialization = self.payload.emit_payload_deserialization();
let data_deserialization = self.data.map(DataField::emit_data_deserialization);
let data = self.data.map(|_| quote! { data, });

quote! {
#sylvia ::cw_std::SubMsgResult::Ok(sub_msg_resp) => {
Expand All @@ -385,7 +395,7 @@ impl<'a> ReplyData<'a> {
#payload_deserialization
#data_deserialization

#contract_turbofish ::new(). #method_name ((deps, env, gas_used, events, msg_responses).into(), data, #(#payload_values),* )
#contract_turbofish ::new(). #method_name ((deps, env, gas_used, events, msg_responses).into(), #data #(#payload_values),* )
}
}
}
Expand Down Expand Up @@ -462,6 +472,8 @@ impl<'a> ReplyData<'a> {

trait ReplyVariant<'a> {
fn as_variant_handlers_pair(&'a self) -> Vec<(&'a MsgVariant<'a>, &'a Ident)>;
fn as_data_field(&'a self) -> Option<&'a MsgField<'a>>;
fn validate_fields_attributes(&'a self);
}

impl<'a> ReplyVariant<'a> for MsgVariant<'a> {
Expand All @@ -479,6 +491,28 @@ impl<'a> ReplyVariant<'a> for MsgVariant<'a> {

variant_handler_id_pair
}

/// Validates attributes and returns `Some(MsgField)` if a field marked with `sv::data` attribute
/// is present and the `reply_on` attribute is set to `ReplyOn::Success`.
fn as_data_field(&'a self) -> Option<&'a MsgField<'a>> {
let data_attrs = self.fields().first().map(|field| {
ParsedSylviaAttributes::new(field.attrs().iter())
.data
.is_some()
});
match data_attrs {
Some(attrs) if attrs && self.msg_attr().reply_on() == ReplyOn::Success => {
self.fields().first()
}
_ => None,
}
}

/// Validates if the fields attributes are correct.
fn validate_fields_attributes(&'a self) {
let field_attrs = self.fields().iter().flat_map(|field| field.attrs());
ParsedSylviaAttributes::new(field_attrs);
}
}

pub trait DataField {
Expand All @@ -489,10 +523,6 @@ impl DataField for MsgField<'_> {
fn emit_data_deserialization(&self) -> TokenStream {
let sylvia = crate_module();
let data = ParsedSylviaAttributes::new(self.attrs().iter()).data;
let is_data_attr = self
.attrs()
.iter()
.any(|attr| SylviaAttribute::new(attr) == Some(SylviaAttribute::Data));
let missing_data_err = "Missing reply data field.";
let invalid_reply_data_err = quote! {
format! {"Invalid reply data: {}\nSerde error while deserializing {}", data, err}
Expand Down Expand Up @@ -555,7 +585,7 @@ impl DataField for MsgField<'_> {
None => None,
};
},
None if is_data_attr => quote! {
_ => quote! {
let data = match data {
Some(data) => {
#execute_data_deserialization
Expand All @@ -565,13 +595,6 @@ impl DataField for MsgField<'_> {
None => return Err(Into::into( #sylvia ::cw_std::StdError::generic_err( #missing_data_err ))),
};
},
_ => {
emit_error!(self.name().span(), "Invalid data usage.";
note = "Reply data should be marked with #[sv::data] attribute.";
note = "Remove this parameter or mark it with #[sv::data] attribute."
);
quote! {}
}
}
}
}
Expand Down Expand Up @@ -616,8 +639,11 @@ impl PayloadFields for Vec<&MsgField<'_>> {
}

fn is_payload_marked(&self) -> bool {
self.iter()
.any(|field| field.contains_attribute(SylviaAttribute::Payload))
self.iter().any(|field| {
ParsedSylviaAttributes::new(field.attrs().iter())
.payload
.is_some()
})
}
}

Expand Down
18 changes: 14 additions & 4 deletions sylvia-derive/src/parser/attributes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use data::DataFieldParams;
use features::SylviaFeatures;
use payload::PayloadFieldParam;
use proc_macro_error::emit_error;
use syn::spanned::Spanned;
use syn::{Attribute, MetaList, PathSegment};
Expand All @@ -15,6 +16,7 @@ pub mod features;
pub mod messages;
pub mod msg;
pub mod override_entry_point;
pub mod payload;

pub use attr::{MsgAttrForwarding, VariantAttrForwarding};
pub use custom::Custom;
Expand Down Expand Up @@ -79,6 +81,7 @@ pub struct ParsedSylviaAttributes {
pub msg_attrs_forward: Vec<MsgAttrForwarding>,
pub sv_features: SylviaFeatures,
pub data: Option<DataFieldParams>,
pub payload: Option<PayloadFieldParam>,
}

impl ParsedSylviaAttributes {
Expand All @@ -90,6 +93,14 @@ impl ParsedSylviaAttributes {

if let (Some(sylvia_attr), Ok(attr)) = (sylvia_attr, &attr_content) {
result.match_attribute(&sylvia_attr, attr);
} else if sylvia_attr == Some(SylviaAttribute::Data) {
// The `sv::data` attribute can be used without parameters.
result.data = Some(DataFieldParams::default());
} else if sylvia_attr == Some(SylviaAttribute::Payload) {
emit_error!(
attr.span(), "Missing parameters for `sv::payload`";
note = "Expected `#[sv::payload(raw)]`"
);
}
}

Expand Down Expand Up @@ -172,10 +183,9 @@ impl ParsedSylviaAttributes {
}
}
SylviaAttribute::Payload => {
emit_error!(
attr, "The attribute `sv::payload` used in wrong context";
note = attr.span() => "The `sv::payload` should be used as a prefix for `Binary` payload.";
);
if let Ok(payload) = PayloadFieldParam::new(attr) {
self.payload = Some(payload);
}
}
SylviaAttribute::Data => {
if let Ok(data) = DataFieldParams::new(attr) {
Expand Down
44 changes: 44 additions & 0 deletions sylvia-derive/src/parser/attributes/payload.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use proc_macro_error::emit_error;
use syn::parse::{Parse, ParseStream, Parser};
use syn::{Error, Ident, MetaList, Result};

/// Type wrapping data parsed from `sv::payload` attribute.
#[derive(Default, Debug)]
pub struct PayloadFieldParam;

impl PayloadFieldParam {
pub fn new(attr: &MetaList) -> Result<Self> {
let data = PayloadFieldParam::parse
.parse2(attr.tokens.clone())
.map_err(|err| {
emit_error!(err.span(), err);
err
})?;

Ok(data)
}
}

impl Parse for PayloadFieldParam {
fn parse(input: ParseStream) -> Result<Self> {
let option: Ident = input.parse()?;
match option.to_string().as_str() {
"raw" => (),
_ => {
return Err(Error::new(
option.span(),
"Invalid payload parameter.\n= note: Expected [`raw`].\n",
))
}
};

if !input.is_empty() {
return Err(Error::new(
input.span(),
"Unexpected tokens inside `sv::payload` attribute.\n= note: Expected parameters: [`raw`] `.\n",
));
}

Ok(Self)
}
}
7 changes: 0 additions & 7 deletions sylvia-derive/src/types/msg_field.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::fold::StripSelfPath;
use crate::parser::check_generics::{CheckGenerics, GetPath};
use crate::parser::SylviaAttribute;
use proc_macro2::TokenStream;
use proc_macro_error::emit_error;
use quote::quote;
Expand Down Expand Up @@ -124,10 +123,4 @@ impl<'a> MsgField<'a> {
pub fn attrs(&self) -> &'a Vec<Attribute> {
self.attrs
}

pub fn contains_attribute(&self, sv_attr: SylviaAttribute) -> bool {
self.attrs
.iter()
.any(|attr| SylviaAttribute::new(attr) == Some(sv_attr))
}
}
Loading

0 comments on commit 3c89bcc

Please sign in to comment.