Skip to content

Commit

Permalink
feat: Wasm br_table translation
Browse files Browse the repository at this point in the history
  • Loading branch information
greenhat committed Sep 8, 2023
1 parent 75def81 commit b1db289
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 3 deletions.
94 changes: 93 additions & 1 deletion frontend-wasm/src/code_translator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
//!
//! Based on Cranelift's Wasm -> CLIF translator v11.0.0

use std::collections::{hash_map, HashMap};

use crate::environ::{FuncEnvironment, ModuleInfo};
use crate::error::WasmResult;
use crate::error::{WasmError, WasmResult};
use crate::func_translation_state::ControlStackFrame;
use crate::func_translation_state::{ElseData, FuncTranslationState};
use crate::function_builder_ext::FunctionBuilderExt;
Expand Down Expand Up @@ -95,6 +97,7 @@ pub fn translate_operator(
/**************************** Branch instructions *********************************/
Operator::Br { relative_depth } => translate_br(state, relative_depth, builder, span),
Operator::BrIf { relative_depth } => translate_br_if(*relative_depth, builder, state, span),
Operator::BrTable { targets } => translate_br_table(targets, state, builder, span)?,
Operator::Return => translate_return(state, builder, diagnostics, span)?,
/************************************ Calls ****************************************/
Operator::Call { function_index } => {
Expand Down Expand Up @@ -320,6 +323,95 @@ pub fn translate_operator(
Ok(())
}

fn translate_br_table(
targets: &wasmparser::BrTable<'_>,
state: &mut FuncTranslationState,
builder: &mut FunctionBuilderExt<'_>,
span: SourceSpan,
) -> Result<(), WasmError> {
let default = targets.default();
let mut min_depth = default;
for depth in targets.targets() {
let depth = depth?;
if depth < min_depth {
min_depth = depth;
}
}
let jump_args_count = {
let i = state.control_stack.len() - 1 - (min_depth as usize);
let min_depth_frame = &state.control_stack[i];
if min_depth_frame.is_loop() {
min_depth_frame.num_param_values()
} else {
min_depth_frame.num_return_values()
}
};
let val = state.pop1();
let mut data = Vec::with_capacity(targets.len() as usize);
if jump_args_count == 0 {
// No jump arguments
for depth in targets.targets() {
let depth = depth?;
let block = {
let i = state.control_stack.len() - 1 - (depth as usize);
let frame = &mut state.control_stack[i];
frame.set_branched_to_exit();
frame.br_destination()
};
data.push((depth, block));
}
let def_block = {
let i = state.control_stack.len() - 1 - (default as usize);
let frame = &mut state.control_stack[i];
frame.set_branched_to_exit();
frame.br_destination()
};
builder.ins().switch(val, data, def_block, span);
} else {
// Here we have jump arguments, but Midens's switch op doesn't support them
// We then proceed to split the edges going out of the br_table
let return_count = jump_args_count;
let mut dest_block_sequence = vec![];
let mut dest_block_map = HashMap::new();
for depth in targets.targets() {
let depth = depth?;
let branch_block = match dest_block_map.entry(depth as usize) {
hash_map::Entry::Occupied(entry) => *entry.get(),
hash_map::Entry::Vacant(entry) => {
let block = builder.create_block();
dest_block_sequence.push((depth as usize, block));
*entry.insert(block)
}
};
data.push((depth, branch_block));
}
let default_branch_block = match dest_block_map.entry(default as usize) {
hash_map::Entry::Occupied(entry) => *entry.get(),
hash_map::Entry::Vacant(entry) => {
let block = builder.create_block();
dest_block_sequence.push((default as usize, block));
*entry.insert(block)
}
};
builder.ins().switch(val, data, default_branch_block, span);
for (depth, dest_block) in dest_block_sequence {
builder.switch_to_block(dest_block);
builder.seal_block(dest_block);
let real_dest_block = {
let i = state.control_stack.len() - 1 - depth;
let frame = &mut state.control_stack[i];
frame.set_branched_to_exit();
frame.br_destination()
};
let destination_args = state.peekn_mut(return_count);
builder.ins().br(real_dest_block, destination_args, span);
}
state.popn(return_count);
}
state.reachable = false;
Ok(())
}

/// Return the total Miden VM memory size (2^32 addresses * word (4 felts) bytes) in 64KB pages
const fn mem_total_pages() -> i32 {
let bytes_fit_in_felt = 4; // although more than 32 bits can fit into a felt, use 32 bits to be safe
Expand Down
2 changes: 0 additions & 2 deletions frontend-wasm/src/code_translator/tests_unsupported.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ const UNSUPPORTED_WASM_V1_OPS: &[Operator] = &[
table_index: 0,
table_byte: 0,
},
/**************************** Branch instructions *********************************/
// BrTable { targets: .. },
/****************************** Memory Operators ************************************/
F32Load {
memarg: MemArg {
Expand Down
25 changes: 25 additions & 0 deletions frontend-wasm/src/function_builder_ext.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use miden_diagnostics::SourceSpan;
use miden_hir::cranelift_entity::EntitySet;
use miden_hir::cranelift_entity::SecondaryMap;
use miden_hir::Block;
use miden_hir::Br;
Expand All @@ -11,6 +12,7 @@ use miden_hir::Inst;
use miden_hir::InstBuilderBase;
use miden_hir::Instruction;
use miden_hir::ProgramPoint;
use miden_hir::Switch;
use miden_hir::Value;
use miden_hir_type::Type;

Expand Down Expand Up @@ -438,6 +440,29 @@ impl<'a, 'b> InstBuilderBase<'a> for FuncInstBuilderExt<'a, 'b> {
self.builder.declare_successor(*block_else, inst);
}
}
Instruction::Switch(Switch {
op: _,
arg: _,
arms,
default: _,
}) => {
// Unlike all other jumps/branches, arms are
// capable of having the same successor appear
// multiple times, so we must deduplicate.
let mut unique = EntitySet::<Block>::new();
for (_, dest_block) in arms {
if !unique.insert(*dest_block) {
continue;
}

// Call `declare_block_predecessor` instead of `declare_successor` for
// avoiding the borrow checker.
self.builder
.func_ctx
.ssa
.declare_block_predecessor(*dest_block, inst);
}
}
inst => debug_assert!(!inst.opcode().is_branch()),
}

Expand Down
28 changes: 28 additions & 0 deletions frontend-wasm/tests/rust_source/enum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#![no_std]
#![no_main]

#[panic_handler]
fn my_panic(_info: &core::panic::PanicInfo) -> ! {
loop {}
}

enum Op {
Add,
Sub,
Mul,
}

#[inline(never)]
#[no_mangle]
fn match_enum(a: u32, b: u32, foo: Op) -> u32 {
match foo {
Op::Add => a + b,
Op::Sub => a - b,
Op::Mul => a * b,
}
}

#[no_mangle]
pub extern "C" fn __main() -> u32 {
match_enum(3, 5, Op::Add) + match_enum(3, 5, Op::Sub) + match_enum(3, 5, Op::Mul)
}
107 changes: 107 additions & 0 deletions frontend-wasm/tests/test_rust_comp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,110 @@ fn rust_fib() {
"#]],
);
}

#[test]
fn rust_enum() {
check_ir(
include_str!("rust_source/enum.rs"),
expect![[r#"
(module
(type (;0;) (func (param i32 i32 i32) (result i32)))
(type (;1;) (func (result i32)))
(func $match_enum (;0;) (type 0) (param i32 i32 i32) (result i32)
block ;; label = @1
block ;; label = @2
block ;; label = @3
local.get 2
i32.const 255
i32.and
br_table 0 (;@3;) 1 (;@2;) 2 (;@1;) 0 (;@3;)
end
local.get 1
local.get 0
i32.add
return
end
local.get 0
local.get 1
i32.sub
return
end
local.get 1
local.get 0
i32.mul
)
(func $__main (;1;) (type 1) (result i32)
i32.const 3
i32.const 5
i32.const 0
call $match_enum
i32.const 3
i32.const 5
i32.const 1
call $match_enum
i32.add
i32.const 3
i32.const 5
i32.const 2
call $match_enum
i32.add
)
(memory (;0;) 16)
(global $__stack_pointer (;0;) (mut i32) i32.const 1048576)
(global (;1;) i32 i32.const 1048576)
(global (;2;) i32 i32.const 1048576)
(export "memory" (memory 0))
(export "match_enum" (func $match_enum))
(export "__main" (func $__main))
(export "__data_end" (global 1))
(export "__heap_base" (global 2))
)"#]],
expect![[r#"
module noname
pub fn match_enum(i32, i32, i32) -> i32 {
block0(v0: i32, v1: i32, v2: i32):
v4 = const.i32 255 : i32
v5 = band v2, v4 : i32
switch v5, 0 => block4, 1 => block3, 2 => block2, block4
block1(v3: i32):
v11 = ret v3 : ()
block2:
v10 = mul v1, v0 : i32
br block1(v10)
block3:
v8 = sub v0, v1 : i32
v9 = ret v8 : ()
block4:
v6 = add v1, v0 : i32
v7 = ret v6 : ()
}
pub fn __main() -> i32 {
block0:
v1 = const.i32 3 : i32
v2 = const.i32 5 : i32
v3 = const.i32 0 : i32
v4 = call noname::match_enum(v1, v2, v3) : i32
v5 = const.i32 3 : i32
v6 = const.i32 5 : i32
v7 = const.i32 1 : i32
v8 = call noname::match_enum(v5, v6, v7) : i32
v9 = add v4, v8 : i32
v10 = const.i32 3 : i32
v11 = const.i32 5 : i32
v12 = const.i32 2 : i32
v13 = call noname::match_enum(v10, v11, v12) : i32
v14 = add v9, v13 : i32
br block1(v14)
block1(v0: i32):
v15 = ret v0 : ()
}
"#]],
)
}

0 comments on commit b1db289

Please sign in to comment.