Skip to content

Commit

Permalink
fix: ensure intrinsic modules are linked to program
Browse files Browse the repository at this point in the history
  • Loading branch information
bitwalker committed Oct 18, 2023
1 parent a8180bb commit dedea3a
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 111 deletions.
25 changes: 25 additions & 0 deletions codegen/masm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ impl<'a> ProgramCompiler<'a> {
self.output.modules.push(masm_module);
}

// Ensure intrinsics modules are linked
self.output
.modules
.push(Module::load_intrinsic("intrinsics::mem").expect("parsing failed"));
self.output
.modules
.push(Module::load_intrinsic("intrinsics::i32").expect("parsing failed"));

Ok(self.output)
}

Expand Down Expand Up @@ -148,6 +156,23 @@ impl<'a> ProgramCompiler<'a> {
}
}

// If this module makes use of any intrinsics modules, add them to the program
for import in output
.imports
.iter()
.filter(|import| import.name.as_str().starts_with("intrinsics::"))
{
if self.output.contains(import.name) {
continue;
}
match Module::load_intrinsic(import.name.as_str()) {
Some(loaded) => {
self.output.modules.push(loaded);
}
None => unimplemented!("unrecognized intrinsic module: '{}'", &import.name),
}
}

// Removing a function via this cursor will move the cursor to
// the next function in the module. Once the end of the module
// is reached, the cursor will point to the null object, and
Expand Down
21 changes: 13 additions & 8 deletions codegen/masm/src/masm/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ const I32_INTRINSICS: &'static str =
const MEM_INTRINSICS: &'static str =
include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/intrinsics/mem.masm"));

/// This is a mapping of intrinsics module name to the raw MASM source for that module
const INTRINSICS: [(&'static str, &'static str); 2] = [
("intrinsics::i32", I32_INTRINSICS),
("intrinsics::mem", MEM_INTRINSICS),
];

/// This represents a single compiled Miden Assembly module in a form that is
/// designed to integrate well with the rest of our IR. You can think of this
/// as an intermediate representation corresponding to the Miden Assembly AST,
Expand Down Expand Up @@ -225,13 +231,12 @@ impl fmt::Display for Module {
}

impl Module {
/// This is a helper that parses and returns the predefined `intrinsics::mem` module
pub fn mem_intrinsics() -> Self {
Self::parse_str(MEM_INTRINSICS, "intrinsics::mem").expect("invalid module")
}

/// This is a helper that parses and returns the predefined `intrinsics::i32` module
pub fn i32_intrinsics() -> Self {
Self::parse_str(I32_INTRINSICS, "intrinsics::i32").expect("invalid module")
/// This helper loads the named module from the set of intrinsics modules defined in this crate.
///
/// Expects the fully-qualified name to be given, e.g. `intrinsics::mem`
pub fn load_intrinsic<N: AsRef<str>>(name: N) -> Option<Self> {
let name = name.as_ref();
let (_, source) = INTRINSICS.iter().find(|(n, _)| *n == name)?;
Some(Self::parse_str(source, name).expect("invalid module"))
}
}
8 changes: 8 additions & 0 deletions codegen/masm/src/masm/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ impl Program {
self.entrypoint.is_none()
}

/// Returns true if this program contains a [Module] named `name`
pub fn contains<N>(&self, name: N) -> bool
where
Ident: PartialEq<N>,
{
self.modules.iter().any(|m| m.name == name)
}

/// Write this [Program] to the given output directory.
///
/// The provided [miden_diagnostics::CodeMap] is used for computing source locations.
Expand Down
136 changes: 35 additions & 101 deletions codegen/masm/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use miden_hir::{
SourceSpan, Stack, StarkField, Type,
};
use miden_hir_analysis::FunctionAnalysis;
use std::fmt::Write;

use super::*;

Expand Down Expand Up @@ -151,6 +150,18 @@ impl TestByEmulationHarness {
.expect("failed to load module");
self.emulator.invoke(entrypoint, args)
}

pub fn execute_program(
&mut self,
program: Program,
args: &[Felt],
) -> Result<OperandStack<Felt>, EmulationError> {
let entrypoint = program.entrypoint.expect("cannot execute a library");
self.emulator
.load_program(program)
.expect("failed to load program");
self.emulator.invoke(entrypoint, args)
}
}

/// Test the emulator on the fibonacci function
Expand All @@ -163,7 +174,7 @@ fn fib_emulator() {

// Build test module with fib function
let mut mb = builder.module("test");
let id = testing::fib1(mb.as_mut(), &harness.context);
testing::fib1(mb.as_mut(), &harness.context);
mb.build()
.expect("unexpected error constructing test module");

Expand All @@ -173,36 +184,13 @@ fn fib_emulator() {
.link()
.expect("failed to link program");

// Get the fib function
let (mut function, imports) = {
let modules = program.modules_mut();
let mut test = modules.find_mut("test").remove().expect("undefined module");
let function = test
.cursor_mut_at(id.function)
.remove()
.expect("undefined function");
let imports = test.imports();
modules.insert(test);
(function, imports)
};

let masm = harness
.stackify(&program, &mut function)
.expect("stackification failed");

let mut output = String::with_capacity(1024);
write!(&mut output, "{}", masm.display(&imports)).expect("formatting failed");

println!("{}", output.as_str());

let mut module = Module::new(id.module);
module.functions.push_back(masm);
module.entry = Some(id);
let mut compiler = MasmCompiler::new(&harness.context.diagnostics);
let program = compiler.compile(&mut program).expect("compilation failed");

// Test it via the emulator
let n = Felt::new(10);
let mut stack = harness
.execute_module(module, &[n])
.execute_program(program, &[n])
.expect("execution failed");
assert_eq!(stack.len(), 1);
assert_eq!(stack.pop().map(|e| e.as_int()), Some(55));
Expand Down Expand Up @@ -261,31 +249,14 @@ fn stackify_fundamental_if() {
.link()
.expect("failed to link program");

// Get the sum_matrix function
let mut function = {
let modules = program.modules_mut();
let mut test = modules.find_mut("test").remove().expect("undefined module");
let function = test
.cursor_mut_at(id.function)
.remove()
.expect("undefined function");
modules.insert(test);
function
};

let masm = harness
.stackify(&program, &mut function)
.expect("stackification failed");

let mut module = Module::new(id.module);
module.functions.push_back(masm);
module.entry = Some(id);
let mut compiler = MasmCompiler::new(&harness.context.diagnostics);
let program = compiler.compile(&mut program).expect("compilation failed");

let a = Felt::new(3);
let b = Felt::new(4);

let mut stack = harness
.execute_module(module, &[a, b])
.execute_program(program, &[a, b])
.expect("execution failed");
assert_eq!(stack.len(), 1);
assert_eq!(stack.pop().map(|e| e.as_int()), Some(12));
Expand Down Expand Up @@ -356,37 +327,14 @@ fn stackify_fundamental_loops() {
.link()
.expect("failed to link program");

// Get the sum_matrix function
let (mut function, imports) = {
let modules = program.modules_mut();
let mut test = modules.find_mut("test").remove().expect("undefined module");
let function = test
.cursor_mut_at(id.function)
.remove()
.expect("undefined function");
let imports = test.imports();
modules.insert(test);
(function, imports)
};

let masm = harness
.stackify(&program, &mut function)
.expect("stackification failed");

let mut output = String::with_capacity(1024);
write!(&mut output, "{}", masm.display(&imports)).expect("formatting failed");

println!("{}", output.as_str());

let mut module = Module::new(id.module);
module.functions.push_back(masm);
module.entry = Some(id);
let mut compiler = MasmCompiler::new(&harness.context.diagnostics);
let program = compiler.compile(&mut program).expect("compilation failed");

let a = Felt::new(3);
let n = Felt::new(4);

let mut stack = harness
.execute_module(module, &[a, n])
.execute_program(program, &[a, n])
.expect("execution failed");
assert_eq!(stack.len(), 1);
assert_eq!(stack.pop().map(|e| e.as_int()), Some(7));
Expand All @@ -399,7 +347,7 @@ fn verify_i32_intrinsics_syntax() {

harness
.emulator
.load_module(Module::i32_intrinsics())
.load_module(Module::load_intrinsic("intrinsics::i32").expect("parsing failed"))
.expect("failed to load intrinsics::i32");
}

Expand All @@ -408,17 +356,17 @@ fn verify_i32_intrinsics_syntax() {
fn stackify_sum_matrix() {
let mut harness = TestByEmulationHarness::default();

harness
.emulator
.load_module(Module::mem_intrinsics())
.expect("failed to load intrinsics::mem");
//harness
//.emulator
//.load_module(Module::mem_intrinsics())
//.expect("failed to load intrinsics::mem");

// Build a simple program
let mut builder = ProgramBuilder::new(&harness.context.diagnostics);

// Build test module with fib function
let mut mb = builder.module("test");
let id = testing::sum_matrix(mb.as_mut(), &harness.context);
testing::sum_matrix(mb.as_mut(), &harness.context);
mb.build()
.expect("unexpected error constructing test module");

Expand All @@ -428,27 +376,13 @@ fn stackify_sum_matrix() {
.link()
.expect("failed to link program");

// Get the sum_matrix function
let (mut function, _imports) = {
let modules = program.modules_mut();
let mut test = modules.find_mut("test").remove().expect("undefined module");
let function = test
.cursor_mut_at(id.function)
.remove()
.expect("undefined function");
let imports = test.imports();
modules.insert(test);
(function, imports)
};

let masm = harness
.stackify(&program, &mut function)
.expect("stackification failed");
// Compile
let mut compiler = MasmCompiler::new(&harness.context.diagnostics);
let program = compiler.compile(&mut program).expect("compilation failed");

let mut module = Module::new(id.module);
module.functions.push_back(masm);
module.entry = Some(id);
dbg!(program.modules.iter().map(|m| m.name).collect::<Vec<_>>());

// Prep emulator
let addr = harness.malloc(core::mem::size_of::<u32>() * 3 * 3);
let ptr = Felt::new(addr as u64);
let rows = Felt::new(3);
Expand All @@ -467,11 +401,11 @@ fn stackify_sum_matrix() {
harness.store(addr + 24, Felt::ONE);
harness.store(addr + 28, Felt::ONE);
harness.store(addr + 32, Felt::ONE);

harness.set_cycle_budget(1000);

// Execute test::sum_matrix
let mut stack = harness
.execute_module(module, &[ptr, rows, cols])
.execute_program(program, &[ptr, rows, cols])
.expect("execution failed");
assert_eq!(stack.len(), 1);
assert_eq!(stack.pop().map(|e| e.as_int()), Some(6));
Expand Down
5 changes: 3 additions & 2 deletions hir/src/program/linker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,16 +369,17 @@ impl Linker {
// If the module is pending, it is being linked
let is_linked = self.pending.contains_key(&node.module);
let is_stdlib = node.module.as_str().starts_with("std::");
let is_intrinsic = node.module.as_str().starts_with("intrinsics::");

// If a referenced module is not being linked, raise an error
if !is_linked {
// However we ignore standard library modules in this check,
// However we ignore standard library/intrinsic modules in this check,
// as they are known to be provided at runtime.
//
// TODO: We need to validate that the given module/function
// is actually in the standard library though, and that the
// signature matches what is expected.
if is_stdlib {
if is_stdlib || is_intrinsic {
continue;
}

Expand Down

0 comments on commit dedea3a

Please sign in to comment.