Skip to content

Commit 406bac9

Browse files
committed
feat: Python requests
2 parents c850a1c + 3af386b commit 406bac9

File tree

4 files changed

+182
-50
lines changed

4 files changed

+182
-50
lines changed

guests/python/requirements.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
1-
# That's https://github.com/urllib3/urllib3/pull/3593 as a wheel.
2-
urllib3 @ https://github.com/crepererum/urllib3/releases/download/2.5.100/urllib3-2.5.100-py3-none-any.whl --hash=sha256:b48be99c923c989e8db2a41a21f2511d830a7b7e99bfde24ae7214baadf7daaf
1+
certifi==2025.10.5 --hash=sha256:0f212c2744a9bb6de0c56639a6f68afe01ecd92d91f14ae897c4fe7bbeeef0de
2+
charset_normalizer==3.4.4 --hash=sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f
3+
idna==3.11 --hash=sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea
4+
requests==2.32.5 --hash=sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6
5+
6+
# That's https://github.com/urllib3/urllib3/pull/3593 + https://github.com/golemcloud/urllib3/pull/2 as a wheel.
7+
urllib3 @ https://github.com/crepererum/urllib3/releases/download/2.5.101/urllib3-2.5.101-py3-none-any.whl --hash=sha256:2b2dcd0944a3b5d6ce7517cf068c232c1963a6702146695147454c46b4da71f1

guests/python/src/python_modules/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,12 @@ mod wit_world {
11471147
self.inner.take();
11481148
}
11491149

1150+
/// Finishes the incoming body and returns a handle to its trailers.
///
/// The wrapped body is taken out of `self.inner` and handed by value to
/// `wasip2::http::types::IncomingBody::finish`; the value it returns is wrapped in a
/// [`FutureTrailers`].
///
/// # Errors
///
/// Propagates the error produced by `require_resource()`.
/// NOTE(review): presumably that happens when the body was already taken out of this wrapper
/// (e.g. by an earlier call that also does `self.inner.take()`) -- confirm against its definition.
fn finish(&mut self) -> PyResult<FutureTrailers> {
1151+
// `take()` moves the body out of the wrapper so it can be passed by value below.
let body = self.inner.take().require_resource()?;
1152+
let trailers = wasip2::http::types::IncomingBody::finish(body);
1153+
Ok(FutureTrailers { inner: trailers })
1154+
}
1155+
11501156
fn stream(&self) -> PyResult<InputStream> {
11511157
let stream = self.inner()?.stream().to_pyres()?;
11521158
Ok(InputStream {

host/src/lib.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
//! [DataFusion]: https://datafusion.apache.org/
55
use std::{any::Any, io::Cursor, ops::DerefMut, sync::Arc};
66

7+
use ::http::HeaderName;
78
use arrow::datatypes::DataType;
89
use datafusion_common::{DataFusionError, Result as DataFusionResult};
910
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature};
@@ -20,7 +21,10 @@ use wasmtime_wasi_http::{
2021
HttpResult, WasiHttpCtx, WasiHttpView,
2122
bindings::http::types::ErrorCode as HttpErrorCode,
2223
body::HyperOutgoingBody,
23-
types::{HostFutureIncomingResponse, OutgoingRequestConfig, default_send_request_handler},
24+
types::{
25+
DEFAULT_FORBIDDEN_HEADERS, HostFutureIncomingResponse, OutgoingRequestConfig,
26+
default_send_request_handler,
27+
},
2428
};
2529

2630
use crate::{
@@ -112,9 +116,12 @@ impl WasiHttpView for WasmStateImpl {
112116

113117
fn send_request(
114118
&mut self,
115-
request: hyper::Request<HyperOutgoingBody>,
119+
mut request: hyper::Request<HyperOutgoingBody>,
116120
config: OutgoingRequestConfig,
117121
) -> HttpResult<HostFutureIncomingResponse> {
122+
// Python `requests` sends this so we allow it but later drop it from the actual request.
123+
request.headers_mut().remove(hyper::header::CONNECTION);
124+
118125
// technically we could return an error straight away, but `urllib3` doesn't handle that super well, so we
119126
// create a future and validate the error in there (before actually starting the request of course)
120127

@@ -134,6 +141,15 @@ impl WasiHttpView for WasmStateImpl {
134141

135142
Ok(HostFutureIncomingResponse::pending(handle))
136143
}
144+
145+
/// Reports whether the header `name` is forbidden.
///
/// Returns `false` for `Connection`; every other name is forbidden exactly when it is
/// contained in `DEFAULT_FORBIDDEN_HEADERS`.
fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
146+
// Python `requests` sends a `Connection` header, so we allow it here; `send_request` removes it
// from the outgoing request before it is actually sent.
147+
if name == hyper::header::CONNECTION {
148+
return false;
149+
}
150+
151+
DEFAULT_FORBIDDEN_HEADERS.contains(name)
152+
}
137153
}
138154

139155
/// Pre-compiled WASM component.

host/tests/integration_tests/python/runtime/http.rs

Lines changed: 151 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
use std::sync::Arc;
22

33
use arrow::{
4-
array::{Array, StringBuilder},
4+
array::{Array, StringArray, StringBuilder},
55
datatypes::{DataType, Field},
66
};
77
use datafusion_common::ScalarValue;
88
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
99
use datafusion_udf_wasm_host::{
1010
WasmPermissions, WasmScalarUdf,
11-
http::{AllowCertainHttpRequests, Matcher},
11+
http::{AllowCertainHttpRequests, HttpRequestValidator, Matcher},
1212
};
1313
use wasmtime_wasi_http::types::DEFAULT_FORBIDDEN_HEADERS;
1414
use wiremock::{Mock, MockServer, ResponseTemplate, matchers};
@@ -18,6 +18,46 @@ use crate::integration_tests::{
1818
test_utils::ColumnarValueExt,
1919
};
2020

21+
/// Smoke test for the Python `requests` library inside a UDF.
///
/// A UDF that returns `requests.get(url).text` is invoked once with the URI of a local mock
/// server; the result must be a single-element string array holding the mocked response body.
#[tokio::test(flavor = "multi_thread")]
22+
async fn test_requests_simple() {
23+
const CODE: &str = r#"
24+
import requests
25+
26+
def perform_request(url: str) -> str:
27+
return requests.get(url).text
28+
"#;
29+
30+
// The mock answers any request with status 200 and a fixed body, and expects one request.
let server = MockServer::start().await;
31+
Mock::given(matchers::any())
32+
.respond_with(ResponseTemplate::new(200).set_body_string("hello world!"))
33+
.expect(1)
34+
.mount(&server)
35+
.await;
36+
37+
// Allow only `GET` requests to the mock server's IP address and port.
let mut permissions = AllowCertainHttpRequests::new();
38+
permissions.allow(Matcher {
39+
method: http::Method::GET,
40+
host: server.address().ip().to_string().into(),
41+
port: server.address().port(),
42+
});
43+
let udf = python_udf_with_permissions(CODE, permissions).await;
44+
45+
// Call the UDF for one row, passing the mock server URI as a scalar string argument.
let array = udf
46+
.invoke_with_args(ScalarFunctionArgs {
47+
args: vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(server.uri())))],
48+
arg_fields: vec![Arc::new(Field::new("uri", DataType::Utf8, true))],
49+
number_rows: 1,
50+
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
51+
})
52+
.unwrap()
53+
.unwrap_array();
54+
55+
// The UDF output must equal the body configured on the mock.
assert_eq!(
56+
array.as_ref(),
57+
&StringArray::from_iter([Some("hello world!".to_owned()),]) as &dyn Array,
58+
);
59+
}
60+
2161
#[tokio::test(flavor = "multi_thread")]
2262
async fn test_urllib3_unguarded_fail() {
2363
const CODE: &str = r#"
@@ -131,6 +171,7 @@ def perform_request(url: str) -> str:
131171
#[tokio::test(flavor = "multi_thread")]
132172
async fn test_integration() {
133173
const CODE: &str = r#"
174+
import requests
134175
import urllib3
135176
136177
def _headers_str_to_dict(headers: str) -> dict[str, str]:
@@ -152,14 +193,38 @@ def _headers_dict_to_str(headers: dict[str, str]) -> str:
152193
else:
153194
return headers
154195
155-
def perform_request(method: str, url: str, headers: str | None, body: str | None) -> str:
196+
def test_requests(method: str, url: str, headers: str | None, body: str | None) -> str:
197+
try:
198+
resp = requests.request(
199+
method=method,
200+
url=url,
201+
headers=_headers_str_to_dict(headers),
202+
data=body,
203+
)
204+
except requests.ConnectionError as e:
205+
(e,) = e.args
206+
assert isinstance(e, Exception)
207+
return f"ERR: {e}"
208+
except Exception as e:
209+
return f"ERR: {e}"
210+
211+
resp_status = resp.status_code
212+
resp_body = f"'{resp.text}'" if resp.text else "n/a"
213+
resp_headers = _headers_dict_to_str(resp.headers)
214+
215+
return f"OK: status={resp_status} headers={resp_headers} body={resp_body}"
216+
217+
def test_urllib3(method: str, url: str, headers: str | None, body: str | None) -> str:
156218
try:
157219
resp = urllib3.request(
158220
method=method,
159221
url=url,
160222
headers=_headers_str_to_dict(headers),
161223
body=body,
162224
)
225+
except urllib3.exceptions.MaxRetryError as e:
226+
e = e.reason
227+
return f"ERR: {e}"
163228
except Exception as e:
164229
return f"ERR: {e}"
165230
@@ -169,6 +234,7 @@ def perform_request(method: str, url: str, headers: str | None, body: str | None
169234
170235
return f"OK: status={resp_status} headers={resp_headers} body={resp_body}"
171236
"#;
237+
const NUMBER_OF_IMPLEMENTATIONS: usize = 2;
172238

173239
let mut cases = vec![
174240
TestCase {
@@ -246,16 +312,32 @@ def perform_request(method: str, url: str, headers: str | None, body: str | None
246312
},
247313
TestCase {
248314
base: Some("http://test.com"),
249-
resp: Err("HTTPConnectionPool(host='test.com', port=80): Max retries exceeded with url: / (Caused by ProtocolError('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_HttpRequestDenied')))".to_owned()),
315+
resp: Err("('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_HttpRequestDenied'))".to_owned()),
316+
..Default::default()
317+
},
318+
// Python `requests` sends this so we allow it but later drop it from the actual request.
319+
TestCase {
320+
path: "/forbidden_header/connection".to_owned(),
321+
requ_headers: vec![(http::header::CONNECTION.to_string(), &["foo"])],
322+
resp: Ok(TestResponse {
323+
body: Some("header is filtered"),
324+
..Default::default()
325+
}),
250326
..Default::default()
251327
},
252328
];
253-
cases.extend(DEFAULT_FORBIDDEN_HEADERS.iter().map(|h| TestCase {
254-
path: format!("/forbidden_header/{h}"),
255-
requ_headers: vec![(h.to_string(), &["foo"])],
256-
resp: Err("Err { value: HeaderError_Forbidden }".to_owned()),
257-
..Default::default()
258-
}));
329+
cases.extend(
330+
DEFAULT_FORBIDDEN_HEADERS
331+
.iter()
332+
// Python `requests` sends this so we allow it but later drop it from the actual request.
333+
.filter(|h| *h != http::header::CONNECTION)
334+
.map(|h| TestCase {
335+
path: format!("/forbidden_header/{h}"),
336+
requ_headers: vec![(h.to_string(), &["foo"])],
337+
resp: Err("Err { value: HeaderError_Forbidden }".to_owned()),
338+
..Default::default()
339+
}),
340+
);
259341

260342
let server = MockServer::start().await;
261343
let mut permissions = AllowCertainHttpRequests::default();
@@ -267,7 +349,7 @@ def perform_request(method: str, url: str, headers: str | None, body: str | None
267349
let mut builder_result = StringBuilder::new();
268350

269351
for case in &cases {
270-
if let Some(mock) = case.mock(&server) {
352+
if let Some(mock) = case.mock(&server, NUMBER_OF_IMPLEMENTATIONS) {
271353
mock.mount(&server).await;
272354
}
273355
permissions.allow(case.matcher(&server));
@@ -309,37 +391,33 @@ def perform_request(method: str, url: str, headers: str | None, body: str | None
309391
}
310392
}
311393

312-
let udfs = WasmScalarUdf::new(
313-
python_component().await,
314-
&WasmPermissions::new().with_http(permissions),
315-
CODE.to_string(),
316-
)
317-
.await
318-
.unwrap();
319-
assert_eq!(udfs.len(), 1);
320-
let udf = udfs.into_iter().next().unwrap();
321-
322-
let array = udf
323-
.invoke_with_args(ScalarFunctionArgs {
324-
args: vec![
325-
ColumnarValue::Array(Arc::new(builder_method.finish())),
326-
ColumnarValue::Array(Arc::new(builder_url.finish())),
327-
ColumnarValue::Array(Arc::new(builder_headers.finish())),
328-
ColumnarValue::Array(Arc::new(builder_body.finish())),
329-
],
330-
arg_fields: vec![
331-
Arc::new(Field::new("method", DataType::Utf8, true)),
332-
Arc::new(Field::new("url", DataType::Utf8, true)),
333-
Arc::new(Field::new("headers", DataType::Utf8, true)),
334-
Arc::new(Field::new("body", DataType::Utf8, true)),
335-
],
336-
number_rows: cases.len(),
337-
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
338-
})
339-
.unwrap()
340-
.unwrap_array();
341-
342-
assert_eq!(array.as_ref(), &builder_result.finish() as &dyn Array,);
394+
let args = ScalarFunctionArgs {
395+
args: vec![
396+
ColumnarValue::Array(Arc::new(builder_method.finish())),
397+
ColumnarValue::Array(Arc::new(builder_url.finish())),
398+
ColumnarValue::Array(Arc::new(builder_headers.finish())),
399+
ColumnarValue::Array(Arc::new(builder_body.finish())),
400+
],
401+
arg_fields: vec![
402+
Arc::new(Field::new("method", DataType::Utf8, true)),
403+
Arc::new(Field::new("url", DataType::Utf8, true)),
404+
Arc::new(Field::new("headers", DataType::Utf8, true)),
405+
Arc::new(Field::new("body", DataType::Utf8, true)),
406+
],
407+
number_rows: cases.len(),
408+
return_field: Arc::new(Field::new("r", DataType::Utf8, true)),
409+
};
410+
let array_result = builder_result.finish();
411+
412+
let udfs = python_udfs_with_permissions(CODE, permissions).await;
413+
assert_eq!(udfs.len(), NUMBER_OF_IMPLEMENTATIONS);
414+
415+
for udf in udfs {
416+
println!("{}", udf.name());
417+
418+
let array = udf.invoke_with_args(args.clone()).unwrap().unwrap_array();
419+
assert_eq!(array.as_ref(), &array_result as &dyn Array);
420+
}
343421
}
344422

345423
#[derive(Debug, Clone)]
@@ -390,7 +468,7 @@ impl TestCase {
390468
}
391469
}
392470

393-
fn mock(&self, server: &MockServer) -> Option<Mock> {
471+
fn mock(&self, server: &MockServer, hits: usize) -> Option<Mock> {
394472
let Self {
395473
base,
396474
method,
@@ -417,6 +495,11 @@ impl TestCase {
417495
));
418496

419497
for (k, v) in requ_headers {
498+
// Python `requests` sends this so we allow it but later drop it from the actual request.
499+
if k.as_str() == http::header::CONNECTION {
500+
continue;
501+
}
502+
420503
builder = builder.and(matchers::headers(k.as_str(), v.to_vec()));
421504
}
422505

@@ -430,9 +513,9 @@ impl TestCase {
430513
resp_template = resp_template.set_body_string(resp_body);
431514
}
432515

433-
let mock = builder
434-
.respond_with(resp_template)
435-
.expect(resp.is_ok() as u64);
516+
let expect = if resp.is_ok() { hits as u64 } else { 0 };
517+
518+
let mock = builder.respond_with(resp_template).expect(expect);
436519
Some(mock)
437520
}
438521
}
@@ -477,3 +560,25 @@ impl wiremock::Match for NoForbiddenHeaders {
477560
.all(|h| !request.headers.contains_key(h))
478561
}
479562
}
563+
564+
/// Test helper: creates the UDFs for the Python source `code` with HTTP validator `permissions`.
///
/// Calls `WasmScalarUdf::new` with the component returned by `python_component()`, a fresh
/// `WasmPermissions` configured via `with_http(permissions)`, and an owned copy of `code`, and
/// returns every UDF that call yields.
///
/// # Panics
///
/// Panics if `WasmScalarUdf::new` returns an error (the result is unwrapped).
async fn python_udfs_with_permissions<V>(code: &'static str, permissions: V) -> Vec<WasmScalarUdf>
565+
where
566+
V: HttpRequestValidator,
567+
{
568+
WasmScalarUdf::new(
569+
python_component().await,
570+
&WasmPermissions::new().with_http(permissions),
571+
code.to_owned(),
572+
)
573+
.await
574+
.unwrap()
575+
}
576+
577+
/// Test helper: like `python_udfs_with_permissions`, but for code that yields exactly one UDF.
///
/// # Panics
///
/// Panics if the number of UDFs created from `code` is not exactly one, or if creating them
/// fails inside `python_udfs_with_permissions`.
async fn python_udf_with_permissions<V>(code: &'static str, permissions: V) -> WasmScalarUdf
578+
where
579+
V: HttpRequestValidator,
580+
{
581+
let udfs = python_udfs_with_permissions(code, permissions).await;
582+
assert_eq!(udfs.len(), 1);
583+
// The length check above guarantees that `next()` yields a value.
udfs.into_iter().next().unwrap()
584+
}

0 commit comments

Comments
 (0)