Skip to content

Commit

Permalink
refactor compile into Compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Nov 9, 2023
1 parent e08a3c4 commit c74c1ba
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 123 deletions.
184 changes: 145 additions & 39 deletions src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use inkwell::{context::Context, OptimizationLevel, targets::{TargetTriple, Initi
use ouroboros::self_referencing;
use crate::discretise::DiscreteModel;
use crate::utils::find_executable;
use crate::utils::find_library_path;
use crate::utils::find_runtime_path;
use std::process::Command;

Expand Down Expand Up @@ -62,10 +63,22 @@ pub struct Compiler {
number_of_parameters: usize,
number_of_outputs: usize,
data_layout: DataLayout,
bitcode_filename: String,
output_base_filename: String,
}

impl Compiler {
const OPT_VARIENTS: [&str; 2] = ["opt-14", "opt"];
const CLANG_VARIENTS: [&str; 2] = ["clang", "clang-14"];
fn find_opt() -> Result<&'static str> {
find_executable(&Compiler::OPT_VARIENTS)
}
fn find_clang() -> Result<&'static str> {
find_executable(&Compiler::CLANG_VARIENTS)
}
fn find_enzyme_lib() -> Result<String> {
let enzyme_lib_varients = ["LLVMEnzyme-14.so", "LLVMEnzyme-14.dylib"];
find_library_path(&enzyme_lib_varients)
}
pub fn from_discrete_model(model: &DiscreteModel, out: &str) -> Result<Self> {
let number_of_states = usize::try_from(
*model.state().shape().first().unwrap_or(&1)
Expand All @@ -75,15 +88,13 @@ impl Compiler {
let context = Context::create();
let number_of_parameters = input_names.iter().fold(0, |acc, name| acc + data_layout.get_data_length(name).unwrap());
let number_of_outputs = data_layout.get_data_length("out").unwrap();
let bitcode_filename = format!("{}.bc", out);
let bitcodefile = Path::new(bitcode_filename.as_str());
CompilerTryBuilder {
data_layout,
number_of_states,
number_of_parameters,
number_of_outputs,
context,
bitcode_filename: bitcode_filename.clone(),
output_base_filename: out.to_owned(),
data_builder: |context| {
let module = context.create_module(model.name());
let real_type = context.f64_type();
Expand All @@ -102,21 +113,20 @@ impl Compiler {
let _set_inputs_grad = codegen.compile_gradient(_set_inputs, &[CompileGradientArgType::Dup, CompileGradientArgType::Dup])?;
let _get_output = codegen.compile_get_tensor(model, "out")?;

let pre_enzyme_bitcodefilename = format!("{}.pre-enzyme.bc", out);
let pre_enzyme_bitcodefile = Path::new(pre_enzyme_bitcodefilename.as_str());
codegen.module().write_bitcode_to_path(pre_enzyme_bitcodefile);

let pre_enzyme_bitcodefilename = Compiler::get_pre_enzyme_bitcode_filename(out);
codegen.module().write_bitcode_to_path(Path::new(&pre_enzyme_bitcodefilename));

let opt_name_varients = ["opt-14"];
let opt_name = find_executable(&opt_name_varients)?;
let enzyme_lib_path = find_runtime_path(&["LLVMEnzyme-14.so"])?;
let enzyme_lib = Path::new(enzyme_lib_path.as_str()).join("LLVMEnzyme-14.so");
let opt_name = Compiler::find_opt()?;
let enzyme_lib = Compiler::find_enzyme_lib()?;

let bitcodefilename = Compiler::get_bitcode_filename(out);
let output = Command::new(opt_name)
.arg(pre_enzyme_bitcodefile.to_str().unwrap())
.arg(format!("-load={}", enzyme_lib.to_str().unwrap()))
.arg(pre_enzyme_bitcodefilename.as_str())
.arg(format!("-load={}", enzyme_lib))
.arg("-enzyme")
.arg("--enable-new-pm=0")
.arg("-o").arg(bitcodefile.to_str().unwrap())
.arg("-o").arg(bitcodefilename.as_str())
.output()?;

if let Some(code) = output.status.code() {
Expand All @@ -126,10 +136,10 @@ impl Compiler {
}
}

let buffer = MemoryBuffer::create_from_file(&bitcodefile).unwrap();
let buffer = MemoryBuffer::create_from_file(Path::new(bitcodefilename.as_str())).unwrap();
let module = Module::parse_bitcode_from_buffer(&buffer, context).map_err(|e| anyhow::anyhow!("Error parsing bitcode: {:?}", e))?;
let ee = module.create_jit_execution_engine(OptimizationLevel::None).map_err(|e| anyhow::anyhow!("Error creating execution engine: {:?}", e))?;

let set_u0 = Compiler::jit("set_u0", &ee)?;
let residual = Compiler::jit("residual", &ee)?;
let calc_out = Compiler::jit("calc_out", &ee)?;
Expand All @@ -142,33 +152,129 @@ impl Compiler {
let calc_out_grad = Compiler::jit("calc_out_grad", &ee)?;
let residual_grad = Compiler::jit("residual_grad", &ee)?;
let set_u0_grad = Compiler::jit("set_u0_grad", &ee)?;

Ok({
CompilerData {
codegen: codegen,
jit_functions: JitFunctions {
set_u0,
residual,
calc_out,
set_id,
get_dims,
set_inputs,
get_out,
},
jit_grad_functions: JitGradFunctions {
set_u0_grad,
residual_grad,
calc_out_grad,
set_inputs_grad,
},
}
})

let data = CompilerData {
codegen: codegen,
jit_functions: JitFunctions {
set_u0,
residual,
calc_out,
set_id,
get_dims,
set_inputs,
get_out,
},
jit_grad_functions: JitGradFunctions {
set_u0_grad,
residual_grad,
calc_out_grad,
set_inputs_grad,
},
};
Ok(data)
}
}.try_build()
}

pub fn compile(&self, standalone: bool, wasm: bool) -> Result<()> {

let opt_name = Compiler::find_opt()?;
let clang_name = Compiler::find_clang()?;
let enzyme_lib = Compiler::find_enzyme_lib()?;
let out = self.borrow_output_base_filename();
let object_filename = Compiler::get_object_filename(out);

let pre_enzyme_bitcodefilename = Compiler::get_pre_enzyme_bitcode_filename(out);
let output = Command::new(clang_name)
.arg(pre_enzyme_bitcodefilename.as_str())
.arg("-c")
.arg(format!("-fplugin={}", enzyme_lib))
.arg("-o").arg(object_filename.as_str())
.output()?;

if let Some(code) = output.status.code() {
if code != 0 {
println!("{}", String::from_utf8_lossy(&output.stderr));
return Err(anyhow!("{} returned error code {}", opt_name, code));
}
}

// link the object file and our runtime library
let output = if wasm {
let emcc_varients = ["emcc"];
let command_name = find_executable(&emcc_varients)?;
let exported_functions = vec![
"Vector_destroy",
"Vector_create",
"Vector_create_with_capacity",
"Vector_push",

"Options_destroy",
"Options_create",

"Sundials_destroy",
"Sundials_create",
"Sundials_init",
"Sundials_solve",
];
let mut linked_files = vec![
"libdiffeq_runtime_lib_wasm.a",
"libsundials_idas_wasm.a",
"libargparse_wasm.a",
];
if standalone {
linked_files.push("libdiffeq_runtime_wasm.a");
}
let linked_files = linked_files;
let runtime_path = find_runtime_path(&linked_files)?;
let mut command = Command::new(command_name);
command.arg("-o").arg(out).arg(object_filename.as_str());
for file in linked_files {
command.arg(Path::new(runtime_path.as_str()).join(file));
}
if !standalone {
let exported_functions = exported_functions.into_iter().map(|s| format!("_{}", s)).collect::<Vec<_>>().join(",");
command.arg("-s").arg(format!("EXPORTED_FUNCTIONS={}", exported_functions));
command.arg("--no-entry");
}
command.output()
} else {
let mut command = Command::new(clang_name);
command.arg("-o").arg(out).arg(out);
if standalone {
command.arg("-ldiffeq_runtime");
} else {
command.arg("-ldiffeq_runtime_lib");
}
command.output()
};

let output = match output {
Ok(output) => output,
Err(e) => {
return Err(anyhow!("Error linking in runtime: {}", e));
}
};

if let Some(code) = output.status.code() {
if code != 0 {
println!("{}", String::from_utf8_lossy(&output.stderr));
return Err(anyhow!("Error linking in runtime, returned error code {}", code));
}
}
Ok(())
}

fn get_pre_enzyme_bitcode_filename(out: &str) -> String {
format!("{}.pre-enzyme.bc", out)
}

fn get_bitcode_filename(out: &str) -> String {
format!("{}.bc", out)
}

pub fn get_bitcode_filename(&self) -> &str {
&self.borrow_bitcode_filename().as_str()
fn get_object_filename(out: &str) -> String {
format!("{}.o", out)
}

fn jit<'ctx, T>(name: &str, ee: &ExecutionEngine<'ctx>) -> Result<JitFunction<'ctx, T>>
Expand Down
88 changes: 4 additions & 84 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::{path::Path, ffi::OsStr, process::Command};
use std::{path::Path, ffi::OsStr};
use anyhow::{Result, anyhow};
use codegen::Compiler;
use continuous::ModelInfo;
use discretise::DiscreteModel;
use parser::{parse_ms_string, parse_ds_string};
use utils::{find_executable, find_runtime_path};

extern crate pest;
#[macro_use]
Expand All @@ -25,8 +24,6 @@ pub struct CompilerOptions {
}




pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, options: CompilerOptions) -> Result<()> {
let inputfile = Path::new(input);
let is_discrete = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "ds";
Expand All @@ -49,8 +46,6 @@ pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, options: Com
}

pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOptions, is_discrete: bool) -> Result<()> {
let CompilerOptions { bitcode_only, wasm, standalone } = options;

let is_continuous = !is_discrete;

let continuous_ast = if is_continuous {
Expand Down Expand Up @@ -92,87 +87,12 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp
panic!("No model found");
};
let compiler = Compiler::from_discrete_model(&discrete_model, out)?;

// if we are only compiling to bitcode , we are done
if bitcode_only {

if options.bitcode_only {
return Ok(());
}

let bitcode_filename = compiler.get_bitcode_filename();

// compile the bitcode to an object file or standalone wasm or executable
let emcc_varients = ["emcc"];
let clang_varients = ["clang", "clang-14"];
let command_name = if wasm {
find_executable(&emcc_varients)?
} else {
find_executable(&clang_varients)?
};


// link the object file and our runtime library
let output = if wasm {
let exported_functions = vec![
"Vector_destroy",
"Vector_create",
"Vector_create_with_capacity",
"Vector_push",

"Options_destroy",
"Options_create",

"Sundials_destroy",
"Sundials_create",
"Sundials_init",
"Sundials_solve",
];
let mut linked_files = vec![
"libdiffeq_runtime_lib_wasm.a",
"libsundials_idas_wasm.a",
"libargparse_wasm.a",
];
if standalone {
linked_files.push("libdiffeq_runtime_wasm.a");
}
let linked_files = linked_files;
let runtime_path = find_runtime_path(&linked_files)?;
let mut command = Command::new(command_name);
command.arg("-o").arg(out).arg(bitcode_filename);
for file in linked_files {
command.arg(Path::new(runtime_path.as_str()).join(file));
}
if !standalone {
let exported_functions = exported_functions.into_iter().map(|s| format!("_{}", s)).collect::<Vec<_>>().join(",");
command.arg("-s").arg(format!("EXPORTED_FUNCTIONS={}", exported_functions));
command.arg("--no-entry");
}
command.output()
} else {
let mut command = Command::new(command_name);
command.arg("-o").arg(out).arg(out);
if standalone {
command.arg("-ldiffeq_runtime");
} else {
command.arg("-ldiffeq_runtime_lib");
}
command.output()
};

let output = match output {
Ok(output) => output,
Err(e) => {
return Err(anyhow!("Error running {}: {}", command_name, e));
}
};

if let Some(code) = output.status.code() {
if code != 0 {
println!("{}", String::from_utf8_lossy(&output.stderr));
return Err(anyhow!("{} returned error code {}", command_name, code));
}
}

Ok(())
compiler.compile(options.standalone, options.wasm)
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit c74c1ba

Please sign in to comment.