From 9a706567b9259f02428a110eafe7e6e956b0ce72 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 27 Oct 2023 15:02:02 +0000 Subject: [PATCH] remove enzyme bindings, fix changes in inkwell api --- Cargo.toml | 4 -- src/build.rs | 44 ------------- src/codegen/codegen.rs | 140 ++++++++++++++++++++--------------------- src/enzyme/mod.rs | 36 ----------- src/enzyme/wrapper.h | 2 - src/lib.rs | 10 --- 6 files changed, 70 insertions(+), 166 deletions(-) delete mode 100644 src/build.rs delete mode 100644 src/enzyme/mod.rs delete mode 100644 src/enzyme/wrapper.h diff --git a/Cargo.toml b/Cargo.toml index 34f05d2..22db40f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,6 @@ name = "diffeq" version = "0.1.0" edition = "2021" -build = "src/build.rs" [[bin]] name = "diffeq" @@ -21,9 +20,6 @@ sundials-sys = { version = ">=0.3", features = ["idas", "build_libraries"] } ouroboros = ">=0.17" clap = { version = "4.3.23", features = ["derive"] } -[build-dependencies] -bindgen = ">=0.68.1" - [profile.dev] opt-level = 0 debug = true diff --git a/src/build.rs b/src/build.rs deleted file mode 100644 index d452019..0000000 --- a/src/build.rs +++ /dev/null @@ -1,44 +0,0 @@ -use std::env; -use std::path::PathBuf; - -fn generate_enzyme_wrapper() { - - // Tell cargo to look for shared libraries in the specified directory - println!("cargo:rustc-link-search={}", std::env::var("ENZYME_LIB_DIR").unwrap_or("/usr/local/lib".to_string())); - - // Tell cargo to tell rustc to link the system bzip2 - // shared library. - println!("cargo:rustc-link-lib=LLVMEnzyme-14"); - - // Tell cargo to invalidate the built crate whenever the wrapper changes - println!("cargo:rerun-if-changed=src/enzyme/wrapper.h"); - - // The bindgen::Builder is the main entry point - // to bindgen, and lets you build up options for - // the resulting bindings. - let bindings = bindgen::Builder::default() - // The input header we would like to generate - // bindings for. - .header("src/enzyme/wrapper.h") - // set the clang args to include the llvm prefix dir - .clang_arg(format!("-I{}", std::env::var("LLVM_SYS_14_PREFIX/include").unwrap_or("/usr/lib/llvm-14/include".to_string()))) - // set the clang args to include the llvm include dir - .clang_arg(format!("-I{}", std::env::var("ENZYME_INCLUDE_DIR").unwrap_or("/usr/local/include".to_string()))) - // Tell cargo to invalidate the built crate whenever any of the - // included header files changed. - .parse_callbacks(Box::new(bindgen::CargoCallbacks)) - // Finish the builder and generate the bindings. - .generate() - // Unwrap the Result and panic on failure. - .expect("Unable to generate bindings"); - - // Write the bindings to the $OUT_DIR/bindings.rs file. - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); -} - -fn main() { - generate_enzyme_wrapper(); -} \ No newline at end of file diff --git a/src/codegen/codegen.rs b/src/codegen/codegen.rs index 83041ff..e2794e0 100644 --- a/src/codegen/codegen.rs +++ b/src/codegen/codegen.rs @@ -103,7 +103,7 @@ impl<'ctx> CodeGen<'ctx> { if let Some(name) = blk.name() { let ptr = self.variables.get("u").unwrap(); let i = self.context.i32_type().const_int(data_index.try_into().unwrap(), false); - let alloca = unsafe { self.create_entry_block_builder().build_in_bounds_gep(*ptr, &[i], blk.name().unwrap()) }; + let alloca = unsafe { self.create_entry_block_builder().build_in_bounds_gep(*ptr, &[i], blk.name().unwrap()).unwrap() }; self.variables.insert(name.to_owned(), alloca); } data_index += blk.nnz(); @@ -113,7 +113,7 @@ impl<'ctx> CodeGen<'ctx> { if let Some(name) = blk.name() { let ptr = self.variables.get("dudt").unwrap(); let i = self.context.i32_type().const_int(data_index.try_into().unwrap(), false); - let alloca = unsafe { self.create_entry_block_builder().build_in_bounds_gep(*ptr, &[i], blk.name().unwrap()) }; + let alloca = unsafe { self.create_entry_block_builder().build_in_bounds_gep(*ptr, &[i], blk.name().unwrap()).unwrap() }; self.variables.insert(name.to_owned(), alloca); } data_index += blk.nnz(); @@ -123,14 +123,14 @@ impl<'ctx> CodeGen<'ctx> { let ptr = self.variables.get("data").unwrap().clone(); let mut data_index = self.layout.get_data_index(tensor.name()).unwrap(); let i = self.context.i32_type().const_int(data_index.try_into().unwrap(), false); - let alloca = unsafe { self.create_entry_block_builder().build_in_bounds_gep(ptr, &[i], tensor.name()) }; + let alloca = unsafe { self.create_entry_block_builder().build_in_bounds_gep(ptr, &[i], tensor.name()).unwrap() }; self.variables.insert(tensor.name().to_owned(), alloca); //insert any named blocks for blk in tensor.elmts() { if let Some(name) = blk.name() { let i = self.context.i32_type().const_int(data_index.try_into().unwrap(), false); - let alloca = unsafe { self.create_entry_block_builder().build_in_bounds_gep(ptr, &[i], name) }; + let alloca = unsafe { self.create_entry_block_builder().build_in_bounds_gep(ptr, &[i], name).unwrap() }; self.variables.insert(name.to_owned(), alloca); } // named blocks only supported for rank <= 1, so we can just add the nnz to get the next data index @@ -185,11 +185,11 @@ impl<'ctx> CodeGen<'ctx> { self.builder.position_at_end(basic_block); let x = fn_val.get_nth_param(0)?.into_float_value(); let one = self.real_type.const_float(1.0); - let negx = self.builder.build_float_neg(x, name); + let negx = self.builder.build_float_neg(x, name).ok()?; let exp = self.get_function("exp").unwrap(); - let exp_negx = self.builder.build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name); - let one_plus_exp_negx = self.builder.build_float_add(exp_negx.try_as_basic_value().left().unwrap().into_float_value(), one, name); - let sigmoid = self.builder.build_float_div(one, one_plus_exp_negx, name); + let exp_negx = self.builder.build_call(exp, &[BasicMetadataValueEnum::FloatValue(negx)], name).ok()?; + let one_plus_exp_negx = self.builder.build_float_add(exp_negx.try_as_basic_value().left().unwrap().into_float_value(), one, name).ok()?; + let sigmoid = self.builder.build_float_div(one, one_plus_exp_negx, name).ok()?; self.builder.build_return(Some(&sigmoid)); self.builder.position_at_end(current_block); Some(fn_val) @@ -228,7 +228,7 @@ impl<'ctx> CodeGen<'ctx> { let res_type = self.real_type; let res_ptr = match res_ptr_opt { Some(ptr) => ptr, - None => self.create_entry_block_builder().build_alloca(res_type, a.name()), + None => self.create_entry_block_builder().build_alloca(res_type, a.name())?, }; let name = a.name(); let elmt = a.elmts().first().unwrap(); @@ -246,7 +246,7 @@ impl<'ctx> CodeGen<'ctx> { let res_type = self.real_type; let res_ptr = match res_ptr_opt { Some(ptr) => ptr, - None => self.create_entry_block_builder().build_alloca(res_type, a.name()), + None => self.create_entry_block_builder().build_alloca(res_type, a.name())?, }; // set up the tensor storage pointer and index into this data @@ -293,8 +293,8 @@ impl<'ctx> CodeGen<'ctx> { let one = int_type.const_int(1, false); let zero = int_type.const_int(0, false); - let expr_index_ptr = self.builder.build_alloca(int_type, "expr_index"); - let elmt_index_ptr = self.builder.build_alloca(int_type, "elmt_index"); + let expr_index_ptr = self.builder.build_alloca(int_type, "expr_index")?; + let elmt_index_ptr = self.builder.build_alloca(int_type, "elmt_index")?; self.builder.build_store(expr_index_ptr, zero); self.builder.build_store(elmt_index_ptr, zero); @@ -305,7 +305,7 @@ impl<'ctx> CodeGen<'ctx> { // allocate the contract sum if needed let (contract_sum, contract_by) = if let TranslationFrom::DenseContraction { contract_by, contract_len: _} = translation.source { - (Some(self.builder.build_alloca(self.real_type, "contract_sum")), contract_by) + (Some(self.builder.build_alloca(self.real_type, "contract_sum")?), contract_by) } else { (None, 0) }; @@ -316,7 +316,7 @@ impl<'ctx> CodeGen<'ctx> { self.builder.position_at_end(block); let start_index = int_type.const_int(0, false); - let curr_index = self.builder.build_phi(int_type, format!["i{}", i].as_str()); + let curr_index = self.builder.build_phi(int_type, format!["i{}", i].as_str())?; curr_index.add_incoming(&[(&start_index, preblock)]); if i == expr_rank - contract_by - 1 && contract_sum.is_some() { @@ -332,19 +332,19 @@ impl<'ctx> CodeGen<'ctx> { let indices_int: Vec = indices.iter().map(|i| i.as_basic_value().into_int_value()).collect(); // load and increment the expression index - let expr_index = self.builder.build_load(expr_index_ptr, "expr_index").into_int_value(); - let elmt_index = self.builder.build_load(elmt_index_ptr, "elmt_index").into_int_value(); - let next_expr_index = self.builder.build_int_add(expr_index, one, "next_expr_index"); + let expr_index = self.builder.build_load(expr_index_ptr, "expr_index")?.into_int_value(); + let elmt_index = self.builder.build_load(elmt_index_ptr, "elmt_index")?.into_int_value(); + let next_expr_index = self.builder.build_int_add(expr_index, one, "next_expr_index")?; self.builder.build_store(expr_index_ptr, next_expr_index); let float_value = self.jit_compile_expr(name, &elmt.expr(), indices_int.as_slice(), elmt, Some(expr_index))?; if contract_sum.is_some() { - let contract_sum_value = self.builder.build_load(contract_sum.unwrap(), "contract_sum").into_float_value(); - let new_contract_sum_value = self.builder.build_float_add(contract_sum_value, float_value, "new_contract_sum"); + let contract_sum_value = self.builder.build_load(contract_sum.unwrap(), "contract_sum")?.into_float_value(); + let new_contract_sum_value = self.builder.build_float_add(contract_sum_value, float_value, "new_contract_sum")?; self.builder.build_store(contract_sum.unwrap(), new_contract_sum_value); } else { preblock = self.jit_compile_broadcast_and_store(name, elmt, float_value, expr_index, translation, preblock)?; - let next_elmt_index = self.builder.build_int_add(elmt_index, one, "next_elmt_index"); + let next_elmt_index = self.builder.build_int_add(elmt_index, one, "next_elmt_index")?; self.builder.build_store(elmt_index_ptr, next_elmt_index); } @@ -355,14 +355,14 @@ impl<'ctx> CodeGen<'ctx> { indices[i].add_incoming(&[(&next_index, preblock)]); if i == expr_rank - contract_by - 1 && contract_sum.is_some() { - let contract_sum_value= self.builder.build_load(contract_sum.unwrap(), "contract_sum").into_float_value(); - let next_elmt_index = self.builder.build_int_add(elmt_index, one, "next_elmt_index"); + let contract_sum_value= self.builder.build_load(contract_sum.unwrap(), "contract_sum")?.into_float_value(); + let next_elmt_index = self.builder.build_int_add(elmt_index, one, "next_elmt_index")?; self.builder.build_store(elmt_index_ptr, next_elmt_index); self.jit_compile_store(name, elmt, elmt_index, contract_sum_value, translation)?; } // loop condition - let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, expr_shape[i], name); + let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, expr_shape[i], name)?; let block = self.context.append_basic_block(self.fn_value(), name); self.builder.build_conditional_branch(loop_while, blocks[i], block); self.builder.position_at_end(block); @@ -386,7 +386,7 @@ impl<'ctx> CodeGen<'ctx> { let translation_index = self.layout.get_translation_index(elmt.expr_layout(), elmt.layout()).unwrap(); let translation_index = translation_index + translation.get_from_index_in_data_layout(); - let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum"); + let contract_sum_ptr = self.builder.build_alloca(self.real_type, "contract_sum")?; // loop through each contraction @@ -394,7 +394,7 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_unconditional_branch(block); self.builder.position_at_end(block); - let contract_index = self.builder.build_phi(int_type, "i"); + let contract_index = self.builder.build_phi(int_type, "i")?; let final_contract_index = int_type.const_int(elmt.layout().nnz().try_into().unwrap(), false); contract_index.add_incoming(&[(&int_type.const_int(0, false), preblock)]); @@ -404,18 +404,18 @@ impl<'ctx> CodeGen<'ctx> { int_type.const_int(2, false), contract_index.as_basic_value().into_int_value(), name, - ), + )?, name - ); + )?; let end_index = self.builder.build_int_add( start_index, - int_type.const_int(1, false)?, + int_type.const_int(1, false), name - ); - let start_ptr = unsafe { self.builder.build_gep(*self.get_param("indices"), &[start_index], "start_index_ptr") }; - let start_contract= self.builder.build_load(start_ptr, "start").into_int_value(); - let end_ptr = unsafe { self.builder.build_gep(*self.get_param("indices"), &[end_index], "end_index_ptr") }; - let end_contract = self.builder.build_load(end_ptr, "end").into_int_value(); + )?; + let start_ptr = unsafe { self.builder.build_gep(*self.get_param("indices"), &[start_index], "start_index_ptr")? }; + let start_contract= self.builder.build_load(start_ptr, "start")?.into_int_value(); + let end_ptr = unsafe { self.builder.build_gep(*self.get_param("indices"), &[end_index], "end_index_ptr")? }; + let end_contract = self.builder.build_load(end_ptr, "end")?.into_int_value(); // initialise the contract sum self.builder.build_store(contract_sum_ptr, self.real_type.const_float(0.0)); @@ -425,7 +425,7 @@ impl<'ctx> CodeGen<'ctx> { self.builder.build_unconditional_branch(contract_block); self.builder.position_at_end(contract_block); - let expr_index_phi = self.builder.build_phi(int_type, "j"); + let expr_index_phi = self.builder.build_phi(int_type, "j")?; expr_index_phi.add_incoming(&[(&start_contract, block)]); // loop body - load index from layout @@ -433,23 +433,23 @@ impl<'ctx> CodeGen<'ctx> { let elmt_index_mult_rank = self.builder.build_int_mul(expr_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name)?; let indices_int: Vec = (0..elmt.expr_layout().rank()).map(|i| { let layout_index_plus_offset = int_type.const_int((layout_index + i).try_into().unwrap(), false); - let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name); - let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name) }; - self.builder.build_load(ptr, name).into_int_value() + let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name)?; + let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name)? }; + self.builder.build_load(ptr, name)?.into_int_value() }).collect(); // loop body - eval expression and increment sum let float_value = self.jit_compile_expr(name, &elmt.expr(), indices_int.as_slice(), elmt, Some(expr_index))?; - let contract_sum_value = self.builder.build_load(contract_sum_ptr, "contract_sum").into_float_value(); - let new_contract_sum_value = self.builder.build_float_add(contract_sum_value, float_value, "new_contract_sum"); + let contract_sum_value = self.builder.build_load(contract_sum_ptr, "contract_sum")?.into_float_value(); + let new_contract_sum_value = self.builder.build_float_add(contract_sum_value, float_value, "new_contract_sum")?; self.builder.build_store(contract_sum_ptr, new_contract_sum_value); // increment contract loop index - let next_elmt_index = self.builder.build_int_add(expr_index, int_type.const_int(1, false), name); + let next_elmt_index = self.builder.build_int_add(expr_index, int_type.const_int(1, false), name)?; expr_index_phi.add_incoming(&[(&next_elmt_index, contract_block)]); // contract loop condition - let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_elmt_index, end_contract, name); + let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_elmt_index, end_contract, name)?; let post_contract_block = self.context.append_basic_block(self.fn_value(), name); self.builder.build_conditional_branch(loop_while, contract_block, post_contract_block); self.builder.position_at_end(post_contract_block); @@ -462,7 +462,7 @@ impl<'ctx> CodeGen<'ctx> { contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]); // outer loop condition - let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_contract_index, final_contract_index, name); + let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_contract_index, final_contract_index, name)?; let post_block = self.context.append_basic_block(self.fn_value(), name); self.builder.build_conditional_branch(loop_while, block, post_block); self.builder.position_at_end(post_block); @@ -485,7 +485,7 @@ impl<'ctx> CodeGen<'ctx> { let start_index = int_type.const_int(0, false); let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false); - let curr_index = self.builder.build_phi(int_type, "i"); + let curr_index = self.builder.build_phi(int_type, "i")?; curr_index.add_incoming(&[(&start_index, preblock)]); // loop body - load index from layout @@ -493,9 +493,9 @@ impl<'ctx> CodeGen<'ctx> { let elmt_index_mult_rank = self.builder.build_int_mul(elmt_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name)?; let indices_int: Vec = (0..elmt.expr_layout().rank()).map(|i| { let layout_index_plus_offset = int_type.const_int((layout_index + i).try_into().unwrap(), false); - let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name); - let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name) }; - self.builder.build_load(ptr, name).into_int_value() + let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name)?; + let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name)? }; + self.builder.build_load(ptr, name)?.into_int_value() }).collect(); // loop body - eval expression @@ -509,7 +509,7 @@ impl<'ctx> CodeGen<'ctx> { curr_index.add_incoming(&[(&next_index, block)]); // loop condition - let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, end_index, name); + let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, end_index, name)?; let post_block = self.context.append_basic_block(self.fn_value(), name); self.builder.build_conditional_branch(loop_while, block, post_block); self.builder.position_at_end(post_block); @@ -530,7 +530,7 @@ impl<'ctx> CodeGen<'ctx> { let start_index = int_type.const_int(0, false); let end_index = int_type.const_int(elmt.expr_layout().nnz().try_into().unwrap(), false); - let curr_index = self.builder.build_phi(int_type, "i"); + let curr_index = self.builder.build_phi(int_type, "i")?; curr_index.add_incoming(&[(&start_index, preblock)]); // loop body - index is just the same for each element @@ -551,7 +551,7 @@ impl<'ctx> CodeGen<'ctx> { curr_index.add_incoming(&[(&next_index, block)]); // loop condition - let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, end_index, name); + let loop_while = self.builder.build_int_compare(IntPredicate::ULT, next_index, end_index, name)?; let post_block = self.context.append_basic_block(self.fn_value(), name); self.builder.build_conditional_branch(loop_while, block, post_block); self.builder.position_at_end(post_block); @@ -572,15 +572,15 @@ impl<'ctx> CodeGen<'ctx> { let bcast_block = self.context.append_basic_block(self.fn_value(), name); self.builder.build_unconditional_branch(bcast_block); self.builder.position_at_end(bcast_block); - let bcast_index = self.builder.build_phi(int_type, "broadcast_index"); + let bcast_index = self.builder.build_phi(int_type, "broadcast_index")?; bcast_index.add_incoming(&[(&bcast_start_index, pre_block)]); // store value let store_index = self.builder.build_int_add( - self.builder.build_int_mul(expr_index, bcast_end_index, "store_index"), + self.builder.build_int_mul(expr_index, bcast_end_index, "store_index")?, bcast_index.as_basic_value().into_int_value(), "bcast_store_index" - ); + )?; self.jit_compile_store(name, elmt, store_index, float_value, translation)?; // increment index @@ -588,7 +588,7 @@ impl<'ctx> CodeGen<'ctx> { bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]); // loop condition - let bcast_cond = self.builder.build_int_compare(IntPredicate::ULT, bcast_next_index, bcast_end_index, "broadcast_cond"); + let bcast_cond = self.builder.build_int_compare(IntPredicate::ULT, bcast_next_index, bcast_end_index, "broadcast_cond")?; let post_bcast_block = self.context.append_basic_block(self.fn_value(), name); self.builder.build_conditional_branch(bcast_cond, bcast_block, post_bcast_block); self.builder.position_at_end(post_bcast_block); @@ -612,7 +612,7 @@ impl<'ctx> CodeGen<'ctx> { let res_index = match &translation.target { TranslationTo::Contiguous { start, end: _ } => { let start_const = int_type.const_int((*start).try_into().unwrap(), false); - self.builder.build_int_add(start_const, store_index, name) + self.builder.build_int_add(start_const, store_index, name)? }, TranslationTo::Sparse { indices: _ } => { // load store index from layout @@ -621,12 +621,12 @@ impl<'ctx> CodeGen<'ctx> { let translate_store_index = int_type.const_int(translate_store_index.try_into().unwrap(), false); let rank_const = int_type.const_int(rank.try_into().unwrap(), false); let elmt_index_strided = self.builder.build_int_mul(store_index, rank_const, name)?; - let curr_index = self.builder.build_int_add(elmt_index_strided, translate_store_index, name); - let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name) }; - self.builder.build_load(ptr, name).into_int_value() + let curr_index = self.builder.build_int_add(elmt_index_strided, translate_store_index, name)?; + let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name)? }; + self.builder.build_load(ptr, name)?.into_int_value() }, }; - let resi_ptr = unsafe { self.builder.build_in_bounds_gep(self.tensor_ptr(), &[res_index], name) }; + let resi_ptr = unsafe { self.builder.build_in_bounds_gep(self.tensor_ptr(), &[res_index], name)? }; self.builder.build_store(resi_ptr, float_value); Ok(()) } @@ -638,17 +638,17 @@ impl<'ctx> CodeGen<'ctx> { let lhs = self.jit_compile_expr(name, binop.left.as_ref(), index, elmt, expr_index)?; let rhs = self.jit_compile_expr(name, binop.right.as_ref(), index, elmt, expr_index)?; match binop.op { - '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)), - '/' => Ok(self.builder.build_float_div(lhs, rhs, name)), - '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)), - '+' => Ok(self.builder.build_float_add(lhs, rhs, name)), + '*' => Ok(self.builder.build_float_mul(lhs, rhs, name)?), + '/' => Ok(self.builder.build_float_div(lhs, rhs, name)?), + '-' => Ok(self.builder.build_float_sub(lhs, rhs, name)?), + '+' => Ok(self.builder.build_float_add(lhs, rhs, name)?), unknown => Err(anyhow!("unknown binop op '{}'", unknown)) } }, AstKind::Monop(monop) => { let child = self.jit_compile_expr(name, monop.child.as_ref(), index, elmt, expr_index)?; match monop.op { - '-' => Ok(self.builder.build_float_neg(child, name)), + '-' => Ok(self.builder.build_float_neg(child, name)?), unknown => Err(anyhow!("unknown monop op '{}'", unknown)) } }, @@ -660,7 +660,7 @@ impl<'ctx> CodeGen<'ctx> { let arg_val = self.jit_compile_expr(name, arg.as_ref(), index, elmt, expr_index)?; args.push(BasicMetadataValueEnum::FloatValue(arg_val)); } - let ret_value = self.builder.build_call(function, args.as_slice(), name) + let ret_value = self.builder.build_call(function, args.as_slice(), name)? .try_as_basic_value().left().unwrap().into_float_value(); Ok(ret_value) }, @@ -696,8 +696,8 @@ impl<'ctx> CodeGen<'ctx> { let shapei: u64 = layout.shape()[i + 1].try_into().unwrap(); stride *= shapei; let stride_intval = self.context.i32_type().const_int(stride, false); - let stride_mul_i = self.builder.build_int_mul(stride_intval, iname_i, name); - iname_elmt_index = self.builder.build_int_add(iname_elmt_index, stride_mul_i, name); + let stride_mul_i = self.builder.build_int_mul(stride_intval, iname_i, name)?; + iname_elmt_index = self.builder.build_int_add(iname_elmt_index, stride_mul_i, name)?; } Some(iname_elmt_index) } else { @@ -718,17 +718,17 @@ impl<'ctx> CodeGen<'ctx> { Some(index) => unsafe { self.builder.build_in_bounds_gep(*ptr, &[index], name)? }, None => *ptr }; - Ok(self.builder.build_load(value_ptr, name).into_float_value()) + Ok(self.builder.build_load(value_ptr, name)?.into_float_value()) }, AstKind::Name(name) => { // must be a scalar, just load the value let ptr = self.get_param(name); - Ok(self.builder.build_load(*ptr, name).into_float_value()) + Ok(self.builder.build_load(*ptr, name)?.into_float_value()) }, AstKind::NamedGradient(name) => { let name_str = name.to_string(); let ptr = self.get_param(name_str.as_str()); - Ok(self.builder.build_load(*ptr, name_str.as_str()).into_float_value()) + Ok(self.builder.build_load(*ptr, name_str.as_str())?.into_float_value()) }, AstKind::Index(_) => todo!(), AstKind::Slice(_) => todo!(), @@ -749,7 +749,7 @@ impl<'ctx> CodeGen<'ctx> { match arg { BasicValueEnum::PointerValue(v) => v, BasicValueEnum::FloatValue(v) => { - let alloca = self.create_entry_block_builder().build_alloca(arg.get_type(), name); + let alloca = self.create_entry_block_builder().build_alloca(arg.get_type(), name)?; self.builder.build_store(alloca, v); alloca } diff --git a/src/enzyme/mod.rs b/src/enzyme/mod.rs deleted file mode 100644 index 63a0ee2..0000000 --- a/src/enzyme/mod.rs +++ /dev/null @@ -1,36 +0,0 @@ -use std::env; -use std::path::PathBuf; - -pub fn generate_enzyme_wrapper() { - - // Tell cargo to look for shared libraries in the specified directory - println!("cargo:rustc-link-search=/home/mrobins/.local/lib"); - - // Tell cargo to tell rustc to link the system bzip2 - // shared library. - println!("cargo:rustc-link-lib=LLVMEnzyme-14"); - - // Tell cargo to invalidate the built crate whenever the wrapper changes - println!("cargo:rerun-if-changed=wrapper.h"); - - // The bindgen::Builder is the main entry point - // to bindgen, and lets you build up options for - // the resulting bindings. - let bindings = bindgen::Builder::default() - // The input header we would like to generate - // bindings for. - .header("wrapper.h") - // Tell cargo to invalidate the built crate whenever any of the - // included header files changed. - .parse_callbacks(Box::new(bindgen::CargoCallbacks)) - // Finish the builder and generate the bindings. - .generate() - // Unwrap the Result and panic on failure. - .expect("Unable to generate bindings"); - - // Write the bindings to the $OUT_DIR/bindings.rs file. - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - bindings - .write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); -} \ No newline at end of file diff --git a/src/enzyme/wrapper.h b/src/enzyme/wrapper.h deleted file mode 100644 index 1f6ac9e..0000000 --- a/src/enzyme/wrapper.h +++ /dev/null @@ -1,2 +0,0 @@ -#include -#include \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index a2aa49f..5228b68 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,3 @@ -#![allow( - non_upper_case_globals, - non_camel_case_types, - non_snake_case, - improper_ctypes, - clippy::all -)] -include!(concat!(env!("OUT_DIR"), "/bindings.rs")); - use std::{path::Path, ffi::OsStr, process::Command, env}; use anyhow::{Result, anyhow}; use codegen::Compiler; @@ -23,7 +14,6 @@ pub mod ast; pub mod discretise; pub mod continuous; pub mod codegen; -pub mod enzyme; pub struct CompilerOptions {