diff --git a/README.md b/README.md index e026e67f3..90419f23c 100644 --- a/README.md +++ b/README.md @@ -59,10 +59,9 @@ fn main() -> hyperlight_host::Result<()> { let message = "Hello, World! I am executing inside of a VM :)\n".to_string(); // in order to call a function it first must be defined in the guest and exposed so that // the host can call it - let result = multi_use_sandbox.call_guest_function_by_name( + let result: i32 = multi_use_sandbox.call_guest_function_by_name( "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String(message.clone())]), + message, ); assert!(result.is_ok()); diff --git a/fuzz/fuzz_targets/guest_call.rs b/fuzz/fuzz_targets/guest_call.rs index a89bb15de..e3b3dd398 100644 --- a/fuzz/fuzz_targets/guest_call.rs +++ b/fuzz/fuzz_targets/guest_call.rs @@ -42,8 +42,8 @@ fuzz_target!( SANDBOX.set(Mutex::new(mu_sbox)).unwrap(); }, - |data: (ReturnType, Option>)| { + |data: (ReturnType, Vec)| { let mut sandbox = SANDBOX.get().unwrap().lock().unwrap(); - let _ = sandbox.call_guest_function_by_name("PrintOutput", data.0, data.1); + let _ = sandbox.call_type_erased_guest_function_by_name("PrintOutput", data.0, data.1); } ); diff --git a/fuzz/fuzz_targets/host_call.rs b/fuzz/fuzz_targets/host_call.rs index 423186f16..2e7fe4bd6 100644 --- a/fuzz/fuzz_targets/host_call.rs +++ b/fuzz/fuzz_targets/host_call.rs @@ -45,7 +45,7 @@ fuzz_target!( let (host_func_name, host_func_return, mut host_func_params) = data; let mut sandbox = SANDBOX.get().unwrap().lock().unwrap(); host_func_params.insert(0, ParameterValue::String(host_func_name)); - match sandbox.call_guest_function_by_name("FuzzHostFunc", host_func_return, Some(host_func_params)) { + match sandbox.call_type_erased_guest_function_by_name("FuzzHostFunc", host_func_return, host_func_params) { Err(HyperlightError::GuestAborted(_, message)) if !message.contains("Host Function Not Found") => { // We don't allow GuestAborted errors, except for the "Host Function Not Found" case panic!("Guest Aborted: {}", message); diff --git a/fuzz/fuzz_targets/host_print.rs b/fuzz/fuzz_targets/host_print.rs index c90245e13..6829399e7 100644 --- a/fuzz/fuzz_targets/host_print.rs +++ b/fuzz/fuzz_targets/host_print.rs @@ -2,7 +2,6 @@ use std::sync::{Mutex, OnceLock}; -use hyperlight_host::func::{ParameterValue, ReturnType, ReturnValue}; use hyperlight_host::sandbox::uninitialized::GuestBinary; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -28,22 +27,14 @@ fuzz_target!( SANDBOX.set(Mutex::new(mu_sbox)).unwrap(); }, - |data: ParameterValue| -> Corpus { - // only interested in String types - if !matches!(data, ParameterValue::String(_)) { - return Corpus::Reject; - } - + |data: String| -> Corpus { let mut sandbox = SANDBOX.get().unwrap().lock().unwrap(); - let res = sandbox.call_guest_function_by_name( + let len: i32 = sandbox.call_guest_function_by_name::( "PrintOutput", - ReturnType::Int, - Some(vec![data.clone()]), - ); - match res { - Ok(ReturnValue::Int(len)) => assert!(len >= 0), - _ => panic!("Unexpected return value: {:?}", res), - } + data, + ) + .expect("Unexpected return value"); + assert!(len >= 0); Corpus::Keep }); diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index 62624dd5c..fdabe8cae 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -17,7 +17,6 @@ limitations under the License. use std::time::Duration; use criterion::{criterion_group, criterion_main, Criterion}; -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; use hyperlight_host::sandbox::{MultiUseSandbox, SandboxConfiguration, UninitializedSandbox}; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -43,11 +42,7 @@ fn guest_call_benchmark(c: &mut Criterion) { b.iter(|| { call_ctx - .call( - "Echo", - ReturnType::Int, - Some(vec![ParameterValue::String("hello\n".to_string())]), - ) + .call::("Echo", "hello\n".to_string()) .unwrap() }); }); @@ -59,11 +54,7 @@ fn guest_call_benchmark(c: &mut Criterion) { b.iter(|| { sandbox - .call_guest_function_by_name( - "Echo", - ReturnType::Int, - Some(vec![ParameterValue::String("hello\n".to_string())]), - ) + .call_guest_function_by_name::("Echo", "hello\n".to_string()) .unwrap() }); }); @@ -88,13 +79,9 @@ fn guest_call_benchmark(c: &mut Criterion) { b.iter(|| { sandbox - .call_guest_function_by_name( + .call_guest_function_by_name::<()>( "LargeParameters", - ReturnType::Void, - Some(vec![ - ParameterValue::VecBytes(large_vec.clone()), - ParameterValue::String(large_string.clone()), - ]), + (large_vec.clone(), large_string.clone()), ) .unwrap() }); @@ -114,15 +101,7 @@ fn guest_call_benchmark(c: &mut Criterion) { uninitialized_sandbox.evolve(Noop::default()).unwrap(); let mut call_ctx = multiuse_sandbox.new_call_context(); - b.iter(|| { - call_ctx - .call( - "Add", - ReturnType::Int, - Some(vec![ParameterValue::Int(1), ParameterValue::Int(41)]), - ) - .unwrap() - }); + b.iter(|| call_ctx.call::("Add", (1_i32, 41_i32)).unwrap()); }); group.finish(); diff --git a/src/hyperlight_host/examples/func_ctx/main.rs b/src/hyperlight_host/examples/func_ctx/main.rs index 8fe041393..4bfda0e23 100644 --- a/src/hyperlight_host/examples/func_ctx/main.rs +++ b/src/hyperlight_host/examples/func_ctx/main.rs @@ -14,12 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; use hyperlight_host::func::call_ctx::MultiUseGuestCallContext; use hyperlight_host::sandbox::{MultiUseSandbox, UninitializedSandbox}; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; -use hyperlight_host::{new_error, GuestBinary, Result}; +use hyperlight_host::{GuestBinary, Result}; use hyperlight_testing::simple_guest_as_string; fn main() { @@ -47,29 +46,10 @@ fn main() { /// call `ctx.finish()` and return the resulting `MultiUseSandbox`. Return an `Err` /// if anything failed. fn do_calls(mut ctx: MultiUseGuestCallContext) -> Result { - { - let res1: String = { - let rv = ctx.call( - "Echo", - ReturnType::Int, - Some(vec![ParameterValue::String("hello".to_string())]), - )?; - rv.try_into() - } - .map_err(|e| new_error!("failed to get Echo result: {}", e))?; - println!("got Echo res: {res1}"); - } - { - let res2: i32 = { - let rv = ctx.call( - "CallMalloc", - ReturnType::Int, - Some(vec![ParameterValue::Int(200)]), - )?; - rv.try_into() - } - .map_err(|e| new_error!("failed to get CallMalloc result: {}", e))?; - println!("got CallMalloc res: {res2}"); - } + let res: String = ctx.call("Echo", "hello".to_string())?; + println!("got Echo res: {res}"); + + let res: i32 = ctx.call("CallMalloc", 200_i32)?; + println!("got CallMalloc res: {res}"); ctx.finish() } diff --git a/src/hyperlight_host/examples/guest-debugging/main.rs b/src/hyperlight_host/examples/guest-debugging/main.rs index 2ea5a6eb7..fbdf3a84d 100644 --- a/src/hyperlight_host/examples/guest-debugging/main.rs +++ b/src/hyperlight_host/examples/guest-debugging/main.rs @@ -16,7 +16,6 @@ limitations under the License. #![allow(clippy::disallowed_macros)] use std::thread; -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; #[cfg(gdb)] use hyperlight_host::sandbox::config::DebugInfo; use hyperlight_host::sandbox::SandboxConfiguration; @@ -62,13 +61,12 @@ fn main() -> hyperlight_host::Result<()> { // Call guest function let message = "Hello, World! I am executing inside of a VM :)\n".to_string(); - let result = multi_use_sandbox.call_guest_function_by_name( - "PrintOutput", // function must be defined in the guest binary - ReturnType::Int, - Some(vec![ParameterValue::String(message.clone())]), - ); - - assert!(result.is_ok()); + multi_use_sandbox + .call_guest_function_by_name::( + "PrintOutput", // function must be defined in the guest binary + message.clone(), + ) + .unwrap(); Ok(()) } diff --git a/src/hyperlight_host/examples/hello-world/main.rs b/src/hyperlight_host/examples/hello-world/main.rs index 2ee72a945..77133ba83 100644 --- a/src/hyperlight_host/examples/hello-world/main.rs +++ b/src/hyperlight_host/examples/hello-world/main.rs @@ -16,7 +16,6 @@ limitations under the License. #![allow(clippy::disallowed_macros)] use std::thread; -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; use hyperlight_host::{MultiUseSandbox, UninitializedSandbox}; @@ -42,13 +41,12 @@ fn main() -> hyperlight_host::Result<()> { // Call guest function let message = "Hello, World! I am executing inside of a VM :)\n".to_string(); - let result = multi_use_sandbox.call_guest_function_by_name( - "PrintOutput", // function must be defined in the guest binary - ReturnType::Int, - Some(vec![ParameterValue::String(message.clone())]), - ); - - assert!(result.is_ok()); + multi_use_sandbox + .call_guest_function_by_name::( + "PrintOutput", // function must be defined in the guest binary + message, + ) + .unwrap(); Ok(()) } diff --git a/src/hyperlight_host/examples/logging/main.rs b/src/hyperlight_host/examples/logging/main.rs index 3948947e0..fc5e43f6c 100644 --- a/src/hyperlight_host/examples/logging/main.rs +++ b/src/hyperlight_host/examples/logging/main.rs @@ -16,7 +16,6 @@ limitations under the License. #![allow(clippy::disallowed_macros)] extern crate hyperlight_host; -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; use hyperlight_host::sandbox::uninitialized::UninitializedSandbox; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -53,12 +52,9 @@ fn main() -> Result<()> { // Call a guest function 5 times to generate some log entries. for _ in 0..5 { - let result = multiuse_sandbox.call_guest_function_by_name( - "Echo", - ReturnType::String, - Some(vec![ParameterValue::String("a".to_string())]), - ); - result.unwrap(); + multiuse_sandbox + .call_guest_function_by_name::("Echo", "a".to_string()) + .unwrap(); } // Define a message to send to the guest. @@ -67,12 +63,9 @@ fn main() -> Result<()> { // Call a guest function that calls the HostPrint host function 5 times to generate some log entries. for _ in 0..5 { - let result = multiuse_sandbox.call_guest_function_by_name( - "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String(msg.clone())]), - ); - result.unwrap(); + multiuse_sandbox + .call_guest_function_by_name::("PrintOutput", msg.clone()) + .unwrap(); } Ok(()) }; @@ -95,10 +88,8 @@ fn main() -> Result<()> { for _ in 0..5 { let mut ctx = multiuse_sandbox.new_call_context(); - let result = ctx.call("Spin", ReturnType::Void, None); - assert!(result.is_err()); - let result = ctx.finish(); - multiuse_sandbox = result.unwrap(); + ctx.call::<()>("Spin", ()).unwrap_err(); + multiuse_sandbox = ctx.finish().unwrap(); } Ok(()) diff --git a/src/hyperlight_host/examples/metrics/main.rs b/src/hyperlight_host/examples/metrics/main.rs index 979b29506..cd8ae0bdc 100644 --- a/src/hyperlight_host/examples/metrics/main.rs +++ b/src/hyperlight_host/examples/metrics/main.rs @@ -17,7 +17,6 @@ limitations under the License. extern crate hyperlight_host; use std::thread::{spawn, JoinHandle}; -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; use hyperlight_host::sandbox::uninitialized::UninitializedSandbox; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -65,12 +64,9 @@ fn do_hyperlight_stuff() { // Call a guest function 5 times to generate some metrics. for _ in 0..5 { - let result = multiuse_sandbox.call_guest_function_by_name( - "Echo", - ReturnType::String, - Some(vec![ParameterValue::String("a".to_string())]), - ); - assert!(result.is_ok()); + multiuse_sandbox + .call_guest_function_by_name::("Echo", "a".to_string()) + .unwrap(); } // Define a message to send to the guest. @@ -79,12 +75,9 @@ fn do_hyperlight_stuff() { // Call a guest function that calls the HostPrint host function 5 times to generate some metrics. for _ in 0..5 { - let result = multiuse_sandbox.call_guest_function_by_name( - "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String(msg.clone())]), - ); - assert!(result.is_ok()); + multiuse_sandbox + .call_guest_function_by_name::("PrintOutput", msg.clone()) + .unwrap(); } Ok(()) }); @@ -108,11 +101,8 @@ fn do_hyperlight_stuff() { for _ in 0..5 { let mut ctx = multiuse_sandbox.new_call_context(); - let result = ctx.call("Spin", ReturnType::Void, None); - assert!(result.is_err()); - let result = ctx.finish(); - assert!(result.is_ok()); - multiuse_sandbox = result.unwrap(); + ctx.call::<()>("Spin", ()).unwrap_err(); + multiuse_sandbox = ctx.finish().unwrap(); } for join_handle in join_handles { diff --git a/src/hyperlight_host/examples/tracing-chrome/main.rs b/src/hyperlight_host/examples/tracing-chrome/main.rs index b5fe3f79a..1d68aec89 100644 --- a/src/hyperlight_host/examples/tracing-chrome/main.rs +++ b/src/hyperlight_host/examples/tracing-chrome/main.rs @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #![allow(clippy::disallowed_macros)] -use hyperlight_host::func::{ParameterValue, ReturnType, ReturnValue}; use hyperlight_host::sandbox::uninitialized::UninitializedSandbox; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -41,13 +40,9 @@ fn main() -> Result<()> { // do the function call let current_time = std::time::Instant::now(); - let res = sbox.call_guest_function_by_name( - "Echo", - ReturnType::String, - Some(vec![ParameterValue::String("Hello, World!".to_string())]), - )?; + let res: String = sbox.call_guest_function_by_name("Echo", "Hello, World!".to_string())?; let elapsed = current_time.elapsed(); println!("Function call finished in {:?}.", elapsed); - assert!(matches!(res, ReturnValue::String(s) if s == "Hello, World!")); + assert_eq!(res, "Hello, World!"); Ok(()) } diff --git a/src/hyperlight_host/examples/tracing-otlp/main.rs b/src/hyperlight_host/examples/tracing-otlp/main.rs index 97cb8d5c6..51c0eb10c 100644 --- a/src/hyperlight_host/examples/tracing-otlp/main.rs +++ b/src/hyperlight_host/examples/tracing-otlp/main.rs @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #![allow(clippy::disallowed_macros)] -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; //use opentelemetry_sdk::resource::ResourceBuilder; use opentelemetry_sdk::trace::SdkTracerProvider; use rand::Rng; @@ -108,7 +107,7 @@ fn run_example(wait_input: bool) -> HyperlightResult<()> { let mut join_handles: Vec>> = vec![]; // Construct a new span named "hyperlight otel tracing example" with INFO level. - let span = span!(Level::INFO, "hyperlight otel tracing example",); + let span = span!(Level::INFO, "hyperlight otel tracing example"); let _entered = span.enter(); let should_exit = Arc::new(Mutex::new(false)); @@ -141,12 +140,9 @@ fn run_example(wait_input: bool) -> HyperlightResult<()> { // Call a guest function 5 times to generate some log entries. for _ in 0..5 { - let result = multiuse_sandbox.call_guest_function_by_name( - "Echo", - ReturnType::String, - Some(vec![ParameterValue::String("a".to_string())]), - ); - assert!(result.is_ok()); + multiuse_sandbox + .call_guest_function_by_name::("Echo", "a".to_string()) + .unwrap(); } // Define a message to send to the guest. @@ -155,12 +151,9 @@ fn run_example(wait_input: bool) -> HyperlightResult<()> { // Call a guest function that calls the HostPrint host function 5 times to generate some log entries. for _ in 0..5 { - let result = multiuse_sandbox.call_guest_function_by_name( - "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String(msg.clone())]), - ); - assert!(result.is_ok()); + multiuse_sandbox + .call_guest_function_by_name::("PrintOutput", msg.clone()) + .unwrap(); } // Call a function that gets cancelled by the host function 5 times to generate some log entries. @@ -177,11 +170,8 @@ fn run_example(wait_input: bool) -> HyperlightResult<()> { let _entered = span.enter(); let mut ctx = multiuse_sandbox.new_call_context(); - let result = ctx.call("Spin", ReturnType::Void, None); - assert!(result.is_err()); - let result = ctx.finish(); - assert!(result.is_ok()); - multiuse_sandbox = result.unwrap(); + ctx.call::<()>("Spin", ()).unwrap_err(); + multiuse_sandbox = ctx.finish().unwrap(); } let sleep_for = { let mut rng = rand::rng(); diff --git a/src/hyperlight_host/examples/tracing-tracy/main.rs b/src/hyperlight_host/examples/tracing-tracy/main.rs index 216eef1fc..1b2b6e5a3 100644 --- a/src/hyperlight_host/examples/tracing-tracy/main.rs +++ b/src/hyperlight_host/examples/tracing-tracy/main.rs @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #![allow(clippy::disallowed_macros)] -use hyperlight_host::func::{ParameterValue, ReturnType, ReturnValue}; use hyperlight_host::sandbox::uninitialized::UninitializedSandbox; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -47,13 +46,9 @@ fn main() -> Result<()> { // do the function call let current_time = std::time::Instant::now(); - let res = sbox.call_guest_function_by_name( - "Echo", - ReturnType::String, - Some(vec![ParameterValue::String("Hello, World!".to_string())]), - )?; + let res: String = sbox.call_guest_function_by_name("Echo", "Hello, World!".to_string())?; let elapsed = current_time.elapsed(); println!("Function call finished in {:?}.", elapsed); - assert!(matches!(res, ReturnValue::String(s) if s == "Hello, World!")); + assert_eq!(res, "Hello, World!"); Ok(()) } diff --git a/src/hyperlight_host/examples/tracing/main.rs b/src/hyperlight_host/examples/tracing/main.rs index a3ca8fb7e..e270b48d1 100644 --- a/src/hyperlight_host/examples/tracing/main.rs +++ b/src/hyperlight_host/examples/tracing/main.rs @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #![allow(clippy::disallowed_macros)] -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; use tracing::{span, Level}; extern crate hyperlight_host; use std::thread::{spawn, JoinHandle}; @@ -54,7 +53,7 @@ fn run_example() -> Result<()> { let mut join_handles: Vec>> = vec![]; // Construct a new span named "hyperlight tracing example" with INFO level. - let span = span!(Level::INFO, "hyperlight tracing example",); + let span = span!(Level::INFO, "hyperlight tracing example"); let _entered = span.enter(); for i in 0..10 { @@ -82,12 +81,9 @@ fn run_example() -> Result<()> { // Call a guest function 5 times to generate some log entries. for _ in 0..5 { - let result = multiuse_sandbox.call_guest_function_by_name( - "Echo", - ReturnType::String, - Some(vec![ParameterValue::String("a".to_string())]), - ); - assert!(result.is_ok()); + multiuse_sandbox + .call_guest_function_by_name::("Echo", "a".to_string()) + .unwrap(); } // Define a message to send to the guest. @@ -96,12 +92,9 @@ fn run_example() -> Result<()> { // Call a guest function that calls the HostPrint host function 5 times to generate some log entries. for _ in 0..5 { - let result = multiuse_sandbox.call_guest_function_by_name( - "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String(msg.clone())]), - ); - assert!(result.is_ok()); + multiuse_sandbox + .call_guest_function_by_name::("PrintOutput", msg.clone()) + .unwrap(); } Ok(()) }); @@ -132,11 +125,8 @@ fn run_example() -> Result<()> { let _entered = span.enter(); let mut ctx = multiuse_sandbox.new_call_context(); - let result = ctx.call("Spin", ReturnType::Void, None); - assert!(result.is_err()); - let result = ctx.finish(); - assert!(result.is_ok()); - multiuse_sandbox = result.unwrap(); + ctx.call::<()>("Spin", ()).unwrap_err(); + multiuse_sandbox = ctx.finish().unwrap(); } for join_handle in join_handles { diff --git a/src/hyperlight_host/src/func/call_ctx.rs b/src/hyperlight_host/src/func/call_ctx.rs index 619c45818..412643c3d 100644 --- a/src/hyperlight_host/src/func/call_ctx.rs +++ b/src/hyperlight_host/src/func/call_ctx.rs @@ -14,12 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ -use hyperlight_common::flatbuffer_wrappers::function_types::{ - ParameterValue, ReturnType, ReturnValue, -}; use tracing::{instrument, Span}; use super::guest_dispatch::call_function_on_guest; +use super::{ParameterTuple, SupportedReturnType}; use crate::{MultiUseSandbox, Result}; /// A context for calling guest functions. /// @@ -61,18 +59,19 @@ impl MultiUseGuestCallContext { /// If you want to reset state, call `finish()` on this `MultiUseGuestCallContext` /// and get a new one from the resulting `MultiUseSandbox` #[instrument(err(Debug),skip(self, args),parent = Span::current())] - pub fn call( + pub fn call( &mut self, func_name: &str, - func_ret_type: ReturnType, - args: Option>, - ) -> Result { + args: impl ParameterTuple, + ) -> Result { // we are guaranteed to be holding a lock now, since `self` can't // exist without doing so. Since GuestCallContext is effectively // !Send (and !Sync), we also don't need to worry about // synchronization - call_function_on_guest(&mut self.sbox, func_name, func_ret_type, args) + let ret = + call_function_on_guest(&mut self.sbox, func_name, Output::TYPE, args.into_value()); + Output::from_value(ret?) } /// Close out the context and get back the internally-stored @@ -104,11 +103,9 @@ mod tests { use std::sync::mpsc::sync_channel; use std::thread::{self, JoinHandle}; - use hyperlight_common::flatbuffer_wrappers::function_types::{ - ParameterValue, ReturnType, ReturnValue, - }; use hyperlight_testing::simple_guest_as_string; + use super::MultiUseGuestCallContext; use crate::sandbox_state::sandbox::EvolvableSandbox; use crate::sandbox_state::transition::Noop; use crate::{GuestBinary, HyperlightError, MultiUseSandbox, Result, UninitializedSandbox}; @@ -148,10 +145,7 @@ mod tests { while let Ok(calls) = recv.recv() { let mut ctx = sbox.new_call_context(); for call in calls { - let res = ctx - .call(call.func_name.as_str(), call.ret_type, call.params) - .unwrap(); - assert_eq!(call.expected_ret, res); + call.call(&mut ctx); } sbox = ctx.finish().unwrap(); } @@ -162,21 +156,16 @@ mod tests { .map(|i| { let sender = snd.clone(); thread::spawn(move || { - let calls: Vec = vec![ - TestFuncCall { - func_name: "Echo".to_string(), - ret_type: ReturnType::String, - params: Some(vec![ParameterValue::String( - format!("Hello {}", i).to_string(), - )]), - expected_ret: ReturnValue::String(format!("Hello {}", i).to_string()), - }, - TestFuncCall { - func_name: "CallMalloc".to_string(), - ret_type: ReturnType::Int, - params: Some(vec![ParameterValue::Int(i + 2)]), - expected_ret: ReturnValue::Int(i + 2), - }, + let calls = vec![ + TestFuncCall::new(move |ctx| { + let msg = format!("Hello {}", i); + let ret: String = ctx.call("Echo", msg.clone()).unwrap(); + assert_eq!(ret, msg) + }), + TestFuncCall::new(move |ctx| { + let ret: i32 = ctx.call("CallMalloc", i + 2).unwrap(); + assert_eq!(ret, i + 2) + }), ]; sender.send(calls).unwrap(); }) @@ -206,15 +195,11 @@ mod tests { let mut ctx = self.sandbox.new_call_context(); let mut sum: i32 = 0; for n in 0..i { - let result = ctx.call( - "AddToStatic", - ReturnType::Int, - Some(vec![ParameterValue::Int(n)]), - ); + let result = ctx.call::("AddToStatic", n); sum += n; println!("{:?}", result); let result = result.unwrap(); - assert_eq!(result, ReturnValue::Int(sum)); + assert_eq!(result, sum); } let result = ctx.finish(); assert!(result.is_ok()); @@ -224,14 +209,12 @@ mod tests { pub fn call_add_to_static(mut self, i: i32) -> Result<()> { for n in 0..i { - let result = self.sandbox.call_guest_function_by_name( - "AddToStatic", - ReturnType::Int, - Some(vec![ParameterValue::Int(n)]), - ); + let result = self + .sandbox + .call_guest_function_by_name::("AddToStatic", n); println!("{:?}", result); let result = result.unwrap(); - assert_eq!(result, ReturnValue::Int(n)); + assert_eq!(result, n); } Ok(()) } @@ -251,10 +234,15 @@ mod tests { assert!(result.is_ok()); } - struct TestFuncCall { - func_name: String, - ret_type: ReturnType, - params: Option>, - expected_ret: ReturnValue, + struct TestFuncCall(Box); + + impl TestFuncCall { + fn new(f: impl FnOnce(&mut MultiUseGuestCallContext) + Send + 'static) -> Self { + TestFuncCall(Box::new(f)) + } + + fn call(self, ctx: &mut MultiUseGuestCallContext) { + (self.0)(ctx); + } } } diff --git a/src/hyperlight_host/src/func/guest_dispatch.rs b/src/hyperlight_host/src/func/guest_dispatch.rs index 8b512cead..ad980b146 100644 --- a/src/hyperlight_host/src/func/guest_dispatch.rs +++ b/src/hyperlight_host/src/func/guest_dispatch.rs @@ -33,19 +33,19 @@ use crate::{HyperlightError, Result}; parent = Span::current(), level = "Trace" )] -pub(crate) fn call_function_on_guest( - wrapper_getter: &mut WrapperGetterT, +pub(crate) fn call_function_on_guest( + wrapper_getter: &mut impl WrapperGetter, function_name: &str, - return_type: ReturnType, - args: Option>, + ret_type: ReturnType, + args: Vec, ) -> Result { let mut timedout = false; let fc = FunctionCall::new( function_name.to_string(), - args, + Some(args), FunctionCallType::Guest, - return_type, + ret_type, ); let buffer: Vec = fc @@ -83,7 +83,7 @@ pub(crate) fn call_function_on_guest( mem_mgr.check_stack_guard()?; // <- wrapper around mem_mgr `check_for_stack_guard` check_for_guest_error(mem_mgr)?; - mem_mgr + let ret = mem_mgr .as_mut() .get_guest_function_call_result() .map_err(|e| { @@ -100,7 +100,9 @@ pub(crate) fn call_function_on_guest( } else { e } - }) + })?; + + Ok(ret) } #[cfg(test)] @@ -110,7 +112,6 @@ mod tests { use hyperlight_testing::{callback_guest_as_string, simple_guest_as_string}; - use super::*; use crate::func::call_ctx::MultiUseGuestCallContext; use crate::sandbox::is_hypervisor_present; use crate::sandbox::uninitialized::GuestBinary; @@ -161,8 +162,7 @@ mod tests { let mut sbox: MultiUseSandbox = usbox.evolve(Noop::default())?; - let res = - sbox.call_guest_function_by_name("ViolateSeccompFilters", ReturnType::ULong, None); + let res: Result = sbox.call_guest_function_by_name("ViolateSeccompFilters", ()); #[cfg(feature = "seccomp")] match res { @@ -198,8 +198,7 @@ mod tests { let mut sbox: MultiUseSandbox = usbox.evolve(Noop::default())?; - let res = - sbox.call_guest_function_by_name("ViolateSeccompFilters", ReturnType::ULong, None); + let res: Result = sbox.call_guest_function_by_name("ViolateSeccompFilters", ()); match res { Ok(_) => {} @@ -311,15 +310,9 @@ mod tests { let msg = "Hello, World!!\n".to_string(); let len = msg.len() as i32; let mut ctx = mu_sbox.new_call_context(); - let result = ctx - .call( - "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String(msg.clone())]), - ) - .unwrap(); + let result: i32 = ctx.call("PrintOutput", msg).unwrap(); - assert_eq!(result, ReturnValue::Int(len)); + assert_eq!(result, len); } fn call_guest_function_by_name_hv() { @@ -353,7 +346,7 @@ mod tests { )?; let sandbox: MultiUseSandbox = usbox.evolve(Noop::default())?; let mut ctx = sandbox.new_call_context(); - let result = ctx.call("Spin", ReturnType::Void, None); + let result: Result<()> = ctx.call("Spin", ()); assert!(result.is_err()); match result.unwrap_err() { @@ -417,7 +410,7 @@ mod tests { let sandbox: MultiUseSandbox = usbox.evolve(Noop::default()).unwrap(); let mut ctx = sandbox.new_call_context(); - let result = ctx.call("CallHostSpin", ReturnType::Void, None); + let result: Result<()> = ctx.call("CallHostSpin", ()); assert!(result.is_err()); match result.unwrap_err() { @@ -439,11 +432,7 @@ mod tests { let mut multi_use_sandbox: MultiUseSandbox = usbox.evolve(Noop::default()).unwrap(); - let res = multi_use_sandbox.call_guest_function_by_name( - "TriggerException", - ReturnType::Void, - None, - ); + let res: Result<()> = multi_use_sandbox.call_guest_function_by_name("TriggerException", ()); assert!(res.is_err()); diff --git a/src/hyperlight_host/src/func/param_type.rs b/src/hyperlight_host/src/func/param_type.rs index 407b202da..50de7f8c3 100644 --- a/src/hyperlight_host/src/func/param_type.rs +++ b/src/hyperlight_host/src/func/param_type.rs @@ -45,6 +45,8 @@ macro_rules! for_each_param_type { $macro!(u32, UInt); $macro!(i64, Long); $macro!(u64, ULong); + $macro!(f32, Float); + $macro!(f64, Double); $macro!(bool, Bool); $macro!(Vec, VecBytes); }; @@ -94,6 +96,27 @@ pub trait ParameterTuple: Sized + Clone + Send + Sync + 'static { fn from_value(value: Vec) -> Result; } +impl ParameterTuple for T { + const SIZE: usize = 1; + + const TYPE: &[ParameterType] = &[T::TYPE]; + + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + fn into_value(self) -> Vec { + vec![self.into_value()] + } + + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + fn from_value(value: Vec) -> Result { + match <[ParameterValue; 1]>::try_from(value) { + Ok([val]) => Ok(T::from_value(val)?), + Err(value) => { + log_then_return!(UnexpectedNoOfArguments(value.len(), 1)); + } + } + } +} + macro_rules! impl_param_tuple { ([$N:expr] ($($name:ident: $param:ident),*)) => { impl<$($param: SupportedParameterType),*> ParameterTuple for ($($param,)*) { diff --git a/src/hyperlight_host/src/func/ret_type.rs b/src/hyperlight_host/src/func/ret_type.rs index 1006fe649..45bc36d34 100644 --- a/src/hyperlight_host/src/func/ret_type.rs +++ b/src/hyperlight_host/src/func/ret_type.rs @@ -49,6 +49,8 @@ macro_rules! for_each_return_type { $macro!(u32, UInt); $macro!(i64, Long); $macro!(u64, ULong); + $macro!(f32, Float); + $macro!(f64, Double); $macro!(bool, Bool); $macro!(Vec, VecBytes); }; diff --git a/src/hyperlight_host/src/hypervisor/hypervisor_handler.rs b/src/hyperlight_host/src/hypervisor/hypervisor_handler.rs index 8e351708c..c7f46df9e 100644 --- a/src/hyperlight_host/src/hypervisor/hypervisor_handler.rs +++ b/src/hyperlight_host/src/hypervisor/hypervisor_handler.rs @@ -929,7 +929,6 @@ mod tests { use std::sync::{Arc, Barrier}; use std::thread; - use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; use hyperlight_testing::simple_guest_as_string; #[cfg(target_os = "windows")] @@ -1013,11 +1012,7 @@ mod tests { let mut sandbox = create_multi_use_sandbox(); let msg = "Hello, World!\n".to_string(); - let res = sandbox.call_guest_function_by_name( - "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String(msg.clone())]), - ); + let res = sandbox.call_guest_function_by_name::("PrintOutput", msg); assert!(res.is_ok()); @@ -1028,7 +1023,7 @@ mod tests { fn terminate_execution_then_call_another_function() -> Result<()> { let mut sandbox = create_multi_use_sandbox(); - let res = sandbox.call_guest_function_by_name("Spin", ReturnType::Void, None); + let res = sandbox.call_guest_function_by_name::<()>("Spin", ()); assert!(res.is_err()); @@ -1037,11 +1032,7 @@ mod tests { _ => panic!("Expected ExecutionTerminated error"), } - let res = sandbox.call_guest_function_by_name( - "Echo", - ReturnType::String, - Some(vec![ParameterValue::String("a".to_string())]), - ); + let res = sandbox.call_guest_function_by_name::("Echo", "a".to_string()); assert!(res.is_ok()); @@ -1053,11 +1044,7 @@ mod tests { { let call_print_output = |sandbox: &mut MultiUseSandbox| { let msg = "Hello, World!\n".to_string(); - let res = sandbox.call_guest_function_by_name( - "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String(msg.clone())]), - ); + let res = sandbox.call_guest_function_by_name::("PrintOutput", msg); assert!(res.is_ok()); }; diff --git a/src/hyperlight_host/src/metrics/mod.rs b/src/hyperlight_host/src/metrics/mod.rs index c08f30c94..f88d78ee0 100644 --- a/src/hyperlight_host/src/metrics/mod.rs +++ b/src/hyperlight_host/src/metrics/mod.rs @@ -85,7 +85,6 @@ pub(crate) fn maybe_time_and_emit_host_call T>( #[cfg(test)] mod tests { - use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; use hyperlight_testing::simple_guest_as_string; use metrics::Key; use metrics_util::CompositeKey; @@ -116,15 +115,11 @@ mod tests { let mut multi = uninit.evolve(Noop::default()).unwrap(); multi - .call_guest_function_by_name( - "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String("Hello".to_string())]), - ) + .call_guest_function_by_name::("PrintOutput", "Hello".to_string()) .unwrap(); multi - .call_guest_function_by_name("Spin", ReturnType::Int, None) + .call_guest_function_by_name::("Spin", ()) .unwrap_err(); snapshotter.snapshot() diff --git a/src/hyperlight_host/src/sandbox/host_funcs.rs b/src/hyperlight_host/src/sandbox/host_funcs.rs index dbdea7187..7bab361b3 100644 --- a/src/hyperlight_host/src/sandbox/host_funcs.rs +++ b/src/hyperlight_host/src/sandbox/host_funcs.rs @@ -159,7 +159,7 @@ fn maybe_with_seccomp( // Use a scoped thread so that we can pass around references without having to clone them. crossbeam::thread::scope(|s| { s.builder() - .name(format!("Host Function Worker Thread for: {name:?}",)) + .name(format!("Host Function Worker Thread for: {name:?}")) .spawn(move |_| { let seccomp_filter = get_seccomp_filter_for_host_function_worker_thread(syscalls)?; seccompiler::apply_filter(&seccomp_filter)?; diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 8769e1643..3ca9f5899 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -25,6 +25,7 @@ use super::host_funcs::FunctionRegistry; use super::{MemMgrWrapper, WrapperGetter}; use crate::func::call_ctx::MultiUseGuestCallContext; use crate::func::guest_dispatch::call_function_on_guest; +use crate::func::{ParameterTuple, SupportedReturnType}; use crate::hypervisor::hypervisor_handler::HypervisorHandler; use crate::mem::shared_mem::HostSharedMemory; use crate::sandbox_state::sandbox::{DevolvableSandbox, EvolvableSandbox, Sandbox}; @@ -155,15 +156,28 @@ impl MultiUseSandbox { /// Call a guest function by name, with the given return type and arguments. #[instrument(err(Debug), skip(self, args), parent = Span::current())] - pub fn call_guest_function_by_name( + pub fn call_guest_function_by_name( &mut self, func_name: &str, - func_ret_type: ReturnType, - args: Option>, + args: impl ParameterTuple, + ) -> Result { + let ret = call_function_on_guest(self, func_name, Output::TYPE, args.into_value()); + self.restore_state()?; + Output::from_value(ret?) + } + + /// This function is kept here for fuzz testing the parameter and return types + #[cfg(feature = "fuzzing")] + #[instrument(err(Debug), skip(self, args), parent = Span::current())] + pub fn call_type_erased_guest_function_by_name( + &mut self, + func_name: &str, + ret_type: ReturnType, + args: Vec, ) -> Result { - let res = call_function_on_guest(self, func_name, func_ret_type, args); + let ret = call_function_on_guest(self, func_name, ret_type, args); self.restore_state()?; - res + ret } /// Restore the Sandbox's state @@ -255,9 +269,6 @@ where #[cfg(test)] mod tests { - use hyperlight_common::flatbuffer_wrappers::function_types::{ - ParameterValue, ReturnType, ReturnValue, - }; use hyperlight_testing::simple_guest_as_string; use crate::func::call_ctx::MultiUseGuestCallContext; @@ -284,12 +295,7 @@ mod tests { let mut ctx = sbox1.new_call_context(); for _ in 0..1000 { - ctx.call( - "Echo", - ReturnType::String, - Some(vec![ParameterValue::String("hello".to_string())]), - ) - .unwrap(); + ctx.call::("Echo", "hello".to_string()).unwrap(); } let sbox2: MultiUseSandbox = { @@ -302,12 +308,9 @@ mod tests { let mut ctx = sbox2.new_call_context(); for i in 0..1000 { - ctx.call( + ctx.call::( "PrintUsingPrintf", - ReturnType::Int, - Some(vec![ParameterValue::String( - format!("Hello World {}\n", i).to_string(), - )]), + format!("Hello World {}\n", i).to_string(), ) .unwrap(); } @@ -325,23 +328,15 @@ mod tests { .unwrap(); let func = Box::new(|call_ctx: &mut MultiUseGuestCallContext| { - call_ctx.call( - "AddToStatic", - ReturnType::Int, - Some(vec![ParameterValue::Int(5)]), - )?; + call_ctx.call::("AddToStatic", 5i32)?; Ok(()) }); let transition_func = MultiUseContextCallback::from(func); let mut sbox2 = sbox1.evolve(transition_func).unwrap(); - let res = sbox2 - .call_guest_function_by_name("GetStatic", ReturnType::Int, None) - .unwrap(); - assert_eq!(res, ReturnValue::Int(5)); + let res: i32 = sbox2.call_guest_function_by_name("GetStatic", ()).unwrap(); + assert_eq!(res, 5); let mut sbox3: MultiUseSandbox = sbox2.devolve(Noop::default()).unwrap(); - let res = sbox3 - .call_guest_function_by_name("GetStatic", ReturnType::Int, None) - .unwrap(); - assert_eq!(res, ReturnValue::Int(0)); + let res: i32 = sbox3.call_guest_function_by_name("GetStatic", ()).unwrap(); + assert_eq!(res, 0); } } diff --git a/src/hyperlight_host/src/seccomp/guest.rs b/src/hyperlight_host/src/seccomp/guest.rs index d8321044f..3b34dd2f5 100644 --- a/src/hyperlight_host/src/seccomp/guest.rs +++ b/src/hyperlight_host/src/seccomp/guest.rs @@ -57,6 +57,8 @@ fn syscalls_allowlist() -> Result)>> { // `sched_yield` is needed for many synchronization primitives that may be invoked // on the host function worker thread (libc::SYS_sched_yield, vec![]), + // `mprotect` is needed by malloc during memory allocation + (libc::SYS_mprotect, vec![]), ]) } diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 2df993070..aa8ffe0e9 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -16,7 +16,6 @@ limitations under the License. #![allow(clippy::disallowed_macros)] use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_common::mem::PAGE_SIZE; -use hyperlight_host::func::{ParameterValue, ReturnType, ReturnValue}; use hyperlight_host::sandbox::SandboxConfiguration; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -35,18 +34,12 @@ fn print_four_args_c_guest() { let uninit = UninitializedSandbox::new(guest_path, None); let mut sbox1 = uninit.unwrap().evolve(Noop::default()).unwrap(); - let res = sbox1.call_guest_function_by_name( + let res = sbox1.call_guest_function_by_name::( "PrintFourArgs", - ReturnType::String, - Some(vec![ - ParameterValue::String("Test4".to_string()), - ParameterValue::Int(3_i32), - ParameterValue::Long(4_i64), - ParameterValue::String("Tested".to_string()), - ]), + ("Test4".to_string(), 3_i32, 4_i64, "Tested".to_string()), ); println!("{:?}", res); - assert!(matches!(res, Ok(ReturnValue::Int(46)))); + assert!(matches!(res, Ok(46))); } // Checks that guest can abort with a specific code. @@ -55,11 +48,7 @@ fn guest_abort() { let mut sbox1 = new_uninit().unwrap().evolve(Noop::default()).unwrap(); let error_code: u8 = 13; // this is arbitrary let res = sbox1 - .call_guest_function_by_name( - "GuestAbortWithCode", - ReturnType::Void, - Some(vec![ParameterValue::Int(error_code as i32)]), - ) + .call_guest_function_by_name::<()>("GuestAbortWithCode", error_code as i32) .unwrap_err(); println!("{:?}", res); assert!( @@ -72,14 +61,7 @@ fn guest_abort_with_context1() { let mut sbox1 = new_uninit().unwrap().evolve(Noop::default()).unwrap(); let res = sbox1 - .call_guest_function_by_name( - "GuestAbortWithMessage", - ReturnType::Void, - Some(vec![ - ParameterValue::Int(25), - ParameterValue::String("Oh no".to_string()), - ]), - ) + .call_guest_function_by_name::<()>("GuestAbortWithMessage", (25_i32, "Oh no".to_string())) .unwrap_err(); println!("{:?}", res); assert!( @@ -124,13 +106,9 @@ fn guest_abort_with_context2() { Proin sagittis nisl rhoncus mattis rhoncus urna. Magna eget est lorem ipsum."; let res = sbox1 - .call_guest_function_by_name( + .call_guest_function_by_name::<()>( "GuestAbortWithMessage", - ReturnType::Void, - Some(vec![ - ParameterValue::Int(60), - ParameterValue::String(abort_message.to_string()), - ]), + (60_i32, abort_message.to_string()), ) .unwrap_err(); println!("{:?}", res); @@ -150,13 +128,9 @@ fn guest_abort_c_guest() { let mut sbox1 = uninit.unwrap().evolve(Noop::default()).unwrap(); let res = sbox1 - .call_guest_function_by_name( + .call_guest_function_by_name::<()>( "GuestAbortWithMessage", - ReturnType::Void, - Some(vec![ - ParameterValue::Int(75_i32), - ParameterValue::String("This is a test error message".to_string()), - ]), + (75_i32, "This is a test error message".to_string()), ) .unwrap_err(); println!("{:?}", res); @@ -171,13 +145,7 @@ fn guest_panic() { let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap(); let res = sbox1 - .call_guest_function_by_name( - "guest_panic", - ReturnType::Void, - Some(vec![ParameterValue::String( - "Error... error...".to_string(), - )]), - ) + .call_guest_function_by_name::<()>("guest_panic", "Error... error...".to_string()) .unwrap_err(); println!("{:?}", res); assert!( @@ -190,32 +158,26 @@ fn guest_malloc() { // this test is rust-only let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap(); - let size_to_allocate = 2000; - let res = sbox1 - .call_guest_function_by_name( - "TestMalloc", - ReturnType::Int, - Some(vec![ParameterValue::Int(size_to_allocate)]), - ) + let size_to_allocate = 2000_i32; + sbox1 + .call_guest_function_by_name::("TestMalloc", size_to_allocate) .unwrap(); - assert!(matches!(res, ReturnValue::Int(_))); } #[test] fn guest_allocate_vec() { let mut sbox1 = new_uninit().unwrap().evolve(Noop::default()).unwrap(); - let size_to_allocate = 2000; + let size_to_allocate = 2000_i32; let res = sbox1 - .call_guest_function_by_name( + .call_guest_function_by_name::( "CallMalloc", // uses the rust allocator to allocate a vector on heap - ReturnType::Int, - Some(vec![ParameterValue::Int(size_to_allocate)]), + size_to_allocate, ) .unwrap(); - assert!(matches!(res, ReturnValue::Int(returned_size) if returned_size == size_to_allocate)); + assert_eq!(res, size_to_allocate); } // checks that malloc failures are captured correctly @@ -223,14 +185,10 @@ fn guest_allocate_vec() { fn guest_malloc_abort() { let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap(); - let size = 20000000; // some big number that should fail when allocated + let size = 20000000_i32; // some big number that should fail when allocated let res = sbox1 - .call_guest_function_by_name( - "TestMalloc", - ReturnType::Int, - Some(vec![ParameterValue::Int(size)]), - ) + .call_guest_function_by_name::("TestMalloc", size) .unwrap_err(); println!("{:?}", res); assert!( @@ -251,10 +209,9 @@ fn guest_malloc_abort() { .unwrap(); let mut sbox2 = uninit.evolve(Noop::default()).unwrap(); - let res = sbox2.call_guest_function_by_name( + let res = sbox2.call_guest_function_by_name::( "CallMalloc", // uses the rust allocator to allocate a vector on heap - ReturnType::Int, - Some(vec![ParameterValue::Int(size_to_allocate as i32)]), + size_to_allocate as i32, ); println!("{:?}", res); assert!(matches!( @@ -272,21 +229,13 @@ fn dynamic_stack_allocate_c_guest() { let uninit = UninitializedSandbox::new(guest_path, None); let mut sbox1: MultiUseSandbox = uninit.unwrap().evolve(Noop::default()).unwrap(); - let res2 = sbox1 - .call_guest_function_by_name( - "StackAllocate", - ReturnType::Int, - Some(vec![ParameterValue::Int(100)]), - ) + let res: i32 = sbox1 + .call_guest_function_by_name("StackAllocate", 100_i32) .unwrap(); - assert!(matches!(res2, ReturnValue::Int(n) if n == 100)); + assert_eq!(res, 100); let res = sbox1 - .call_guest_function_by_name( - "StackAllocate", - ReturnType::Int, - Some(vec![ParameterValue::Int(128 * 1024 * 1024)]), - ) + .call_guest_function_by_name::("StackAllocate", 0x800_0000_i32) .unwrap_err(); assert!(matches!(res, HyperlightError::StackOverflow())); } @@ -296,10 +245,8 @@ fn dynamic_stack_allocate_c_guest() { fn static_stack_allocate() { let mut sbox1 = new_uninit().unwrap().evolve(Noop::default()).unwrap(); - let res = sbox1 - .call_guest_function_by_name("SmallVar", ReturnType::Int, Some(Vec::new())) - .unwrap(); - assert!(matches!(res, ReturnValue::Int(1024))); + let res: i32 = sbox1.call_guest_function_by_name("SmallVar", ()).unwrap(); + assert_eq!(res, 1024); } // checks that a huge buffer on stack fails with stackoverflow @@ -307,7 +254,7 @@ fn static_stack_allocate() { fn static_stack_allocate_overflow() { let mut sbox1 = new_uninit().unwrap().evolve(Noop::default()).unwrap(); let res = sbox1 - .call_guest_function_by_name("LargeVar", ReturnType::Int, Some(Vec::new())) + .call_guest_function_by_name::("LargeVar", ()) .unwrap_err(); assert!(matches!(res, HyperlightError::StackOverflow())); } @@ -317,14 +264,10 @@ fn static_stack_allocate_overflow() { fn recursive_stack_allocate() { let mut sbox1 = new_uninit().unwrap().evolve(Noop::default()).unwrap(); - let iterations = 1; + let iterations = 1_i32; sbox1 - .call_guest_function_by_name( - "StackOverflow", - ReturnType::Int, - Some(vec![ParameterValue::Int(iterations)]), - ) + .call_guest_function_by_name::("StackOverflow", iterations) .unwrap(); } @@ -350,11 +293,7 @@ fn guard_page_check() { // we have to create a sandbox each iteration because can't reuse after MMIO error in release mode let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap(); - let result = sbox1.call_guest_function_by_name( - "test_write_raw_ptr", - ReturnType::String, - Some(vec![ParameterValue::Long(offset)]), - ); + let result = sbox1.call_guest_function_by_name::("test_write_raw_ptr", offset); if guard_range.contains(&offset) { // should have failed assert!(matches!( @@ -373,7 +312,7 @@ fn guard_page_check_2() { let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap(); let result = sbox1 - .call_guest_function_by_name("InfiniteRecursion", ReturnType::Void, Some(vec![])) + .call_guest_function_by_name::<()>("InfiniteRecursion", ()) .unwrap_err(); assert!(matches!(result, HyperlightError::StackOverflow())); } @@ -383,7 +322,7 @@ fn execute_on_stack() { let mut sbox1 = new_uninit().unwrap().evolve(Noop::default()).unwrap(); let result = sbox1 - .call_guest_function_by_name("ExecuteOnStack", ReturnType::String, Some(vec![])) + .call_guest_function_by_name::("ExecuteOnStack", ()) .unwrap_err(); let err = result.to_string(); @@ -397,8 +336,7 @@ fn execute_on_stack() { #[ignore] // ran from Justfile because requires feature "executable_heap" fn execute_on_heap() { let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap(); - let result = - sbox1.call_guest_function_by_name("ExecuteOnHeap", ReturnType::String, Some(vec![])); + let result = sbox1.call_guest_function_by_name::("ExecuteOnHeap", ()); println!("{:#?}", result); #[cfg(feature = "executable_heap")] @@ -417,16 +355,12 @@ fn execute_on_heap() { fn memory_resets_after_failed_guestcall() { let mut sbox1 = new_uninit_rust().unwrap().evolve(Noop::default()).unwrap(); sbox1 - .call_guest_function_by_name("AddToStaticAndFail", ReturnType::String, None) + .call_guest_function_by_name::("AddToStaticAndFail", ()) .unwrap_err(); let res = sbox1 - .call_guest_function_by_name("GetStatic", ReturnType::Int, None) + .call_guest_function_by_name::("GetStatic", ()) .unwrap(); - assert!( - matches!(res, ReturnValue::Int(0)), - "Expected 0, got {:?}", - res - ); + assert_eq!(res, 0, "Expected 0, got {:?}", res); } // checks that a recursive function with stack allocation eventually fails with stackoverflow @@ -434,14 +368,10 @@ fn memory_resets_after_failed_guestcall() { fn recursive_stack_allocate_overflow() { let mut sbox1 = new_uninit().unwrap().evolve(Noop::default()).unwrap(); - let iterations = 10; + let iterations = 10_i32; let res = sbox1 - .call_guest_function_by_name( - "StackOverflow", - ReturnType::Void, - Some(vec![ParameterValue::Int(iterations)]), - ) + .call_guest_function_by_name::<()>("StackOverflow", iterations) .unwrap_err(); println!("{:?}", res); assert!(matches!(res, HyperlightError::StackOverflow())); @@ -511,14 +441,7 @@ fn log_test_messages(levelfilter: Option) { let message = format!("Hello from log_message level {}", level as i32); sbox1 - .call_guest_function_by_name( - "LogMessage", - ReturnType::Void, - Some(vec![ - ParameterValue::String(message.to_string()), - ParameterValue::Int(level as i32), - ]), - ) + .call_guest_function_by_name::<()>("LogMessage", (message.to_string(), level as i32)) .unwrap(); } } diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index e38b23dab..581e17a07 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -15,10 +15,10 @@ limitations under the License. */ #![allow(clippy::disallowed_macros)] use core::f64; +use std::sync::mpsc::channel; use std::sync::{Arc, Mutex}; use common::new_uninit; -use hyperlight_host::func::{ParameterValue, ReturnType, ReturnValue}; use hyperlight_host::sandbox::SandboxConfiguration; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -39,26 +39,13 @@ fn pass_byte_array() { let mut ctx = sandbox.new_call_context(); const LEN: usize = 10; let bytes = vec![1u8; LEN]; - let res = ctx.call( - "SetByteArrayToZero", - ReturnType::VecBytes, - Some(vec![ParameterValue::VecBytes(bytes.clone())]), - ); + let res: Vec = ctx + .call("SetByteArrayToZero", bytes.clone()) + .expect("Expected VecBytes"); + assert_eq!(res, [0; LEN]); - match res.unwrap() { - ReturnValue::VecBytes(res_bytes) => { - assert_eq!(res_bytes.len(), LEN); - assert!(res_bytes.iter().all(|&b| b == 0)); - } - _ => panic!("Expected VecBytes"), - } - - let res = ctx.call( - "SetByteArrayToZeroNoLength", - ReturnType::Int, - Some(vec![ParameterValue::VecBytes(bytes.clone())]), - ); - assert!(res.is_err()); // missing length param + ctx.call::("SetByteArrayToZeroNoLength", bytes.clone()) + .unwrap_err(); // missing length param } } @@ -100,28 +87,24 @@ fn float_roundtrip() { ]; let mut sandbox: MultiUseSandbox = new_uninit().unwrap().evolve(Noop::default()).unwrap(); for f in doubles.iter() { - let res = sandbox.call_guest_function_by_name( - "EchoDouble", - ReturnType::Double, - Some(vec![ParameterValue::Double(*f)]), - ); + let res: f64 = sandbox + .call_guest_function_by_name("EchoDouble", *f) + .unwrap(); assert!( - matches!(res, Ok(ReturnValue::Double(f2)) if f2 == *f || f2.is_nan() && f.is_nan()), + res.total_cmp(f).is_eq(), "Expected {:?} but got {:?}", f, res ); } for f in floats.iter() { - let res = sandbox.call_guest_function_by_name( - "EchoFloat", - ReturnType::Float, - Some(vec![ParameterValue::Float(*f)]), - ); + let res: f32 = sandbox + .call_guest_function_by_name("EchoFloat", *f) + .unwrap(); assert!( - matches!(res, Ok(ReturnValue::Float(f2)) if f2 == *f || f2.is_nan() && f.is_nan()), + res.total_cmp(f).is_eq(), "Expected {:?} but got {:?}", f, res @@ -134,7 +117,7 @@ fn float_roundtrip() { fn invalid_guest_function_name() { for mut sandbox in get_simpleguest_sandboxes(None).into_iter() { let fn_name = "FunctionDoesntExist"; - let res = sandbox.call_guest_function_by_name(fn_name, ReturnType::Int, None); + let res = sandbox.call_guest_function_by_name::(fn_name, ()); println!("{:?}", res); assert!( matches!(res.unwrap_err(), HyperlightError::GuestError(hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode::GuestFunctionNotFound, error_name) if error_name == fn_name) @@ -147,208 +130,68 @@ fn invalid_guest_function_name() { fn set_static() { for mut sandbox in get_simpleguest_sandboxes(None).into_iter() { let fn_name = "SetStatic"; - let res = sandbox.call_guest_function_by_name(fn_name, ReturnType::Int, None); + let res = sandbox.call_guest_function_by_name::(fn_name, ()); println!("{:?}", res); assert!(res.is_ok()); // the result is the size of the static array in the guest - assert_eq!(res.unwrap(), ReturnValue::Int(1024 * 1024)); + assert_eq!(res.unwrap(), 1024 * 1024); } } #[test] #[cfg_attr(target_os = "windows", serial)] // using LoadLibrary requires serial tests fn multiple_parameters() { - let messages = Arc::new(Mutex::new(Vec::new())); - let messages_clone = messages.clone(); + let (tx, rx) = channel(); let writer = move |msg: String| { - let mut lock = messages_clone - .try_lock() - .map_err(|_| new_error!("Error locking")) - .unwrap(); - lock.push(msg); + tx.send(msg).unwrap(); 0 }; - let test_cases = vec![ - ( - "PrintTwoArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ], - format!("Message: arg1:{} arg2:{}.", "1", 2), - ), - ( - "PrintThreeArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ParameterValue::Long(3), - ], - format!("Message: arg1:{} arg2:{} arg3:{}.", "1", 2, 3), - ), - ( - "PrintFourArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ParameterValue::Long(3), - ParameterValue::String("4".to_string()), - ], - format!("Message: arg1:{} arg2:{} arg3:{} arg4:{}.", "1", 2, 3, "4"), - ), - ( - "PrintFiveArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ParameterValue::Long(3), - ParameterValue::String("4".to_string()), - ParameterValue::String("5".to_string()), - ], - format!( - "Message: arg1:{} arg2:{} arg3:{} arg4:{} arg5:{}.", - "1", 2, 3, "4", "5" - ), - ), - ( - "PrintSixArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ParameterValue::Long(3), - ParameterValue::String("4".to_string()), - ParameterValue::String("5".to_string()), - ParameterValue::Bool(true), - ], - format!( - "Message: arg1:{} arg2:{} arg3:{} arg4:{} arg5:{} arg6:{}.", - "1", 2, 3, "4", "5", true - ), - ), - ( - "PrintSevenArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ParameterValue::Long(3), - ParameterValue::String("4".to_string()), - ParameterValue::String("5".to_string()), - ParameterValue::Bool(true), - ParameterValue::Bool(false), - ], - format!( - "Message: arg1:{} arg2:{} arg3:{} arg4:{} arg5:{} arg6:{} arg7:{}.", - "1", 2, 3, "4", "5", true, false - ), - ), - ( - "PrintEightArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ParameterValue::Long(3), - ParameterValue::String("4".to_string()), - ParameterValue::String("5".to_string()), - ParameterValue::Bool(true), - ParameterValue::Bool(false), - ParameterValue::UInt(8), - ], - format!( - "Message: arg1:{} arg2:{} arg3:{} arg4:{} arg5:{} arg6:{} arg7:{} arg8:{}.", - "1", 2, 3, "4", "5", true, false, 8 - ), - ), - ( - "PrintNineArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ParameterValue::Long(3), - ParameterValue::String("4".to_string()), - ParameterValue::String("5".to_string()), - ParameterValue::Bool(true), - ParameterValue::Bool(false), - ParameterValue::UInt(8), - ParameterValue::ULong(9), - ], - format!( - "Message: arg1:{} arg2:{} arg3:{} arg4:{} arg5:{} arg6:{} arg7:{} arg8:{} arg9:{}.", - "1", 2, 3, "4", "5", true, false, 8, 9 - ), - ), - ( - "PrintTenArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ParameterValue::Long(3), - ParameterValue::String("4".to_string()), - ParameterValue::String("5".to_string()), - ParameterValue::Bool(true), - ParameterValue::Bool(false), - ParameterValue::UInt(8), - ParameterValue::ULong(9), - ParameterValue::Int(10), - ], - format!( - "Message: arg1:{} arg2:{} arg3:{} arg4:{} arg5:{} arg6:{} arg7:{} arg8:{} arg9:{} arg10:{}.", - "1", 2, 3, "4", "5", true, false, 8, 9, 10 - ), - ), - ( - "PrintElevenArgs", - vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ParameterValue::Long(3), - ParameterValue::String("4".to_string()), - ParameterValue::String("5".to_string()), - ParameterValue::Bool(true), - ParameterValue::Bool(false), - ParameterValue::UInt(8), - ParameterValue::ULong(9), - ParameterValue::Int(10), - ParameterValue::Float(3.123), - ], - format!( - "Message: arg1:{} arg2:{} arg3:{} arg4:{} arg5:{} arg6:{} arg7:{} arg8:{} arg9:{} arg10:{} arg11:{}.", - "1", 2, 3, "4", "5", true, false, 8, 9, 10, 3.123 - ), - ) - ]; + let args = ( + ("1".to_string(), "arg1:1"), + (2_i32, "arg2:2"), + (3_i64, "arg3:3"), + ("4".to_string(), "arg4:4"), + ("5".to_string(), "arg5:5"), + (true, "arg6:true"), + (false, "arg7:false"), + (8_u32, "arg8:8"), + (9_u64, "arg9:9"), + (10_i32, "arg10:10"), + (3.123_f32, "arg11:3.123"), + ); - for mut sandbox in get_simpleguest_sandboxes(Some(writer.into())).into_iter() { - for (fn_name, args, _expected) in test_cases.clone().into_iter() { - let res = sandbox.call_guest_function_by_name(fn_name, ReturnType::Int, Some(args)); - println!("{:?}", res); - assert!(res.is_ok()); - } + macro_rules! test_case { + ($sandbox:ident, $rx:ident, $name:literal, ($($p:ident),+)) => {{ + let ($($p),+, ..) = args.clone(); + let res: i32 = $sandbox.call_guest_function_by_name($name, ($($p.0,)+)).unwrap(); + println!("{res:?}"); + let output = $rx.try_recv().unwrap(); + println!("{output:?}"); + assert_eq!(output, format!("Message: {}.", [$($p.1),+].join(" "))); + }}; } - let lock = messages - .try_lock() - .map_err(|_| new_error!("Error locking")) - .unwrap(); - lock.clone() - .into_iter() - .zip(test_cases) - .for_each(|(printed_msg, expected)| { - println!("{:?}", printed_msg); - assert_eq!(printed_msg, expected.2); - }); + for mut sb in get_simpleguest_sandboxes(Some(writer.into())).into_iter() { + test_case!(sb, rx, "PrintTwoArgs", (a, b)); + test_case!(sb, rx, "PrintThreeArgs", (a, b, c)); + test_case!(sb, rx, "PrintFourArgs", (a, b, c, d)); + test_case!(sb, rx, "PrintFiveArgs", (a, b, c, d, e)); + test_case!(sb, rx, "PrintSixArgs", (a, b, c, d, e, f)); + test_case!(sb, rx, "PrintSevenArgs", (a, b, c, d, e, f, g)); + test_case!(sb, rx, "PrintEightArgs", (a, b, c, d, e, f, g, h)); + test_case!(sb, rx, "PrintNineArgs", (a, b, c, d, e, f, g, h, i)); + test_case!(sb, rx, "PrintTenArgs", (a, b, c, d, e, f, g, h, i, j)); + test_case!(sb, rx, "PrintElevenArgs", (a, b, c, d, e, f, g, h, i, j, k)); + } } #[test] #[cfg_attr(target_os = "windows", serial)] // using LoadLibrary requires serial tests fn incorrect_parameter_type() { for mut sandbox in get_simpleguest_sandboxes(None) { - let res = sandbox.call_guest_function_by_name( - "Echo", - ReturnType::Int, - Some(vec![ - ParameterValue::Int(2), // should be string - ]), + let res = sandbox.call_guest_function_by_name::( + "Echo", 2_i32, // should be string ); assert!(matches!( @@ -365,14 +208,7 @@ fn incorrect_parameter_type() { #[cfg_attr(target_os = "windows", serial)] // using LoadLibrary requires serial tests fn incorrect_parameter_num() { for mut sandbox in get_simpleguest_sandboxes(None).into_iter() { - let res = sandbox.call_guest_function_by_name( - "Echo", - ReturnType::Int, - Some(vec![ - ParameterValue::String("1".to_string()), - ParameterValue::Int(2), - ]), - ); + let res = sandbox.call_guest_function_by_name::("Echo", ("1".to_string(), 2_i32)); assert!(matches!( res.unwrap_err(), HyperlightError::GuestError( @@ -402,14 +238,10 @@ fn max_memory_sandbox() { #[cfg_attr(target_os = "windows", serial)] // using LoadLibrary requires serial tests fn iostack_is_working() { for mut sandbox in get_simpleguest_sandboxes(None).into_iter() { - let res = sandbox.call_guest_function_by_name( - "ThisIsNotARealFunctionButTheNameIsImportant", - ReturnType::Int, - None, - ); - println!("{:?}", res); - assert!(res.is_ok()); - assert_eq!(res.unwrap(), ReturnValue::Int(99)); + let res: i32 = sandbox + .call_guest_function_by_name::("ThisIsNotARealFunctionButTheNameIsImportant", ()) + .unwrap(); + assert_eq!(res, 99); } } @@ -430,30 +262,21 @@ fn simple_test_helper() -> Result<()> { let message2 = "world"; for mut sandbox in get_simpleguest_sandboxes(Some(writer.into())).into_iter() { - let res = sandbox.call_guest_function_by_name( - "PrintOutput", - ReturnType::Int, - Some(vec![ParameterValue::String(message.to_string())]), - ); - println!("res: {:?}", res); - assert!(matches!(res, Ok(ReturnValue::Int(5)))); + let res: i32 = sandbox + .call_guest_function_by_name("PrintOutput", message.to_string()) + .unwrap(); + assert_eq!(res, 5); - let res2 = sandbox.call_guest_function_by_name( - "Echo", - ReturnType::String, - Some(vec![ParameterValue::String(message2.to_string())]), - ); - println!("res2: {:?}", res2); - assert!(matches!(res2, Ok(ReturnValue::String(s)) if s == "world")); - - let buffer = vec![1u8, 2, 3, 4, 5, 6]; - let res3 = sandbox.call_guest_function_by_name( - "GetSizePrefixedBuffer", - ReturnType::Int, - Some(vec![ParameterValue::VecBytes(buffer.clone())]), - ); - println!("res3: {:?}", res3); - assert!(matches!(res3, Ok(ReturnValue::VecBytes(v)) if v == buffer)); + let res: String = sandbox + .call_guest_function_by_name("Echo", message2.to_string()) + .unwrap(); + assert_eq!(res, "world"); + + let buffer = [1u8, 2, 3, 4, 5, 6]; + let res: Vec = sandbox + .call_guest_function_by_name("GetSizePrefixedBuffer", buffer.to_vec()) + .unwrap(); + assert_eq!(res, buffer); } let expected_calls = 1; @@ -499,39 +322,20 @@ fn simple_test_parallel() { fn callback_test_helper() -> Result<()> { for mut sandbox in get_callbackguest_uninit_sandboxes(None).into_iter() { // create host function - let vec = Arc::new(Mutex::new(vec![])); - let vec_cloned = vec.clone(); - + let (tx, rx) = channel(); sandbox.register("HostMethod1", move |msg: String| { let len = msg.len(); - vec_cloned - .try_lock() - .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? - .push(msg); + tx.send(msg).unwrap(); Ok(len as i32) })?; // call guest function that calls host function let mut init_sandbox: MultiUseSandbox = sandbox.evolve(Noop::default())?; let msg = "Hello world"; - init_sandbox.call_guest_function_by_name( - "GuestMethod1", - ReturnType::Int, - Some(vec![ParameterValue::String(msg.to_string())]), - )?; - - assert_eq!( - vec.try_lock() - .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? - .len(), - 1 - ); - assert_eq!( - vec.try_lock() - .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? - .remove(0), - format!("Hello from GuestFunction1, {}", msg) - ); + init_sandbox.call_guest_function_by_name::("GuestMethod1", msg.to_string())?; + + let messages = rx.try_iter().collect::>(); + assert_eq!(messages, [format!("Hello from GuestFunction1, {msg}")]); } Ok(()) } @@ -570,13 +374,10 @@ fn host_function_error() -> Result<()> { // call guest function that calls host function let mut init_sandbox: MultiUseSandbox = sandbox.evolve(Noop::default())?; let msg = "Hello world"; - let res = init_sandbox.call_guest_function_by_name( - "GuestMethod1", - ReturnType::Int, - Some(vec![ParameterValue::String(msg.to_string())]), - ); - println!("res {:?}", res); - assert!(matches!(res, Err(HyperlightError::Error(msg)) if msg == "Host function error!")); + let res = init_sandbox + .call_guest_function_by_name::("GuestMethod1", msg.to_string()) + .unwrap_err(); + assert!(matches!(res, HyperlightError::Error(msg) if msg == "Host function error!")); } Ok(()) }