Skip to content

Commit

Permalink
improved exception handling (#2)
Browse files Browse the repository at this point in the history
Signed-off-by: Georgi Georgiev <[email protected]>
  • Loading branch information
georg-getz authored Jun 12, 2023
1 parent 9e3b36e commit 91600ed
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 48 deletions.
38 changes: 36 additions & 2 deletions src/exception.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,56 @@
use std::num::TryFromIntError;
use std::str::Utf8Error;
use jni::JNIEnv;
use std::thread;
use jni::errors::Error as JNIError;
use wasmer::MemoryError;
use wasmer_wasi::{WasiError, WasiStateCreationError};

#[derive(Debug)]
pub enum Error {
JNIError(JNIError),
Message(String),
WasiError(WasiError),
WasiStateCreationError(WasiStateCreationError),
MemoryError(MemoryError),
TryFromIntError(TryFromIntError),
Utf8Error(Utf8Error)
}

impl From<JNIError> for Error {
fn from(err: JNIError) -> Self { Self::JNIError(err) }
}

impl From<WasiError> for Error {
fn from(err: WasiError) -> Self { Self::WasiError(err) }
}

impl From<WasiStateCreationError> for Error {
fn from(err: WasiStateCreationError) -> Self { Self::WasiStateCreationError(err) }
}

impl From<MemoryError> for Error {
fn from(err: MemoryError) -> Self { Self::MemoryError(err) }
}

impl From<TryFromIntError> for Error {
fn from(err: TryFromIntError) -> Self { Self::TryFromIntError(err) }
}

impl From<Utf8Error> for Error {
fn from(err: Utf8Error) -> Self { Self::Utf8Error(err) }
}

impl std::fmt::Display for Error {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
Error::JNIError(err) => err.fmt(fmt),
Error::Message(msg) => msg.fmt(fmt),
Error::WasiError(err) => err.fmt(fmt),
Error::WasiStateCreationError(err) => err.fmt(fmt),
Error::MemoryError(err) => err.fmt(fmt),
Error::TryFromIntError(err) => err.fmt(fmt),
Error::Utf8Error(err) => err.fmt(fmt),
}
}
}
Expand Down Expand Up @@ -47,15 +81,15 @@ pub fn joption_or_throw<T>(env: &JNIEnv, result: thread::Result<Result<T, Error>
Err(error) => {
if !env.exception_check().unwrap() {
env.throw_new("java/lang/RuntimeException", &error.to_string())
.expect("Cannot throw an `java/lang/RuntimeException` exception.");
.expect("Cannot throw a `java/lang/RuntimeException` exception.");
}

JOption::None
}
},
Err(ref error) => {
env.throw_new("java/lang/RuntimeException", format!("{:?}", error))
.expect("Cannot throw an `java/lang/RuntimeException` exception.");
.expect("Cannot throw a `java/lang/RuntimeException` exception.");

JOption::None
}
Expand Down
68 changes: 28 additions & 40 deletions src/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,59 +39,60 @@ pub extern "system" fn Java_org_wasmer_Imports_nativeImportsInstantiate(
let mut import_object = ImportObject::new();
let module: &Module = Into::<Pointer<Module>>::into(module).borrow();
let store = module.module.store();
let imports = env.get_list(imports).unwrap();
let imports = env.get_list(imports)?;

for import in imports.iter().unwrap() {
let namespace = env.get_field(import, "namespace", "Ljava/lang/String;").unwrap().l().unwrap();
let namespace = env.get_string(namespace.into()).unwrap().to_str().unwrap().to_string();
let name = env.get_field(import, "name", "Ljava/lang/String;").unwrap().l().unwrap();
let name = env.get_string(name.into()).unwrap().to_str().unwrap().to_string();
for import in imports.iter()? {
let namespace = env.get_field(import, "namespace", "Ljava/lang/String;")?.l()?;
let namespace = env.get_string(namespace.into())?.to_str()?.to_string();
let name = env.get_field(import, "name", "Ljava/lang/String;")?.l()?;
let name = env.get_string(name.into())?.to_str()?.to_string();

if name == "memory" {
let min_pages = env.get_field(import, "minPages", "I").unwrap().i().unwrap();
let max_pages = env.get_field(import, "maxPages", "Ljava/lang/Integer;").unwrap().l().unwrap();
let min_pages = env.get_field(import, "minPages", "I")?.i()?;
let max_pages = env.get_field(import, "maxPages", "Ljava/lang/Integer;")?.l()?;
let max_pages = if max_pages.is_null() {
None
} else {
//have to get the field again if not null as it cannot be cast to int
let max_pages = env.get_field(import, "maxPages", "I").unwrap().i().unwrap();
Some(u32::try_from(max_pages).unwrap())
let max_pages = env.get_field(import, "maxPages", "I")?.i()?;
Some(u32::try_from(max_pages)?)
};
let shared = env.get_field(import, "shared", "Z").unwrap().z().unwrap();
let memory_type = MemoryType::new(u32::try_from(min_pages).unwrap(), max_pages, shared);
namespaces.entry(namespace).or_insert_with(|| Exports::new()).insert(name, Memory::new(&store, memory_type).unwrap())
let shared = env.get_field(import, "shared", "Z")?.z()?;
let memory_type = MemoryType::new(u32::try_from(min_pages)?, max_pages, shared);
namespaces.entry(namespace).or_insert_with(|| Exports::new()).insert(name, Memory::new(&store, memory_type)?)
} else {
let function = env.get_field(import, "function", "Ljava/util/function/Function;").unwrap().l().unwrap();
let params = env.get_field(import, "argTypesInt", "[I").unwrap().l().unwrap();
let returns = env.get_field(import, "retTypesInt", "[I").unwrap().l().unwrap();
let params = env.get_int_array_elements(*params, ReleaseMode::NoCopyBack).unwrap();
let returns = env.get_int_array_elements(*returns, ReleaseMode::NoCopyBack).unwrap();
let function = env.get_field(import, "function", "Ljava/util/function/Function;")?.l()?;
let params = env.get_field(import, "argTypesInt", "[I")?.l()?;
let returns = env.get_field(import, "retTypesInt", "[I")?.l()?;
let params = env.get_int_array_elements(*params, ReleaseMode::NoCopyBack)?;
let returns = env.get_int_array_elements(*returns, ReleaseMode::NoCopyBack)?;
let i2t = |i: &i32| match i { 1 => Type::I32, 2 => Type::I64, 3 => Type::F32, 4 => Type::F64, _ => unreachable!("Unknown {}", i)};
let params = array2vec(&params).into_iter().map(i2t).collect::<Vec<_>>();
let returns = array2vec(&returns).into_iter().map(i2t).collect::<Vec<_>>();
let sig = FunctionType::new(params.clone(), returns.clone());
let function = env.new_global_ref(function).unwrap();
let jvm = env.get_java_vm().unwrap();
let function = env.new_global_ref(function)?;
let jvm = env.get_java_vm()?;
namespaces.entry(namespace).or_insert_with(|| Exports::new()).insert(name, Function::new(store, sig, move |argv| {
// There are many ways of transferring the args from wasm to java, JList being the cleanest,
// but probably also slowest by far (two JNI calls per argument). Benchmark?
let env = jvm.get_env().unwrap();
let env = jvm.get_env().expect("Couldn't get JNIEnv");
env.ensure_local_capacity(argv.len() as i32 + 2).ok();
let jargv = env.new_long_array(argv.len() as i32).unwrap();
let jargv = env.new_long_array(argv.len() as i32).expect("Couldn't create array");
let argv = argv.into_iter().enumerate().map(|(i, arg)| match arg {
Value::I32(arg) => { assert_eq!(params[i], Type::I32); *arg as i64 },
Value::I64(arg) => { assert_eq!(params[i], Type::I64); *arg as i64 },
Value::F32(arg) => { assert_eq!(params[i], Type::F32); arg.to_bits() as i64 },
Value::F64(arg) => { assert_eq!(params[i], Type::F64); arg.to_bits() as i64 },
_ => panic!("Argument of unsupported type {:?}", arg)
}).collect::<Vec<jlong>>();
env.set_long_array_region(jargv, 0, &argv).unwrap();
let jret = env.call_method(function.as_obj(), "apply", "(Ljava/lang/Object;)Ljava/lang/Object;", &[jargv.into()]).unwrap().l().unwrap();
env.set_long_array_region(jargv, 0, &argv).expect("Couldn't set array region");
let jret = env.call_method(function.as_obj(), "apply", "(Ljava/lang/Object;)Ljava/lang/Object;", &[jargv.into()])
.expect("Couldn't call 'apply' function").l().expect("Failed to unwrap object");
let ret = match returns.len() {
0 => vec![],
len => {
let mut ret = vec![0; len];
env.get_long_array_region(*jret, 0, &mut ret).unwrap();
env.get_long_array_region(*jret, 0, &mut ret).expect("Couldn't get array region");
ret.into_iter().enumerate().map(|(i, ret)| match returns[i] {
Type::I32 => Value::I32(ret as i32),
Type::I64 => Value::I64(ret as i64),
Expand All @@ -101,7 +102,6 @@ pub extern "system" fn Java_org_wasmer_Imports_nativeImportsInstantiate(
}).collect()
}
};
// TODO: Error handling
Ok(ret)
}));
}
Expand All @@ -125,10 +125,9 @@ pub extern "system" fn Java_org_wasmer_Imports_nativeImportsWasi(
module: jptr,
) -> jptr {
let output = panic::catch_unwind(|| {
// TODO: When getting serious about this, one might have to expose the wasi builder... :/
let module: &Module = Into::<Pointer<Module>>::into(module).borrow();
let mut wasi = WasiState::new("").finalize().unwrap();
let import_object = wasi.import_object(&module.module).unwrap();
let mut wasi = WasiState::new("").finalize()?;
let import_object = wasi.import_object(&module.module)?;
let import_object = Box::new(import_object);

Ok(Pointer::new(Imports { import_object }).into())
Expand All @@ -145,7 +144,6 @@ pub extern "system" fn Java_org_wasmer_Imports_nativeImportsChain(
front: jptr,
) -> jptr {
let output = panic::catch_unwind(|| {

let back: &Imports = Into::<Pointer<Imports>>::into(back).borrow();
let front: &Imports = Into::<Pointer<Imports>>::into(front).borrow();
let import_object = Box::new((&back.import_object).chain_front(&front.import_object));
Expand All @@ -154,14 +152,4 @@ pub extern "system" fn Java_org_wasmer_Imports_nativeImportsChain(
});

joption_or_throw(&env, output).unwrap_or(0)

}

#[no_mangle]
pub extern "system" fn Java_org_wasmer_Imports_nativeDrop(
_env: JNIEnv,
_class: JClass,
imports_pointer: jptr,
) {
let _: Pointer<Imports> = imports_pointer.into();
}
6 changes: 0 additions & 6 deletions src/java/org/wasmer/Imports.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ public class Imports {
private static native long nativeImportsInstantiate(List<ImportObject> imports, long modulePointer) throws RuntimeException;
private static native long nativeImportsChain(long back, long front) throws RuntimeException;
private static native long nativeImportsWasi(long modulePointer) throws RuntimeException;
private static native void nativeDrop(long nativePointer);

final long importsPointer;

Expand All @@ -31,9 +30,4 @@ public static Imports chain(Imports back, Imports front) {
public static Imports wasi(Module module) {
return new Imports(nativeImportsWasi(module.modulePointer));
}

protected void finalize() {
nativeDrop(importsPointer);
// TODO allow memory-safe user invocation
}
}

0 comments on commit 91600ed

Please sign in to comment.