Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize internal calls by passing only stack end for arguments #642

Draft
wants to merge 8 commits into
base: skip_nop_instructions
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 32 additions & 27 deletions lib/fizzy/execute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,44 +474,45 @@ void branch(const Code& code, OperandStack& stack, const uint8_t*& pc, uint32_t
stack.drop(stack_drop);
}

inline bool invoke_function(const FuncType& func_type, uint32_t func_idx, Instance& instance,
OperandStack& stack, int depth)
inline bool invoke_function(uint32_t func_idx, Instance& instance, OperandStack& stack, int depth)
{
const auto num_args = func_type.inputs.size();
assert(stack.size() >= num_args);
const auto call_args = stack.rend() - num_args;

const auto ret = execute(instance, func_idx, call_args, depth + 1);
const auto ret = execute_internal(instance, func_idx, stack.rend(), depth + 1);
// Bubble up traps
if (ret.trapped)
if (ret == nullptr)
return false;

stack.drop(num_args);

const auto num_outputs = func_type.outputs.size();
// NOTE: we can assume these two from validation
assert(num_outputs <= 1);
assert(ret.has_value == (num_outputs == 1));
// Push back the result
if (num_outputs != 0)
stack.push(ret.value);

stack.set_end(ret);
return true;
}

} // namespace

ExecutionResult execute(Instance& instance, FuncIdx func_idx, const Value* args, int depth)
Value* execute_internal(Instance& instance, FuncIdx func_idx, Value* args_end, int depth)
{
assert(args_end != nullptr);

assert(depth >= 0);
if (depth > CallStackLimit)
return Trap;
return nullptr;

const auto& func_type = instance.module->get_function_type(func_idx);

auto* args = args_end - func_type.inputs.size();

assert(instance.module->imported_function_types.size() == instance.imported_functions.size());
if (func_idx < instance.imported_functions.size())
return instance.imported_functions[func_idx].function(instance, args, depth);
{
const auto res = instance.imported_functions[func_idx].function(instance, args, depth);
if (res.trapped)
return nullptr;

if (res.has_value)
{
args[0] = res.value;
return args + 1;
}
return args;
}

const auto& code = instance.module->get_code(func_idx);
auto* const memory = instance.memory.get();
Expand Down Expand Up @@ -586,9 +587,8 @@ ExecutionResult execute(Instance& instance, FuncIdx func_idx, const Value* args,
case Instr::call:
{
const auto called_func_idx = read<uint32_t>(pc);
const auto& called_func_type = instance.module->get_function_type(called_func_idx);

if (!invoke_function(called_func_type, called_func_idx, instance, stack, depth))
if (!invoke_function(called_func_idx, instance, stack, depth))
goto trap;
break;
}
Expand All @@ -614,8 +614,7 @@ ExecutionResult execute(Instance& instance, FuncIdx func_idx, const Value* args,
if (expected_type != actual_type)
goto trap;

if (!invoke_function(
actual_type, called_func.func_idx, *called_func.instance, stack, depth))
if (!invoke_function(called_func.func_idx, *called_func.instance, stack, depth))
goto trap;
break;
}
Expand Down Expand Up @@ -1535,9 +1534,15 @@ ExecutionResult execute(Instance& instance, FuncIdx func_idx, const Value* args,
assert(pc == &code.instructions[code.instructions.size()]); // End of code must be reached.
assert(stack.size() == instance.module->get_function_type(func_idx).outputs.size());

return stack.size() != 0 ? ExecutionResult{stack.top()} : Void;
if (stack.size() != 0 && args != nullptr)
{
args[0] = stack.top();
return args + 1;
}

return args;

trap:
return Trap;
return nullptr;
}
} // namespace fizzy
33 changes: 32 additions & 1 deletion lib/fizzy/execute.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,39 @@ struct ExecutionResult
constexpr ExecutionResult Void{true};
constexpr ExecutionResult Trap{false};

/// The "unsafe" internal execute function.
Value* execute_internal(Instance& instance, FuncIdx func_idx, Value* args, int depth);

// Execute a function on an instance.
ExecutionResult execute(Instance& instance, FuncIdx func_idx, const Value* args, int depth = 0);
inline ExecutionResult execute(
Instance& instance, FuncIdx func_idx, const Value* args, int depth = 0)
{
const auto& func_type = instance.module->get_function_type(func_idx);
const auto num_args = func_type.inputs.size();
const auto num_outputs = func_type.outputs.size();
assert(num_outputs <= 1);

const auto arg0 = num_args >= 1 ? args[0] : Value{};

Value fake_arg;
auto* p_args = num_args == 0 ? &fake_arg : const_cast<Value*>(args);

auto* args_end = p_args + num_args;
const auto res = execute_internal(instance, func_idx, args_end, depth);

if (res == nullptr)
return Trap;

if (num_outputs == 1)
{
// Restore original value, because the caller does not expect it being modified.
const auto result_value = p_args[0];
p_args[0] = arg0;
return ExecutionResult(result_value);
}

return Void;
}

inline ExecutionResult execute(
Instance& instance, FuncIdx func_idx, std::initializer_list<Value> args)
Expand Down
8 changes: 5 additions & 3 deletions lib/fizzy/stack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class OperandStack
return *m_top;
}

void set_end(Value* end) noexcept { m_top = end - 1; }

/// Returns the reference to the stack item on given position from the stack top.
/// Requires index < size().
Value& operator[](size_t index) noexcept
Expand All @@ -160,14 +162,14 @@ class OperandStack

void drop(size_t num) noexcept
{
assert(num <= size());
// assert(num <= size());
m_top -= num;
}

/// Returns iterator to the bottom of the stack.
const Value* rbegin() const noexcept { return m_bottom; }
Value* rbegin() const noexcept { return m_bottom; }

/// Returns end iterator counting from the bottom of the stack.
const Value* rend() const noexcept { return m_top + 1; }
Value* rend() const noexcept { return m_top + 1; }
};
} // namespace fizzy
88 changes: 88 additions & 0 deletions test/unittests/execute_call_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,69 @@ TEST(execute_call, call_with_arguments)
EXPECT_THAT(execute(parse(wasm), 1, {}), Result(4));
}

TEST(execute_call, call_void_with_zero_arguments)
{
/* wat2wasm
(module
(global $z (mut i32) (i32.const -1))
(func $set
(global.set $z (i32.const 1))
)
(func (result i32)
call $set
global.get $z
)
)
*/
const auto wasm = from_hex(
"0061736d010000000108026000006000017f03030200010606017f01417f0b0a0f020600410124000b06001000"
"23000b");

EXPECT_THAT(execute(parse(wasm), 1, {}), Result(1));
}

TEST(execute_call, call_void_with_one_argument)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chfast I think these tests could be merged anyway.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#641.

{
/* wat2wasm
(module
(global $z (mut i32) (i32.const -1))
(func $set (param $a i32)
(global.set $z (local.get $a))
)
(func (result i32)
(call $set (i32.const 1))
global.get $z
)
)
*/
const auto wasm = from_hex(
"0061736d0100000001090260017f006000017f03030200010606017f01417f0b0a11020600200024000b080041"
"01100023000b");

EXPECT_THAT(execute(parse(wasm), 1, {}), Result(1));
}

TEST(execute_call, call_void_with_two_arguments)
{
/* wat2wasm
(module
(global $z (mut i32) (i32.const -1))
(func $set (param $a i32) (param $b i32)
(global.set $z (i32.add (local.get $a) (local.get $b)))
)
(func (result i32)
(call $set (i32.const 2) (i32.const 3))
global.get $z
)
)
*/
const auto wasm = from_hex(
"0061736d01000000010a0260027f7f006000017f03030200010606017f01417f0b0a16020900200020016a2400"
"0b0a0041024103100023000b");

EXPECT_THAT(execute(parse(wasm), 1, {}), Result(2 + 3));
}

TEST(execute_call, call_shared_stack_space)
{
/* wat2wasm
Expand Down Expand Up @@ -285,6 +348,31 @@ TEST(execute_call, imported_function_call)
EXPECT_THAT(execute(*instance, 1, {}), Result(42));
}

TEST(execute_call, imported_function_call_void)
{
/* wat2wasm
(func (import "m" "foo"))
(func
call 0
)
*/
const auto wasm =
from_hex("0061736d01000000010401600000020901016d03666f6f0000030201000a0601040010000b");

const auto module = parse(wasm);

bool called = false;
const auto host_foo = [&called](Instance&, const Value*, int) {
called = true;
return Void;
};
const auto host_foo_type = module->typesec[0];

auto instance = instantiate(*module, {{host_foo, host_foo_type}});
execute(*instance, 1, {});
EXPECT_TRUE(called);
}

TEST(execute_call, imported_function_call_with_arguments)
{
/* wat2wasm
Expand Down