diff --git a/src/execution/interpreter_loop.rs b/src/execution/interpreter_loop.rs index 7e5874205..2c5c94290 100644 --- a/src/execution/interpreter_loop.rs +++ b/src/execution/interpreter_loop.rs @@ -291,8 +291,12 @@ pub(super) fn run( .unwrap_validated(); let actual_type_idx = func_to_call_inst.ty(); + let actual_ty = modules[*current_module_idx] + .fn_types + .get(actual_type_idx) + .unwrap_validated(); - if given_type_idx != actual_type_idx { + if func_ty != actual_ty { return Err(RuntimeError::SignatureMismatch); } diff --git a/src/execution/mod.rs b/src/execution/mod.rs index e9dc3c7ef..2d41145a7 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -231,6 +231,7 @@ where // Pop return values from stack let return_values = Returns::TYS .iter() + .rev() .map(|ty| stack.pop_value(*ty)) .collect::>(); @@ -308,6 +309,7 @@ where .returns .valtypes .iter() + .rev() .map(|ty| stack.pop_value(*ty)) .collect::>(); @@ -390,6 +392,7 @@ where .returns .valtypes .iter() + .rev() .map(|ty| stack.pop_value(*ty)) .collect::>(); diff --git a/src/validation/code.rs b/src/validation/code.rs index 6fa3dd20c..a7c41b5fb 100644 --- a/src/validation/code.rs +++ b/src/validation/code.rs @@ -267,7 +267,7 @@ fn read_instructions( { if *stp == usize::MAX { //this If was previously matched with an else already, it is already backpatched! - return Err(Error::IfWithoutMatchingElse); + return Err(Error::ElseWithoutMatchingIf); } let stp_here = sidetable.len(); sidetable.push(SidetableEntry { @@ -315,10 +315,9 @@ fn read_instructions( let label_vec = wasm.read_vec(|wasm| wasm.read_var_u32().map(|v| v as LabelIdx))?; let max_label_idx = wasm.read_var_u32()? as LabelIdx; stack.assert_pop_val_type(ValType::NumType(NumType::I32))?; - - for label_idx in label_vec { + for label_idx in &label_vec { validate_intrablock_jump_and_generate_sidetable_entry( - wasm, label_idx, stack, sidetable, + wasm, *label_idx, stack, sidetable, )?; } @@ -329,10 +328,34 @@ fn read_instructions( sidetable, )?; + // The label arity of the branches must be explicitly checked against each other further + // if their arities are the same, then they must unify, as they unify against the stack variables already + // If the following check is not made, the algorithm incorrectly unifies label types with different arities + // in which the smaller arity type is a suffix in the label type list of the larger arity function + + // stack includes all labels, that check is made in the above fn already + let max_label_arity = stack + .ctrl_stack + .get(stack.ctrl_stack.len() - max_label_idx - 1) + .unwrap() + .label_types() + .len(); + for label_idx in &label_vec { + let label_arity = stack + .ctrl_stack + .get(stack.ctrl_stack.len() - *label_idx - 1) + .unwrap() + .label_types() + .len(); + if max_label_arity != label_arity { + return Err(Error::InvalidLabelIdx(*label_idx)); + } + } + stack.make_unspecified()?; } END => { - let (label_info, _) = stack.assert_pop_ctrl()?; + let (label_info, block_ty) = stack.assert_pop_ctrl()?; let stp_here = sidetable.len(); match label_info { @@ -347,6 +370,14 @@ fn read_instructions( stps_to_backpatch, } => { if stp != usize::MAX { + //This If is still not backpatched, meaning it does not have a corresponding + //ELSE. This is only allowed when the corresponding If block has the same input + //types as its output types (an untyped ELSE block with no instruction is valid + //if and only if it is of this type) + if !(block_ty.params == block_ty.returns) { + return Err(Error::IfWithoutMatchingElse); + } + //This If is still not backpatched, meaning it does not have a corresponding //ELSE. Therefore if its condition fails, it jumps after END. sidetable[stp].delta_pc = (wasm.pc as isize) - sidetable[stp].delta_pc; diff --git a/tests/function_recursion.rs b/tests/function_recursion.rs index 860d70241..e2941407d 100644 --- a/tests/function_recursion.rs +++ b/tests/function_recursion.rs @@ -121,3 +121,34 @@ fn recursion_busted_stack() { "validation incorrectly passed" ); } + +#[test_log::test] +fn multivalue_call() { + let wat = r#" + (module + (func $foo (param $x i64) (param $y i32) (param $z f32) (result i32 f32 i64) + local.get $y + local.get $z + local.get $x + ) + (func $bar (export "bar") (result i32 f32 i64) + i64.const 5 + i32.const 10 + f32.const 42.0 + call $foo + ) + ) + "#; + let wasm_bytes = wat::parse_str(wat).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let foo_fn = instance + .get_function_by_name(DEFAULT_MODULE, "bar") + .unwrap(); + + assert_eq!( + (10, 42.0, 5), + instance.invoke::<(), (i32, f32, i64)>(&foo_fn, ()).unwrap() + ); +} diff --git a/tests/structured_control_flow/block.rs b/tests/structured_control_flow/block.rs index 8273ae729..16e691df4 100644 --- a/tests/structured_control_flow/block.rs +++ b/tests/structured_control_flow/block.rs @@ -380,6 +380,30 @@ fn switch_case() { assert_eq!(9, instance.invoke(&switch_case_fn, 7).unwrap()); } +#[test_log::test] +fn br_table_label_typecheck1() { + let wasm_bytes = wat::parse_str( + r#" + (module + (func $test (param $value i32) (result i32) + (block + (block (result i32) + unreachable + (br_table 1 0 1 (i32.const 0)) + ) + ) + ) + (export "test" (func $test)) + )"#, + ) + .unwrap(); + + assert_eq!( + validate(&wasm_bytes).err().unwrap(), + wasm::Error::InvalidLabelIdx(0) + ); +} + const POLYMORPHIC_SELECT_VALIDATION: &str = r#" (module (func $polymorphic_select_validation diff --git a/tests/structured_control_flow/if.rs b/tests/structured_control_flow/if.rs index 6f743aae8..a3d618ce2 100644 --- a/tests/structured_control_flow/if.rs +++ b/tests/structured_control_flow/if.rs @@ -166,3 +166,101 @@ fn recursive_fibonacci_if_else() { assert_eq!(5, instance.invoke(&fibonacci_fn, 4).unwrap()); assert_eq!(8, instance.invoke(&fibonacci_fn, 5).unwrap()); } + +#[test_log::test] +fn if_without_else_type_check1() { + let wasm_bytes = wat::parse_str( + r#" +(module + (func $empty (param $cond i32) + (if (local.get $cond) (then)) + ) + + (export "empty" (func $empty)) +)"#, + ) + .unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let empty_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!((), instance.invoke(&empty_fn, 1).unwrap()); + assert_eq!((), instance.invoke(&empty_fn, 0).unwrap()); +} + +#[test_log::test] +fn if_without_else_type_check2() { + let wasm_bytes = wat::parse_str( + r#" +(module + (func $empty (param $cond i32) + (i32.const 1) + (if (param i32) (local.get $cond) (then drop)) + ) + + (export "empty" (func $empty)) +)"#, + ) + .unwrap(); + assert_eq!( + validate(&wasm_bytes).err().unwrap(), + wasm::Error::IfWithoutMatchingElse + ); +} + +#[test_log::test] +fn if_without_else_type_check3() { + let wasm_bytes = wat::parse_str( + r#" +(module + (func $add_one_if_true (param $cond i32) (result i32) + (i32.const 5) + (if (param i32) (result i32) (local.get $cond) (then (i32.const 2) (i32.add))) + ) + + (export "add_one_if_true" (func $add_one_if_true)) +)"#, + ) + .unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let add_one_if_true_fn = instance.get_function_by_index(0, 0).unwrap(); + + assert_eq!(7, instance.invoke(&add_one_if_true_fn, 1).unwrap()); + assert_eq!(5, instance.invoke(&add_one_if_true_fn, 0).unwrap()); +} + +#[test_log::test] +fn if_without_else_type_check4() { + let wasm_bytes = wat::parse_str( + r#" +(module + (func $do_stuff_if_true (param $cond i32) (result i32) (result i64) + (i32.const 5) + (i64.const 20) + (if (param i32) (param i64) (result i32) (result i64) (local.get $cond) (then drop (i32.const 2) (i32.add) (i64.const 42))) + ) + + (export "do_stuff_if_true" (func $do_stuff_if_true)) +)"#, + ) + .unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let add_one_if_true_fn = instance.get_function_by_index(0, 0).unwrap(); + assert_eq!( + (7, 42), + instance + .invoke::(&add_one_if_true_fn, 1) + .unwrap() + ); + assert_eq!( + (5, 20), + instance + .invoke::(&add_one_if_true_fn, 0) + .unwrap() + ); +} diff --git a/tests/table.rs b/tests/table.rs index 03b0e82f3..7ab214ec2 100644 --- a/tests/table.rs +++ b/tests/table.rs @@ -213,6 +213,73 @@ fn table_get_set_test() { } } +#[test_log::test] +fn call_indirect_type_check() { + let wat = r#" + (module + ;; duplicate same type for different ids to make sure types themselves are compared + ;; during call_indirect, not type ids + (type $type_1 (func (param i32) (result i32))) + (type $type_2 (func (param i32) (result i32))) + (type $type_3 (func (param i32) (result i32))) + + (func $add_one_func (type $type_1) (param $x i32) (result i32) + local.get $x + i32.const 1 + i32.add + ) + + (func $mul_two_func (type $type_2) (param $x i32) (result i32) + local.get $x + i32.const 2 + i32.mul + ) + + (table funcref (elem $add_one_func $mul_two_func)) + + (func $call_function (param $value i32) (param $index i32) (result i32) + local.get $value + local.get $index + call_indirect 0 (type $type_3) + ) + + (export "call_function" (func $call_function)) + ) + "#; + let wasm_bytes = wat::parse_str(wat).unwrap(); + let validation_info = validate(&wasm_bytes).expect("validation failed"); + let mut instance = RuntimeInstance::new(&validation_info).expect("instantiation failed"); + + let call_fn = instance + .get_function_by_name(DEFAULT_MODULE, "call_function") + .unwrap(); + + assert_eq!( + 4, + instance + .invoke::<(i32, i32), i32>(&call_fn, (3, 0)) + .unwrap() + ); + assert_eq!( + 6, + instance + .invoke::<(i32, i32), i32>(&call_fn, (5, 0)) + .unwrap() + ); + assert_eq!( + 6, + instance + .invoke::<(i32, i32), i32>(&call_fn, (3, 1)) + .unwrap() + ); + assert_eq!( + 10, + instance + .invoke::<(i32, i32), i32>(&call_fn, (5, 1)) + .unwrap() + ); +} + // (assert_malformed // (module quote "(table 0x1_0000_0000 funcref)") // "i32 constant out of range"