Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

the neural network api support. #133

Merged
merged 2 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ libc = { version = "0.2.112", default-features = true }
wasmtime = "=26.0.0"
wasmtime-wasi = "=26.0.0"
wiggle-generate = "=26.0.0"
wasmtime-wasi-nn = { version = "=26.0.0" }
wasmtime-wasi-threads = "=26.0.0"
wasi-common = { path = "crates/wasi-common", version="=26.0.0" }
# witx dependency by wiggle
Expand Down
1 change: 1 addition & 0 deletions blockless/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ wasi-common = {workspace = true}
wasmtime = {workspace = true}
wasmtime-wasi = {workspace = true}
cap-std = {workspace = true}
wasmtime-wasi-nn = {workspace = true}
blockless-drivers = {workspace = true}
blockless-multiaddr = {workspace = true}
blockless-env = {path = "../crates/blockless-env"}
Expand Down
6 changes: 6 additions & 0 deletions blockless/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ pub(crate) struct BlocklessContext {

pub(crate) wasi_threads: Option<Arc<WasiThreadsCtx<BlocklessContext>>>,

pub(crate) wasi_nn_wit: Option<Arc<wasmtime_wasi_nn::wit::WasiNnCtx>>,

pub(crate) wasi_nn_witx: Option<Arc<wasmtime_wasi_nn::witx::WasiNnCtx>>,

pub(crate) store_limits: StoreLimits,
}

impl Default for BlocklessContext {
fn default() -> Self {
Self {
wasi_nn_wit: None,
wasi_nn_witx: None,
preview1_ctx: None,
preview2_ctx: None,
wasi_threads: None,
Expand Down
52 changes: 52 additions & 0 deletions blockless/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use wasmtime::{
component::Component, Config, Engine, Linker, Module, Precompiled, Store, StoreLimits,
StoreLimitsBuilder, Trap,
};
use wasmtime_wasi::WasiView;
use wasmtime_wasi::{DirPerms, FilePerms};
use wasmtime_wasi_nn::wit::WasiNnView;
use wasmtime_wasi_threads::WasiThreadsCtx;

// the default wasm entry name.
Expand Down Expand Up @@ -331,6 +333,9 @@ impl BlocklessRunner {
let (mut linker, mut run_target, entry) =
self.module_linker(entry, &engine, &mut store).await?;
let mut is_component = false;
if b_conf.nn {
self.nn_setup(&mut linker, &mut store)?;
}
// prepare linker.
match linker {
BlsLinker::Core(ref mut linker) => {
Expand Down Expand Up @@ -385,6 +390,18 @@ impl BlocklessRunner {
Ok(())
}

fn collect_preloaded_nn_graphs(
&self,
) -> AnyResult<(Vec<wasmtime_wasi_nn::Backend>, wasmtime_wasi_nn::Registry)> {
let graphs = self
.0
.nn_graph
.iter()
.map(|g| (g.format.clone(), g.dir.clone()))
.collect::<Vec<_>>();
wasmtime_wasi_nn::preload(&graphs)
}

fn write_core_dump(
store: &mut Store<BlocklessContext>,
err: &anyhow::Error,
Expand Down Expand Up @@ -481,6 +498,41 @@ impl BlocklessRunner {
result
}

fn nn_setup(
&self,
linker: &mut BlsLinker,
store: &mut Store<BlocklessContext>,
) -> AnyResult<()> {
let (backends, registry) = self.collect_preloaded_nn_graphs()?;
match linker {
BlsLinker::Core(linker) => {
wasmtime_wasi_nn::witx::add_to_linker(linker, |host| {
Arc::get_mut(host.wasi_nn_witx.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support")
})?;
store.data_mut().wasi_nn_witx = Some(Arc::new(
wasmtime_wasi_nn::witx::WasiNnCtx::new(backends, registry),
));
}
BlsLinker::Component(linker) => {
wasmtime_wasi_nn::wit::add_to_linker(linker, |h: &mut BlocklessContext| {
let preview2_ctx = h.preview2_ctx.as_mut().expect("wasip2 is not configured");
let preview2_ctx = Arc::get_mut(preview2_ctx)
.expect("wasmtime_wasi is not compatible with threads")
.get_mut()
.unwrap();
let nn_ctx = Arc::get_mut(h.wasi_nn_wit.as_mut().unwrap())
.expect("wasi-nn is not implemented with multi-threading support");
WasiNnView::new(preview2_ctx.table(), nn_ctx)
})?;
store.data_mut().wasi_nn_wit = Some(Arc::new(
wasmtime_wasi_nn::wit::WasiNnCtx::new(backends, registry),
));
}
}
Ok(())
}

fn handle_core_dump(
cfg: &BlocklessConfig,
store: &mut Store<BlocklessContext>,
Expand Down
29 changes: 27 additions & 2 deletions bls-runtime/src/cli_clap.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#![allow(unused)]
use anyhow::{bail, Result};
use blockless::{
BlocklessConfig, BlocklessModule, BlsOptions, ModuleType, OptimizeOpts, Permission, Stderr,
Stdin, Stdout,
BlocklessConfig, BlocklessModule, BlsNnGraph, BlsOptions, ModuleType, OptimizeOpts, Permission,
Stderr, Stdin, Stdout,
};
use clap::{
builder::{TypedValueParser, ValueParser},
Expand Down Expand Up @@ -80,6 +80,12 @@ const NETWORK_ERROR_CODE_HELP: &str =

const MAX_MEMORY_SIZE_HELP: &str = "The max memory size limited.";

const NN_HELP: &str = "Enable support for WASI neural network imports .";

const NN_GRAPH_HELP: &str =
"Pre-load machine learning graphs (i.e., models) for use by wasi-nn. \
Each use of the flag will preload a ML model from the host directory using the given model encoding";

fn parse_envs(envs: &str) -> Result<(String, String)> {
let parts: Vec<_> = envs.splitn(2, "=").collect();
if parts.len() != 2 {
Expand All @@ -88,6 +94,17 @@ fn parse_envs(envs: &str) -> Result<(String, String)> {
Ok((parts[0].to_string(), parts[1].to_string()))
}

fn parse_nn_graph(envs: &str) -> Result<BlsNnGraph> {
let parts: Vec<_> = envs.splitn(2, "=").collect();
if parts.len() != 2 {
bail!("must be of the form `key=value`")
}
Ok(BlsNnGraph {
format: parts[0].to_string(),
dir: parts[1].to_string(),
})
}

fn parse_opts(opt: &str) -> Result<OptimizeOpts> {
let kvs: Vec<_> = opt.splitn(2, ",").collect();
if kvs.len() == 1 {
Expand Down Expand Up @@ -280,6 +297,12 @@ pub(crate) struct CliCommandOpts {

#[clap(long = "max_memory_size", value_name = "MAX_MEMORY_SIZE", help = MAX_MEMORY_SIZE_HELP)]
max_memory_size: Option<u64>,

#[clap(long = "nn", value_name = "NN", help = NN_HELP)]
nn: bool,

#[clap(long = "nn-graph", value_name = "NN_GRAPH", value_parser = parse_nn_graph, help = NN_GRAPH_HELP)]
nn_graph: Vec<BlsNnGraph>,
}

impl CliCommandOpts {
Expand Down Expand Up @@ -354,9 +377,11 @@ impl CliCommandOpts {
conf.0
.set_version(blockless::BlocklessConfigVersion::Version1);
}
conf.0.nn = self.nn;
conf.0.tcp_listens = self.tcp_listens;
conf.0.network_error_code = self.network_error_code;
conf.0.unknown_imports_trap = self.unknown_imports_trap;
conf.0.nn_graph = self.nn_graph;
Ok(())
}

Expand Down
10 changes: 10 additions & 0 deletions crates/wasi-common/src/blockless/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,23 @@ impl Default for Stdio {
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct BlsNnGraph {
pub format: String,
pub dir: String,
}

#[derive(Clone)]
pub struct BlocklessConfig {
pub entry: String,
pub nn: bool,
pub stdio: Stdio,
pub debug_info: bool,
pub is_carfile: bool,
pub opts: OptimizeOpts,
pub feature_thread: bool,
pub run_time: Option<u64>,
pub nn_graph: Vec<BlsNnGraph>,
pub stdin_args: Vec<String>,
pub coredump: Option<String>,
pub limited_fuel: Option<u64>,
Expand Down Expand Up @@ -406,6 +414,7 @@ pub struct BlocklessConfig {
impl BlocklessConfig {
pub fn new(entry: &str) -> BlocklessConfig {
Self {
nn: false,
run_time: None,
coredump: None,
envs: Vec::new(),
Expand All @@ -431,6 +440,7 @@ impl BlocklessConfig {
extensions_path: None,
drivers_root_path: None,
unknown_imports_trap: false,
nn_graph: Vec::new(),
entry: String::from(entry),
permisions: Default::default(),
group_permisions: HashMap::new(),
Expand Down
Loading