Skip to content

Commit c7cfa17

Browse files
committed
feat: filter HTTP requests
1 parent 5945834 commit c7cfa17

File tree

8 files changed

+495
-8
lines changed

8 files changed

+495
-8
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ datafusion-udf-wasm-arrow2bytes = { path = "arrow2bytes", version = "0.1.0" }
2323
datafusion-udf-wasm-bundle = { path = "guests/bundle", version = "0.1.0" }
2424
datafusion-udf-wasm-guest = { path = "guests/rust", version = "0.1.0" }
2525
datafusion-udf-wasm-python = { path = "guests/python", version = "0.1.0" }
26+
http = { version = "1.3.1", default-features = false }
27+
hyper = { version = "1.7", default-features = false }
2628
tokio = { version = "1.48.0", default-features = false }
2729
pyo3 = { version = "0.27.1", default-features = false, features = ["macros"] }
2830
tar = { version = "0.4.44", default-features = false }

host/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ arrow.workspace = true
1313
datafusion-common.workspace = true
1414
datafusion-expr.workspace = true
1515
datafusion-udf-wasm-arrow2bytes.workspace = true
16+
http.workspace = true
17+
hyper.workspace = true
1618
tar.workspace = true
1719
tempfile.workspace = true
1820
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "sync"] }

host/src/http.rs

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
//! Interfaces for HTTP interactions of the guest.
2+
3+
use std::{borrow::Cow, collections::HashSet, fmt};
4+
5+
use http::Method;
6+
use wasmtime_wasi_http::body::HyperOutgoingBody;
7+
8+
/// Validates if an outgoing HTTP interaction is allowed.
9+
///
10+
/// You can implement your own business logic here or use one of the pre-built implementations in [this module](self).
11+
pub trait HttpRequestValidator: fmt::Debug + Send + Sync + 'static {
12+
/// Validate incoming request.
13+
///
14+
/// Return [`Ok`] if the request should be allowed, return [`Err`] otherwise.
15+
fn validate(
16+
&self,
17+
request: &hyper::Request<HyperOutgoingBody>,
18+
use_tls: bool,
19+
) -> Result<(), Rejected>;
20+
}
21+
22+
/// Reject ALL requests.
23+
#[derive(Debug, Clone, Copy, Default)]
24+
pub struct RejectAllHttpRequests;
25+
26+
impl HttpRequestValidator for RejectAllHttpRequests {
27+
fn validate(
28+
&self,
29+
_request: &hyper::Request<HyperOutgoingBody>,
30+
_use_tls: bool,
31+
) -> Result<(), Rejected> {
32+
Err(Rejected)
33+
}
34+
}
35+
36+
/// A request matcher.
37+
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
38+
pub struct Matcher {
39+
/// Method.
40+
pub method: Method,
41+
42+
/// Host.
43+
///
44+
/// Requests without a host will be rejected.
45+
pub host: Cow<'static, str>,
46+
47+
/// Port.
48+
///
49+
/// For requests without an explicit port, this defaults to `80` for non-TLS requests and to `443` for TLS requests.
50+
pub port: u16,
51+
}
52+
53+
/// Allow-list requests.
54+
#[derive(Debug, Clone, Default)]
55+
pub struct AllowCertainHttpRequests {
56+
/// Set of all matchers.
57+
///
58+
/// If ANY of them matches, the request will be allowed.
59+
matchers: HashSet<Matcher>,
60+
}
61+
62+
impl AllowCertainHttpRequests {
63+
/// Create new, empty request matcher.
64+
pub fn new() -> Self {
65+
Self::default()
66+
}
67+
68+
/// Allow given request.
69+
pub fn allow(&mut self, matcher: Matcher) {
70+
self.matchers.insert(matcher);
71+
}
72+
}
73+
74+
impl HttpRequestValidator for AllowCertainHttpRequests {
75+
fn validate(
76+
&self,
77+
request: &hyper::Request<HyperOutgoingBody>,
78+
use_tls: bool,
79+
) -> Result<(), Rejected> {
80+
let matcher = Matcher {
81+
method: request.method().clone(),
82+
host: request.uri().host().ok_or(Rejected)?.to_owned().into(),
83+
port: request
84+
.uri()
85+
.port_u16()
86+
.unwrap_or(if use_tls { 443 } else { 80 }),
87+
};
88+
89+
if self.matchers.contains(&matcher) {
90+
Ok(())
91+
} else {
92+
Err(Rejected)
93+
}
94+
}
95+
}
96+
97+
/// Error signalling that an HTTP request was rejected.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Rejected;

impl fmt::Display for Rejected {
    /// Writes the fixed text `rejected`; formatter width and fill options are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "rejected")
    }
}

impl std::error::Error for Rejected {}
108+
109+
#[cfg(test)]
mod test {
    use super::*;

    /// Build a `GET` request with a default body for the given URI.
    fn get(uri: &str) -> hyper::Request<HyperOutgoingBody> {
        hyper::Request::builder()
            .method(Method::GET)
            .uri(uri)
            .body(Default::default())
            .unwrap()
    }

    /// Shorthand for constructing a [`Matcher`].
    fn matcher(method: Method, host: &'static str, port: u16) -> Matcher {
        Matcher {
            method,
            host: host.into(),
            port,
        }
    }

    #[test]
    fn reject_all() {
        let request = hyper::Request::builder().body(Default::default()).unwrap();

        assert!(RejectAllHttpRequests.validate(&request, false).is_err());
    }

    #[test]
    fn allow_certain() {
        const OK: Result<(), Rejected> = Ok(());
        const NO: Result<(), Rejected> = Err(Rejected);

        let request_no_port = get("http://foo.bar");
        let request_with_port = get("http://my.universe:1337");

        // Each case: the allowed matchers, then the expected outcomes in the order
        // [no port / no TLS, no port / TLS, with port / no TLS, with port / TLS].
        let cases: [(Vec<Matcher>, [Result<(), Rejected>; 4]); 7] = [
            // Empty allow-list rejects everything.
            (vec![], [NO, NO, NO, NO]),
            // Port 80 only matches the port-less request when TLS is off.
            (
                vec![matcher(Method::GET, "foo.bar", 80)],
                [OK, NO, NO, NO],
            ),
            // Port 443 only matches the port-less request when TLS is on.
            (
                vec![matcher(Method::GET, "foo.bar", 443)],
                [NO, OK, NO, NO],
            ),
            // Wrong method.
            (
                vec![matcher(Method::POST, "foo.bar", 80)],
                [NO, NO, NO, NO],
            ),
            // Right host, but the request carries an explicit, different port.
            (
                vec![matcher(Method::GET, "my.universe", 80)],
                [NO, NO, NO, NO],
            ),
            // An explicit port matches independently of TLS.
            (
                vec![matcher(Method::GET, "my.universe", 1337)],
                [NO, NO, OK, OK],
            ),
            // Several matchers: any single match is sufficient.
            (
                vec![
                    matcher(Method::GET, "foo.bar", 80),
                    matcher(Method::POST, "foo.bar", 80),
                    matcher(Method::GET, "my.universe", 1337),
                ],
                [OK, NO, OK, OK],
            ),
        ];

        for (i, (matchers, expected)) in cases.into_iter().enumerate() {
            println!("case: {}", i + 1);

            let mut policy = AllowCertainHttpRequests::default();
            for m in matchers {
                policy.allow(m);
            }

            let actual = [
                policy.validate(&request_no_port, false),
                policy.validate(&request_no_port, true),
                policy.validate(&request_with_port, false),
                policy.validate(&request_with_port, true),
            ];

            assert_eq!(actual, expected);
        }
    }
}

0 commit comments

Comments
 (0)