Skip to content

Commit

Permalink
Merge pull request #133 from blocklessnetwork/feature/wasi-nn
Browse files Browse the repository at this point in the history
the neural network api support.
  • Loading branch information
Joinhack authored Dec 21, 2024
2 parents 563f4a4 + 8057d56 commit 6403b11
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ libc = { version = "0.2.112", default-features = true }
wasmtime = "=26.0.0"
wasmtime-wasi = "=26.0.0"
wiggle-generate = "=26.0.0"
wasmtime-wasi-nn = { version = "=26.0.0" }
wasmtime-wasi-threads = "=26.0.0"
wasi-common = { path = "crates/wasi-common", version="=26.0.0" }
# witx dependency by wiggle
Expand Down
1 change: 1 addition & 0 deletions blockless/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ wasi-common = {workspace = true}
wasmtime = {workspace = true}
wasmtime-wasi = {workspace = true}
cap-std = {workspace = true}
wasmtime-wasi-nn = {workspace = true}
blockless-drivers = {workspace = true}
blockless-multiaddr = {workspace = true}
blockless-env = {path = "../crates/blockless-env"}
Expand Down
6 changes: 6 additions & 0 deletions blockless/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ pub(crate) struct BlocklessContext {

pub(crate) wasi_threads: Option<Arc<WasiThreadsCtx<BlocklessContext>>>,

pub(crate) wasi_nn_wit: Option<Arc<wasmtime_wasi_nn::wit::WasiNnCtx>>,

pub(crate) wasi_nn_witx: Option<Arc<wasmtime_wasi_nn::witx::WasiNnCtx>>,

pub(crate) store_limits: StoreLimits,
}

impl Default for BlocklessContext {
fn default() -> Self {
Self {
wasi_nn_wit: None,
wasi_nn_witx: None,
preview1_ctx: None,
preview2_ctx: None,
wasi_threads: None,
Expand Down
52 changes: 52 additions & 0 deletions blockless/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use wasmtime::{
component::Component, Config, Engine, Linker, Module, Precompiled, Store, StoreLimits,
StoreLimitsBuilder, Trap,
};
use wasmtime_wasi::WasiView;
use wasmtime_wasi::{DirPerms, FilePerms};
use wasmtime_wasi_nn::wit::WasiNnView;
use wasmtime_wasi_threads::WasiThreadsCtx;

// the default wasm entry name.
Expand Down Expand Up @@ -331,6 +333,9 @@ impl BlocklessRunner {
let (mut linker, mut run_target, entry) =
self.module_linker(entry, &engine, &mut store).await?;
let mut is_component = false;
if b_conf.nn {
self.nn_setup(&mut linker, &mut store)?;
}
// prepare linker.
match linker {
BlsLinker::Core(ref mut linker) => {
Expand Down Expand Up @@ -385,6 +390,18 @@ impl BlocklessRunner {
Ok(())
}

fn collect_preloaded_nn_graphs(
&self,
) -> AnyResult<(Vec<wasmtime_wasi_nn::Backend>, wasmtime_wasi_nn::Registry)> {
let graphs = self
.0
.nn_graph
.iter()
.map(|g| (g.format.clone(), g.dir.clone()))
.collect::<Vec<_>>();
wasmtime_wasi_nn::preload(&graphs)
}

fn write_core_dump(
store: &mut Store<BlocklessContext>,
err: &anyhow::Error,
Expand Down Expand Up @@ -481,6 +498,41 @@ impl BlocklessRunner {
result
}

fn nn_setup(
&self,
linker: &mut BlsLinker,
store: &mut Store<BlocklessContext>,
) -> AnyResult<()> {
let (backends, registry) = self.collect_preloaded_nn_graphs()?;
match linker {
BlsLinker::Core(linker) => {
wasmtime_wasi_nn::witx::add_to_linker(linker, |host| {
Arc::get_mut(host.wasi_nn_witx.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support")
})?;
store.data_mut().wasi_nn_witx = Some(Arc::new(
wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry),
));
}
BlsLinker::Component(linker) => {
wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut BlocklessContext| {
let preview2_ctx = h.preview2_ctx.as_mut().expect("wasip2 is not configured");
let preview2_ctx = Arc::get_mut(preview2_ctx)
.expect("wasmtime_wasi is not compatible with threads")
.get_mut()
.unwrap();
let nn_ctx = Arc::get_mut(h.wasi_nn_wit.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support");
WasiNnView::new(preview2_ctx.table(), nn_ctx)
})?;
store.data_mut().wasi_nn_wit = Some(Arc::new(
wasmtime_wasi_nn::wit::WasiNnCtx::new(backends, registry),
));
}
}
Ok(())
}

fn handle_core_dump(
cfg: &BlocklessConfig,
store: &mut Store<BlocklessContext>,
Expand Down
29 changes: 27 additions & 2 deletions bls-runtime/src/cli_clap.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#![allow(unused)]
use anyhow::{bail, Result};
use blockless::{
BlocklessConfig, BlocklessModule, BlsOptions, ModuleType, OptimizeOpts, Permission, Stderr,
Stdin, Stdout,
BlocklessConfig, BlocklessModule, BlsNnGraph, BlsOptions, ModuleType, OptimizeOpts, Permission,
Stderr, Stdin, Stdout,
};
use clap::{
builder::{TypedValueParser, ValueParser},
Expand Down Expand Up @@ -80,6 +80,12 @@ const NETWORK_ERROR_CODE_HELP: &str =

const MAX_MEMORY_SIZE_HELP: &str = "The max memory size limited.";

const NN_HELP: &str = "Enable support for WASI neural network imports .";

const NN_GRAPH_HELP: &str =
"Pre-load machine learning graphs (i.e., models) for use by wasi-nn. \
Each use of the flag will preload a ML model from the host directory using the given model encoding";

fn parse_envs(envs: &str) -> Result<(String, String)> {
let parts: Vec<_> = envs.splitn(2, "=").collect();
if parts.len() != 2 {
Expand All @@ -88,6 +94,17 @@ fn parse_envs(envs: &str) -> Result<(String, String)> {
Ok((parts[0].to_string(), parts[1].to_string()))
}

fn parse_nn_graph(envs: &str) -> Result<BlsNnGraph> {
let parts: Vec<_> = envs.splitn(2, "=").collect();
if parts.len() != 2 {
bail!("must be of the form `key=value`")
}
Ok(BlsNnGraph {
format: parts[0].to_string(),
dir: parts[1].to_string(),
})
}

fn parse_opts(opt: &str) -> Result<OptimizeOpts> {
let kvs: Vec<_> = opt.splitn(2, ",").collect();
if kvs.len() == 1 {
Expand Down Expand Up @@ -280,6 +297,12 @@ pub(crate) struct CliCommandOpts {

#[clap(long = "max_memory_size", value_name = "MAX_MEMORY_SIZE", help = MAX_MEMORY_SIZE_HELP)]
max_memory_size: Option<u64>,

#[clap(long = "nn", value_name = "NN", help = NN_HELP)]
nn: bool,

#[clap(long = "nn-graph", value_name = "NN_GRAPH", value_parser = parse_nn_graph, help = NN_GRAPH_HELP)]
nn_graph: Vec<BlsNnGraph>,
}

impl CliCommandOpts {
Expand Down Expand Up @@ -354,9 +377,11 @@ impl CliCommandOpts {
conf.0
.set_version(blockless::BlocklessConfigVersion::Version1);
}
conf.0.nn = self.nn;
conf.0.tcp_listens = self.tcp_listens;
conf.0.network_error_code = self.network_error_code;
conf.0.unknown_imports_trap = self.unknown_imports_trap;
conf.0.nn_graph = self.nn_graph;
Ok(())
}

Expand Down
10 changes: 10 additions & 0 deletions crates/wasi-common/src/blockless/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,23 @@ impl Default for Stdio {
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct BlsNnGraph {
pub format: String,
pub dir: String,
}

#[derive(Clone)]
pub struct BlocklessConfig {
pub entry: String,
pub nn: bool,
pub stdio: Stdio,
pub debug_info: bool,
pub is_carfile: bool,
pub opts: OptimizeOpts,
pub feature_thread: bool,
pub run_time: Option<u64>,
pub nn_graph: Vec<BlsNnGraph>,
pub stdin_args: Vec<String>,
pub coredump: Option<String>,
pub limited_fuel: Option<u64>,
Expand Down Expand Up @@ -406,6 +414,7 @@ pub struct BlocklessConfig {
impl BlocklessConfig {
pub fn new(entry: &str) -> BlocklessConfig {
Self {
nn: false,
run_time: None,
coredump: None,
envs: Vec::new(),
Expand All @@ -431,6 +440,7 @@ impl BlocklessConfig {
extensions_path: None,
drivers_root_path: None,
unknown_imports_trap: false,
nn_graph: Vec::new(),
entry: String::from(entry),
permisions: Default::default(),
group_permisions: HashMap::new(),
Expand Down

0 comments on commit 6403b11

Please sign in to comment.