Skip to content

Commit

Permalink
attempt to add enzymead bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Oct 27, 2023
1 parent 0c069cf commit ac888b1
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 36 deletions.
14 changes: 3 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
name = "diffeq"
version = "0.1.0"
edition = "2021"

[lib]
crate-type = ["cdylib", "rlib"]
build = "src/build.rs"

[[bin]]
name = "diffeq"
Expand All @@ -23,14 +21,8 @@ sundials-sys = { version = ">=0.3", features = ["idas", "build_libraries"] }
ouroboros = ">=0.17"
clap = { version = "4.3.23", features = ["derive"] }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]


[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen = ">=0.2"

[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3.13"
[build-dependencies]
bindgen = ">=0.68.1"

[profile.dev]
opt-level = 0
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
### Installing Enzyme AD

```bash
cmake -DCMAKE_INSTALL_PREFIX=<install> -DCMAKE_BUILD_TYPE=Release ..
```
2 changes: 1 addition & 1 deletion src/bin/diffeq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct Args {
fn main() -> Result<()> {
let cli = Args::parse();
let options = CompilerOptions {
compile: cli.compile,
bytecode_only: cli.compile,
wasm: cli.wasm,
standalone: cli.standalone,
};
Expand Down
44 changes: 44 additions & 0 deletions src/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use std::env;
use std::path::PathBuf;

fn generate_enzyme_wrapper() {

// Tell cargo to look for shared libraries in the specified directory
println!("cargo:rustc-link-search={}", std::env::var("ENZYME_LIB_DIR").unwrap_or("/usr/local/lib".to_string()));

// Tell cargo to tell rustc to link the system bzip2
// shared library.
println!("cargo:rustc-link-lib=LLVMEnzyme-14");

// Tell cargo to invalidate the built crate whenever the wrapper changes
println!("cargo:rerun-if-changed=src/enzyme/wrapper.h");

// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// The input header we would like to generate
// bindings for.
.header("src/enzyme/wrapper.h")
// set the clang args to include the llvm prefix dir
.clang_arg(format!("-I{}", std::env::var("LLVM_SYS_14_PREFIX/include").unwrap_or("/usr/lib/llvm-14/include".to_string())))
// set the clang args to include the llvm include dir
.clang_arg(format!("-I{}", std::env::var("ENZYME_INCLUDE_DIR").unwrap_or("/usr/local/include".to_string())))
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");

// Write the bindings to the $OUT_DIR/bindings.rs file.
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}

fn main() {
generate_enzyme_wrapper();
}
24 changes: 12 additions & 12 deletions src/codegen/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ impl<'ctx> CodeGen<'ctx> {
// unwind the nested loops
for i in (0..expr_rank).rev() {
// increment index
let next_index = self.builder.build_int_add(indices_int[i], one, name);
let next_index = self.builder.build_int_add(indices_int[i], one, name)?;
indices[i].add_incoming(&[(&next_index, preblock)]);

if i == expr_rank - contract_by - 1 && contract_sum.is_some() {
Expand Down Expand Up @@ -409,7 +409,7 @@ impl<'ctx> CodeGen<'ctx> {
);
let end_index = self.builder.build_int_add(
start_index,
int_type.const_int(1, false),
int_type.const_int(1, false)?,
name
);
let start_ptr = unsafe { self.builder.build_gep(*self.get_param("indices"), &[start_index], "start_index_ptr") };
Expand All @@ -430,7 +430,7 @@ impl<'ctx> CodeGen<'ctx> {

// loop body - load index from layout
let expr_index = expr_index_phi.as_basic_value().into_int_value();
let elmt_index_mult_rank = self.builder.build_int_mul(expr_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name);
let elmt_index_mult_rank = self.builder.build_int_mul(expr_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name)?;
let indices_int: Vec<IntValue> = (0..elmt.expr_layout().rank()).map(|i| {
let layout_index_plus_offset = int_type.const_int((layout_index + i).try_into().unwrap(), false);
let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name);
Expand Down Expand Up @@ -458,7 +458,7 @@ impl<'ctx> CodeGen<'ctx> {
self.jit_compile_store(name, elmt, contract_index.as_basic_value().into_int_value(), new_contract_sum_value, translation)?;

// increment outer loop index
let next_contract_index = self.builder.build_int_add(contract_index.as_basic_value().into_int_value(), int_type.const_int(1, false), name);
let next_contract_index = self.builder.build_int_add(contract_index.as_basic_value().into_int_value(), int_type.const_int(1, false), name)?;
contract_index.add_incoming(&[(&next_contract_index, post_contract_block)]);

// outer loop condition
Expand Down Expand Up @@ -490,7 +490,7 @@ impl<'ctx> CodeGen<'ctx> {

// loop body - load index from layout
let elmt_index = curr_index.as_basic_value().into_int_value();
let elmt_index_mult_rank = self.builder.build_int_mul(elmt_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name);
let elmt_index_mult_rank = self.builder.build_int_mul(elmt_index, int_type.const_int(elmt.expr_layout().rank().try_into().unwrap(), false), name)?;
let indices_int: Vec<IntValue> = (0..elmt.expr_layout().rank()).map(|i| {
let layout_index_plus_offset = int_type.const_int((layout_index + i).try_into().unwrap(), false);
let curr_index = self.builder.build_int_add(elmt_index_mult_rank, layout_index_plus_offset, name);
Expand All @@ -505,7 +505,7 @@ impl<'ctx> CodeGen<'ctx> {

// increment loop index
let one = int_type.const_int(1, false);
let next_index = self.builder.build_int_add(elmt_index, one, name);
let next_index = self.builder.build_int_add(elmt_index, one, name)?;
curr_index.add_incoming(&[(&next_index, block)]);

// loop condition
Expand Down Expand Up @@ -547,7 +547,7 @@ impl<'ctx> CodeGen<'ctx> {

// increment loop index
let one = int_type.const_int(1, false);
let next_index = self.builder.build_int_add(elmt_index, one, name);
let next_index = self.builder.build_int_add(elmt_index, one, name)?;
curr_index.add_incoming(&[(&next_index, block)]);

// loop condition
Expand Down Expand Up @@ -584,7 +584,7 @@ impl<'ctx> CodeGen<'ctx> {
self.jit_compile_store(name, elmt, store_index, float_value, translation)?;

// increment index
let bcast_next_index= self.builder.build_int_add(bcast_index.as_basic_value().into_int_value(), one, name);
let bcast_next_index= self.builder.build_int_add(bcast_index.as_basic_value().into_int_value(), one, name)?;
bcast_index.add_incoming(&[(&bcast_next_index, bcast_block)]);

// loop condition
Expand Down Expand Up @@ -620,7 +620,7 @@ impl<'ctx> CodeGen<'ctx> {
let translate_store_index = translate_index + translation.get_to_index_in_data_layout();
let translate_store_index = int_type.const_int(translate_store_index.try_into().unwrap(), false);
let rank_const = int_type.const_int(rank.try_into().unwrap(), false);
let elmt_index_strided = self.builder.build_int_mul(store_index, rank_const, name);
let elmt_index_strided = self.builder.build_int_mul(store_index, rank_const, name)?;
let curr_index = self.builder.build_int_add(elmt_index_strided, translate_store_index, name);
let ptr = unsafe { self.builder.build_in_bounds_gep(*self.get_param("indices"), &[curr_index], name) };
self.builder.build_load(ptr, name).into_int_value()
Expand Down Expand Up @@ -715,7 +715,7 @@ impl<'ctx> CodeGen<'ctx> {
panic!("unexpected layout");
};
let value_ptr = match iname_elmt_index {
Some(index) => unsafe { self.builder.build_in_bounds_gep(*ptr, &[index], name) },
Some(index) => unsafe { self.builder.build_in_bounds_gep(*ptr, &[index], name)? },
None => *ptr
};
Ok(self.builder.build_load(value_ptr, name).into_float_value())
Expand Down Expand Up @@ -1029,7 +1029,7 @@ impl<'ctx> CodeGen<'ctx> {

// increment loop index
let one = self.int_type.const_int(1, false);
let next_index = self.builder.build_int_add(curr_input_index, one, name.as_str());
let next_index = self.builder.build_int_add(curr_input_index, one, name.as_str())?;
index.add_incoming(&[(&next_index, input_block)]);

// loop condition
Expand Down Expand Up @@ -1100,7 +1100,7 @@ impl<'ctx> CodeGen<'ctx> {

// increment loop index
let one = self.int_type.const_int(1, false);
let next_index = self.builder.build_int_add(curr_blk_index, one, name);
let next_index = self.builder.build_int_add(curr_blk_index, one, name)?;
index.add_incoming(&[(&next_index, blk_block)]);

// loop condition
Expand Down
6 changes: 6 additions & 0 deletions src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,12 @@ impl Compiler {
Ok(target_machine)
}

pub fn write_bitcode_to_path(&self, path: &Path) -> Result<()> {
self.with_data(|data|
data.codegen.module().write_bitcode_to_path(path).map_err(|e| anyhow::anyhow!("Error writing bitcode: {:?}", e))
)
}

pub fn write_object_file(&self, path: &Path) -> Result<()> {
let target_machine = Compiler::get_native_machine()?;
self.with_data(|data|
Expand Down
36 changes: 36 additions & 0 deletions src/enzyme/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use std::env;
use std::path::PathBuf;

pub fn generate_enzyme_wrapper() {

// Tell cargo to look for shared libraries in the specified directory
println!("cargo:rustc-link-search=/home/mrobins/.local/lib");

// Tell cargo to tell rustc to link the system bzip2
// shared library.
println!("cargo:rustc-link-lib=LLVMEnzyme-14");

// Tell cargo to invalidate the built crate whenever the wrapper changes
println!("cargo:rerun-if-changed=wrapper.h");

// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// The input header we would like to generate
// bindings for.
.header("wrapper.h")
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");

// Write the bindings to the $OUT_DIR/bindings.rs file.
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap());
bindings
.write_to_file(out_path.join("bindings.rs"))
.expect("Couldn't write bindings!");
}
2 changes: 2 additions & 0 deletions src/enzyme/wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#include <Enzyme/CApi.h>
#include <Enzyme/Enzyme.h>
31 changes: 19 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
#![allow(
non_upper_case_globals,
non_camel_case_types,
non_snake_case,
improper_ctypes,
clippy::all
)]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));

use std::{path::Path, ffi::OsStr, process::Command, env};
use anyhow::{Result, anyhow};
use codegen::Compiler;
Expand All @@ -14,9 +23,11 @@ pub mod ast;
pub mod discretise;
pub mod continuous;
pub mod codegen;
pub mod enzyme;


pub struct CompilerOptions {
pub compile: bool,
pub bytecode_only: bool,
pub wasm: bool,
pub standalone: bool,
}
Expand All @@ -40,8 +51,8 @@ pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, options: Com
};
let out = if let Some(out) = out {
out.clone()
} else if options.compile {
"out.o"
} else if options.bytecode_only {
"out.ll"
} else {
"out"
};
Expand All @@ -50,12 +61,12 @@ pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, options: Com
}

pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOptions, is_discrete: bool) -> Result<()> {
let CompilerOptions { compile, wasm, standalone } = options;
let CompilerOptions { bytecode_only, wasm, standalone } = options;

let is_continuous = !is_discrete;

let objectname = if compile { out.to_owned() } else { format!("{}.o", out) };
let objectfile = Path::new(objectname.as_str());
let bytecodename = if bytecode_only { out.to_owned() } else { format!("{}.ll", out) };
let bytecodefile = Path::new(bytecodename.as_str());

let continuous_ast = if is_continuous {
Some(parse_ms_string(text)?)
Expand Down Expand Up @@ -98,13 +109,9 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp

let compiler = Compiler::from_discrete_model(&discrete_model)?;

if wasm {
compiler.write_wasm_object_file(objectfile)?;
} else {
compiler.write_object_file(objectfile)?;
}
compiler.write_bitcode_to_path(bytecodefile);

if compile {
if bytecode_only {
return Ok(());
}

Expand Down

0 comments on commit ac888b1

Please sign in to comment.