From ac888b1c3cb1c2f32c5096a6103aeb370746b3b4 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 27 Oct 2023 12:25:03 +0000 Subject: [PATCH] attempt to add enzymead bindings --- Cargo.toml | 14 +++---------- README.md | 5 +++++ src/bin/diffeq.rs | 2 +- src/build.rs | 44 +++++++++++++++++++++++++++++++++++++++++ src/codegen/codegen.rs | 24 +++++++++++----------- src/codegen/compiler.rs | 6 ++++++ src/enzyme/mod.rs | 36 +++++++++++++++++++++++++++++++++ src/enzyme/wrapper.h | 2 ++ src/lib.rs | 31 ++++++++++++++++++----------- 9 files changed, 128 insertions(+), 36 deletions(-) create mode 100644 README.md create mode 100644 src/build.rs create mode 100644 src/enzyme/mod.rs create mode 100644 src/enzyme/wrapper.h diff --git a/Cargo.toml b/Cargo.toml index 41b8863..34f05d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,9 +2,7 @@ name = "diffeq" version = "0.1.0" edition = "2021" - -[lib] -crate-type = ["cdylib", "rlib"] +build = "src/build.rs" [[bin]] name = "diffeq" @@ -23,14 +21,8 @@ sundials-sys = { version = ">=0.3", features = ["idas", "build_libraries"] } ouroboros = ">=0.17" clap = { version = "4.3.23", features = ["derive"] } -[target.'cfg(not(target_arch = "wasm32"))'.dependencies] - - -[target.'cfg(target_arch = "wasm32")'.dependencies] -wasm-bindgen = ">=0.2" - -[target.'cfg(target_arch = "wasm32")'.dev-dependencies] -wasm-bindgen-test = "0.3.13" +[build-dependencies] +bindgen = ">=0.68.1" [profile.dev] opt-level = 0 diff --git a/README.md b/README.md new file mode 100644 index 0000000..bee0271 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +### Installing Enzyme AD + +```bash +cmake -DCMAKE_INSTALL_PREFIX= -DCMAKE_BUILD_TYPE=Release .. +``` \ No newline at end of file diff --git a/src/bin/diffeq.rs b/src/bin/diffeq.rs index 23e2d42..300569a 100644 --- a/src/bin/diffeq.rs +++ b/src/bin/diffeq.rs @@ -33,7 +33,7 @@ struct Args { fn main() -> Result<()> { let cli = Args::parse(); let options = CompilerOptions { - compile: cli.compile, + bytecode_only: cli.compile, wasm: cli.wasm, standalone: cli.standalone, }; diff --git a/src/build.rs b/src/build.rs new file mode 100644 index 0000000..d452019 --- /dev/null +++ b/src/build.rs @@ -0,0 +1,44 @@ +use std::env; +use std::path::PathBuf; + +fn generate_enzyme_wrapper() { + + // Tell cargo to look for shared libraries in the specified directory + println!("cargo:rustc-link-search={}", std::env::var("ENZYME_LIB_DIR").unwrap_or("/usr/local/lib".to_string())); + + // Tell cargo to tell rustc to link the system bzip2 + // shared library. + println!("cargo:rustc-link-lib=LLVMEnzyme-14"); + + // Tell cargo to invalidate the built crate whenever the wrapper changes + println!("cargo:rerun-if-changed=src/enzyme/wrapper.h"); + + // The bindgen::Builder is the main entry point + // to bindgen, and lets you build up options for + // the resulting bindings. + let bindings = bindgen::Builder::default() + // The input header we would like to generate + // bindings for. + .header("src/enzyme/wrapper.h") + // set the clang args to include the llvm prefix dir + .clang_arg(format!("-I{}", std::env::var("LLVM_SYS_14_PREFIX/include").unwrap_or("/usr/lib/llvm-14/include".to_string()))) + // set the clang args to include the llvm include dir + .clang_arg(format!("-I{}", std::env::var("ENZYME_INCLUDE_DIR").unwrap_or("/usr/local/include".to_string()))) + // Tell cargo to invalidate the built crate whenever any of the + // included header files changed. + .parse_callbacks(Box::new(bindgen::CargoCallbacks)) + // Finish the builder and generate the bindings. + .generate() + // Unwrap the Result and panic on failure. + .expect("Unable to generate bindings"); + + // Write the bindings to the $OUT_DIR/bindings.rs file. + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings!"); +} + +fn main() { + generate_enzyme_wrapper(); +} \ No newline at end of file diff --git a/src/codegen/codegen.rs b/src/codegen/codegen.rs index eac5b42..83041ff 100644 --- a/src/codegen/codegen.rs +++ b/src/codegen/codegen.rs @@ -351,7 +351,7 @@ impl<'ctx> CodeGen<'ctx> { // unwind the nested loops for i in (0..expr_rank).rev() { // increment index - let next_index = self.builder.build_int_add(indices_int[i], one, name); + let next_index = self.builder.build_int_add(indices_int[i], one, name)?; indices[i].add_incoming(&[(&next_index, preblock)]); if i == expr_rank - contract_by - 1 && contract_sum.is_some() { @@ -409,7 +409,7 @@ impl<'ctx> CodeGen<'ctx> { ); let end_index = self.builder.build_int_add( start_index, - int_type.const_int(1, false), + int_type.const_int(1, false)?, name ); let start_ptr = unsafe { self.builder.build_gep(*self.get_param("indices"), &[start_index], "start_index_ptr") }; @@ -430,7 +430,7 @@ impl<'ctx> CodeGen<'ctx> { // loop body - load index from layout let expr_index = expr_index_phi.as_basic_value().into_int_value(); - let elmt_index_mult_rank = self.builder.build_int_mul(expr_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name); + let elmt_index_mult_rank = self.builder.build_int_mul(expr_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name)?; let indices_int: Vec = (0..elmt.expr_layout().rank()).map(|i| { let layout_index_plus_offset = int_type.const_int((layout_index + i).try_into().unwrap(), false); let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name); @@ -458,7 +458,7 @@ impl<'ctx> CodeGen<'ctx> { self.jit_compile_store(name, elmt, contract_index.as_basic_value().into_int_value(), new_contract_sum_value, translation)?; // increment outer loop index - let next_contract_index = self.builder.build_int_add(contract_index.as_basic_value().into_int_value(), int_type.const_int(1, false), name); + let next_contract_index = self.builder.build_int_add(contract_index.as_basic_value().into_int_value(), int_type.const_int(1, false), name)?; contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]); // outer loop condition @@ -490,7 +490,7 @@ impl<'ctx> CodeGen<'ctx> { // loop body - load index from layout let elmt_index = curr_index.as_basic_value().into_int_value(); - let elmt_index_mult_rank = self.builder.build_int_mul(elmt_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name); + let elmt_index_mult_rank = self.builder.build_int_mul(elmt_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name)?; let indices_int: Vec = (0..elmt.expr_layout().rank()).map(|i| { let layout_index_plus_offset = int_type.const_int((layout_index + i).try_into().unwrap(), false); let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name); @@ -505,7 +505,7 @@ impl<'ctx> CodeGen<'ctx> { // increment loop index let one = int_type.const_int(1, false); - let next_index = self.builder.build_int_add(elmt_index, one, name); + let next_index = self.builder.build_int_add(elmt_index, one, name)?; curr_index.add_incoming(&[(&next_index, block)]); // loop condition @@ -547,7 +547,7 @@ impl<'ctx> CodeGen<'ctx> { // increment loop index let one = int_type.const_int(1, false); - let next_index = self.builder.build_int_add(elmt_index, one, name); + let next_index = self.builder.build_int_add(elmt_index, one, name)?; curr_index.add_incoming(&[(&next_index, block)]); // loop condition @@ -584,7 +584,7 @@ impl<'ctx> CodeGen<'ctx> { self.jit_compile_store(name, elmt, store_index, float_value, translation)?; // increment index - let bcast_next_index= self.builder.build_int_add(bcast_index.as_basic_value().into_int_value(), one, name); + let bcast_next_index= self.builder.build_int_add(bcast_index.as_basic_value().into_int_value(), one, name)?; bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]); // loop condition @@ -620,7 +620,7 @@ impl<'ctx> CodeGen<'ctx> { let translate_store_index = translate_index + translation.get_to_index_in_data_layout(); let translate_store_index = int_type.const_int(translate_store_index.try_into().unwrap(), false); let rank_const = int_type.const_int(rank.try_into().unwrap(), false); - let elmt_index_strided = self.builder.build_int_mul(store_index, rank_const, name); + let elmt_index_strided = self.builder.build_int_mul(store_index, rank_const, name)?; let curr_index = self.builder.build_int_add(elmt_index_strided, translate_store_index, name); let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name) }; self.builder.build_load(ptr, name).into_int_value() @@ -715,7 +715,7 @@ impl<'ctx> CodeGen<'ctx> { panic!("unexpected layout"); }; let value_ptr = match iname_elmt_index { - Some(index) => unsafe { self.builder.build_in_bounds_gep(*ptr, &[index], name) }, + Some(index) => unsafe { self.builder.build_in_bounds_gep(*ptr, &[index], name)? }, None => *ptr }; Ok(self.builder.build_load(value_ptr, name).into_float_value()) @@ -1029,7 +1029,7 @@ impl<'ctx> CodeGen<'ctx> { // increment loop index let one = self.int_type.const_int(1, false); - let next_index = self.builder.build_int_add(curr_input_index, one, name.as_str()); + let next_index = self.builder.build_int_add(curr_input_index, one, name.as_str())?; index.add_incoming(&[(&next_index, input_block)]); // loop condition @@ -1100,7 +1100,7 @@ impl<'ctx> CodeGen<'ctx> { // increment loop index let one = self.int_type.const_int(1, false); - let next_index = self.builder.build_int_add(curr_blk_index, one, name); + let next_index = self.builder.build_int_add(curr_blk_index, one, name)?; index.add_incoming(&[(&next_index, blk_block)]); // loop condition diff --git a/src/codegen/compiler.rs b/src/codegen/compiler.rs index 430163e..324628e 100644 --- a/src/codegen/compiler.rs +++ b/src/codegen/compiler.rs @@ -264,6 +264,12 @@ impl Compiler { Ok(target_machine) } + pub fn write_bitcode_to_path(&self, path: &Path) -> Result<()> { + self.with_data(|data| + data.codegen.module().write_bitcode_to_path(path).map_err(|e| anyhow::anyhow!("Error writing bitcode: {:?}", e)) + ) + } + pub fn write_object_file(&self, path: &Path) -> Result<()> { let target_machine = Compiler::get_native_machine()?; self.with_data(|data| diff --git a/src/enzyme/mod.rs b/src/enzyme/mod.rs new file mode 100644 index 0000000..63a0ee2 --- /dev/null +++ b/src/enzyme/mod.rs @@ -0,0 +1,36 @@ +use std::env; +use std::path::PathBuf; + +pub fn generate_enzyme_wrapper() { + + // Tell cargo to look for shared libraries in the specified directory + println!("cargo:rustc-link-search=/home/mrobins/.local/lib"); + + // Tell cargo to tell rustc to link the system bzip2 + // shared library. + println!("cargo:rustc-link-lib=LLVMEnzyme-14"); + + // Tell cargo to invalidate the built crate whenever the wrapper changes + println!("cargo:rerun-if-changed=wrapper.h"); + + // The bindgen::Builder is the main entry point + // to bindgen, and lets you build up options for + // the resulting bindings. + let bindings = bindgen::Builder::default() + // The input header we would like to generate + // bindings for. + .header("wrapper.h") + // Tell cargo to invalidate the built crate whenever any of the + // included header files changed. + .parse_callbacks(Box::new(bindgen::CargoCallbacks)) + // Finish the builder and generate the bindings. + .generate() + // Unwrap the Result and panic on failure. + .expect("Unable to generate bindings"); + + // Write the bindings to the $OUT_DIR/bindings.rs file. + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings!"); +} \ No newline at end of file diff --git a/src/enzyme/wrapper.h b/src/enzyme/wrapper.h new file mode 100644 index 0000000..1f6ac9e --- /dev/null +++ b/src/enzyme/wrapper.h @@ -0,0 +1,2 @@ +#include +#include \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 7fb5768..a2aa49f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,12 @@ +#![allow( + non_upper_case_globals, + non_camel_case_types, + non_snake_case, + improper_ctypes, + clippy::all +)] +include!(concat!(env!("OUT_DIR"), "/bindings.rs")); + use std::{path::Path, ffi::OsStr, process::Command, env}; use anyhow::{Result, anyhow}; use codegen::Compiler; @@ -14,9 +23,11 @@ pub mod ast; pub mod discretise; pub mod continuous; pub mod codegen; +pub mod enzyme; + pub struct CompilerOptions { - pub compile: bool, + pub bytecode_only: bool, pub wasm: bool, pub standalone: bool, } @@ -40,8 +51,8 @@ pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, options: Com }; let out = if let Some(out) = out { out.clone() - } else if options.compile { - "out.o" + } else if options.bytecode_only { + "out.ll" } else { "out" }; @@ -50,12 +61,12 @@ pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, options: Com } pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOptions, is_discrete: bool) -> Result<()> { - let CompilerOptions { compile, wasm, standalone } = options; + let CompilerOptions { bytecode_only, wasm, standalone } = options; let is_continuous = !is_discrete; - let objectname = if compile { out.to_owned() } else { format!("{}.o", out) }; - let objectfile = Path::new(objectname.as_str()); + let bytecodename = if bytecode_only { out.to_owned() } else { format!("{}.ll", out) }; + let bytecodefile = Path::new(bytecodename.as_str()); let continuous_ast = if is_continuous { Some(parse_ms_string(text)?) @@ -98,13 +109,9 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp let compiler = Compiler::from_discrete_model(&discrete_model)?; - if wasm { - compiler.write_wasm_object_file(objectfile)?; - } else { - compiler.write_object_file(objectfile)?; - } + compiler.write_bitcode_to_path(bytecodefile); - if compile { + if bytecode_only { return Ok(()); }