2
2
// SPDX-License-Identifier: Apache-2.0
3
3
4
4
use axum_extra:: headers:: { self , Header , HeaderName , HeaderValue } ;
5
+ use base64:: prelude:: * ;
5
6
use lazy_static:: lazy_static;
6
7
use prometheus:: { register_counter, Counter } ;
8
+ use prost:: Message ;
9
+ use tap_aggregator:: grpc;
7
10
use tap_graph:: SignedReceipt ;
8
11
9
12
use crate :: tap:: TapReceipt ;
@@ -26,17 +29,28 @@ impl Header for TapHeader {
26
29
where
27
30
I : Iterator < Item = & ' i HeaderValue > ,
28
31
{
29
- let mut execute = || {
30
- let value = values. next ( ) ;
31
- let raw_receipt = value. ok_or ( headers:: Error :: invalid ( ) ) ?;
32
- let raw_receipt = raw_receipt
33
- . to_str ( )
34
- . map_err ( |_| headers:: Error :: invalid ( ) ) ?;
35
- let parsed_receipt: SignedReceipt =
36
- serde_json:: from_str ( raw_receipt) . map_err ( |_| headers:: Error :: invalid ( ) ) ?;
37
- Ok ( TapHeader ( crate :: tap:: TapReceipt :: V1 ( parsed_receipt) ) )
32
+ let mut execute = || -> anyhow:: Result < TapHeader > {
33
+ let raw_receipt = values. next ( ) . ok_or ( headers:: Error :: invalid ( ) ) ?;
34
+
35
+ // we first try to decode a v2 receipt since it's cheaper and fail earlier than using
36
+ // serde
37
+ match BASE64_STANDARD . decode ( raw_receipt) {
38
+ Ok ( raw_receipt) => {
39
+ tracing:: debug!( "Decoded v2" ) ;
40
+ let receipt = grpc:: v2:: SignedReceipt :: decode ( raw_receipt. as_ref ( ) ) ?;
41
+ Ok ( TapHeader ( TapReceipt :: V2 ( receipt. try_into ( ) ?) ) )
42
+ }
43
+ Err ( _) => {
44
+ tracing:: debug!( "Could not decode v2, trying v1" ) ;
45
+ let parsed_receipt: SignedReceipt =
46
+ serde_json:: from_slice ( raw_receipt. as_ref ( ) ) ?;
47
+ Ok ( TapHeader ( TapReceipt :: V1 ( parsed_receipt) ) )
48
+ }
49
+ }
38
50
} ;
39
- execute ( ) . inspect_err ( |_| TAP_RECEIPT_INVALID . inc ( ) )
51
+ execute ( )
52
+ . map_err ( |_| headers:: Error :: invalid ( ) )
53
+ . inspect_err ( |_| TAP_RECEIPT_INVALID . inc ( ) )
40
54
}
41
55
42
56
fn encode < E > ( & self , _values : & mut E )
@@ -51,13 +65,16 @@ impl Header for TapHeader {
51
65
mod test {
52
66
use axum:: http:: HeaderValue ;
53
67
use axum_extra:: headers:: Header ;
54
- use test_assets:: { create_signed_receipt, SignedReceiptRequest } ;
68
+ use base64:: prelude:: * ;
69
+ use prost:: Message ;
70
+ use tap_aggregator:: grpc:: v2:: SignedReceipt ;
71
+ use test_assets:: { create_signed_receipt, create_signed_receipt_v2, SignedReceiptRequest } ;
55
72
56
73
use super :: TapHeader ;
57
74
use crate :: tap:: TapReceipt ;
58
75
59
76
#[ tokio:: test]
60
- async fn test_decode_valid_tap_receipt_header ( ) {
77
+ async fn test_decode_valid_tap_v1_receipt_header ( ) {
61
78
let original_receipt = create_signed_receipt ( SignedReceiptRequest :: builder ( ) . build ( ) ) . await ;
62
79
let serialized_receipt = serde_json:: to_string ( & original_receipt) . unwrap ( ) ;
63
80
let header_value = HeaderValue :: from_str ( & serialized_receipt) . unwrap ( ) ;
@@ -68,6 +85,20 @@ mod test {
68
85
assert_eq ! ( decoded_receipt, TapHeader ( TapReceipt :: V1 ( original_receipt) ) ) ;
69
86
}
70
87
88
+ #[ test_log:: test( tokio:: test) ]
89
+ async fn test_decode_valid_tap_v2_receipt_header ( ) {
90
+ let original_receipt = create_signed_receipt_v2 ( ) . call ( ) . await ;
91
+ let protobuf_receipt = SignedReceipt :: from ( original_receipt. clone ( ) ) ;
92
+ let encoded = protobuf_receipt. encode_to_vec ( ) ;
93
+ let base64_encoded = BASE64_STANDARD . encode ( encoded) ;
94
+ let header_value = HeaderValue :: from_str ( & base64_encoded) . unwrap ( ) ;
95
+ let header_values = vec ! [ & header_value] ;
96
+ let decoded_receipt = TapHeader :: decode ( & mut header_values. into_iter ( ) )
97
+ . expect ( "tap receipt header value should be valid" ) ;
98
+
99
+ assert_eq ! ( decoded_receipt, TapHeader ( TapReceipt :: V2 ( original_receipt) ) ) ;
100
+ }
101
+
71
102
#[ test]
72
103
fn test_decode_non_string_tap_receipt_header ( ) {
73
104
let header_value = HeaderValue :: from_static ( "123" ) ;
0 commit comments