Skip to content

Commit

Permalink
use main module's realloc if available when applying an adapter (#900)
Browse files Browse the repository at this point in the history
* use main module's realloc if available when applying an adapter

Originally, adapters were expected to allocate their memory using `memory.grow`,
and `wit-component` would also use `memory.grow` to allocate stack space for the
adapter if needed.  However, it turned out this didn't work reliably with
`wasi-sdk`'s libc, which assumed it owned the entire memory region, leading to
chaos when both the adapter and the main module thought they owned the same
pages.

Alex has since patched `wasi-sdk` to address the issue, but it still affects
older toolchains, and particularly modules which were built prior to the patch.
However, many of those older modules export `cabi_realloc` (or it's older
equivalent, `canonical_abi_realloc`), so we might as well use that if it's
available.  That's what this commit does.

In cases where the main module exports a realloc function and we have a use for
it (i.e. either the adapter imports `cabi_realloc` or it needs a shadow stack),
we'll use it.  In cases where we have a use for it but the main module does
_not_ export it, we fall back to the old behavior of using `memory.grow` and
hope the module was built with a recent-enough `wasi-sdk`.

This also addresses alternative toolchains which are not based on `wasi-sdk` and
may expect to own the module's entire memory.  In that case, the toolchain can
simply export `cabi_realloc` and `wit-component` will use it.

Signed-off-by: Joel Dice <[email protected]>

* simplify code and name synthesized `cabi_realloc` import

Signed-off-by: Joel Dice <[email protected]>

---------

Signed-off-by: Joel Dice <[email protected]>
  • Loading branch information
dicej authored Feb 1, 2023
1 parent 3292dfb commit 5ba052b
Show file tree
Hide file tree
Showing 16 changed files with 757 additions and 56 deletions.
2 changes: 1 addition & 1 deletion crates/wit-component/src/encoding/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl<'a> ComponentWorld<'a> {
if required.is_empty() {
continue;
}
let wasm = crate::gc::run(wasm, &required)
let wasm = crate::gc::run(wasm, &required, self.info.realloc)
.context("failed to reduce input adapter module to its minimal size")?;
let info = validate_adapter_module(&wasm, resolve, *world, metadata, &required)
.context("failed to validate the imports of the minimized adapter module")?;
Expand Down
215 changes: 168 additions & 47 deletions crates/wit-component/src/gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,18 @@ use std::mem;
use wasm_encoder::{Encode, EntityType};
use wasmparser::*;

const PAGE_SIZE: i32 = 64 * 1024;

/// This function will reduce the input core `wasm` module to only the set of
/// exports `required`.
///
/// This internally performs a "gc" pass after removing exports to ensure that
/// the resulting module imports the minimal set of functions necessary.
pub fn run(wasm: &[u8], required: &IndexMap<String, FuncType>) -> Result<Vec<u8>> {
pub fn run(
wasm: &[u8],
required: &IndexMap<String, FuncType>,
main_module_realloc: Option<&str>,
) -> Result<Vec<u8>> {
assert!(!required.is_empty());

let mut module = Module::default();
Expand All @@ -35,7 +41,7 @@ pub fn run(wasm: &[u8], required: &IndexMap<String, FuncType>) -> Result<Vec<u8>
}
assert!(!module.exports.is_empty());
module.liveness()?;
module.encode()
module.encode(main_module_realloc)
}

fn always_keep(name: &str) -> bool {
Expand All @@ -47,6 +53,58 @@ fn always_keep(name: &str) -> bool {
}
}

/// This function generates a Wasm function body which implements `cabi_realloc` in terms of `memory.grow`. It
/// only accepts new, page-sized allocations.
fn realloc_via_memory_grow() -> wasm_encoder::Function {
use wasm_encoder::Instruction::*;

let mut func = wasm_encoder::Function::new([(1, wasm_encoder::ValType::I32)]);

// Assert `old_ptr` is null.
func.instruction(&I32Const(0));
func.instruction(&LocalGet(0));
func.instruction(&I32Ne);
func.instruction(&If(wasm_encoder::BlockType::Empty));
func.instruction(&Unreachable);
func.instruction(&End);

// Assert `old_len` is zero.
func.instruction(&I32Const(0));
func.instruction(&LocalGet(1));
func.instruction(&I32Ne);
func.instruction(&If(wasm_encoder::BlockType::Empty));
func.instruction(&Unreachable);
func.instruction(&End);

// Assert `new_len` is equal to the page size (which is the only value we currently support)
// Note: we could easily support arbitrary multiples of PAGE_SIZE here if the need arises.
func.instruction(&I32Const(PAGE_SIZE));
func.instruction(&LocalGet(3));
func.instruction(&I32Ne);
func.instruction(&If(wasm_encoder::BlockType::Empty));
func.instruction(&Unreachable);
func.instruction(&End);

// Grow the memory by 1 page.
func.instruction(&I32Const(1));
func.instruction(&MemoryGrow(0));
func.instruction(&LocalTee(4));

// Test if the return value of the growth was -1 and, if so, trap due to a failed allocation.
func.instruction(&I32Const(-1));
func.instruction(&I32Eq);
func.instruction(&If(wasm_encoder::BlockType::Empty));
func.instruction(&Unreachable);
func.instruction(&End);

func.instruction(&LocalGet(4));
func.instruction(&I32Const(16));
func.instruction(&I32Shl);
func.instruction(&End);

func
}

// Represents a function called while processing a module work list.
type WorklistFunc<'a> = fn(&mut Module<'a>, u32) -> Result<()>;

Expand Down Expand Up @@ -364,7 +422,7 @@ impl<'a> Module<'a> {

/// Encodes this `Module` to a new wasm module which is gc'd and only
/// contains the items that are live as calculated by the `liveness` pass.
fn encode(&mut self) -> Result<Vec<u8>> {
fn encode(&mut self, main_module_realloc: Option<&str>) -> Result<Vec<u8>> {
// Data structure used to track the mapping of old index to new index
// for all live items.
let mut map = Encoder::default();
Expand Down Expand Up @@ -454,22 +512,95 @@ impl<'a> Module<'a> {
}
}

let mut realloc_index = None;
let mut num_func_imports = 0;

// For functions first assign a new index to all functions and then
// afterwards actually map the body of all functions so the `map` of all
// index mappings is fully populated before instructions are mapped.
let mut num_funcs = 0;
for (i, func) in self.live_funcs() {

let is_realloc = |m, n| m == "__main_module__" && n == "cabi_realloc";

let (imported, local) =
self.live_funcs()
.partition::<Vec<_>, _>(|(_, func)| match &func.def {
Definition::Import(m, n) => {
!is_realloc(*m, *n) || main_module_realloc.is_some()
}
Definition::Local(_) => false,
});

for (i, func) in imported {
map.funcs.push(i);
let ty = map.types.remap(func.ty);
match &func.def {
Definition::Import(m, n) => {
imports.import(m, n, EntityType::Function(ty));
let name = if is_realloc(*m, *n) {
// The adapter is importing `cabi_realloc` from the main module, and the main module
// exports that function, but possibly using a different name
// (e.g. `canonical_abi_realloc`). Update the name to match if necessary.
realloc_index = Some(num_func_imports);
main_module_realloc.unwrap_or(n)
} else {
n
};
imports.import(m, name, EntityType::Function(ty));
num_func_imports += 1;
}
Definition::Local(_) => unreachable!(),
}
}

let add_realloc_type = |types: &mut wasm_encoder::TypeSection| {
let type_index = types.len();
types.function(
[
wasm_encoder::ValType::I32,
wasm_encoder::ValType::I32,
wasm_encoder::ValType::I32,
wasm_encoder::ValType::I32,
],
[wasm_encoder::ValType::I32],
);
type_index
};

let sp = self.find_stack_pointer()?;

let mut func_names = Vec::new();

if let (Some(realloc), Some(_), None) = (main_module_realloc, sp, realloc_index) {
// The main module exports a realloc function, and although the adapter doesn't import it, we're going
// to add a function which calls it to allocate some stack space, so let's add an import now.

// Tell the function remapper we're reserving a slot for our extra import:
map.funcs.next += 1;

realloc_index = Some(num_func_imports);
imports.import(
"__main_module__",
realloc,
EntityType::Function(add_realloc_type(&mut types)),
);
func_names.push((num_func_imports, realloc));
num_func_imports += 1;
}

for (i, func) in local {
map.funcs.push(i);
let ty = map.types.remap(func.ty);
match &func.def {
Definition::Import(_, _) => {
// The adapter is importing `cabi_realloc` from the main module, but the main module isn't
// exporting it. In this case, we need to define a local function it can call instead.
realloc_index = Some(num_func_imports + funcs.len());
funcs.function(ty);
code.function(&realloc_via_memory_grow());
}
Definition::Local(_) => {
funcs.function(ty);
}
}
num_funcs += 1;
}

for (_, func) in self.live_funcs() {
Expand All @@ -489,11 +620,18 @@ impl<'a> Module<'a> {
code.function(&func);
}

if sp.is_some() && realloc_index.is_none() {
// The main module does _not_ export a realloc function, nor does the adapter import it, but we need a
// function to allocate some stack space, so we'll add one here.
realloc_index = Some(num_func_imports + funcs.len());
funcs.function(add_realloc_type(&mut types));
code.function(&realloc_via_memory_grow());
}

// Inject a start function to initialize the stack pointer which will be
// local to this module. This only happens if a memory is preserved and
// a stack pointer global is found.
let mut start = None;
let sp = self.find_stack_pointer()?;
if let Some(sp) = sp {
if num_memories > 0 {
use wasm_encoder::Instruction::*;
Expand All @@ -507,42 +645,30 @@ impl<'a> Module<'a> {

let sp = map.globals.remap(sp);

let function_index = num_func_imports + funcs.len();

// Generate a function type for this start function, adding a new
// function type to the module if necessary.
let empty_type = empty_type.unwrap_or_else(|| {
types.function([], []);
types.len() - 1
});
funcs.function(empty_type);

let mut func = wasm_encoder::Function::new([(1, wasm_encoder::ValType::I32)]);
// Grow the memory by 1 page to allocate ourselves some stack space.
func.instruction(&I32Const(1));
func.instruction(&MemoryGrow(0));
func.instruction(&LocalTee(0));

// Test if the return value of the growth was -1 and trap if so
// since we don't have a stack page.
func.instruction(&I32Const(-1));
func.instruction(&I32Eq);
func.instruction(&If(wasm_encoder::BlockType::Empty));
func.instruction(&Unreachable);
func.instruction(&End);

// Set our stack pointer to the top of the page we were given, which
// is the page index times the page size plus the size of a page.
func.instruction(&LocalGet(0));
func.instruction(&I32Const(1));
func_names.push((function_index, "initialize_stack_pointer"));

let mut func = wasm_encoder::Function::new([]);
func.instruction(&I32Const(0));
func.instruction(&I32Const(0));
func.instruction(&I32Const(8));
func.instruction(&I32Const(PAGE_SIZE));
func.instruction(&Call(realloc_index.unwrap()));
func.instruction(&I32Const(PAGE_SIZE));
func.instruction(&I32Add);
func.instruction(&I32Const(16));
func.instruction(&I32Shl);
func.instruction(&GlobalSet(sp));
func.instruction(&End);
code.function(&func);

start = Some(wasm_encoder::StartSection {
function_index: num_funcs,
});
start = Some(wasm_encoder::StartSection { function_index });
}
}

Expand Down Expand Up @@ -622,7 +748,6 @@ impl<'a> Module<'a> {

// Append a custom `name` section using the names of the functions that
// were found prior to the GC pass in the original module.
let mut func_names = Vec::new();
let mut global_names = Vec::new();
for (i, _func) in self.live_funcs() {
let name = match self.func_names.get(&i) {
Expand All @@ -631,9 +756,6 @@ impl<'a> Module<'a> {
};
func_names.push((map.funcs.remap(i), *name));
}
if start.is_some() {
func_names.push((num_funcs, "initialize_stack_pointer"));
}
for (i, _global) in self.live_globals() {
let name = match self.global_names.get(&i) {
Some(name) => name,
Expand All @@ -655,6 +777,9 @@ impl<'a> Module<'a> {
section.push(code);
subsection.encode(&mut section);
};
if let (Some(realloc_index), None) = (realloc_index, main_module_realloc) {
func_names.push((realloc_index, "realloc_via_memory_grow"));
}
encode_subsection(0x01, &func_names);
encode_subsection(0x07, &global_names);
if !section.is_empty() {
Expand Down Expand Up @@ -977,31 +1102,27 @@ mod bitvec {
#[derive(Default)]
struct Remap {
/// Map, indexed by the old index set, to the new index set.
///
/// Placeholders of `u32::MAX` means that the old index is not present in
/// the new index space.
map: Vec<u32>,
map: HashMap<u32, u32>,
/// The next available index in the new index space.
next: u32,
}

impl Remap {
/// Appends a new live "old index" into this remapping structure.
///
/// This will assign a new index for the old index provided. This method
/// must be called in increasing order of old indexes.
/// This will assign a new index for the old index provided.
fn push(&mut self, old: u32) {
self.map.resize(old as usize, u32::MAX);
self.map.push(self.next);
self.map.insert(old, self.next);
self.next += 1;
}

/// Returns the new index corresponding to an old index.
///
/// Panics if the `old` index was not added via `push` above.
fn remap(&self, old: u32) -> u32 {
let ret = self.map[old as usize];
assert!(ret != u32::MAX);
ret
*self
.map
.get(&old)
.unwrap_or_else(|| panic!("can't map {old} to a new index"))
}
}
6 changes: 4 additions & 2 deletions crates/wit-component/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use wit_parser::{
};

fn is_canonical_function(name: &str) -> bool {
name.starts_with("cabi_")
name.starts_with("cabi_") || name.starts_with("canonical_abi_")
}

fn wasm_sig_to_func_type(signature: WasmSignature) -> FuncType {
Expand Down Expand Up @@ -139,7 +139,9 @@ pub fn validate_module<'a>(
if is_canonical_function(export.name) {
// TODO: validate that the cabi_realloc
// function is [i32, i32, i32, i32] -> [i32]
if export.name == "cabi_realloc" {
if export.name == "cabi_realloc"
|| export.name == "canonical_abi_realloc"
{
ret.realloc = Some(export.name);
}
continue;
Expand Down
Loading

0 comments on commit 5ba052b

Please sign in to comment.