Skip to content

Commit

Permalink
refactoring compiler so can load from bitcode
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Oct 30, 2023
1 parent d7e2f92 commit 6ef5458
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 27 deletions.
154 changes: 128 additions & 26 deletions src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,25 @@ use std::path::Path;
use anyhow::anyhow;

use anyhow::Result;
use inkwell::memory_buffer::MemoryBuffer;
use inkwell::module::Module;
use inkwell::targets::TargetMachine;
use inkwell::{context::Context, OptimizationLevel, targets::{TargetTriple, InitializationConfig, Target, RelocMode, CodeModel, FileType}, execution_engine::{JitFunction, ExecutionEngine, UnsafeFunctionPointer}};
use ouroboros::self_referencing;
use crate::discretise::DiscreteModel;


use super::codegen::CalcOutGradientFunc;
use super::codegen::GetDimsFunc;
use super::codegen::GetOutFunc;
use super::codegen::ResidualGradientFunc;
use super::codegen::SetIdFunc;
use super::codegen::SetInputsFunc;
use super::codegen::SetInputsGradientFunc;
use super::codegen::U0GradientFunc;
use super::{CodeGen, codegen::{U0Func, ResidualFunc, CalcOutFunc}, data_layout::DataLayout};

struct CompilerData<'ctx> {
codegen: CodeGen<'ctx>,
struct JitFunctions<'ctx> {
set_u0: JitFunction<'ctx, U0Func>,
residual: JitFunction<'ctx, ResidualFunc>,
calc_out: JitFunction<'ctx, CalcOutFunc>,
Expand All @@ -25,31 +30,120 @@ struct CompilerData<'ctx> {
get_out: JitFunction<'ctx, GetOutFunc>,
}

struct JitGradFunctions<'ctx> {
set_u0_grad: JitFunction<'ctx, U0GradientFunc>,
residual_grad: JitFunction<'ctx, ResidualGradientFunc>,
calc_out_grad: JitFunction<'ctx, CalcOutGradientFunc>,
set_inputs_grad: JitFunction<'ctx, SetInputsGradientFunc>,
}

struct CompilerData<'ctx> {
codegen: Option<CodeGen<'ctx>>,
jit_functions: JitFunctions<'ctx>,
jit_grad_functions: Option<JitGradFunctions<'ctx>>,
}

struct Buffers {
data: Vec<f64>,
indices: Vec<i32>,
}

// if we build from a model file use a DataLayout,
// otherwise use raw buffers
enum CompilerDataLayout {
Layout(DataLayout),
Buffers(Buffers),
}

#[self_referencing]
pub struct Compiler {
context: Context,
#[borrows(context)]
#[not_covariant]
data: CompilerData<'this>,

data_layout: DataLayout,
data_layout: CompilerDataLayout,
number_of_states: usize,
input_names: Vec<String>,
number_of_parameters: usize,
}

impl Compiler {
pub fn from_bitcode(path: &Path) -> Result<Self> {
let context = Context::create();
let buffer = MemoryBuffer::create_from_file(&path).unwrap();
let module = Module::parse_bitcode_from_buffer(&buffer, &context)?;
let ee = module.create_jit_execution_engine(OptimizationLevel::None).map_err(|e| anyhow::anyhow!("Error creating execution engine: {:?}", e))?;

let set_u0 = Compiler::jit("set_u0", &ee)?;
let residual = Compiler::jit("residual", &ee)?;
let calc_out = Compiler::jit("calc_out", &ee)?;
let set_id = Compiler::jit("set_id", &ee)?;
let get_dims= Compiler::jit("get_dims", &ee)?;
let set_inputs = Compiler::jit("set_inputs", &ee)?;
let get_out= Compiler::jit("get_out", &ee)?;

let set_inputs_grad = Compiler::jit("set_inputs_grad", &ee)?;
let calc_out_grad = Compiler::jit("calc_out_grad", &ee)?;
let residual_grad = Compiler::jit("residual_grad", &ee)?;
let set_u0_grad = Compiler::jit("set_u0_grad", &ee)?;


// get the number of states, inputs, outputs, data, and indices
let get_dims: JitFunction<'_, GetDimsFunc> = get_dims;
let mut n_states = 0u32;
let mut n_inputs= 0u32;
let mut n_outputs = 0u32;
let mut n_data= 0u32;
let mut n_indices = 0u32;
unsafe { get_dims.call(&mut n_states, &mut n_inputs, &mut n_outputs, &mut n_data, &mut n_indices); }

let buffers = Buffers {
data: vec![0.0; usize::try_from(n_data).unwrap()],
indices: vec![0; usize::try_from(n_indices).unwrap()],
};

CompilerTryBuilder {
context,
data_layout: CompilerDataLayout::Buffers(buffers),
number_of_states: usize::try_from(n_states).unwrap(),
number_of_parameters: usize::try_from(n_inputs).unwrap(),
data_builder: |context| {
Ok({
CompilerData {
codegen: None,
jit_functions: JitFunctions {
set_u0,
residual,
calc_out,
set_id,
get_dims,
set_inputs,
get_out,
},
jit_grad_functions: Some(JitGradFunctions {
set_u0_grad,
residual_grad,
calc_out_grad,
set_inputs_grad,
}),
}
})
}
}.try_build()

}
pub fn from_discrete_model(model: &DiscreteModel) -> Result<Self> {
let number_of_states = usize::try_from(
*model.state().shape().first().unwrap_or(&1)
).unwrap();
let input_names = model.inputs().iter().map(|input| input.name().to_owned()).collect::<Vec<_>>();
let data_layout = DataLayout::new(model);
let context = Context::create();
let number_of_parameters = input_names.iter().fold(0, |acc, name| acc + data_layout.get_data_length(name).unwrap());
CompilerTryBuilder {
context,
data_layout,
data_layout: CompilerDataLayout::Layout(data_layout),
number_of_states,
input_names,
number_of_parameters,
data_builder: |context| {
let module = context.create_module(model.name());

Expand Down Expand Up @@ -78,23 +172,26 @@ impl Compiler {
let set_inputs = Compiler::jit("set_inputs", &ee)?;
let get_out= Compiler::jit("get_out", &ee)?;


Ok({
CompilerData {
codegen,
set_u0,
residual,
calc_out,
set_id,
get_dims,
set_inputs,
get_out,
Some(codegen),
jit_functions: JitFunctions {
set_u0,
residual,
calc_out,
set_id,
get_dims,
set_inputs,
get_out,
},
jit_grad_functions: None,
}
})
}
}.try_build()
}


fn jit<'ctx, T>(name: &str, ee: &ExecutionEngine<'ctx>) -> Result<JitFunction<'ctx, T>>
where T: UnsafeFunctionPointer
{
Expand All @@ -109,7 +206,10 @@ impl Compiler {
}

pub fn get_tensor_data(&self, name: &str) -> Option<&[f64]> {
self.borrow_data_layout().get_tensor_data(name)
match self.borrow_data_layout() {
CompilerDataLayout::Layout(layout) => layout.get_tensor_data(name),
CompilerDataLayout::Buffers(buffers) => None
}
}

pub fn set_u0(&mut self, yy: &mut [f64], yp: &mut [f64]) -> Result<()> {
Expand All @@ -121,11 +221,13 @@ impl Compiler {
return Err(anyhow!("Expected {} state derivatives, got {}", number_of_states, yp.len()));
}
self.with_mut(|compiler| {
let data_ptr = compiler.data_layout.data_mut().as_mut_ptr();
let indices_ptr = compiler.data_layout.indices().as_ptr();
let (data_ptr, indices_ptr) = match compiler.data_layout {
CompilerDataLayout::Layout(ref layout) => (layout.data_mut().as_mut_ptr(), layout.indices().as_ptr()),
CompilerDataLayout::Buffers(ref buffers) => (buffers.data.as_mut_ptr(), buffers.indices.as_ptr()),
};
let yy_ptr = yy.as_mut_ptr();
let yp_ptr = yp.as_mut_ptr();
unsafe { compiler.data.set_u0.call(data_ptr, indices_ptr, yy_ptr, yp_ptr); }
unsafe { compiler.data.jit_functions.set_u0.call(data_ptr, indices_ptr, yy_ptr, yp_ptr); }
});

Ok(())
Expand All @@ -149,7 +251,7 @@ impl Compiler {
let yy_ptr = yy.as_ptr();
let yp_ptr = yp.as_ptr();
let rr_ptr = rr.as_mut_ptr();
unsafe { compiler.data.residual.call(t, yy_ptr, yp_ptr, data_ptr, indices_ptr, rr_ptr); }
unsafe { compiler.data.jit_functions.residual.call(t, yy_ptr, yp_ptr, data_ptr, indices_ptr, rr_ptr); }
});
Ok(())
}
Expand All @@ -168,7 +270,7 @@ impl Compiler {
let indices_ptr = layout.indices().as_ptr();
let yy_ptr = yy.as_ptr();
let yp_ptr = yp.as_ptr();
unsafe { compiler.data.calc_out.call(t, yy_ptr, yp_ptr, data_ptr, indices_ptr); }
unsafe { compiler.data.jit_functions.calc_out.call(t, yy_ptr, yp_ptr, data_ptr, indices_ptr); }
});
Ok(())
}
Expand All @@ -185,7 +287,7 @@ impl Compiler {
let mut n_data= 0u32;
let mut n_indices = 0u32;
self.with(|compiler| {
unsafe { compiler.data.get_dims.call(&mut n_states, &mut n_inputs, &mut n_outputs, &mut n_data, &mut n_indices); }
unsafe { compiler.data.jit_functions.get_dims.call(&mut n_states, &mut n_inputs, &mut n_outputs, &mut n_data, &mut n_indices); }
});
(n_states as usize, n_inputs as usize, n_outputs as usize, n_data as usize, n_indices as usize)
}
Expand All @@ -198,7 +300,7 @@ impl Compiler {
self.with_mut(|compiler| {
let layout = compiler.data_layout;
let data_ptr = layout.data_mut().as_mut_ptr();
unsafe { compiler.data.set_inputs.call(inputs.as_ptr(), data_ptr); }
unsafe { compiler.data.jit_functions.set_inputs.call(inputs.as_ptr(), data_ptr); }
});
Ok(())
}
Expand All @@ -212,7 +314,7 @@ impl Compiler {
self.with(|compiler| {
let layout = compiler.data_layout;
let data_ptr = layout.data().as_ptr();
unsafe { compiler.data.get_out.call(data_ptr, tensor_data_ptr_ptr, tensor_data_len_ptr); }
unsafe { compiler.data.get_out.jit_functions.call(data_ptr, tensor_data_ptr_ptr, tensor_data_len_ptr); }
});
assert!(tensor_data_len as usize == n_outputs);
unsafe { std::slice::from_raw_parts(tensor_data_ptr, tensor_data_len as usize) }
Expand All @@ -224,7 +326,7 @@ impl Compiler {
return Err(anyhow!("Expected {} states, got {}", n_states, id.len()));
}
self.with_mut(|compiler| {
Ok(unsafe { compiler.data.set_id.call(id.as_mut_ptr()); })
Ok(unsafe { compiler.data.jit_functions.set_id.call(id.as_mut_ptr()); })
})
}

Expand Down
65 changes: 64 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp
.arg("-enzyme")
.arg("--enable-new-pm=0")
.arg("-o").arg(bitcodefile.to_str().unwrap())
.arg("-S")
.output()?;

if let Some(code) = output.status.code() {
Expand Down Expand Up @@ -247,6 +246,70 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp
Ok(())
}

#[cfg(test)]
mod tests {
use crate::{parser::{parse_ds_string, parse_ms_string}, continuous::ModelInfo};
use approx::assert_relative_eq;

use super::*;

fn test_ds_example(example: &str) -> Compiler {
compile(format!("examples/{}.ds", example).as_str(), None, None, CompilerOptions {
bitcode_only: true,
wasm: false,
standalone: false,
}).unwrap()
let model = parse_ds_string(text.as_str()).unwrap();
let discrete_model = match DiscreteModel::build("name", &model) {
Ok(model) => {
model
}
Err(e) => {
panic!("{}", e.as_error_message(text.as_str()));
}
};
Compiler::from_discrete_model(&discrete_model).unwrap()
}

#[test]
fn test_logistic_ds_example(example: &str) {
let inputs = vec![];
let mut u0 = vec![1.];
let mut up0 = vec![1.];
let mut res = vec![0.];
compiler.set_inputs(inputs.as_slice()).unwrap();
compiler.set_u0(u0.as_mut_slice(), up0.as_mut_slice()).unwrap();
compiler.residual(0., u0.as_slice(), up0.as_slice(), res.as_mut_slice()).unwrap();
let tensor = compiler.get_tensor_data($tensor_name).unwrap();
assert_relative_eq!(tensor, $expected_value.as_slice());

assert_eq!(model_info.errors.len(), 0);
let discrete_model = DiscreteModel::from(&model_info);
let object = Compiler::from_discrete_model(&discrete_model).unwrap();
let path = Path::new("main.o");
object.write_object_file(path).unwrap();
}

#[test]
fn test_object_file() {
let text = "
model logistic_growth(r -> NonNegative, k -> NonNegative, y(t), z(t)) {
dot(y) = r * y * (1 - y / k)
y(0) = 1.0
z = 2 * y
}
";
let models = parse_ms_string(text).unwrap();
let model_info = ModelInfo::build("logistic_growth", &models).unwrap();
assert_eq!(model_info.errors.len(), 0);
let discrete_model = DiscreteModel::from(&model_info);
let object = Compiler::from_discrete_model(&discrete_model).unwrap();
let path = Path::new("main.o");
object.write_object_file(path).unwrap();
}
}





0 comments on commit 6ef5458

Please sign in to comment.