diff --git a/wren-core-py/src/remote_functions.rs b/wren-core-py/src/remote_functions.rs index 9ba07d380..2d96f0fe0 100644 --- a/wren-core-py/src/remote_functions.rs +++ b/wren-core-py/src/remote_functions.rs @@ -28,8 +28,10 @@ pub struct PyRemoteFunction { pub function_type: String, pub name: String, pub return_type: Option, - pub param_names: Option>, - pub param_types: Option>, + /// It's a comma separated string of parameter names + pub param_names: Option, + /// It's a comma separated string of parameter types + pub param_types: Option, pub description: Option, } @@ -54,12 +56,26 @@ impl PyRemoteFunction { impl From for PyRemoteFunction { fn from(remote_function: wren_core::mdl::function::RemoteFunction) -> Self { + let param_names = remote_function.param_names.map(|names| { + names + .iter() + .map(|name| name.to_string()) + .collect::>() + .join(",") + }); + let param_types = remote_function.param_types.map(|types| { + types + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(",") + }); Self { function_type: remote_function.function_type.to_string(), name: remote_function.name, return_type: Some(remote_function.return_type), - param_names: remote_function.param_names, - param_types: remote_function.param_types, + param_names, + param_types, description: remote_function.description, } } @@ -69,14 +85,26 @@ impl From for wren_core::mdl::function::RemoteFunction { fn from( remote_function: PyRemoteFunction, ) -> wren_core::mdl::function::RemoteFunction { + let param_names = remote_function.param_names.map(|names| { + names + .split(",") + .map(|name| name.to_string()) + .collect::>() + }); + let param_types = remote_function.param_types.map(|types| { + types + .split(",") + .map(|t| t.to_string()) + .collect::>() + }); wren_core::mdl::function::RemoteFunction { function_type: FunctionType::from_str(&remote_function.function_type) .unwrap(), name: remote_function.name, // TODO: Get the return type form DataFusion SessionState return_type: remote_function.return_type.unwrap_or("string".to_string()), - param_names: remote_function.param_names, - param_types: remote_function.param_types, + param_names, + param_types, description: remote_function.description, } } diff --git a/wren-core-py/tests/functions.csv b/wren-core-py/tests/functions.csv index 33d302c19..cf9464767 100644 --- a/wren-core-py/tests/functions.csv +++ b/wren-core-py/tests/functions.csv @@ -1,3 +1,3 @@ -function_type,name,return_type,description -scalar,add_two,int,"Adds two numbers together." -window,max_if,int,"If the condition is true, returns the maximum value in the window." +function_type,name,return_type,param_names,param_types,description +scalar,add_two,int,"f1,f2","int,int","Adds two numbers together." +window,max_if,int,,,"If the condition is true, returns the maximum value in the window." diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 307ad620b..f94357334 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -63,3 +63,13 @@ def test_get_available_functions(): assert add_two["name"] == "add_two" assert add_two["function_type"] == "scalar" assert add_two["description"] == "Adds two numbers together." + assert add_two["return_type"] == "int" + assert add_two["param_names"] == "f1,f2" + assert add_two["param_types"] == "int,int" + + max_if = next(filter(lambda x: x["name"] == "max_if", map(lambda x: x.to_dict(), functions))) + assert max_if["name"] == "max_if" + assert max_if["function_type"] == "window" + assert max_if["param_names"] is None + assert max_if["param_types"] is None +