Skip to content

Commit f3d994b

Browse files
committed
feat: filter HTTP requests
3 parents 5e6ad9e + c464190 + 213e822 commit f3d994b

File tree

9 files changed

+508
-19
lines changed

9 files changed

+508
-19
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ datafusion-udf-wasm-arrow2bytes = { path = "arrow2bytes", version = "0.1.0" }
2323
datafusion-udf-wasm-bundle = { path = "guests/bundle", version = "0.1.0" }
2424
datafusion-udf-wasm-guest = { path = "guests/rust", version = "0.1.0" }
2525
datafusion-udf-wasm-python = { path = "guests/python", version = "0.1.0" }
26+
http = { version = "1.3.1", default-features = false }
27+
hyper = { version = "1.7", default-features = false }
2628
tokio = { version = "1.48.0", default-features = false }
2729
pyo3 = { version = "0.27.1", default-features = false, features = ["macros"] }
2830
tar = { version = "0.4.44", default-features = false }

guests/python/src/python_modules/mod.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,7 +1564,7 @@ mod wit_world {
15641564
}
15651565

15661566
#[pyclass]
1567-
#[derive(Debug, IntoPyObject)]
1567+
#[derive(Debug)]
15681568
#[pyo3(extends = PyValueError, frozen, get_all, name = "Err", str)]
15691569
pub(crate) struct ErrWrapper {
15701570
value: Py<PyAny>,
@@ -1623,9 +1623,9 @@ mod wit_world {
16231623
#[derive(Debug, IntoPyObject)]
16241624
pub(crate) enum ResultWrapper {
16251625
#[pyo3(transparent)]
1626-
Ok(OkWrapper),
1626+
Ok(Py<OkWrapper>),
16271627
#[pyo3(transparent)]
1628-
Err(ErrWrapper),
1628+
Err(Py<ErrWrapper>),
16291629
}
16301630

16311631
impl ResultWrapper {
@@ -1635,26 +1635,30 @@ mod wit_world {
16351635
E: IntoPyObject<'py>,
16361636
{
16371637
let res = match res {
1638-
Ok(val) => Self::Ok(OkWrapper {
1639-
value: val
1638+
Ok(val) => {
1639+
let val = val
16401640
.into_pyobject(py)
16411641
.map_err(|e| {
16421642
let e: PyErr = e.into();
16431643
e
16441644
})?
16451645
.into_any()
1646-
.unbind(),
1647-
}),
1648-
Err(val) => Self::Err(ErrWrapper {
1649-
value: val
1646+
.unbind();
1647+
1648+
Self::Ok(Py::new(py, OkWrapper { value: val })?)
1649+
}
1650+
Err(val) => {
1651+
let val = val
16501652
.into_pyobject(py)
16511653
.map_err(|e| {
16521654
let e: PyErr = e.into();
16531655
e
16541656
})?
16551657
.into_any()
1656-
.unbind(),
1657-
}),
1658+
.unbind();
1659+
1660+
Self::Err(Py::new(py, ErrWrapper { value: val })?)
1661+
}
16581662
};
16591663
Ok(res)
16601664
}

host/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ arrow.workspace = true
1313
datafusion-common.workspace = true
1414
datafusion-expr.workspace = true
1515
datafusion-udf-wasm-arrow2bytes.workspace = true
16+
http.workspace = true
17+
hyper.workspace = true
1618
tar.workspace = true
1719
tempfile.workspace = true
1820
tokio = { workspace = true, features = ["rt", "rt-multi-thread", "sync"] }

host/src/http.rs

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
//! Interfaces for HTTP interactions of the guest.
2+
3+
use std::{borrow::Cow, collections::HashSet, fmt};
4+
5+
use http::Method;
6+
use wasmtime_wasi_http::body::HyperOutgoingBody;
7+
8+
/// Validates if an outgoing HTTP interaction is allowed.
9+
pub trait HttpRequestValidator: fmt::Debug + Send + Sync + 'static {
10+
/// Validate incoming request.
11+
///
12+
/// Return [`Ok`] if the request should be allowed, return [`Err`] otherwise.
13+
fn validate(
14+
&self,
15+
request: &hyper::Request<HyperOutgoingBody>,
16+
use_tls: bool,
17+
) -> Result<(), Rejected>;
18+
}
19+
20+
/// Reject ALL requests.
21+
#[derive(Debug, Clone, Copy, Default)]
22+
pub struct RejectAllHttpRequests;
23+
24+
impl HttpRequestValidator for RejectAllHttpRequests {
25+
fn validate(
26+
&self,
27+
_request: &hyper::Request<HyperOutgoingBody>,
28+
_use_tls: bool,
29+
) -> Result<(), Rejected> {
30+
Err(Rejected)
31+
}
32+
}
33+
34+
/// A request matcher.
35+
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
36+
pub struct Matcher {
37+
/// Method.
38+
pub method: Method,
39+
40+
/// Host.
41+
///
42+
/// Requests without a host will be rejected.
43+
pub host: Cow<'static, str>,
44+
45+
/// Port.
46+
///
47+
/// For requests without an explicit port, this defaults to `80` for non-TLS requests and to `443` for TLS requests.
48+
pub port: u16,
49+
}
50+
51+
/// Allow-list requests.
52+
#[derive(Debug, Clone, Default)]
53+
pub struct AllowCertainHttpRequests {
54+
/// Set of all matchers.
55+
///
56+
/// If ANY of them matches, the request will be allowed.
57+
matchers: HashSet<Matcher>,
58+
}
59+
60+
impl AllowCertainHttpRequests {
61+
/// Create new, empty request matcher.
62+
pub fn new() -> Self {
63+
Self::default()
64+
}
65+
66+
/// Allow given request.
67+
pub fn allow(&mut self, matcher: Matcher) {
68+
self.matchers.insert(matcher);
69+
}
70+
}
71+
72+
impl HttpRequestValidator for AllowCertainHttpRequests {
73+
fn validate(
74+
&self,
75+
request: &hyper::Request<HyperOutgoingBody>,
76+
use_tls: bool,
77+
) -> Result<(), Rejected> {
78+
let matcher = Matcher {
79+
method: request.method().clone(),
80+
host: request.uri().host().ok_or(Rejected)?.to_owned().into(),
81+
port: request
82+
.uri()
83+
.port_u16()
84+
.unwrap_or(if use_tls { 443 } else { 80 }),
85+
};
86+
87+
if self.matchers.contains(&matcher) {
88+
Ok(())
89+
} else {
90+
Err(Rejected)
91+
}
92+
}
93+
}
94+
95+
/// Reject HTTP request.
96+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
97+
pub struct Rejected;
98+
99+
impl fmt::Display for Rejected {
100+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101+
f.write_str("rejected")
102+
}
103+
}
104+
105+
impl std::error::Error for Rejected {}
106+
107+
#[cfg(test)]
108+
mod test {
109+
use super::*;
110+
111+
#[test]
112+
fn reject_all() {
113+
let policy = RejectAllHttpRequests;
114+
115+
let request = hyper::Request::builder().body(Default::default()).unwrap();
116+
policy.validate(&request, false).unwrap_err();
117+
}
118+
119+
#[test]
120+
fn allow_certain() {
121+
let request_no_port = hyper::Request::builder()
122+
.method(Method::GET)
123+
.uri("http://foo.bar")
124+
.body(Default::default())
125+
.unwrap();
126+
127+
let request_with_port = hyper::Request::builder()
128+
.method(Method::GET)
129+
.uri("http://my.universe:1337")
130+
.body(Default::default())
131+
.unwrap();
132+
133+
struct Case {
134+
matchers: Vec<Matcher>,
135+
result_no_port_no_tls: Result<(), Rejected>,
136+
result_no_port_with_tls: Result<(), Rejected>,
137+
result_with_port_no_tls: Result<(), Rejected>,
138+
result_with_port_with_tls: Result<(), Rejected>,
139+
}
140+
141+
let cases = [
142+
Case {
143+
matchers: vec![],
144+
result_no_port_no_tls: Err(Rejected),
145+
result_no_port_with_tls: Err(Rejected),
146+
result_with_port_no_tls: Err(Rejected),
147+
result_with_port_with_tls: Err(Rejected),
148+
},
149+
Case {
150+
matchers: vec![Matcher {
151+
method: Method::GET,
152+
host: "foo.bar".into(),
153+
port: 80,
154+
}],
155+
result_no_port_no_tls: Ok(()),
156+
result_no_port_with_tls: Err(Rejected),
157+
result_with_port_no_tls: Err(Rejected),
158+
result_with_port_with_tls: Err(Rejected),
159+
},
160+
Case {
161+
matchers: vec![Matcher {
162+
method: Method::GET,
163+
host: "foo.bar".into(),
164+
port: 443,
165+
}],
166+
result_no_port_no_tls: Err(Rejected),
167+
result_no_port_with_tls: Ok(()),
168+
result_with_port_no_tls: Err(Rejected),
169+
result_with_port_with_tls: Err(Rejected),
170+
},
171+
Case {
172+
matchers: vec![Matcher {
173+
method: Method::POST,
174+
host: "foo.bar".into(),
175+
port: 80,
176+
}],
177+
result_no_port_no_tls: Err(Rejected),
178+
result_no_port_with_tls: Err(Rejected),
179+
result_with_port_no_tls: Err(Rejected),
180+
result_with_port_with_tls: Err(Rejected),
181+
},
182+
Case {
183+
matchers: vec![Matcher {
184+
method: Method::GET,
185+
host: "my.universe".into(),
186+
port: 80,
187+
}],
188+
result_no_port_no_tls: Err(Rejected),
189+
result_no_port_with_tls: Err(Rejected),
190+
result_with_port_no_tls: Err(Rejected),
191+
result_with_port_with_tls: Err(Rejected),
192+
},
193+
Case {
194+
matchers: vec![Matcher {
195+
method: Method::GET,
196+
host: "my.universe".into(),
197+
port: 1337,
198+
}],
199+
result_no_port_no_tls: Err(Rejected),
200+
result_no_port_with_tls: Err(Rejected),
201+
result_with_port_no_tls: Ok(()),
202+
result_with_port_with_tls: Ok(()),
203+
},
204+
Case {
205+
matchers: vec![
206+
Matcher {
207+
method: Method::GET,
208+
host: "foo.bar".into(),
209+
port: 80,
210+
},
211+
Matcher {
212+
method: Method::POST,
213+
host: "foo.bar".into(),
214+
port: 80,
215+
},
216+
Matcher {
217+
method: Method::GET,
218+
host: "my.universe".into(),
219+
port: 1337,
220+
},
221+
],
222+
result_no_port_no_tls: Ok(()),
223+
result_no_port_with_tls: Err(Rejected),
224+
result_with_port_no_tls: Ok(()),
225+
result_with_port_with_tls: Ok(()),
226+
},
227+
];
228+
229+
for (i, case) in cases.into_iter().enumerate() {
230+
println!("case: {}", i + 1);
231+
232+
let Case {
233+
matchers,
234+
result_no_port_no_tls,
235+
result_no_port_with_tls,
236+
result_with_port_no_tls,
237+
result_with_port_with_tls,
238+
} = case;
239+
240+
let mut policy = AllowCertainHttpRequests::default();
241+
242+
for matcher in matchers {
243+
policy.allow(matcher);
244+
}
245+
246+
assert_eq!(
247+
policy.validate(&request_no_port, false),
248+
result_no_port_no_tls,
249+
);
250+
assert_eq!(
251+
policy.validate(&request_no_port, true),
252+
result_no_port_with_tls,
253+
);
254+
assert_eq!(
255+
policy.validate(&request_with_port, false),
256+
result_with_port_no_tls,
257+
);
258+
assert_eq!(
259+
policy.validate(&request_with_port, true),
260+
result_with_port_with_tls,
261+
);
262+
}
263+
}
264+
}

0 commit comments

Comments
 (0)