From 8c509adeb72f427f3afe34d2914d09cc905e1029 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Wo=C5=BAniak?= Date: Thu, 17 Oct 2024 18:15:24 +0200 Subject: [PATCH] feat: Add auto deserialization of reply data --- .../src/contract/communication/reply.rs | 58 ++++++++++++++++++- sylvia-derive/src/parser/attributes/data.rs | 48 +++++++++++++++ sylvia-derive/src/parser/attributes/mod.rs | 10 ++++ sylvia-derive/src/types/msg_field.rs | 4 ++ sylvia/tests/remote.rs | 1 + sylvia/tests/reply.rs | 4 +- sylvia/tests/reply_generation.rs | 2 +- 7 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 sylvia-derive/src/parser/attributes/data.rs diff --git a/sylvia-derive/src/contract/communication/reply.rs b/sylvia-derive/src/contract/communication/reply.rs index 0a528200..7580fcd6 100644 --- a/sylvia-derive/src/contract/communication/reply.rs +++ b/sylvia-derive/src/contract/communication/reply.rs @@ -6,7 +6,7 @@ use syn::{parse_quote, GenericParam, Ident, ItemImpl, Type}; use crate::crate_module; use crate::parser::attributes::msg::ReplyOn; -use crate::parser::{MsgType, SylviaAttribute}; +use crate::parser::{MsgType, ParsedSylviaAttributes, SylviaAttribute}; use crate::types::msg_field::MsgField; use crate::types::msg_variant::{MsgVariant, MsgVariants}; use crate::utils::emit_turbofish; @@ -190,12 +190,15 @@ struct ReplyData<'a> { pub handler_id: &'a Ident, /// Methods handling the reply id for the associated reply on. pub handlers: Vec<(&'a Ident, ReplyOn)>, + /// Data parameter associated with the handlers. + pub data: Option<&'a MsgField<'a>>, /// Payload parameters associated with the handlers. pub payload: Vec<&'a MsgField<'a>>, } impl<'a> ReplyData<'a> { pub fn new(reply_id: Ident, variant: &'a MsgVariant<'a>, handler_id: &'a Ident) -> Self { + let data = variant.fields().first(); // Skip the first field reserved for the `data`. let payload = variant.fields().iter().skip(1).collect::>(); let method_name = variant.function_name(); @@ -205,6 +208,7 @@ impl<'a> ReplyData<'a> { reply_id, handler_id, handlers: vec![(method_name, reply_on)], + data, payload, } } @@ -372,12 +376,14 @@ impl<'a> ReplyData<'a> { Some((method_name, reply_on)) if reply_on == &ReplyOn::Success => { let payload_values = self.payload.iter().map(|field| field.name()); let payload_deserialization = self.payload.emit_payload_deserialization(); + let data_deserialization = self.data.map(DataField::emit_data_deserialization); quote! { #sylvia ::cw_std::SubMsgResult::Ok(sub_msg_resp) => { #[allow(deprecated)] let #sylvia ::cw_std::SubMsgResponse { events, data, msg_responses} = sub_msg_resp; #payload_deserialization + #data_deserialization #contract_turbofish ::new(). #method_name ((deps, env, gas_used, events, msg_responses).into(), data, #(#payload_values),* ) } @@ -475,6 +481,56 @@ impl<'a> ReplyVariant<'a> for MsgVariant<'a> { } } +pub trait DataField { + fn emit_data_deserialization(&self) -> TokenStream; +} + +impl DataField for MsgField<'_> { + fn emit_data_deserialization(&self) -> TokenStream { + let sylvia = crate_module(); + let data = ParsedSylviaAttributes::new(self.attrs().iter()).data; + let missing_data_err = "Missing reply data field."; + let invalid_reply_data_err = quote! { + format! {"Invalid reply data: {}\nSerde error while deserializing {}", data, err} + }; + + match data { + Some(data) if data.raw && data.opt => quote! { + if data.is_none() { + return Err( #sylvia ::cw_std::StdError::generic_err(#missing_data_err)).map_err(Into::into); + } + }, + Some(data) if data.raw => quote! { + let data = match data { + Some(data) => data, + None => return Err( #sylvia ::cw_std::StdError::generic_err(#missing_data_err)).map_err(Into::into); + }; + }, + Some(data) if data.opt => quote! { + let data = match data { + Some(data) => + Some(#sylvia ::cw_std::from_json(&data).map_err(|err| Err( #sylvia ::cw_std::StdError::generic_err(format! {"Invalid reply data: {}\nSerde error while deserializing {}", data, err})).map_err(Into::into))?), + None => None; + }; + }, + Some(_) => quote! { + let data = match data { + Some(data) => + #sylvia ::cw_std::from_json(&data).map_err(|err| Err( #sylvia ::cw_std::StdError::generic_err( #invalid_reply_data_err )).map_err(Into::into))?, + None => return Err( #sylvia ::cw_std::StdError::generic_err(#missing_data_err)).map_err(Into::into); + }; + }, + None => { + emit_error!(self.name().span(), "Invalid data usage."; + note = "Reply data should be marked with #[sv::data] attribute."; + note = "Remove this parameter or mark it with #[sv::data] attribute." + ); + quote! {} + } + } + } +} + pub trait PayloadFields { fn emit_payload_deserialization(&self) -> TokenStream; fn emit_payload_serialization(&self) -> TokenStream; diff --git a/sylvia-derive/src/parser/attributes/data.rs b/sylvia-derive/src/parser/attributes/data.rs new file mode 100644 index 00000000..fe8e9f0c --- /dev/null +++ b/sylvia-derive/src/parser/attributes/data.rs @@ -0,0 +1,48 @@ +use proc_macro_error::emit_error; +use syn::parse::{Parse, ParseStream, Parser}; +use syn::{Error, Ident, MetaList, Result, Token}; + +/// Type wrapping data parsed from `sv::data` attribute. +#[derive(Default, Debug)] +pub struct DataFieldParams { + pub raw: bool, + pub opt: bool, +} + +impl DataFieldParams { + pub fn new(attr: &MetaList) -> Result { + DataFieldParams::parse + .parse2(attr.tokens.clone()) + .map_err(|err| { + emit_error!(err.span(), err); + err + }) + } +} + +impl Parse for DataFieldParams { + fn parse(input: ParseStream) -> Result { + let mut data = Self::default(); + + while !input.is_empty() { + let option: Ident = input.parse()?; + match option.to_string().as_str() { + "raw" => data.raw = true, + "opt" => data.opt = true, + _ => { + return Err(Error::new( + option.span(), + "Invalid data type.\n + = note: Expected one of [`raw`, `opt`] comma separated.\n", + )) + } + } + if !input.peek(Token![,]) { + break; + } + let _: Token![,] = input.parse()?; + } + + Ok(data) + } +} diff --git a/sylvia-derive/src/parser/attributes/mod.rs b/sylvia-derive/src/parser/attributes/mod.rs index 2134e818..215bda73 100644 --- a/sylvia-derive/src/parser/attributes/mod.rs +++ b/sylvia-derive/src/parser/attributes/mod.rs @@ -1,6 +1,7 @@ //! Module defining parsing of Sylvia attributes. //! Every Sylvia attribute should be prefixed with `sv::` +use data::DataFieldParams; use features::SylviaFeatures; use proc_macro_error::emit_error; use syn::spanned::Spanned; @@ -8,6 +9,7 @@ use syn::{Attribute, MetaList, PathSegment}; pub mod attr; pub mod custom; +pub mod data; pub mod error; pub mod features; pub mod messages; @@ -33,6 +35,7 @@ pub enum SylviaAttribute { VariantAttrs, MsgAttrs, Payload, + Data, Features, } @@ -56,6 +59,7 @@ impl SylviaAttribute { "attr" => Some(Self::VariantAttrs), "msg_attr" => Some(Self::MsgAttrs), "payload" => Some(Self::Payload), + "data" => Some(Self::Data), "features" => Some(Self::Features), _ => None, } @@ -74,6 +78,7 @@ pub struct ParsedSylviaAttributes { pub variant_attrs_forward: Vec, pub msg_attrs_forward: Vec, pub sv_features: SylviaFeatures, + pub data: Option, } impl ParsedSylviaAttributes { @@ -172,6 +177,11 @@ impl ParsedSylviaAttributes { note = attr.span() => "The `sv::payload` should be used as a prefix for `Binary` payload."; ); } + SylviaAttribute::Data => { + if let Ok(data) = DataFieldParams::new(attr) { + self.data = Some(data); + } + } SylviaAttribute::Features => { if let Ok(features) = SylviaFeatures::new(attr) { self.sv_features = features; diff --git a/sylvia-derive/src/types/msg_field.rs b/sylvia-derive/src/types/msg_field.rs index 2834ed71..46bfffc4 100644 --- a/sylvia-derive/src/types/msg_field.rs +++ b/sylvia-derive/src/types/msg_field.rs @@ -121,6 +121,10 @@ impl<'a> MsgField<'a> { self.ty } + pub fn attrs(&self) -> &'a Vec { + self.attrs + } + pub fn contains_attribute(&self, sv_attr: SylviaAttribute) -> bool { self.attrs .iter() diff --git a/sylvia/tests/remote.rs b/sylvia/tests/remote.rs index f7468e66..17c3f28c 100644 --- a/sylvia/tests/remote.rs +++ b/sylvia/tests/remote.rs @@ -190,6 +190,7 @@ where } pub mod manager { + use cosmwasm_std::Binary; use cw_storage_plus::Item; use schemars::JsonSchema; use serde::de::DeserializeOwned; diff --git a/sylvia/tests/reply.rs b/sylvia/tests/reply.rs index ed478151..23569c5c 100644 --- a/sylvia/tests/reply.rs +++ b/sylvia/tests/reply.rs @@ -214,7 +214,7 @@ where fn remote_instantiated( &self, ctx: ReplyCtx, - data: Option, + #[sv::data(raw, opt)] data: Option, // TODO: Blocked by https://github.com/CosmWasm/cw-multi-test/pull/216. Uncomment when new // MultiTest version is released. // Payload is not currently forwarded in the MultiTest. @@ -236,7 +236,7 @@ where fn success( &self, ctx: ReplyCtx, - _data: Option, + #[sv::data(raw, opt)] _data: Option, #[sv::payload] _payload: Binary, ) -> Result, ContractError> { self.last_reply.save(ctx.deps.storage, &SUCCESS_REPLY_ID)?; diff --git a/sylvia/tests/reply_generation.rs b/sylvia/tests/reply_generation.rs index a4f471ca..afd65ad2 100644 --- a/sylvia/tests/reply_generation.rs +++ b/sylvia/tests/reply_generation.rs @@ -45,7 +45,7 @@ impl Contract { fn reply_on( &self, _ctx: ReplyCtx, - _data: Option, + #[sv::data(raw, opt)] _data: Option, #[sv::payload] _payload: Binary, ) -> StdResult { Ok(Response::new())