11use std:: sync:: Arc ;
22
33use arrow:: {
4- array:: { Array , StringBuilder } ,
4+ array:: { Array , StringArray , StringBuilder } ,
55 datatypes:: { DataType , Field } ,
66} ;
77use datafusion_common:: ScalarValue ;
88use datafusion_expr:: { ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl } ;
99use datafusion_udf_wasm_host:: {
1010 WasmPermissions , WasmScalarUdf ,
11- http:: { AllowCertainHttpRequests , Matcher } ,
11+ http:: { AllowCertainHttpRequests , HttpRequestValidator , Matcher } ,
1212} ;
1313use wasmtime_wasi_http:: types:: DEFAULT_FORBIDDEN_HEADERS ;
1414use wiremock:: { Mock , MockServer , ResponseTemplate , matchers} ;
@@ -18,6 +18,46 @@ use crate::integration_tests::{
1818 test_utils:: ColumnarValueExt ,
1919} ;
2020
21+ #[ tokio:: test( flavor = "multi_thread" ) ]
22+ async fn test_requests_simple ( ) {
23+ const CODE : & str = r#"
24+ import requests
25+
26+ def perform_request(url: str) -> str:
27+ return requests.get(url).text
28+ "# ;
29+
30+ let server = MockServer :: start ( ) . await ;
31+ Mock :: given ( matchers:: any ( ) )
32+ . respond_with ( ResponseTemplate :: new ( 200 ) . set_body_string ( "hello world!" ) )
33+ . expect ( 1 )
34+ . mount ( & server)
35+ . await ;
36+
37+ let mut permissions = AllowCertainHttpRequests :: new ( ) ;
38+ permissions. allow ( Matcher {
39+ method : http:: Method :: GET ,
40+ host : server. address ( ) . ip ( ) . to_string ( ) . into ( ) ,
41+ port : server. address ( ) . port ( ) ,
42+ } ) ;
43+ let udf = python_udf_with_permissions ( CODE , permissions) . await ;
44+
45+ let array = udf
46+ . invoke_with_args ( ScalarFunctionArgs {
47+ args : vec ! [ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( server. uri( ) ) ) ) ] ,
48+ arg_fields : vec ! [ Arc :: new( Field :: new( "uri" , DataType :: Utf8 , true ) ) ] ,
49+ number_rows : 1 ,
50+ return_field : Arc :: new ( Field :: new ( "r" , DataType :: Utf8 , true ) ) ,
51+ } )
52+ . unwrap ( )
53+ . unwrap_array ( ) ;
54+
55+ assert_eq ! (
56+ array. as_ref( ) ,
57+ & StringArray :: from_iter( [ Some ( "hello world!" . to_owned( ) ) , ] ) as & dyn Array ,
58+ ) ;
59+ }
60+
2161#[ tokio:: test( flavor = "multi_thread" ) ]
2262async fn test_urllib3_unguarded_fail ( ) {
2363 const CODE : & str = r#"
@@ -131,6 +171,7 @@ def perform_request(url: str) -> str:
131171#[ tokio:: test( flavor = "multi_thread" ) ]
132172async fn test_integration ( ) {
133173 const CODE : & str = r#"
174+ import requests
134175import urllib3
135176
136177def _headers_str_to_dict(headers: str) -> dict[str, str]:
@@ -152,14 +193,38 @@ def _headers_dict_to_str(headers: dict[str, str]) -> str:
152193 else:
153194 return headers
154195
155- def perform_request(method: str, url: str, headers: str | None, body: str | None) -> str:
196+ def test_requests(method: str, url: str, headers: str | None, body: str | None) -> str:
197+ try:
198+ resp = requests.request(
199+ method=method,
200+ url=url,
201+ headers=_headers_str_to_dict(headers),
202+ data=body,
203+ )
204+ except requests.ConnectionError as e:
205+ (e,) = e.args
206+ assert isinstance(e, Exception)
207+ return f"ERR: {e}"
208+ except Exception as e:
209+ return f"ERR: {e}"
210+
211+ resp_status = resp.status_code
212+ resp_body = f"'{resp.text}'" if resp.text else "n/a"
213+ resp_headers = _headers_dict_to_str(resp.headers)
214+
215+ return f"OK: status={resp_status} headers={resp_headers} body={resp_body}"
216+
217+ def test_urllib3(method: str, url: str, headers: str | None, body: str | None) -> str:
156218 try:
157219 resp = urllib3.request(
158220 method=method,
159221 url=url,
160222 headers=_headers_str_to_dict(headers),
161223 body=body,
162224 )
225+ except urllib3.exceptions.MaxRetryError as e:
226+ e = e.reason
227+ return f"ERR: {e}"
163228 except Exception as e:
164229 return f"ERR: {e}"
165230
@@ -169,6 +234,7 @@ def perform_request(method: str, url: str, headers: str | None, body: str | None
169234
170235 return f"OK: status={resp_status} headers={resp_headers} body={resp_body}"
171236"# ;
237+ const NUMBER_OF_IMPLEMENTATIONS : usize = 2 ;
172238
173239 let mut cases = vec ! [
174240 TestCase {
@@ -246,16 +312,32 @@ def perform_request(method: str, url: str, headers: str | None, body: str | None
246312 } ,
247313 TestCase {
248314 base: Some ( "http://test.com" ) ,
249- resp: Err ( "HTTPConnectionPool(host='test.com', port=80): Max retries exceeded with url: / (Caused by ProtocolError('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_HttpRequestDenied')))" . to_owned( ) ) ,
315+ resp: Err ( "('Connection aborted.', WasiErrorCode('Request failed with wasi http error ErrorCode_HttpRequestDenied'))" . to_owned( ) ) ,
316+ ..Default :: default ( )
317+ } ,
318+ // Python `requests` sends this so we allow it but later drop it from the actual request.
319+ TestCase {
320+ path: "/forbidden_header/connection" . to_owned( ) ,
321+ requ_headers: vec![ ( http:: header:: CONNECTION . to_string( ) , & [ "foo" ] ) ] ,
322+ resp: Ok ( TestResponse {
323+ body: Some ( "header is filtered" ) ,
324+ ..Default :: default ( )
325+ } ) ,
250326 ..Default :: default ( )
251327 } ,
252328 ] ;
253- cases. extend ( DEFAULT_FORBIDDEN_HEADERS . iter ( ) . map ( |h| TestCase {
254- path : format ! ( "/forbidden_header/{h}" ) ,
255- requ_headers : vec ! [ ( h. to_string( ) , & [ "foo" ] ) ] ,
256- resp : Err ( "Err { value: HeaderError_Forbidden }" . to_owned ( ) ) ,
257- ..Default :: default ( )
258- } ) ) ;
329+ cases. extend (
330+ DEFAULT_FORBIDDEN_HEADERS
331+ . iter ( )
332+ // Python `requests` sends this so we allow it but later drop it from the actual request.
333+ . filter ( |h| * h != http:: header:: CONNECTION )
334+ . map ( |h| TestCase {
335+ path : format ! ( "/forbidden_header/{h}" ) ,
336+ requ_headers : vec ! [ ( h. to_string( ) , & [ "foo" ] ) ] ,
337+ resp : Err ( "Err { value: HeaderError_Forbidden }" . to_owned ( ) ) ,
338+ ..Default :: default ( )
339+ } ) ,
340+ ) ;
259341
260342 let server = MockServer :: start ( ) . await ;
261343 let mut permissions = AllowCertainHttpRequests :: default ( ) ;
@@ -267,7 +349,7 @@ def perform_request(method: str, url: str, headers: str | None, body: str | None
267349 let mut builder_result = StringBuilder :: new ( ) ;
268350
269351 for case in & cases {
270- if let Some ( mock) = case. mock ( & server) {
352+ if let Some ( mock) = case. mock ( & server, NUMBER_OF_IMPLEMENTATIONS ) {
271353 mock. mount ( & server) . await ;
272354 }
273355 permissions. allow ( case. matcher ( & server) ) ;
@@ -309,37 +391,33 @@ def perform_request(method: str, url: str, headers: str | None, body: str | None
309391 }
310392 }
311393
312- let udfs = WasmScalarUdf :: new (
313- python_component ( ) . await ,
314- & WasmPermissions :: new ( ) . with_http ( permissions) ,
315- CODE . to_string ( ) ,
316- )
317- . await
318- . unwrap ( ) ;
319- assert_eq ! ( udfs. len( ) , 1 ) ;
320- let udf = udfs. into_iter ( ) . next ( ) . unwrap ( ) ;
321-
322- let array = udf
323- . invoke_with_args ( ScalarFunctionArgs {
324- args : vec ! [
325- ColumnarValue :: Array ( Arc :: new( builder_method. finish( ) ) ) ,
326- ColumnarValue :: Array ( Arc :: new( builder_url. finish( ) ) ) ,
327- ColumnarValue :: Array ( Arc :: new( builder_headers. finish( ) ) ) ,
328- ColumnarValue :: Array ( Arc :: new( builder_body. finish( ) ) ) ,
329- ] ,
330- arg_fields : vec ! [
331- Arc :: new( Field :: new( "method" , DataType :: Utf8 , true ) ) ,
332- Arc :: new( Field :: new( "url" , DataType :: Utf8 , true ) ) ,
333- Arc :: new( Field :: new( "headers" , DataType :: Utf8 , true ) ) ,
334- Arc :: new( Field :: new( "body" , DataType :: Utf8 , true ) ) ,
335- ] ,
336- number_rows : cases. len ( ) ,
337- return_field : Arc :: new ( Field :: new ( "r" , DataType :: Utf8 , true ) ) ,
338- } )
339- . unwrap ( )
340- . unwrap_array ( ) ;
341-
342- assert_eq ! ( array. as_ref( ) , & builder_result. finish( ) as & dyn Array , ) ;
394+ let args = ScalarFunctionArgs {
395+ args : vec ! [
396+ ColumnarValue :: Array ( Arc :: new( builder_method. finish( ) ) ) ,
397+ ColumnarValue :: Array ( Arc :: new( builder_url. finish( ) ) ) ,
398+ ColumnarValue :: Array ( Arc :: new( builder_headers. finish( ) ) ) ,
399+ ColumnarValue :: Array ( Arc :: new( builder_body. finish( ) ) ) ,
400+ ] ,
401+ arg_fields : vec ! [
402+ Arc :: new( Field :: new( "method" , DataType :: Utf8 , true ) ) ,
403+ Arc :: new( Field :: new( "url" , DataType :: Utf8 , true ) ) ,
404+ Arc :: new( Field :: new( "headers" , DataType :: Utf8 , true ) ) ,
405+ Arc :: new( Field :: new( "body" , DataType :: Utf8 , true ) ) ,
406+ ] ,
407+ number_rows : cases. len ( ) ,
408+ return_field : Arc :: new ( Field :: new ( "r" , DataType :: Utf8 , true ) ) ,
409+ } ;
410+ let array_result = builder_result. finish ( ) ;
411+
412+ let udfs = python_udfs_with_permissions ( CODE , permissions) . await ;
413+ assert_eq ! ( udfs. len( ) , NUMBER_OF_IMPLEMENTATIONS ) ;
414+
415+ for udf in udfs {
416+ println ! ( "{}" , udf. name( ) ) ;
417+
418+ let array = udf. invoke_with_args ( args. clone ( ) ) . unwrap ( ) . unwrap_array ( ) ;
419+ assert_eq ! ( array. as_ref( ) , & array_result as & dyn Array ) ;
420+ }
343421}
344422
345423#[ derive( Debug , Clone ) ]
@@ -390,7 +468,7 @@ impl TestCase {
390468 }
391469 }
392470
393- fn mock ( & self , server : & MockServer ) -> Option < Mock > {
471+ fn mock ( & self , server : & MockServer , hits : usize ) -> Option < Mock > {
394472 let Self {
395473 base,
396474 method,
@@ -417,6 +495,11 @@ impl TestCase {
417495 ) ) ;
418496
419497 for ( k, v) in requ_headers {
498+ // Python `requests` sends this so we allow it but later drop it from the actual request.
499+ if k. as_str ( ) == http:: header:: CONNECTION {
500+ continue ;
501+ }
502+
420503 builder = builder. and ( matchers:: headers ( k. as_str ( ) , v. to_vec ( ) ) ) ;
421504 }
422505
@@ -430,9 +513,9 @@ impl TestCase {
430513 resp_template = resp_template. set_body_string ( resp_body) ;
431514 }
432515
433- let mock = builder
434- . respond_with ( resp_template )
435- . expect ( resp . is_ok ( ) as u64 ) ;
516+ let expect = if resp . is_ok ( ) { hits as u64 } else { 0 } ;
517+
518+ let mock = builder . respond_with ( resp_template ) . expect ( expect ) ;
436519 Some ( mock)
437520 }
438521}
@@ -477,3 +560,25 @@ impl wiremock::Match for NoForbiddenHeaders {
477560 . all ( |h| !request. headers . contains_key ( h) )
478561 }
479562}
563+
564+ async fn python_udfs_with_permissions < V > ( code : & ' static str , permissions : V ) -> Vec < WasmScalarUdf >
565+ where
566+ V : HttpRequestValidator ,
567+ {
568+ WasmScalarUdf :: new (
569+ python_component ( ) . await ,
570+ & WasmPermissions :: new ( ) . with_http ( permissions) ,
571+ code. to_owned ( ) ,
572+ )
573+ . await
574+ . unwrap ( )
575+ }
576+
577+ async fn python_udf_with_permissions < V > ( code : & ' static str , permissions : V ) -> WasmScalarUdf
578+ where
579+ V : HttpRequestValidator ,
580+ {
581+ let udfs = python_udfs_with_permissions ( code, permissions) . await ;
582+ assert_eq ! ( udfs. len( ) , 1 ) ;
583+ udfs. into_iter ( ) . next ( ) . unwrap ( )
584+ }
0 commit comments