| 
 | 1 | +//! Interfaces for HTTP interactions of the guest.  | 
 | 2 | +
  | 
 | 3 | +use std::{borrow::Cow, collections::HashSet, fmt};  | 
 | 4 | + | 
 | 5 | +use http::Method;  | 
 | 6 | +use wasmtime_wasi_http::body::HyperOutgoingBody;  | 
 | 7 | + | 
 | 8 | +/// Validates if an outgoing HTTP interaction is allowed.  | 
 | 9 | +pub trait HttpRequestValidator: fmt::Debug + Send + Sync + 'static {  | 
 | 10 | +    /// Validate incoming request.  | 
 | 11 | +    ///  | 
 | 12 | +    /// Return [`Ok`] if the request should be allowed, return [`Err`] otherwise.  | 
 | 13 | +    fn validate(  | 
 | 14 | +        &self,  | 
 | 15 | +        request: &hyper::Request<HyperOutgoingBody>,  | 
 | 16 | +        use_tls: bool,  | 
 | 17 | +    ) -> Result<(), Rejected>;  | 
 | 18 | +}  | 
 | 19 | + | 
 | 20 | +/// Reject ALL requests.  | 
 | 21 | +#[derive(Debug, Clone, Copy, Default)]  | 
 | 22 | +pub struct RejectAllHttpRequests;  | 
 | 23 | + | 
 | 24 | +impl HttpRequestValidator for RejectAllHttpRequests {  | 
 | 25 | +    fn validate(  | 
 | 26 | +        &self,  | 
 | 27 | +        _request: &hyper::Request<HyperOutgoingBody>,  | 
 | 28 | +        _use_tls: bool,  | 
 | 29 | +    ) -> Result<(), Rejected> {  | 
 | 30 | +        Err(Rejected)  | 
 | 31 | +    }  | 
 | 32 | +}  | 
 | 33 | + | 
 | 34 | +/// A request matcher.  | 
 | 35 | +#[derive(Debug, Clone, Hash, PartialEq, Eq)]  | 
 | 36 | +pub struct Matcher {  | 
 | 37 | +    /// Method.  | 
 | 38 | +    pub method: Method,  | 
 | 39 | + | 
 | 40 | +    /// Host.  | 
 | 41 | +    ///  | 
 | 42 | +    /// Requests without a host will be rejected.  | 
 | 43 | +    pub host: Cow<'static, str>,  | 
 | 44 | + | 
 | 45 | +    /// Port.  | 
 | 46 | +    ///  | 
 | 47 | +    /// For requests without an explicit port, this defaults to `80` for non-TLS requests and to `443` for TLS requests.  | 
 | 48 | +    pub port: u16,  | 
 | 49 | +}  | 
 | 50 | + | 
 | 51 | +/// Allow-list requests.  | 
 | 52 | +#[derive(Debug, Clone, Default)]  | 
 | 53 | +pub struct AllowCertainHttpRequests {  | 
 | 54 | +    /// Set of all matchers.  | 
 | 55 | +    ///  | 
 | 56 | +    /// If ANY of them matches, the request will be allowed.  | 
 | 57 | +    matchers: HashSet<Matcher>,  | 
 | 58 | +}  | 
 | 59 | + | 
 | 60 | +impl AllowCertainHttpRequests {  | 
 | 61 | +    /// Create new, empty request matcher.  | 
 | 62 | +    pub fn new() -> Self {  | 
 | 63 | +        Self::default()  | 
 | 64 | +    }  | 
 | 65 | + | 
 | 66 | +    /// Allow given request.  | 
 | 67 | +    pub fn allow(&mut self, matcher: Matcher) {  | 
 | 68 | +        self.matchers.insert(matcher);  | 
 | 69 | +    }  | 
 | 70 | +}  | 
 | 71 | + | 
 | 72 | +impl HttpRequestValidator for AllowCertainHttpRequests {  | 
 | 73 | +    fn validate(  | 
 | 74 | +        &self,  | 
 | 75 | +        request: &hyper::Request<HyperOutgoingBody>,  | 
 | 76 | +        use_tls: bool,  | 
 | 77 | +    ) -> Result<(), Rejected> {  | 
 | 78 | +        let matcher = Matcher {  | 
 | 79 | +            method: request.method().clone(),  | 
 | 80 | +            host: request.uri().host().ok_or(Rejected)?.to_owned().into(),  | 
 | 81 | +            port: request  | 
 | 82 | +                .uri()  | 
 | 83 | +                .port_u16()  | 
 | 84 | +                .unwrap_or(if use_tls { 443 } else { 80 }),  | 
 | 85 | +        };  | 
 | 86 | + | 
 | 87 | +        if self.matchers.contains(&matcher) {  | 
 | 88 | +            Ok(())  | 
 | 89 | +        } else {  | 
 | 90 | +            Err(Rejected)  | 
 | 91 | +        }  | 
 | 92 | +    }  | 
 | 93 | +}  | 
 | 94 | + | 
 | 95 | +/// Reject HTTP request.  | 
 | 96 | +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]  | 
 | 97 | +pub struct Rejected;  | 
 | 98 | + | 
 | 99 | +impl fmt::Display for Rejected {  | 
 | 100 | +    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {  | 
 | 101 | +        f.write_str("rejected")  | 
 | 102 | +    }  | 
 | 103 | +}  | 
 | 104 | + | 
 | 105 | +impl std::error::Error for Rejected {}  | 
 | 106 | + | 
 | 107 | +#[cfg(test)]  | 
 | 108 | +mod test {  | 
 | 109 | +    use super::*;  | 
 | 110 | + | 
 | 111 | +    #[test]  | 
 | 112 | +    fn reject_all() {  | 
 | 113 | +        let policy = RejectAllHttpRequests;  | 
 | 114 | + | 
 | 115 | +        let request = hyper::Request::builder().body(Default::default()).unwrap();  | 
 | 116 | +        policy.validate(&request, false).unwrap_err();  | 
 | 117 | +    }  | 
 | 118 | + | 
 | 119 | +    #[test]  | 
 | 120 | +    fn allow_certain() {  | 
 | 121 | +        let request_no_port = hyper::Request::builder()  | 
 | 122 | +            .method(Method::GET)  | 
 | 123 | +            .uri("http://foo.bar")  | 
 | 124 | +            .body(Default::default())  | 
 | 125 | +            .unwrap();  | 
 | 126 | + | 
 | 127 | +        let request_with_port = hyper::Request::builder()  | 
 | 128 | +            .method(Method::GET)  | 
 | 129 | +            .uri("http://my.universe:1337")  | 
 | 130 | +            .body(Default::default())  | 
 | 131 | +            .unwrap();  | 
 | 132 | + | 
 | 133 | +        struct Case {  | 
 | 134 | +            matchers: Vec<Matcher>,  | 
 | 135 | +            result_no_port_no_tls: Result<(), Rejected>,  | 
 | 136 | +            result_no_port_with_tls: Result<(), Rejected>,  | 
 | 137 | +            result_with_port_no_tls: Result<(), Rejected>,  | 
 | 138 | +            result_with_port_with_tls: Result<(), Rejected>,  | 
 | 139 | +        }  | 
 | 140 | + | 
 | 141 | +        let cases = [  | 
 | 142 | +            Case {  | 
 | 143 | +                matchers: vec![],  | 
 | 144 | +                result_no_port_no_tls: Err(Rejected),  | 
 | 145 | +                result_no_port_with_tls: Err(Rejected),  | 
 | 146 | +                result_with_port_no_tls: Err(Rejected),  | 
 | 147 | +                result_with_port_with_tls: Err(Rejected),  | 
 | 148 | +            },  | 
 | 149 | +            Case {  | 
 | 150 | +                matchers: vec![Matcher {  | 
 | 151 | +                    method: Method::GET,  | 
 | 152 | +                    host: "foo.bar".into(),  | 
 | 153 | +                    port: 80,  | 
 | 154 | +                }],  | 
 | 155 | +                result_no_port_no_tls: Ok(()),  | 
 | 156 | +                result_no_port_with_tls: Err(Rejected),  | 
 | 157 | +                result_with_port_no_tls: Err(Rejected),  | 
 | 158 | +                result_with_port_with_tls: Err(Rejected),  | 
 | 159 | +            },  | 
 | 160 | +            Case {  | 
 | 161 | +                matchers: vec![Matcher {  | 
 | 162 | +                    method: Method::GET,  | 
 | 163 | +                    host: "foo.bar".into(),  | 
 | 164 | +                    port: 443,  | 
 | 165 | +                }],  | 
 | 166 | +                result_no_port_no_tls: Err(Rejected),  | 
 | 167 | +                result_no_port_with_tls: Ok(()),  | 
 | 168 | +                result_with_port_no_tls: Err(Rejected),  | 
 | 169 | +                result_with_port_with_tls: Err(Rejected),  | 
 | 170 | +            },  | 
 | 171 | +            Case {  | 
 | 172 | +                matchers: vec![Matcher {  | 
 | 173 | +                    method: Method::POST,  | 
 | 174 | +                    host: "foo.bar".into(),  | 
 | 175 | +                    port: 80,  | 
 | 176 | +                }],  | 
 | 177 | +                result_no_port_no_tls: Err(Rejected),  | 
 | 178 | +                result_no_port_with_tls: Err(Rejected),  | 
 | 179 | +                result_with_port_no_tls: Err(Rejected),  | 
 | 180 | +                result_with_port_with_tls: Err(Rejected),  | 
 | 181 | +            },  | 
 | 182 | +            Case {  | 
 | 183 | +                matchers: vec![Matcher {  | 
 | 184 | +                    method: Method::GET,  | 
 | 185 | +                    host: "my.universe".into(),  | 
 | 186 | +                    port: 80,  | 
 | 187 | +                }],  | 
 | 188 | +                result_no_port_no_tls: Err(Rejected),  | 
 | 189 | +                result_no_port_with_tls: Err(Rejected),  | 
 | 190 | +                result_with_port_no_tls: Err(Rejected),  | 
 | 191 | +                result_with_port_with_tls: Err(Rejected),  | 
 | 192 | +            },  | 
 | 193 | +            Case {  | 
 | 194 | +                matchers: vec![Matcher {  | 
 | 195 | +                    method: Method::GET,  | 
 | 196 | +                    host: "my.universe".into(),  | 
 | 197 | +                    port: 1337,  | 
 | 198 | +                }],  | 
 | 199 | +                result_no_port_no_tls: Err(Rejected),  | 
 | 200 | +                result_no_port_with_tls: Err(Rejected),  | 
 | 201 | +                result_with_port_no_tls: Ok(()),  | 
 | 202 | +                result_with_port_with_tls: Ok(()),  | 
 | 203 | +            },  | 
 | 204 | +            Case {  | 
 | 205 | +                matchers: vec![  | 
 | 206 | +                    Matcher {  | 
 | 207 | +                        method: Method::GET,  | 
 | 208 | +                        host: "foo.bar".into(),  | 
 | 209 | +                        port: 80,  | 
 | 210 | +                    },  | 
 | 211 | +                    Matcher {  | 
 | 212 | +                        method: Method::POST,  | 
 | 213 | +                        host: "foo.bar".into(),  | 
 | 214 | +                        port: 80,  | 
 | 215 | +                    },  | 
 | 216 | +                    Matcher {  | 
 | 217 | +                        method: Method::GET,  | 
 | 218 | +                        host: "my.universe".into(),  | 
 | 219 | +                        port: 1337,  | 
 | 220 | +                    },  | 
 | 221 | +                ],  | 
 | 222 | +                result_no_port_no_tls: Ok(()),  | 
 | 223 | +                result_no_port_with_tls: Err(Rejected),  | 
 | 224 | +                result_with_port_no_tls: Ok(()),  | 
 | 225 | +                result_with_port_with_tls: Ok(()),  | 
 | 226 | +            },  | 
 | 227 | +        ];  | 
 | 228 | + | 
 | 229 | +        for (i, case) in cases.into_iter().enumerate() {  | 
 | 230 | +            println!("case: {}", i + 1);  | 
 | 231 | + | 
 | 232 | +            let Case {  | 
 | 233 | +                matchers,  | 
 | 234 | +                result_no_port_no_tls,  | 
 | 235 | +                result_no_port_with_tls,  | 
 | 236 | +                result_with_port_no_tls,  | 
 | 237 | +                result_with_port_with_tls,  | 
 | 238 | +            } = case;  | 
 | 239 | + | 
 | 240 | +            let mut policy = AllowCertainHttpRequests::default();  | 
 | 241 | + | 
 | 242 | +            for matcher in matchers {  | 
 | 243 | +                policy.allow(matcher);  | 
 | 244 | +            }  | 
 | 245 | + | 
 | 246 | +            assert_eq!(  | 
 | 247 | +                policy.validate(&request_no_port, false),  | 
 | 248 | +                result_no_port_no_tls,  | 
 | 249 | +            );  | 
 | 250 | +            assert_eq!(  | 
 | 251 | +                policy.validate(&request_no_port, true),  | 
 | 252 | +                result_no_port_with_tls,  | 
 | 253 | +            );  | 
 | 254 | +            assert_eq!(  | 
 | 255 | +                policy.validate(&request_with_port, false),  | 
 | 256 | +                result_with_port_no_tls,  | 
 | 257 | +            );  | 
 | 258 | +            assert_eq!(  | 
 | 259 | +                policy.validate(&request_with_port, true),  | 
 | 260 | +                result_with_port_with_tls,  | 
 | 261 | +            );  | 
 | 262 | +        }  | 
 | 263 | +    }  | 
 | 264 | +}  | 
0 commit comments