Skip to content

Commit

Permalink
add compile_gradient function with enzymeAD
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Oct 30, 2023
1 parent c058009 commit d7e2f92
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 42 deletions.
85 changes: 61 additions & 24 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -1,27 +1,64 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in library 'diffeq'",
"cargo": {
"args": [
"test",
"--no-run",
"--lib",
"--package=diffeq"
],
"filter": {
"name": "diffeq",
"kind": "lib"
}
},
"args": [],
"cwd": "${workspaceFolder}"
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in library 'diffeq'",
"cargo": {
"args": [
"test",
"--no-run",
"--lib",
"--package=diffeq"
],
"filter": {
"name": "diffeq",
"kind": "lib"
}
]
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug executable 'diffeq'",
"cargo": {
"args": [
"build",
"--bin=diffeq",
"--package=diffeq"
],
"filter": {
"name": "diffeq",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in executable 'diffeq'",
"cargo": {
"args": [
"test",
"--no-run",
"--bin=diffeq",
"--package=diffeq"
],
"filter": {
"name": "diffeq",
"kind": "bin"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pest = ">=2.1.3"
pest_derive = ">=2.1.0"
itertools = ">=0.10.3"
inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm14-0"] }
sundials-sys = { version = ">=0.3", features = ["idas", "build_libraries"] }
sundials-sys = { version = ">=0.3", features = ["idas", "build_libraries", "static_libraries"] }
ouroboros = ">=0.17"
clap = { version = "4.3.23", features = ["derive"] }

Expand Down
132 changes: 119 additions & 13 deletions src/codegen/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use inkwell::basic_block::BasicBlock;
use inkwell::intrinsics::Intrinsic;
use inkwell::passes::PassManager;
use inkwell::types::{FloatType, BasicMetadataTypeEnum, BasicTypeEnum, IntType};
use inkwell::values::{PointerValue, FloatValue, FunctionValue, IntValue, BasicMetadataValueEnum, BasicValueEnum, BasicValue};
use inkwell::values::{PointerValue, FloatValue, FunctionValue, IntValue, BasicMetadataValueEnum, BasicValueEnum, BasicValue, GlobalValue};
use inkwell::{AddressSpace, IntPredicate};
use inkwell::builder::Builder;
use inkwell::module::Module;
use inkwell::module::{Module, Linkage};
use std::collections::HashMap;
use std::iter::zip;
use anyhow::{Result, anyhow};
Expand All @@ -21,13 +21,31 @@ use crate::codegen::{Translation, TranslationFrom, TranslationTo, DataLayout};
/// Calling this is innately `unsafe` because there's no guarantee it doesn't
/// do `unsafe` operations internally.
pub type ResidualFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, up: *const realtype, data: *mut realtype, indices: *const i32, rr: *mut realtype);
/// Signature of the compiled gradient counterpart of [`ResidualFunc`]: the same parameters, with
/// `du`, `dup`, `ddata` and `drr` inserted directly after `u`, `up`, `data` and `rr`.
/// NOTE(review): the `d*` pointers are presumably the derivative (tangent) buffers used by the
/// Enzyme forward-mode wrapper — confirm.
pub type ResidualGradientFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, du: *const realtype, up: *const realtype, dup: *const realtype, data: *mut realtype, ddata: *mut realtype, indices: *const i32, rr: *mut realtype, drr: *mut realtype);
/// Signature of a compiled function taking the `data` and `indices` buffers and mutable `u` and
/// `up` pointers. Presumably the `set_u0` entry point — confirm against the caller.
pub type U0Func = unsafe extern "C" fn(data: *mut realtype, indices: *const i32, u: *mut realtype, up: *mut realtype);
/// Gradient counterpart of [`U0Func`]: `ddata`, `du` and `dup` follow `data`, `u` and `up`;
/// `indices` has no derivative companion.
pub type U0GradientFunc = unsafe extern "C" fn(data: *mut realtype, ddata: *mut realtype, indices: *const i32, u: *mut realtype, du: *mut realtype, up: *mut realtype, dup: *mut realtype);
/// Signature of a compiled function taking `time`, read-only `u` and `up`, mutable `data`, and
/// `indices`. Presumably the `calc_out` entry point — confirm against the caller.
pub type CalcOutFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, up: *const realtype, data: *mut realtype, indices: *const i32);
/// Gradient counterpart of [`CalcOutFunc`]: `du`, `dup` and `ddata` follow `u`, `up` and `data`;
/// `time` and `indices` have no derivative companion.
pub type CalcOutGradientFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, du: *const realtype, up: *const realtype, dup: *const realtype, data: *mut realtype, ddata: *mut realtype, indices: *const i32);
/// Signature of a compiled function receiving `u32` pointers named `states`, `inputs`, `outputs`,
/// `data` and `indices`. Presumably each mutable pointer receives one dimension — TODO confirm.
pub type GetDimsFunc = unsafe extern "C" fn(states: *mut u32, inputs: *mut u32, outputs: *mut u32, data: *mut u32, indices: *const u32);
/// Signature of a compiled function taking a read-only `inputs` buffer and a mutable `data` buffer.
pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const realtype, data: *mut realtype);
/// Gradient counterpart of [`SetInputsFunc`]: `dinputs` and `ddata` follow `inputs` and `data`.
pub type SetInputsGradientFunc = unsafe extern "C" fn(inputs: *const realtype, dinputs: *const realtype, data: *mut realtype, ddata: *mut realtype);
/// Signature of a compiled function taking a single mutable `id` buffer.
pub type SetIdFunc = unsafe extern "C" fn(id: *mut realtype);
/// Signature of a compiled function taking a read-only `data` buffer plus out-pointers
/// `tensor_data` (pointer to a `realtype` pointer) and `tensor_size`.
pub type GetOutFunc = unsafe extern "C" fn(data: *const realtype, tensor_data: *mut *mut realtype, tensor_size: *mut u32);

/// Handles to the two `i32` module globals named `enzyme_dup` and `enzyme_const`.
///
/// `compile_gradient` loads these globals and passes the loaded values as the per-argument
/// marker in the `__enzyme_fwddiff_*` call it builds.
struct EnzymeGlobals<'ctx> {
// Global named "enzyme_dup"; its loaded value precedes an argument that is passed together
// with a derivative argument.
enzyme_dup: GlobalValue<'ctx>,
// Global named "enzyme_const"; its loaded value precedes an argument flagged as constant.
enzyme_const: GlobalValue<'ctx>,
}

impl<'ctx> EnzymeGlobals<'ctx> {
    /// Adds the `enzyme_dup` and `enzyme_const` globals (both of type `i32`, in the default
    /// address space) to `module` and returns handles to them.
    ///
    /// The body has no failure path; the `Result` wrapper is part of the existing signature.
    fn new(context: &'ctx inkwell::context::Context, module: &Module<'ctx>) -> Result<Self> {
        let i32_type = context.i32_type();

        // Declaration order is kept: `enzyme_dup` first, then `enzyme_const`.
        let enzyme_dup = module.add_global(i32_type, Some(AddressSpace::default()), "enzyme_dup");
        let enzyme_const = module.add_global(i32_type, Some(AddressSpace::default()), "enzyme_const");

        Ok(Self {
            enzyme_dup,
            enzyme_const,
        })
    }
}


pub struct CodeGen<'ctx> {
Expand All @@ -43,6 +61,7 @@ pub struct CodeGen<'ctx> {
real_type_str: String,
int_type: IntType<'ctx>,
layout: DataLayout,
enzyme_globals: Option<EnzymeGlobals<'ctx>>,
}

impl<'ctx> CodeGen<'ctx> {
Expand All @@ -57,10 +76,12 @@ impl<'ctx> CodeGen<'ctx> {
fpm.add_instruction_combining_pass();
fpm.add_reassociate_pass();
fpm.initialize();
let builder = context.create_builder();
let enzyme_globals = EnzymeGlobals::new(context, &module).ok();
Self {
context: &context,
module,
builder: context.create_builder(),
builder: builder,
fpm,
real_type,
real_type_str: real_type_str.to_owned(),
Expand All @@ -70,6 +91,7 @@ impl<'ctx> CodeGen<'ctx> {
tensor_ptr_opt: None,
layout: DataLayout::new(model),
int_type: context.i32_type(),
enzyme_globals,
}
}

Expand Down Expand Up @@ -223,7 +245,6 @@ impl<'ctx> CodeGen<'ctx> {
builder
}


fn jit_compile_scalar(&mut self, a: &Tensor, res_ptr_opt: Option<PointerValue<'ctx>>) -> Result<PointerValue<'ctx>> {
let res_type = self.real_type;
let res_ptr = match res_ptr_opt {
Expand Down Expand Up @@ -896,8 +917,93 @@ impl<'ctx> CodeGen<'ctx> {
Err(anyhow!("Invalid generated function."))
}
}

pub fn compile_gradient(&mut self, original_function: FunctionValue<'ctx>, args_is_const: &[bool]) -> Result<FunctionValue<'ctx>> {
self.clear();

let enzyme_globals = match self.enzyme_globals {
Some(ref globals) => globals,
None => panic!("enzyme globals not set"),
};

// construct the gradient function
let mut fn_type: Vec<BasicMetadataTypeEnum> = Vec::new();
let orig_fn_type_ptr = original_function.get_type().ptr_type(AddressSpace::default());
let mut enzyme_fn_type: Vec<BasicMetadataTypeEnum> = vec![orig_fn_type_ptr.into()];
let mut start_param_index: Vec<u32> = Vec::new();
for (i, arg) in original_function.get_param_iter().enumerate() {
start_param_index.push(u32::try_from(fn_type.len()).unwrap());
fn_type.push(arg.get_type().into());

// constant args with type T in the original funciton have 2 args of type [int, T]
enzyme_fn_type.push(self.int_type.into());
enzyme_fn_type.push(arg.get_type().into());

if !args_is_const[i] {
// active args with type T in the original funciton have 3 args of type [int, T, T]
fn_type.push(arg.get_type().into());
enzyme_fn_type.push(arg.get_type().into());
}
}
let void_type = self.context.void_type();
let fn_type = void_type.fn_type(fn_type.as_slice(), false);
let fn_name = format!("{}_grad", original_function.get_name().to_str().unwrap());
let function = self.module.add_function(fn_name.as_str(), fn_type, None);
let basic_block = self.context.append_basic_block(function, "entry");
self.fn_value_opt = Some(function);
self.builder.position_at_end(basic_block);

pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<()> {
let mut enzyme_fn_args: Vec<BasicMetadataValueEnum> = vec![original_function.as_global_value().as_pointer_value().into()];
let enzyme_const = self.builder.build_load(enzyme_globals.enzyme_const.as_pointer_value(), "enzyme_const")?;
let enzyme_dup= self.builder.build_load(enzyme_globals.enzyme_dup.as_pointer_value(), "enzyme_const")?;
for (i, _arg) in original_function.get_param_iter().enumerate() {
let param_index = start_param_index[i];
let fn_arg = function.get_nth_param(param_index).unwrap();
if args_is_const[i] {
// let enzyme know its a constant arg
enzyme_fn_args.push(enzyme_const.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());
} else {
// let enzyme know its an active arg
enzyme_fn_args.push(enzyme_dup.into());

// pass in the arg value
enzyme_fn_args.push(fn_arg.into());

// pass in the darg value
let fn_darg = function.get_nth_param(param_index + 1).unwrap();
enzyme_fn_args.push(fn_darg.into());
}
}

// construct enzyme function
// double df = __enzyme_fwddiff<double>((void*)f, enzyme_dup, x, dx, enzyme_dup, y, dy);
let enzyme_fn_type = void_type.fn_type(&enzyme_fn_type, false);
let orig_fn_name = original_function.get_name().to_str().unwrap();
let enzyme_fn_name = format!("__enzyme_fwddiff_{}", orig_fn_name);
let enzyme_function = self.module.add_function(&enzyme_fn_name.as_str(), enzyme_fn_type, Some(Linkage::External));

// call enzyme function
self.builder.build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?;

// return
self.builder.build_return(None)?;

if function.verify(true) {
self.fpm.run_on(&function);
Ok(function)
} else {
function.print_to_stderr();
unsafe {
function.delete();
}
Err(anyhow!("Invalid generated function."))
}
}

pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
self.clear();
let int_ptr_type = self.context.i32_type().ptr_type(AddressSpace::default());
let fn_type = self.context.void_type().fn_type(
Expand Down Expand Up @@ -930,17 +1036,17 @@ impl<'ctx> CodeGen<'ctx> {

if function.verify(true) {
self.fpm.run_on(&function);
Ok(function)
} else {
function.print_to_stderr();
unsafe {
function.delete();
}
return Err(anyhow!("Invalid generated function."))
Err(anyhow!("Invalid generated function."))
}
Ok(())
}

pub fn compile_get_tensor(&mut self, model: &DiscreteModel, name: &str) -> Result<()> {
pub fn compile_get_tensor(&mut self, model: &DiscreteModel, name: &str) -> Result<FunctionValue<'ctx>> {
self.clear();
let real_ptr_ptr_type = self.real_type.ptr_type(AddressSpace::default()).ptr_type(AddressSpace::default());
let real_ptr_type = self.real_type.ptr_type(AddressSpace::default());
Expand Down Expand Up @@ -973,7 +1079,7 @@ impl<'ctx> CodeGen<'ctx> {

if function.verify(true) {
self.fpm.run_on(&function);
Ok(())
Ok(function)
} else {
function.print_to_stderr();
unsafe {
Expand All @@ -983,7 +1089,7 @@ impl<'ctx> CodeGen<'ctx> {
}
}

pub fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result<()> {
pub fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
self.clear();
let real_ptr_type = self.real_type.ptr_type(AddressSpace::default());
let void_type = self.context.void_type();
Expand Down Expand Up @@ -1047,7 +1153,7 @@ impl<'ctx> CodeGen<'ctx> {

if function.verify(true) {
self.fpm.run_on(&function);
Ok(())
Ok(function)
} else {
function.print_to_stderr();
unsafe {
Expand All @@ -1057,7 +1163,7 @@ impl<'ctx> CodeGen<'ctx> {
}
}

pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<()> {
pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<FunctionValue<'ctx>> {
self.clear();
let real_ptr_type = self.real_type.ptr_type(AddressSpace::default());
let void_type = self.context.void_type();
Expand Down Expand Up @@ -1118,7 +1224,7 @@ impl<'ctx> CodeGen<'ctx> {

if function.verify(true) {
self.fpm.run_on(&function);
Ok(())
Ok(function)
} else {
function.print_to_stderr();
unsafe {
Expand Down
10 changes: 6 additions & 4 deletions src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ struct CompilerData<'ctx> {
get_dims: JitFunction<'ctx, GetDimsFunc>,
set_inputs: JitFunction<'ctx, SetInputsFunc>,
get_out: JitFunction<'ctx, GetOutFunc>,

}

#[self_referencing]
Expand Down Expand Up @@ -60,14 +59,17 @@ impl Compiler {
let mut codegen = CodeGen::new(model, &context, module, real_type, real_type_str);

let _set_u0 = codegen.compile_set_u0(model)?;
let _set_u0_grad = codegen.compile_gradient(_set_u0, &[false, true, false, false])?;
let _residual = codegen.compile_residual(model)?;
let _residual_grad = codegen.compile_gradient(_residual, &[true, false, false, false, true, false])?;
let _calc_out = codegen.compile_calc_out(model)?;
let _calc_out_grad = codegen.compile_gradient(_calc_out, &[true, false, false, false, true])?;
let _set_id = codegen.compile_set_id(model)?;
let _get_dims= codegen.compile_get_dims(model)?;
let _set_inputs = codegen.compile_set_inputs(model)?;
let _set_inputs_grad = codegen.compile_gradient(_set_inputs, &[false, false])?;
let _get_output = codegen.compile_get_tensor(model, "out")?;



let set_u0 = Compiler::jit("set_u0", &ee)?;
let residual = Compiler::jit("residual", &ee)?;
let calc_out = Compiler::jit("calc_out", &ee)?;
Expand Down Expand Up @@ -388,7 +390,7 @@ mod tests {
derived: "r_i {2, 3} k_i { 2 * r_i }" expect "k" vec![4., 6.],
concatenate: "r_i {2, 3} k_i { r_i, 2 * r_i }" expect "k" vec![2., 3., 4., 6.],
ones_matrix_dense: "I_ij { (0:2, 0:2): 1 }" expect "I" vec![1., 1., 1., 1.],
dense_matrix: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 }" expect "A" vec![1., 2., 3., 4.],
dense_matrix: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 }" expect "A" vec![1., 2., 3., 4.],
identity_matrix_diagonal: "I_ij { (0..2, 0..2): 1 }" expect "I" vec![1., 1.],
concatenate_diagonal: "A_ij { (0..2, 0..2): 1 } B_ij { (0:2, 0:2): A_ij, (2:4, 2:4): A_ij }" expect "B" vec![1., 1., 1., 1.],
identity_matrix_sparse: "I_ij { (0, 0): 1, (1, 1): 2 }" expect "I" vec![1., 2.],
Expand Down

0 comments on commit d7e2f92

Please sign in to comment.