diff --git a/src/codegen/compiler.rs b/src/codegen/compiler.rs index bc53fab..199c189 100644 --- a/src/codegen/compiler.rs +++ b/src/codegen/compiler.rs @@ -9,6 +9,7 @@ use inkwell::{context::Context, OptimizationLevel, targets::{TargetTriple, Initi use ouroboros::self_referencing; use crate::discretise::DiscreteModel; use crate::utils::find_executable; +use crate::utils::find_library_path; use crate::utils::find_runtime_path; use std::process::Command; @@ -62,10 +63,22 @@ pub struct Compiler { number_of_parameters: usize, number_of_outputs: usize, data_layout: DataLayout, - bitcode_filename: String, + output_base_filename: String, } impl Compiler { + const OPT_VARIENTS: [&str; 2] = ["opt-14", "opt"]; + const CLANG_VARIENTS: [&str; 2] = ["clang", "clang-14"]; + fn find_opt() -> Result<&'static str> { + find_executable(&Compiler::OPT_VARIENTS) + } + fn find_clang() -> Result<&'static str> { + find_executable(&Compiler::CLANG_VARIENTS) + } + fn find_enzyme_lib() -> Result { + let enzyme_lib_varients = ["LLVMEnzyme-14.so", "LLVMEnzyme-14.dylib"]; + find_library_path(&enzyme_lib_varients) + } pub fn from_discrete_model(model: &DiscreteModel, out: &str) -> Result { let number_of_states = usize::try_from( *model.state().shape().first().unwrap_or(&1) @@ -75,15 +88,13 @@ impl Compiler { let context = Context::create(); let number_of_parameters = input_names.iter().fold(0, |acc, name| acc + data_layout.get_data_length(name).unwrap()); let number_of_outputs = data_layout.get_data_length("out").unwrap(); - let bitcode_filename = format!("{}.bc", out); - let bitcodefile = Path::new(bitcode_filename.as_str()); CompilerTryBuilder { data_layout, number_of_states, number_of_parameters, number_of_outputs, context, - bitcode_filename: bitcode_filename.clone(), + output_base_filename: out.to_owned(), data_builder: |context| { let module = context.create_module(model.name()); let real_type = context.f64_type(); @@ -102,21 +113,20 @@ impl Compiler { let _set_inputs_grad = codegen.compile_gradient(_set_inputs, &[CompileGradientArgType::Dup, CompileGradientArgType::Dup])?; let _get_output = codegen.compile_get_tensor(model, "out")?; - let pre_enzyme_bitcodefilename = format!("{}.pre-enzyme.bc", out); - let pre_enzyme_bitcodefile = Path::new(pre_enzyme_bitcodefilename.as_str()); - codegen.module().write_bitcode_to_path(pre_enzyme_bitcodefile); + + let pre_enzyme_bitcodefilename = Compiler::get_pre_enzyme_bitcode_filename(out); + codegen.module().write_bitcode_to_path(Path::new(&pre_enzyme_bitcodefilename)); - let opt_name_varients = ["opt-14"]; - let opt_name = find_executable(&opt_name_varients)?; - let enzyme_lib_path = find_runtime_path(&["LLVMEnzyme-14.so"])?; - let enzyme_lib = Path::new(enzyme_lib_path.as_str()).join("LLVMEnzyme-14.so"); + let opt_name = Compiler::find_opt()?; + let enzyme_lib = Compiler::find_enzyme_lib()?; + let bitcodefilename = Compiler::get_bitcode_filename(out); let output = Command::new(opt_name) - .arg(pre_enzyme_bitcodefile.to_str().unwrap()) - .arg(format!("-load={}", enzyme_lib.to_str().unwrap())) + .arg(pre_enzyme_bitcodefilename.as_str()) + .arg(format!("-load={}", enzyme_lib)) .arg("-enzyme") .arg("--enable-new-pm=0") - .arg("-o").arg(bitcodefile.to_str().unwrap()) + .arg("-o").arg(bitcodefilename.as_str()) .output()?; if let Some(code) = output.status.code() { @@ -126,10 +136,10 @@ impl Compiler { } } - let buffer = MemoryBuffer::create_from_file(&bitcodefile).unwrap(); + let buffer = MemoryBuffer::create_from_file(Path::new(bitcodefilename.as_str())).unwrap(); let module = Module::parse_bitcode_from_buffer(&buffer, context).map_err(|e| anyhow::anyhow!("Error parsing bitcode: {:?}", e))?; let ee = module.create_jit_execution_engine(OptimizationLevel::None).map_err(|e| anyhow::anyhow!("Error creating execution engine: {:?}", e))?; - + let set_u0 = Compiler::jit("set_u0", &ee)?; let residual = Compiler::jit("residual", &ee)?; let calc_out = Compiler::jit("calc_out", &ee)?; @@ -142,33 +152,129 @@ impl Compiler { let calc_out_grad = Compiler::jit("calc_out_grad", &ee)?; let residual_grad = Compiler::jit("residual_grad", &ee)?; let set_u0_grad = Compiler::jit("set_u0_grad", &ee)?; - - Ok({ - CompilerData { - codegen: codegen, - jit_functions: JitFunctions { - set_u0, - residual, - calc_out, - set_id, - get_dims, - set_inputs, - get_out, - }, - jit_grad_functions: JitGradFunctions { - set_u0_grad, - residual_grad, - calc_out_grad, - set_inputs_grad, - }, - } - }) + + let data = CompilerData { + codegen: codegen, + jit_functions: JitFunctions { + set_u0, + residual, + calc_out, + set_id, + get_dims, + set_inputs, + get_out, + }, + jit_grad_functions: JitGradFunctions { + set_u0_grad, + residual_grad, + calc_out_grad, + set_inputs_grad, + }, + }; + Ok(data) } }.try_build() } + + pub fn compile(&self, standalone: bool, wasm: bool) -> Result<()> { + + let opt_name = Compiler::find_opt()?; + let clang_name = Compiler::find_clang()?; + let enzyme_lib = Compiler::find_enzyme_lib()?; + let out = self.borrow_output_base_filename(); + let object_filename = Compiler::get_object_filename(out); + + let pre_enzyme_bitcodefilename = Compiler::get_pre_enzyme_bitcode_filename(out); + let output = Command::new(clang_name) + .arg(pre_enzyme_bitcodefilename.as_str()) + .arg("-c") + .arg(format!("-fplugin={}", enzyme_lib)) + .arg("-o").arg(object_filename.as_str()) + .output()?; + + if let Some(code) = output.status.code() { + if code != 0 { + println!("{}", String::from_utf8_lossy(&output.stderr)); + return Err(anyhow!("{} returned error code {}", opt_name, code)); + } + } + + // link the object file and our runtime library + let output = if wasm { + let emcc_varients = ["emcc"]; + let command_name = find_executable(&emcc_varients)?; + let exported_functions = vec![ + "Vector_destroy", + "Vector_create", + "Vector_create_with_capacity", + "Vector_push", + + "Options_destroy", + "Options_create", + + "Sundials_destroy", + "Sundials_create", + "Sundials_init", + "Sundials_solve", + ]; + let mut linked_files = vec![ + "libdiffeq_runtime_lib_wasm.a", + "libsundials_idas_wasm.a", + "libargparse_wasm.a", + ]; + if standalone { + linked_files.push("libdiffeq_runtime_wasm.a"); + } + let linked_files = linked_files; + let runtime_path = find_runtime_path(&linked_files)?; + let mut command = Command::new(command_name); + command.arg("-o").arg(out).arg(object_filename.as_str()); + for file in linked_files { + command.arg(Path::new(runtime_path.as_str()).join(file)); + } + if !standalone { + let exported_functions = exported_functions.into_iter().map(|s| format!("_{}", s)).collect::>().join(","); + command.arg("-s").arg(format!("EXPORTED_FUNCTIONS={}", exported_functions)); + command.arg("--no-entry"); + } + command.output() + } else { + let mut command = Command::new(clang_name); + command.arg("-o").arg(out).arg(out); + if standalone { + command.arg("-ldiffeq_runtime"); + } else { + command.arg("-ldiffeq_runtime_lib"); + } + command.output() + }; + + let output = match output { + Ok(output) => output, + Err(e) => { + return Err(anyhow!("Error linking in runtime: {}", e)); + } + }; + + if let Some(code) = output.status.code() { + if code != 0 { + println!("{}", String::from_utf8_lossy(&output.stderr)); + return Err(anyhow!("Error linking in runtime, returned error code {}", code)); + } + } + Ok(()) + } + + fn get_pre_enzyme_bitcode_filename(out: &str) -> String { + format!("{}.pre-enzyme.bc", out) + } + + fn get_bitcode_filename(out: &str) -> String { + format!("{}.bc", out) + } - pub fn get_bitcode_filename(&self) -> &str { - &self.borrow_bitcode_filename().as_str() + fn get_object_filename(out: &str) -> String { + format!("{}.o", out) } fn jit<'ctx, T>(name: &str, ee: &ExecutionEngine<'ctx>) -> Result> diff --git a/src/lib.rs b/src/lib.rs index 12cdc5f..4d07301 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,9 @@ -use std::{path::Path, ffi::OsStr, process::Command}; +use std::{path::Path, ffi::OsStr}; use anyhow::{Result, anyhow}; use codegen::Compiler; use continuous::ModelInfo; use discretise::DiscreteModel; use parser::{parse_ms_string, parse_ds_string}; -use utils::{find_executable, find_runtime_path}; extern crate pest; #[macro_use] @@ -25,8 +24,6 @@ pub struct CompilerOptions { } - - pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, options: CompilerOptions) -> Result<()> { let inputfile = Path::new(input); let is_discrete = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "ds"; @@ -49,8 +46,6 @@ pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, options: Com } pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOptions, is_discrete: bool) -> Result<()> { - let CompilerOptions { bitcode_only, wasm, standalone } = options; - let is_continuous = !is_discrete; let continuous_ast = if is_continuous { @@ -92,87 +87,12 @@ pub fn compile_text(text: &str, out: &str, model_name: &str, options: CompilerOp panic!("No model found"); }; let compiler = Compiler::from_discrete_model(&discrete_model, out)?; - - // if we are only compiling to bitcode , we are done - if bitcode_only { + + if options.bitcode_only { return Ok(()); } - - let bitcode_filename = compiler.get_bitcode_filename(); - // compile the bitcode to an object file or standalone wasm or executable - let emcc_varients = ["emcc"]; - let clang_varients = ["clang", "clang-14"]; - let command_name = if wasm { - find_executable(&emcc_varients)? - } else { - find_executable(&clang_varients)? - }; - - - // link the object file and our runtime library - let output = if wasm { - let exported_functions = vec![ - "Vector_destroy", - "Vector_create", - "Vector_create_with_capacity", - "Vector_push", - - "Options_destroy", - "Options_create", - - "Sundials_destroy", - "Sundials_create", - "Sundials_init", - "Sundials_solve", - ]; - let mut linked_files = vec![ - "libdiffeq_runtime_lib_wasm.a", - "libsundials_idas_wasm.a", - "libargparse_wasm.a", - ]; - if standalone { - linked_files.push("libdiffeq_runtime_wasm.a"); - } - let linked_files = linked_files; - let runtime_path = find_runtime_path(&linked_files)?; - let mut command = Command::new(command_name); - command.arg("-o").arg(out).arg(bitcode_filename); - for file in linked_files { - command.arg(Path::new(runtime_path.as_str()).join(file)); - } - if !standalone { - let exported_functions = exported_functions.into_iter().map(|s| format!("_{}", s)).collect::>().join(","); - command.arg("-s").arg(format!("EXPORTED_FUNCTIONS={}", exported_functions)); - command.arg("--no-entry"); - } - command.output() - } else { - let mut command = Command::new(command_name); - command.arg("-o").arg(out).arg(out); - if standalone { - command.arg("-ldiffeq_runtime"); - } else { - command.arg("-ldiffeq_runtime_lib"); - } - command.output() - }; - - let output = match output { - Ok(output) => output, - Err(e) => { - return Err(anyhow!("Error running {}: {}", command_name, e)); - } - }; - - if let Some(code) = output.status.code() { - if code != 0 { - println!("{}", String::from_utf8_lossy(&output.stderr)); - return Err(anyhow!("{} returned error code {}", command_name, code)); - } - } - - Ok(()) + compiler.compile(options.standalone, options.wasm) } #[cfg(test)] diff --git a/src/utils.rs b/src/utils.rs index c79e951..9d8cf2b 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -14,6 +14,7 @@ fn is_executable_on_path(executable_name: &str) -> bool { output.status.success() } + pub fn find_executable<'a>(varients: &[&'a str]) -> Result<&'a str> { let mut command = None; for varient in varients { @@ -37,6 +38,7 @@ pub fn find_runtime_path(libraries: &[&str] ) -> Result { let mut found = true; for library in libraries { let library_path = Path::new(path).join(library); + println!("Checking {:?}", library_path); if !library_path.exists() { found = false; break; @@ -47,4 +49,20 @@ pub fn find_runtime_path(libraries: &[&str] ) -> Result { } } Err(anyhow!("Could not find {:?} in LIBRARY_PATH", libraries)) +} + +pub fn find_library_path(varients: &[& str]) -> Result { + let library_paths_env = env::var("LIBRARY_PATH").unwrap_or("".to_owned()); + let library_paths = library_paths_env.split(":").collect::>(); + for path in library_paths { + // check if one of the varients is in the path + for varient in varients { + let library_path = Path::new(path).join(varient); + if library_path.exists() { + let filename = library_path.as_os_str().to_str().unwrap().to_owned(); + return Ok(filename); + } + } + } + Err(anyhow!("Could not find any of {:?} in LIBRARY_PATH", varients)) } \ No newline at end of file