From 3e1048ad378941430b37103114fe02da177daf60 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 25 Aug 2023 22:59:55 +0000 Subject: [PATCH] factor our compile function --- src/bin/diffeq.rs | 83 +++--------------------------- src/codegen/codegen.rs | 4 +- src/codegen/compiler.rs | 43 +++++++++++++--- src/lib.rs | 109 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 155 insertions(+), 84 deletions(-) diff --git a/src/bin/diffeq.rs b/src/bin/diffeq.rs index 66e6dc3..35f9677 100644 --- a/src/bin/diffeq.rs +++ b/src/bin/diffeq.rs @@ -1,9 +1,6 @@ - -use std::{path::Path, ffi::OsStr, process::Command}; - use clap::Parser; -use anyhow::{Result, anyhow}; -use diffeq::{parser::{parse_ms_string, parse_ds_string}, continuous::ModelInfo, discretise::DiscreteModel, codegen::Compiler}; +use anyhow::Result; +use diffeq::compile; /// compiles a model in continuous (.cs) or discrete (.ds) format to an object file #[derive(Parser, Debug)] @@ -23,80 +20,14 @@ struct Args { /// Compile object file only #[arg(short, long)] compile: bool, + + /// Compile to WASM + #[arg(short, long)] + wasm: bool, } fn main() -> Result<()> { let cli = Args::parse(); - - let inputfile = Path::new(&cli.input); - let out = if let Some(out) = cli.out { - out.clone() - } else if cli.compile { - "out.o".to_owned() - } else { - "out".to_owned() - }; - - let objectname = if cli.compile { out.clone() } else { format!("{}.o", out) }; - let objectfile = Path::new(objectname.as_str()); - let is_discrete = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "ds"; - let is_continuous = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "cs"; - if !is_discrete && !is_continuous { - panic!("Input file must have extension .ds or .cs"); - } - let text = std::fs::read_to_string(inputfile)?; - if is_continuous { - let models = parse_ms_string(text.as_str())?; - let model_name = if let Some(model_name) = cli.model { - model_name - } else { - return Err(anyhow!("Model name must be specified for continuous models")); - }; - let model_info = ModelInfo::build(model_name.as_str(), &models).map_err(|e| anyhow!("{}", e))?; - if model_info.errors.len() > 0 { - for error in model_info.errors { - println!("{}", error.as_error_message(text.as_str())); - } - return Err(anyhow!("Errors in model")); - } - let model = DiscreteModel::from(&model_info); - let compiler = Compiler::from_discrete_model(&model)?; - compiler.write_object_file(objectfile)?; - } else { - let model = parse_ds_string(text.as_str())?; - let model = match DiscreteModel::build(&cli.input, &model) { - Ok(model) => model, - Err(e) => { - println!("{}", e.as_error_message(text.as_str())); - return Err(anyhow!("Errors in model")); - } - }; - let compiler = Compiler::from_discrete_model(&model)?; - compiler.write_object_file(objectfile)?; - }; - - if cli.compile { - return Ok(()); - } - - // compile the object file using clang and our runtime library - let output = Command::new("clang") - .arg("-o") - .arg(out) - .arg(objectname.clone()) - .arg("-ldiffeq_runtime") - .output()?; - - // clean up the object file - std::fs::remove_file(objectfile)?; - - if let Some(code) = output.status.code() { - if code != 0 { - println!("{}", String::from_utf8_lossy(&output.stderr)); - return Err(anyhow!("clang returned error code {}", code)); - } - } - Ok(()) - + compile(&cli.input, cli.out.as_deref(), cli.model.as_deref(), cli.compile, cli.wasm) } \ No newline at end of file diff --git a/src/codegen/codegen.rs b/src/codegen/codegen.rs index 9ed5d38..fb43076 100644 --- a/src/codegen/codegen.rs +++ b/src/codegen/codegen.rs @@ -792,10 +792,10 @@ impl<'ctx> CodeGen<'ctx> { let int_ptr_type = self.context.i32_type().ptr_type(AddressSpace::default()); let void_type = self.context.void_type(); let fn_type = void_type.fn_type( - &[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into(), int_ptr_type.into(), real_ptr_type.into()] + &[self.real_type.into(), real_ptr_type.into(), real_ptr_type.into(), real_ptr_type.into(), int_ptr_type.into()] , false ); - let fn_arg_names = &["t", "u", "dudt", "data", "indices", "out"]; + let fn_arg_names = &["t", "u", "dudt", "data", "indices"]; let function = self.module.add_function("calc_out", fn_type, None); let basic_block = self.context.append_basic_block(function, "entry"); self.fn_value_opt = Some(function); diff --git a/src/codegen/compiler.rs b/src/codegen/compiler.rs index dbe7849..121b425 100644 --- a/src/codegen/compiler.rs +++ b/src/codegen/compiler.rs @@ -2,6 +2,7 @@ use std::path::Path; use anyhow::anyhow; use anyhow::Result; +use inkwell::targets::TargetMachine; use inkwell::{context::Context, OptimizationLevel, targets::{TargetTriple, InitializationConfig, Target, RelocMode, CodeModel, FileType}, execution_engine::{JitFunction, ExecutionEngine, UnsafeFunctionPointer}}; use ouroboros::self_referencing; use crate::discretise::DiscreteModel; @@ -225,23 +226,53 @@ impl Compiler { }) } - pub fn write_object_file(&self, path: &Path) -> Result<()> { - Target::initialize_x86(&InitializationConfig::default()); + fn get_native_machine() -> Result { + Target::initialize_native(&InitializationConfig::default()).map_err(|e| anyhow!("{}", e))?; + let opt = OptimizationLevel::Default; + let reloc = RelocMode::Default; + let model = CodeModel::Default; + let target_triple = TargetMachine::get_default_triple(); + let target = Target::from_triple(&target_triple).unwrap(); + let target_machine = target.create_target_machine( + &target_triple, + TargetMachine::get_host_cpu_name().to_str().unwrap(), + TargetMachine::get_host_cpu_features().to_str().unwrap(), + opt, + reloc, + model + ) + .unwrap(); + Ok(target_machine) + } + fn get_wasm_machine() -> Result { + Target::initialize_webassembly(&InitializationConfig::default()); let opt = OptimizationLevel::Default; let reloc = RelocMode::Default; let model = CodeModel::Default; - let target = Target::from_name("x86-64").unwrap(); + let target_triple = TargetTriple::create("wasm32-unknown-emscripten"); + let target = Target::from_triple(&target_triple).unwrap(); let target_machine = target.create_target_machine( - &TargetTriple::create("x86_64-pc-linux-gnu"), - "x86-64", - "+avx2", + &target_triple, + "generic", + "", opt, reloc, model ) .unwrap(); + Ok(target_machine) + } + + pub fn write_object_file(&self, path: &Path) -> Result<()> { + let target_machine = Compiler::get_native_machine()?; + self.with_data(|data| + target_machine.write_to_file(data.codegen.module(), FileType::Object, &path).map_err(|e| anyhow::anyhow!("Error writing object file: {:?}", e)) + ) + } + pub fn write_wasm_object_file(&self, path: &Path) -> Result<()> { + let target_machine = Compiler::get_wasm_machine()?; self.with_data(|data| target_machine.write_to_file(data.codegen.module(), FileType::Object, &path).map_err(|e| anyhow::anyhow!("Error writing object file: {:?}", e)) ) diff --git a/src/lib.rs b/src/lib.rs index 306fbc4..ba3246a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,10 @@ +use std::{path::Path, ffi::OsStr, process::Command}; +use anyhow::{Result, anyhow}; +use codegen::Compiler; +use continuous::ModelInfo; +use discretise::DiscreteModel; +use parser::{parse_ms_string, parse_ds_string}; + extern crate pest; #[macro_use] extern crate pest_derive; @@ -9,5 +16,107 @@ pub mod continuous; pub mod codegen; +pub fn compile(input: &str, out: Option<&str>, model: Option<&str>, compile: bool, wasm: bool) -> Result<()> { + let inputfile = Path::new(input); + let out = if let Some(out) = out { + out.clone() + } else if compile { + "out.o" + } else { + "out" + }; + + let objectname = if compile { out.to_owned() } else { format!("{}.o", out) }; + let objectfile = Path::new(objectname.as_str()); + let is_discrete = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "ds"; + let is_continuous = inputfile.extension().unwrap_or(OsStr::new("")).to_str().unwrap() == "cs"; + if !is_discrete && !is_continuous { + panic!("Input file must have extension .ds or .cs"); + } + + let model_name = if is_continuous { + if let Some(model_name) = model { + model_name + } else { + return Err(anyhow!("Model name must be specified for continuous models")); + } + } else { + "" + }; + + let text = std::fs::read_to_string(inputfile)?; + let continuous_ast = if is_continuous { + Some(parse_ms_string(text.as_str())?) + } else { + None + }; + + let discrete_ast = if is_discrete { + Some(parse_ds_string(text.as_str())?) + } else { + None + }; + + let continuous_model_info = if let Some(ast) = &continuous_ast { + let model_info = ModelInfo::build(model_name, ast).map_err(|e| anyhow!("{}", e))?; + if model_info.errors.len() > 0 { + for error in model_info.errors { + println!("{}", error.as_error_message(text.as_str())); + } + return Err(anyhow!("Errors in model")); + } + Some(model_info) + } else { + None + }; + + let discrete_model = if let Some(model_info) = &continuous_model_info { + let model = DiscreteModel::from(&model_info); + model + } else if let Some(ast) = &discrete_ast { + match DiscreteModel::build(input, ast) { + Ok(model) => model, + Err(e) => { + println!("{}", e.as_error_message(text.as_str())); + return Err(anyhow!("Errors in model")); + } + } + } else { + panic!("No model found"); + }; + + let compiler = Compiler::from_discrete_model(&discrete_model)?; + + if wasm { + compiler.write_wasm_object_file(objectfile)?; + } else { + compiler.write_object_file(objectfile)?; + } + + if compile { + return Ok(()); + } + + // compile the object file using clang and our runtime library + let output = Command::new("clang") + .arg("-o") + .arg(out) + .arg(objectname.clone()) + .arg("-ldiffeq_runtime") + .output()?; + + // clean up the object file + std::fs::remove_file(objectfile)?; + + if let Some(code) = output.status.code() { + if code != 0 { + println!("{}", String::from_utf8_lossy(&output.stderr)); + return Err(anyhow!("clang returned error code {}", code)); + } + } + Ok(()) +} + +