@@ -6,11 +6,13 @@ use syn::{parse_quote, GenericParam, Ident, ItemImpl, Type};
6
6
7
7
use crate :: crate_module;
8
8
use crate :: parser:: attributes:: msg:: ReplyOn ;
9
- use crate :: parser:: { MsgType , ParsedSylviaAttributes , SylviaAttribute } ;
9
+ use crate :: parser:: { MsgType , ParsedSylviaAttributes } ;
10
10
use crate :: types:: msg_field:: MsgField ;
11
11
use crate :: types:: msg_variant:: { MsgVariant , MsgVariants } ;
12
12
use crate :: utils:: emit_turbofish;
13
13
14
+ const NUMBER_OF_DATA_FIELDS : usize = 1 ;
15
+
14
16
pub struct Reply < ' a > {
15
17
source : & ' a ItemImpl ,
16
18
generics : & ' a [ & ' a GenericParam ] ,
@@ -173,7 +175,7 @@ impl<'a> ReplyVariants<'a> for MsgVariants<'a, GenericParam> {
173
175
} ,
174
176
)
175
177
}
176
- Some ( existing_data) => existing_data. add_second_handler ( handler) ,
178
+ Some ( existing_data) => existing_data. merge ( handler) ,
177
179
None => reply_data. push ( ReplyData :: new ( reply_id, handler, handler_id) ) ,
178
180
}
179
181
} ) ;
@@ -198,9 +200,13 @@ struct ReplyData<'a> {
198
200
199
201
impl < ' a > ReplyData < ' a > {
200
202
pub fn new ( reply_id : Ident , variant : & ' a MsgVariant < ' a > , handler_id : & ' a Ident ) -> Self {
201
- let data = variant. fields ( ) . first ( ) ;
202
- // Skip the first field reserved for the `data`.
203
- let payload = variant. fields ( ) . iter ( ) . skip ( 1 ) . collect :: < Vec < _ > > ( ) ;
203
+ let data = variant. as_data_field ( ) ;
204
+ let payload = variant. fields ( ) . iter ( ) ;
205
+ let payload = if data. is_some ( ) || variant. msg_attr ( ) . reply_on ( ) != ReplyOn :: Success {
206
+ payload. skip ( NUMBER_OF_DATA_FIELDS ) . collect :: < Vec < _ > > ( )
207
+ } else {
208
+ payload. collect :: < Vec < _ > > ( )
209
+ } ;
204
210
let method_name = variant. function_name ( ) ;
205
211
let reply_on = variant. msg_attr ( ) . reply_on ( ) ;
206
212
@@ -214,13 +220,15 @@ impl<'a> ReplyData<'a> {
214
220
}
215
221
216
222
/// Adds second handler to the reply data provdided their payload signature match.
217
- pub fn add_second_handler ( & mut self , new_handler : & ' a MsgVariant < ' a > ) {
223
+ pub fn merge ( & mut self , new_handler : & ' a MsgVariant < ' a > ) {
218
224
let ( current_method_name, _) = match self . handlers . first ( ) {
219
225
Some ( handler) => handler,
220
226
_ => return ,
221
227
} ;
222
228
223
- if self . payload . len ( ) != new_handler. fields ( ) . len ( ) - 1 {
229
+ let new_reply_data = ReplyData :: new ( self . reply_id . clone ( ) , new_handler, self . handler_id ) ;
230
+
231
+ if self . payload . len ( ) != new_reply_data. payload . len ( ) {
224
232
emit_error ! ( current_method_name. span( ) , "Mismatched quantity of method parameters." ;
225
233
note = self . handler_id. span( ) => format!( "Both `{}` handlers should have the same number of parameters." , self . handler_id) ;
226
234
note = new_handler. function_name( ) . span( ) => format!( "Previous definition of {} handler." , self . handler_id)
@@ -229,7 +237,7 @@ impl<'a> ReplyData<'a> {
229
237
230
238
self . payload
231
239
. iter ( )
232
- . zip ( new_handler . fields ( ) . iter ( ) . skip ( 1 ) )
240
+ . zip ( new_reply_data . payload . iter ( ) )
233
241
. for_each ( |( current_field, new_field) |
234
242
{
235
243
if current_field. ty ( ) != new_field. ty ( ) {
@@ -377,6 +385,7 @@ impl<'a> ReplyData<'a> {
377
385
let payload_values = self . payload . iter ( ) . map ( |field| field. name ( ) ) ;
378
386
let payload_deserialization = self . payload . emit_payload_deserialization ( ) ;
379
387
let data_deserialization = self . data . map ( DataField :: emit_data_deserialization) ;
388
+ let data = self . data . map ( |_| quote ! { data, } ) ;
380
389
381
390
quote ! {
382
391
#sylvia :: cw_std:: SubMsgResult :: Ok ( sub_msg_resp) => {
@@ -385,7 +394,7 @@ impl<'a> ReplyData<'a> {
385
394
#payload_deserialization
386
395
#data_deserialization
387
396
388
- #contract_turbofish :: new( ) . #method_name ( ( deps, env, gas_used, events, msg_responses) . into( ) , data, #( #payload_values) , * )
397
+ #contract_turbofish :: new( ) . #method_name ( ( deps, env, gas_used, events, msg_responses) . into( ) , # data #( #payload_values) , * )
389
398
}
390
399
}
391
400
}
@@ -462,6 +471,7 @@ impl<'a> ReplyData<'a> {
462
471
463
472
trait ReplyVariant < ' a > {
464
473
fn as_variant_handlers_pair ( & ' a self ) -> Vec < ( & ' a MsgVariant < ' a > , & ' a Ident ) > ;
474
+ fn as_data_field ( & ' a self ) -> Option < & ' a MsgField < ' a > > ;
465
475
}
466
476
467
477
impl < ' a > ReplyVariant < ' a > for MsgVariant < ' a > {
@@ -479,6 +489,22 @@ impl<'a> ReplyVariant<'a> for MsgVariant<'a> {
479
489
480
490
variant_handler_id_pair
481
491
}
492
+
493
+ /// Returns `Some(MsgField)` if a field marked with `sv::data` attribute is present
494
+ /// and the `reply_on` attribute is set to `ReplyOn::Success`.
495
+ fn as_data_field ( & ' a self ) -> Option < & ' a MsgField < ' a > > {
496
+ let data_attrs = self . fields ( ) . first ( ) . map ( |field| {
497
+ ParsedSylviaAttributes :: new ( field. attrs ( ) . iter ( ) )
498
+ . data
499
+ . is_some ( )
500
+ } ) ;
501
+ match data_attrs {
502
+ Some ( attrs) if attrs && self . msg_attr ( ) . reply_on ( ) == ReplyOn :: Success => {
503
+ self . fields ( ) . first ( )
504
+ }
505
+ _ => None ,
506
+ }
507
+ }
482
508
}
483
509
484
510
pub trait DataField {
@@ -489,10 +515,6 @@ impl DataField for MsgField<'_> {
489
515
fn emit_data_deserialization ( & self ) -> TokenStream {
490
516
let sylvia = crate_module ( ) ;
491
517
let data = ParsedSylviaAttributes :: new ( self . attrs ( ) . iter ( ) ) . data ;
492
- let is_data_attr = self
493
- . attrs ( )
494
- . iter ( )
495
- . any ( |attr| SylviaAttribute :: new ( attr) == Some ( SylviaAttribute :: Data ) ) ;
496
518
let missing_data_err = "Missing reply data field." ;
497
519
let invalid_reply_data_err = quote ! {
498
520
format! { "Invalid reply data: {}\n Serde error while deserializing {}" , data, err}
@@ -555,7 +577,7 @@ impl DataField for MsgField<'_> {
555
577
None => None ,
556
578
} ;
557
579
} ,
558
- None if is_data_attr => quote ! {
580
+ Some ( _ ) => quote ! {
559
581
let data = match data {
560
582
Some ( data) => {
561
583
#execute_data_deserialization
@@ -616,8 +638,11 @@ impl PayloadFields for Vec<&MsgField<'_>> {
616
638
}
617
639
618
640
fn is_payload_marked ( & self ) -> bool {
619
- self . iter ( )
620
- . any ( |field| field. contains_attribute ( SylviaAttribute :: Payload ) )
641
+ self . iter ( ) . any ( |field| {
642
+ ParsedSylviaAttributes :: new ( field. attrs ( ) . iter ( ) )
643
+ . payload
644
+ . is_some ( )
645
+ } )
621
646
}
622
647
}
623
648
0 commit comments