Skip to content

Commit

Permalink
works on the docsite but not the harness
Browse files Browse the repository at this point in the history
  • Loading branch information
jkelleyrtp committed Feb 11, 2025
1 parent c4a13ea commit ea8fc7a
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 175 deletions.
189 changes: 133 additions & 56 deletions packages/wasm-split/wasm-split-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub struct Splitter<'a> {
split_points: Vec<SplitPoint>,
chunks: Vec<HashSet<Node>>,
data_symbols: BTreeMap<usize, DataSymbol>,
main_graph: ReachabilityGraph,
main_graph: HashSet<Node>,
call_graph: HashMap<Node, HashSet<Node>>,
parent_graph: HashMap<Node, HashSet<Node>>,
extra_symbols: HashSet<Node>,
Expand Down Expand Up @@ -157,7 +157,7 @@ impl<'a> Splitter<'a> {
self.replace_segments_with_holes(&mut out, &unused_symbols);

// 2. Wipe away the unused functions and data symbols
self.prune_main_symbols(&mut out, &unused_symbols);
self.prune_main_symbols(&mut out, &unused_symbols)?;

// 3. Change the functions called from split modules to be local functions that call the indirect function
self.create_ifunc_table(&mut out);
Expand Down Expand Up @@ -191,27 +191,33 @@ impl<'a> Splitter<'a> {
// These are the symbols that will only exist in this module and not in the main module.
let mut unique_symbols = split
.reachable_graph
.reachable
.difference(&self.main_graph.reachable)
.difference(&self.main_graph)
.cloned()
.collect::<HashSet<_>>();

// The functions we'll need to import
let mut symbols_to_import: HashSet<_> = split
.reachable_graph
.reachable
.intersection(&self.main_graph.reachable)
.intersection(&self.main_graph)
.cloned()
.collect();

// Identify the functions we'll delete
let mut symbols_to_delete: HashSet<_> = self
.main_graph
.reachable
.difference(&split.reachable_graph.reachable)
.difference(&split.reachable_graph)
.cloned()
.collect();

// for s in symbols_to_import.iter() {
// symbols_to_delete.remove(s);
// }

// for extra in self.extra_symbols.iter() {
// // symbols_to_import.insert(extra.clone());
// symbols_to_delete.remove(extra);
// }

// // Convert split chunk functions to imports
let mut relies_on_chunks = HashSet::new();
// tracing::info!("There are {} chunks", self.chunks.len());
Expand Down Expand Up @@ -257,7 +263,7 @@ impl<'a> Splitter<'a> {
self.add_split_imports(&mut out, split.index, split_export_func, split.export_name);

// Delete all the functions that are not reachable from the main module
self.delete_main_funcs_from_split(&mut out, &symbols_to_delete);
self.delete_main_funcs_from_split(&mut out, &symbols_to_delete, &ids_to_fns);

// Remove the reloc and linking custom sections
self.remove_custom_sections(&mut out);
Expand Down Expand Up @@ -289,7 +295,6 @@ impl<'a> Splitter<'a> {
// The functions we'll need to import
let symbols_to_import: HashSet<_> = self
.main_graph
.reachable
.intersection(&unique_symbols)
.cloned()
.collect();
Expand All @@ -298,7 +303,7 @@ impl<'a> Splitter<'a> {
let mut symbols_to_export = HashSet::new();
for sym in unique_symbols.iter() {
for split in self.split_points.iter() {
if split.reachable_graph.reachable.contains(sym) {
if split.reachable_graph.contains(sym) {
symbols_to_export.insert(*sym);
}
}
Expand Down Expand Up @@ -329,7 +334,7 @@ impl<'a> Splitter<'a> {
self.re_export_functions(&mut out, &symbols_to_export);

// Make sure we haven't deleted anything important....
self.delete_main_funcs_from_split(&mut out, &symbols_to_delete);
self.delete_main_funcs_from_split(&mut out, &symbols_to_delete, &ids_to_fns);

// We have to make sure our table matches that of the other tables even though we don't call them.
let ifunc_table_id = self.load_funcref_table(&mut out);
Expand Down Expand Up @@ -628,7 +633,46 @@ impl<'a> Splitter<'a> {
));
}

fn delete_main_funcs_from_split(&self, out: &mut Module, symbols_to_delete: &HashSet<Node>) {
fn delete_main_funcs_from_split(
&self,
out: &mut Module,
symbols_to_delete: &HashSet<Node>,
ids_to_fns: &[FunctionId],
) {
// let injected_symbols = self.remap_ids(self.extra_symbols.clone(), &ids_to_fns);
// let mut deleted_functions = HashSet::new();
// let _r = "__________".to_string();

// for node in symbols_to_delete {
// if let Node::Function(id) = *node {
// if !injected_symbols.contains(node) {
// let func = out.funcs.get(id);
// let func_name = func.name.as_ref();
// let func_name = func_name.unwrap_or(&_r);

// // if func_name == "_ZN5alloc7raw_vec20RawVecInner$LT$A$GT$17try_reserve_exact17hbb1ba48adad83534E" {
// // tracing::error!("deleting {:?}", func);
// // }

// // // we shouldn't delete unnamed functions?
// // let Some(func_name) = func_name else {
// // tracing::error!("Could not find name for function {:?}", func);
// // continue;
// // };

// // let FunctionKind::Local(func) = &func.kind else {
// // continue;
// // };

// // n.contains("__externref_table_")
// if !func_name.contains("__externref_table_") {
// out.funcs.delete(id);
// deleted_functions.insert(*node);
// }
// }
// }
// }

for node in symbols_to_delete {
if let Node::Function(id) = *node {
out.funcs.delete(id);
Expand Down Expand Up @@ -730,6 +774,10 @@ impl<'a> Splitter<'a> {
for unique in unique_symbols {
if let Node::DataSymbol(id) = unique {
let symbol = self.data_symbols.get(&id).expect("missing data symbol");
if symbol.which_data_segment != 0 {
continue;
}

let range = symbol.segment_offset..symbol.segment_offset + symbol.symbol_size;
let offset = ConstExpr::Value(ir::Value::I32(
data_offset + symbol.segment_offset as i32,
Expand Down Expand Up @@ -847,7 +895,7 @@ impl<'a> Splitter<'a> {
// We're only going to try optimizing functions used across multiple chunks
let mut funcs_used_by_chunks: HashMap<Node, HashSet<usize>> = HashMap::new();
for split in self.split_points.iter() {
for item in split.reachable_graph.reachable.iter() {
for item in split.reachable_graph.iter() {
funcs_used_by_chunks
.entry(item.clone())
.or_default()
Expand Down Expand Up @@ -933,12 +981,7 @@ impl<'a> Splitter<'a> {
let mut shared_funcs = HashSet::new();

for split in self.split_points.iter() {
shared_funcs.extend(
split
.reachable_graph
.reachable
.intersection(&self.main_graph.reachable),
);
shared_funcs.extend(split.reachable_graph.intersection(&self.main_graph));
}

for injected in self.extra_symbols.iter() {
Expand All @@ -949,16 +992,63 @@ impl<'a> Splitter<'a> {
}

fn unused_main_symbols(&self) -> HashSet<Node> {
// let mut unique = HashSet::new();

// // Collect *every* symbol
// let all = self.reachable_from_all();

// // get the reachable symbols from every split combined with main
// let mut reachable_from_every = self.main_graph.clone();
// for split in self.split_points.iter() {
// reachable_from_every.extend(split.reachable_graph.iter().cloned());
// unique.extend(split.reachable_graph.difference(&self.main_graph));
// }

// // These are symbols we can't delete in the main module
// let to_save: HashSet<Node> = all.difference(&reachable_from_every).cloned().collect();
// unique.difference(&to_save).cloned().collect()

// let mut unique = HashSet::new();
// // Collect *every* symbol
// // let all = self.reachable_from_all();

// // all.difference(&self.main_graph.reachable)
// // .cloned()
// // .collect()

// // // get the reachable symbols from every split combined with main
// // let mut reachable_from_every = self.main_graph.reachable.clone();
// for split in self.split_points.iter() {
// // reachable_from_every.extend(split_reachable.reachable.iter().cloned());
// unique.extend((&split.reachable_graph).difference(&self.main_graph));
// }

// for import in self.source_module.imports.iter() {
// if let ImportKind::Function(func) = import.kind {
// unique.remove(&Node::Function(func));
// }
// }

// for export in self.source_module.exports.iter() {
// if let ExportItem::Function(func) = export.item {
// unique.remove(&Node::Function(func));
// }
// }

// These are symbols we can't delete in the main module
// let to_save: HashSet<Node> = all.difference(&reachable_from_every).cloned().collect();
// unique.difference(&to_save).cloned().collect()
// unique

let mut unique = HashSet::new(); // self.main_graph.reachable.clone();

for split in self.split_points.iter() {
let roots = [Node::Function(split.export_func)].into();

let graph = ReachabilityGraph::new(&self.call_graph, &roots, &Default::default());
let graph = make_call_graph(&self.call_graph, &roots);

let unique_symbols = graph
.reachable
.difference(&self.main_graph.reachable)
.difference(&self.main_graph)
.cloned()
.collect::<HashSet<_>>();

Expand All @@ -971,7 +1061,7 @@ impl<'a> Splitter<'a> {
tracing::error!("found extra symbol: {:?}", _u);
}

if self.main_graph.reachable.contains(_u) {
if self.main_graph.contains(_u) {
tracing::error!("found main symbol: {:?}", _u);
}

Expand Down Expand Up @@ -1093,13 +1183,11 @@ impl<'a> Splitter<'a> {
}
}

split.reachable_graph =
ReachabilityGraph::new(&self.call_graph, &roots, &Default::default());
split.reachable_graph = make_call_graph(&self.call_graph, &roots);
});

// And then the reachability graph for main
self.main_graph =
ReachabilityGraph::new(&self.call_graph, &self.main_roots(), &Default::default());
self.main_graph = make_call_graph(&self.call_graph, &self.main_roots());

Ok(())
}
Expand Down Expand Up @@ -1355,7 +1443,7 @@ pub struct SplitPoint {
export_func: FunctionId,
component_name: String,
index: usize,
reachable_graph: ReachabilityGraph,
reachable_graph: HashSet<Node>,
hash_name: String,

#[allow(unused)]
Expand Down Expand Up @@ -1426,43 +1514,32 @@ fn accumulate_split_points(module: &Module) -> Vec<SplitPoint> {
.collect()
}

#[derive(Debug, Default, Clone)]
pub struct ReachabilityGraph {
reachable: HashSet<Node>,
}

#[derive(Debug, PartialEq, Eq, Hash, Copy, PartialOrd, Ord, Clone)]
pub enum Node {
Function(FunctionId),
DataSymbol(usize),
}

impl ReachabilityGraph {
fn new(
deps: &HashMap<Node, HashSet<Node>>,
roots: &HashSet<Node>,
exclude: &HashSet<Node>,
) -> ReachabilityGraph {
let mut queue: VecDeque<Node> = roots.iter().copied().collect();
let mut reachable = HashSet::<Node>::new();
let mut parents = HashMap::<Node, Node>::new();

while let Some(node) = queue.pop_front() {
reachable.insert(node);
let Some(children) = deps.get(&node) else {
fn make_call_graph(deps: &HashMap<Node, HashSet<Node>>, roots: &HashSet<Node>) -> HashSet<Node> {
let mut queue: VecDeque<Node> = roots.iter().copied().collect();
let mut reachable = HashSet::<Node>::new();
let mut parents = HashMap::<Node, Node>::new();

while let Some(node) = queue.pop_front() {
reachable.insert(node);
let Some(children) = deps.get(&node) else {
continue;
};
for child in children {
if reachable.contains(&child) {
continue;
};
for child in children {
if reachable.contains(&child) || exclude.contains(&child) {
continue;
}
parents.entry(*child).or_insert(node);
queue.push_back(*child);
}
parents.entry(*child).or_insert(node);
queue.push_back(*child);
}

ReachabilityGraph { reachable }
}

reachable
}

struct RawDataSection<'a> {
Expand Down
11 changes: 5 additions & 6 deletions packages/wasm-split/wasm-split-harness/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@ enum Route {
#[layout(Nav)]
#[route("/")]
Home,

#[route("/child")]
ChildSplit,
// #[route("/child")]
// ChildSplit,
}

fn Nav() -> Element {
rsx! {
div {
Link { to: Route::Home, "Home" }
Link { to: Route::ChildSplit, "Child" }
// Link { to: Route::ChildSplit, "Child" }
Outlet::<Route> {}
}
}
Expand Down Expand Up @@ -68,7 +67,7 @@ fn Home(args: ()) -> Element {
// }
h3 { "Global Counter: {GLOBAL_COUNTER}" }
div { id: "output-box" }
// ChildSplit {}
ChildSplit {}
}
}

Expand Down Expand Up @@ -157,7 +156,7 @@ fn ChildSplit() -> Element {
d: "M13.78 4.22a.75.75 0 010 1.06l-7.25 7.25a.75.75 0 01-1.06 0L2.22 9.28a.75.75 0 011.06-1.06L6 10.94l6.72-6.72a.75.75 0 011.06 0z",
fill_rule: "evenodd",
}

}
button {
onclick: move |_| {
Expand Down
Loading

0 comments on commit ea8fc7a

Please sign in to comment.