From d7e2f9263f9f2d9833d88e41aa69fa9dec7026f8 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Mon, 30 Oct 2023 14:21:29 +0000 Subject: [PATCH] add compile_gradient function with enzymeAD --- .vscode/launch.json | 85 ++++++++++++++++++-------- Cargo.toml | 2 +- src/codegen/codegen.rs | 132 ++++++++++++++++++++++++++++++++++++---- src/codegen/compiler.rs | 10 +-- 4 files changed, 187 insertions(+), 42 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 120bd03..b15661f 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,27 +1,64 @@ { - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'diffeq'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=diffeq" - ], - "filter": { - "name": "diffeq", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'diffeq'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=diffeq" + ], + "filter": { + "name": "diffeq", + "kind": "lib" } - ] + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug executable 'diffeq'", + "cargo": { + "args": [ + "build", + "--bin=diffeq", + "--package=diffeq" + ], + "filter": { + "name": "diffeq", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in executable 'diffeq'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bin=diffeq", + "--package=diffeq" + ], + "filter": { + "name": "diffeq", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] } \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 22db40f..03ce7da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ pest = ">=2.1.3" pest_derive = ">=2.1.0" itertools = ">=0.10.3" inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm14-0"] } -sundials-sys = { version = ">=0.3", features = ["idas", "build_libraries"] } +sundials-sys = { version = ">=0.3", features = ["idas", "build_libraries", "static_libraries"] } ouroboros = ">=0.17" clap = { version = "4.3.23", features = ["derive"] } diff --git a/src/codegen/codegen.rs b/src/codegen/codegen.rs index accb2a2..df57fd1 100644 --- a/src/codegen/codegen.rs +++ b/src/codegen/codegen.rs @@ -2,10 +2,10 @@ use inkwell::basic_block::BasicBlock; use inkwell::intrinsics::Intrinsic; use inkwell::passes::PassManager; use inkwell::types::{FloatType, BasicMetadataTypeEnum, BasicTypeEnum, IntType}; -use inkwell::values::{PointerValue, FloatValue, FunctionValue, IntValue, BasicMetadataValueEnum, BasicValueEnum, BasicValue}; +use inkwell::values::{PointerValue, FloatValue, FunctionValue, IntValue, BasicMetadataValueEnum, BasicValueEnum, BasicValue, GlobalValue}; use inkwell::{AddressSpace, IntPredicate}; use inkwell::builder::Builder; -use inkwell::module::Module; +use inkwell::module::{Module, Linkage}; use std::collections::HashMap; use std::iter::zip; use anyhow::{Result, anyhow}; @@ -21,13 +21,31 @@ use crate::codegen::{Translation, TranslationFrom, TranslationTo, DataLayout}; /// Calling this is innately `unsafe` because there's no guarantee it doesn't /// do `unsafe` operations internally. pub type ResidualFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, up: *const realtype, data: *mut realtype, indices: *const i32, rr: *mut realtype); +pub type ResidualGradientFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, du: *const realtype, up: *const realtype, dup: *const realtype, data: *mut realtype, ddata: *mut realtype, indices: *const i32, rr: *mut realtype, drr: *mut realtype); pub type U0Func = unsafe extern "C" fn(data: *mut realtype, indices: *const i32, u: *mut realtype, up: *mut realtype); +pub type U0GradientFunc = unsafe extern "C" fn(data: *mut realtype, ddata: *mut realtype, indices: *const i32, u: *mut realtype, du: *mut realtype, up: *mut realtype, dup: *mut realtype); pub type CalcOutFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, up: *const realtype, data: *mut realtype, indices: *const i32); +pub type CalcOutGradientFunc = unsafe extern "C" fn(time: realtype, u: *const realtype, du: *const realtype, up: *const realtype, dup: *const realtype, data: *mut realtype, ddata: *mut realtype, indices: *const i32); pub type GetDimsFunc = unsafe extern "C" fn(states: *mut u32, inputs: *mut u32, outputs: *mut u32, data: *mut u32, indices: *const u32); pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const realtype, data: *mut realtype); +pub type SetInputsGradientFunc = unsafe extern "C" fn(inputs: *const realtype, dinputs: *const realtype, data: *mut realtype, ddata: *mut realtype); pub type SetIdFunc = unsafe extern "C" fn(id: *mut realtype); pub type GetOutFunc = unsafe extern "C" fn(data: *const realtype, tensor_data: *mut *mut realtype, tensor_size: *mut u32); +struct EnzymeGlobals<'ctx> { + enzyme_dup: GlobalValue<'ctx>, + enzyme_const: GlobalValue<'ctx>, +} + +impl<'ctx> EnzymeGlobals<'ctx> { + fn new(context: &'ctx inkwell::context::Context, module: &Module<'ctx>) -> Result { + let int_type = context.i32_type(); + Ok(Self { + enzyme_dup: module.add_global(int_type, Some(AddressSpace::default()), "enzyme_dup"), + enzyme_const: module.add_global(int_type, Some(AddressSpace::default()), "enzyme_const"), + }) + } +} pub struct CodeGen<'ctx> { @@ -43,6 +61,7 @@ pub struct CodeGen<'ctx> { real_type_str: String, int_type: IntType<'ctx>, layout: DataLayout, + enzyme_globals: Option>, } impl<'ctx> CodeGen<'ctx> { @@ -57,10 +76,12 @@ impl<'ctx> CodeGen<'ctx> { fpm.add_instruction_combining_pass(); fpm.add_reassociate_pass(); fpm.initialize(); + let builder = context.create_builder(); + let enzyme_globals = EnzymeGlobals::new(context, &module).ok(); Self { context: &context, module, - builder: context.create_builder(), + builder: builder, fpm, real_type, real_type_str: real_type_str.to_owned(), @@ -70,6 +91,7 @@ impl<'ctx> CodeGen<'ctx> { tensor_ptr_opt: None, layout: DataLayout::new(model), int_type: context.i32_type(), + enzyme_globals, } } @@ -223,7 +245,6 @@ impl<'ctx> CodeGen<'ctx> { builder } - fn jit_compile_scalar(&mut self, a: &Tensor, res_ptr_opt: Option>) -> Result> { let res_type = self.real_type; let res_ptr = match res_ptr_opt { @@ -896,8 +917,93 @@ impl<'ctx> CodeGen<'ctx> { Err(anyhow!("Invalid generated function.")) } } + + pub fn compile_gradient(&mut self, original_function: FunctionValue<'ctx>, args_is_const: &[bool]) -> Result> { + self.clear(); + + let enzyme_globals = match self.enzyme_globals { + Some(ref globals) => globals, + None => panic!("enzyme globals not set"), + }; + + // construct the gradient function + let mut fn_type: Vec = Vec::new(); + let orig_fn_type_ptr = original_function.get_type().ptr_type(AddressSpace::default()); + let mut enzyme_fn_type: Vec = vec![orig_fn_type_ptr.into()]; + let mut start_param_index: Vec = Vec::new(); + for (i, arg) in original_function.get_param_iter().enumerate() { + start_param_index.push(u32::try_from(fn_type.len()).unwrap()); + fn_type.push(arg.get_type().into()); + + // constant args with type T in the original funciton have 2 args of type [int, T] + enzyme_fn_type.push(self.int_type.into()); + enzyme_fn_type.push(arg.get_type().into()); + + if !args_is_const[i] { + // active args with type T in the original funciton have 3 args of type [int, T, T] + fn_type.push(arg.get_type().into()); + enzyme_fn_type.push(arg.get_type().into()); + } + } + let void_type = self.context.void_type(); + let fn_type = void_type.fn_type(fn_type.as_slice(), false); + let fn_name = format!("{}_grad", original_function.get_name().to_str().unwrap()); + let function = self.module.add_function(fn_name.as_str(), fn_type, None); + let basic_block = self.context.append_basic_block(function, "entry"); + self.fn_value_opt = Some(function); + self.builder.position_at_end(basic_block); - pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result<()> { + let mut enzyme_fn_args: Vec = vec![original_function.as_global_value().as_pointer_value().into()]; + let enzyme_const = self.builder.build_load(enzyme_globals.enzyme_const.as_pointer_value(), "enzyme_const")?; + let enzyme_dup= self.builder.build_load(enzyme_globals.enzyme_dup.as_pointer_value(), "enzyme_const")?; + for (i, _arg) in original_function.get_param_iter().enumerate() { + let param_index = start_param_index[i]; + let fn_arg = function.get_nth_param(param_index).unwrap(); + if args_is_const[i] { + // let enzyme know its a constant arg + enzyme_fn_args.push(enzyme_const.into()); + + // pass in the arg value + enzyme_fn_args.push(fn_arg.into()); + } else { + // let enzyme know its an active arg + enzyme_fn_args.push(enzyme_dup.into()); + + // pass in the arg value + enzyme_fn_args.push(fn_arg.into()); + + // pass in the darg value + let fn_darg = function.get_nth_param(param_index + 1).unwrap(); + enzyme_fn_args.push(fn_darg.into()); + } + } + + // construct enzyme function + // double df = __enzyme_fwddiff((void*)f, enzyme_dup, x, dx, enzyme_dup, y, dy); + let enzyme_fn_type = void_type.fn_type(&enzyme_fn_type, false); + let orig_fn_name = original_function.get_name().to_str().unwrap(); + let enzyme_fn_name = format!("__enzyme_fwddiff_{}", orig_fn_name); + let enzyme_function = self.module.add_function(&enzyme_fn_name.as_str(), enzyme_fn_type, Some(Linkage::External)); + + // call enzyme function + self.builder.build_call(enzyme_function, enzyme_fn_args.as_slice(), "enzyme_call")?; + + // return + self.builder.build_return(None)?; + + if function.verify(true) { + self.fpm.run_on(&function); + Ok(function) + } else { + function.print_to_stderr(); + unsafe { + function.delete(); + } + Err(anyhow!("Invalid generated function.")) + } + } + + pub fn compile_get_dims(&mut self, model: &DiscreteModel) -> Result> { self.clear(); let int_ptr_type = self.context.i32_type().ptr_type(AddressSpace::default()); let fn_type = self.context.void_type().fn_type( @@ -930,17 +1036,17 @@ impl<'ctx> CodeGen<'ctx> { if function.verify(true) { self.fpm.run_on(&function); + Ok(function) } else { function.print_to_stderr(); unsafe { function.delete(); } - return Err(anyhow!("Invalid generated function.")) + Err(anyhow!("Invalid generated function.")) } - Ok(()) } - pub fn compile_get_tensor(&mut self, model: &DiscreteModel, name: &str) -> Result<()> { + pub fn compile_get_tensor(&mut self, model: &DiscreteModel, name: &str) -> Result> { self.clear(); let real_ptr_ptr_type = self.real_type.ptr_type(AddressSpace::default()).ptr_type(AddressSpace::default()); let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); @@ -973,7 +1079,7 @@ impl<'ctx> CodeGen<'ctx> { if function.verify(true) { self.fpm.run_on(&function); - Ok(()) + Ok(function) } else { function.print_to_stderr(); unsafe { @@ -983,7 +1089,7 @@ impl<'ctx> CodeGen<'ctx> { } } - pub fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result<()> { + pub fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result> { self.clear(); let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); @@ -1047,7 +1153,7 @@ impl<'ctx> CodeGen<'ctx> { if function.verify(true) { self.fpm.run_on(&function); - Ok(()) + Ok(function) } else { function.print_to_stderr(); unsafe { @@ -1057,7 +1163,7 @@ impl<'ctx> CodeGen<'ctx> { } } - pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result<()> { + pub fn compile_set_id(&mut self, model: &DiscreteModel) -> Result> { self.clear(); let real_ptr_type = self.real_type.ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); @@ -1118,7 +1224,7 @@ impl<'ctx> CodeGen<'ctx> { if function.verify(true) { self.fpm.run_on(&function); - Ok(()) + Ok(function) } else { function.print_to_stderr(); unsafe { diff --git a/src/codegen/compiler.rs b/src/codegen/compiler.rs index 8d20851..99c7edd 100644 --- a/src/codegen/compiler.rs +++ b/src/codegen/compiler.rs @@ -23,7 +23,6 @@ struct CompilerData<'ctx> { get_dims: JitFunction<'ctx, GetDimsFunc>, set_inputs: JitFunction<'ctx, SetInputsFunc>, get_out: JitFunction<'ctx, GetOutFunc>, - } #[self_referencing] @@ -60,14 +59,17 @@ impl Compiler { let mut codegen = CodeGen::new(model, &context, module, real_type, real_type_str); let _set_u0 = codegen.compile_set_u0(model)?; + let _set_u0_grad = codegen.compile_gradient(_set_u0, &[false, true, false, false])?; let _residual = codegen.compile_residual(model)?; + let _residual_grad = codegen.compile_gradient(_residual, &[true, false, false, false, true, false])?; let _calc_out = codegen.compile_calc_out(model)?; + let _calc_out_grad = codegen.compile_gradient(_calc_out, &[true, false, false, false, true])?; let _set_id = codegen.compile_set_id(model)?; let _get_dims= codegen.compile_get_dims(model)?; let _set_inputs = codegen.compile_set_inputs(model)?; + let _set_inputs_grad = codegen.compile_gradient(_set_inputs, &[false, false])?; let _get_output = codegen.compile_get_tensor(model, "out")?; - - + let set_u0 = Compiler::jit("set_u0", &ee)?; let residual = Compiler::jit("residual", &ee)?; let calc_out = Compiler::jit("calc_out", &ee)?; @@ -388,7 +390,7 @@ mod tests { derived: "r_i {2, 3} k_i { 2 * r_i }" expect "k" vec![4., 6.], concatenate: "r_i {2, 3} k_i { r_i, 2 * r_i }" expect "k" vec![2., 3., 4., 6.], ones_matrix_dense: "I_ij { (0:2, 0:2): 1 }" expect "I" vec![1., 1., 1., 1.], - dense_matrix: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 }" expect "A" vec![1., 2., 3., 4.], + dense_matrix: "A_ij { (0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4 }" expect "A" vec![1., 2., 3., 4.], identity_matrix_diagonal: "I_ij { (0..2, 0..2): 1 }" expect "I" vec![1., 1.], concatenate_diagonal: "A_ij { (0..2, 0..2): 1 } B_ij { (0:2, 0:2): A_ij, (2:4, 2:4): A_ij }" expect "B" vec![1., 1., 1., 1.], identity_matrix_sparse: "I_ij { (0, 0): 1, (1, 1): 2 }" expect "I" vec![1., 2.],