Skip to content

Commit

Permalink
Merge pull request #81 from NLnetLabs/global-constants
Browse files Browse the repository at this point in the history
Global constants defined by the runtime
  • Loading branch information
tertsdiepraam authored Dec 6, 2024
2 parents 2b3b86e + d294062 commit 6e193b1
Show file tree
Hide file tree
Showing 13 changed files with 243 additions and 84 deletions.
4 changes: 2 additions & 2 deletions examples/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ fn main() -> Result<(), roto::RotoReport> {

let mut arguments = args();
let _program_name = arguments.next().unwrap();

let subcommand = arguments.next();
if Some("doc") == subcommand.as_deref() {
runtime.print_documentation();
return Ok(())
return Ok(());
}

let mut compiled = read_files(["examples/simple.roto"])?
Expand Down
50 changes: 45 additions & 5 deletions src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ use crate::{
value::IrType,
IrFunction,
},
runtime::ty::{Reflect, GLOBAL_TYPE_REGISTRY},
runtime::{
ty::{Reflect, GLOBAL_TYPE_REGISTRY},
RuntimeConstant,
},
typechecker::{info::TypeInfo, scope::ScopeRef, types},
IrValue,
};
Expand All @@ -36,7 +39,7 @@ use cranelift::{
Variable,
},
jit::{JITBuilder, JITModule},
module::{DataDescription, FuncId, Linkage, Module as _},
module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module as _},
prelude::Signature,
};
use cranelift_codegen::ir::SigRef;
Expand All @@ -47,7 +50,7 @@ pub mod check;
mod tests;

/// A wrapper around a cranelift [`JITModule`] that cleans up after itself
///
///
/// This is achieved by wrapping the module in an [`Arc`].
#[derive(Clone)]
pub struct ModuleData(Arc<ManuallyDrop<JITModule>>);
Expand Down Expand Up @@ -105,10 +108,10 @@ pub struct Module {
}

/// A function extracted from Roto
///
///
/// A [`TypedFunc`] can be retrieved from a compiled script using
/// [`Compiled::get_function`](crate::Compiled::get_function).
///
///
/// The function can be called with one of the [`TypedFunc::call`] functions.
#[derive(Clone, Debug)]
pub struct TypedFunc<Params, Return> {
Expand Down Expand Up @@ -226,6 +229,7 @@ const MEMFLAGS: MemFlags = MemFlags::new().with_aligned();
pub fn codegen(
ir: &[ir::Function],
runtime_functions: &HashMap<usize, IrFunction>,
constants: &[RuntimeConstant],
label_store: LabelStore,
type_info: TypeInfo,
) -> Module {
Expand Down Expand Up @@ -274,6 +278,10 @@ pub fn codegen(
clone_signature,
};

for constant in constants {
module.declare_constant(constant);
}

for (roto_func_id, func) in runtime_functions {
let mut sig = module.inner.make_signature();
for ty in &func.params {
Expand Down Expand Up @@ -308,6 +316,23 @@ pub fn codegen(
}

impl ModuleBuilder {
fn declare_constant(&mut self, constant: &RuntimeConstant) {
let data_id = self
.inner
.declare_data(
constant.name.as_str(),
Linkage::Local,
false,
false,
)
.unwrap();

let mut description = DataDescription::new();
description.define(constant.bytes.clone());

self.inner.define_data(data_id, &description).unwrap();
}

/// Declare a function and its signature (without the body)
fn declare_function(&mut self, func: &ir::Function) {
let ir::Function {
Expand Down Expand Up @@ -837,6 +862,21 @@ impl<'c> FuncGen<'c> {
let var = self.variable(to, I32);
self.def(var, val);
}
ir::Instruction::LoadConstant { to, name, ty } => {
let Some(FuncOrDataId::Data(data_id)) =
self.module.inner.get_name(name.as_str())
else {
panic!();
};
let val = self
.module
.inner
.declare_data_in_func(data_id, self.builder.func);
let ty = self.module.cranelift_type(ty);
let val = self.ins().global_value(ty, val);
let to = self.variable(to, ty);
self.def(to, val);
}
}
}

Expand Down
23 changes: 23 additions & 0 deletions src/codegen/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,3 +904,26 @@ fn arc_type() {
output.drops.load(Ordering::Relaxed)
);
}

#[test]
fn use_constant() {
let s = src!(
"
filter-map main() {
define {
safi = 127.0.0.1;
}
apply {
if safi == LOCALHOSTV4 {
reject
}
accept
}
}"
);

let mut p = compile(s);
let f = p.get_function::<(), Verdict<(), ()>>("main").unwrap();
let output = f.call();
assert_eq!(output, Verdict::Reject(()));
}
13 changes: 12 additions & 1 deletion src/lower/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
ir::{Instruction, IntCmp, VarKind},
value::IrValue,
},
runtime::RuntimeFunction,
runtime::{RuntimeConstant, RuntimeFunction},
};
use std::collections::HashMap;

Expand Down Expand Up @@ -247,6 +247,7 @@ pub fn eval(
p: &[Function],
filter_map: &str,
mem: &mut Memory,
constants: &[RuntimeConstant],
rx: Vec<IrValue>,
) -> Option<IrValue> {
let filter_map_ident = Identifier::from(filter_map);
Expand All @@ -267,6 +268,11 @@ pub fn eval(
instructions.extend(block.instructions.clone());
}

let constants: HashMap<Identifier, &[u8]> = constants
.iter()
.map(|g| (g.name, g.bytes.as_ref()))
.collect();

// This is our working memory for the interpreter
let mut vars = HashMap::<Var, IrValue>::new();

Expand Down Expand Up @@ -335,6 +341,11 @@ pub fn eval(
let val = eval_operand(&vars, val);
vars.insert(to.clone(), val.clone());
}
Instruction::LoadConstant { to, name, ty } => {
let val = constants.get(name).unwrap();
let val = IrValue::from_slice(ty, val);
vars.insert(to.clone(), val.clone());
}
Instruction::Call { to, func, args } => {
let f = p.iter().find(|f| f.name == *func).unwrap();

Expand Down
10 changes: 10 additions & 0 deletions src/lower/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ pub enum Instruction {
ty: IrType,
},

/// Load a constant
LoadConstant {
to: Var,
name: Identifier,
ty: IrType,
},

/// Call a function.
Call {
to: Option<(Var, IrType)>,
Expand Down Expand Up @@ -339,6 +346,9 @@ impl<'a> IrPrinter<'a> {
Assign { to, val, ty } => {
format!("{}: {ty} = {}", self.var(to), self.operand(val),)
}
LoadConstant { to, name, ty } => {
format!("{}: {ty} = LoadConstant(\"{}\")", self.var(to), name)
}
Call {
to: Some((to, ty)),
func,
Expand Down
27 changes: 20 additions & 7 deletions src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,13 +650,26 @@ impl<'r> Lowerer<'r> {
ast::Expr::Var(x) => {
let DefinitionRef(scope, ident) =
self.type_info.resolved_name(x);
Some(
Var {
scope,
kind: VarKind::Explicit(ident),
}
.into(),
)

if ScopeRef::ROOT == scope {
let var = self.new_tmp();
let ty = self.type_info.type_of(id);
let ty = self.lower_type(&ty);
self.add(Instruction::LoadConstant {
to: var.clone(),
name: ident,
ty,
});
Some(var.into())
} else {
Some(
Var {
scope,
kind: VarKind::Explicit(ident),
}
.into(),
)
}
}
ast::Expr::TypedRecord(_, record) | ast::Expr::Record(record) => {
let ty = self.type_info.type_of(id);
Expand Down
35 changes: 0 additions & 35 deletions src/lower/test_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,38 +448,3 @@ fn ip_addr_method() {
assert_eq!(expected, u8::from_ne_bytes(res));
}
}

// #[test]
// fn prefix_addr() {
// let s = "
// filter-map main(x: Prefix) {
// apply {
// if x.address() == 0.0.0.0 {
// accept
// }
// reject
// }
// }
// ";

// assert_eq!(
// p(IrValue::from_any(Box::new(
// Prefix::from_str("0.0.0.0/8").unwrap()
// ))),
// Ok(())
// );

// assert_eq!(
// p(IrValue::from_any(Box::new(
// Prefix::from_str("127.0.0.0/8").unwrap()
// ))),
// Err(())
// );

// let mut mem = Memory::new();
// let program = compile(s);
// let pointer = mem.allocate(1);
// program.eval(&mut mem, vec![IrValue::Pointer(pointer), IrValue::U32(0)]);
// let res = mem.read(pointer, 1);
// assert_eq!(&[1], res);
// }
9 changes: 7 additions & 2 deletions src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
meta::{Span, Spans},
ParseError, Parser,
},
runtime::{ty::Reflect, Runtime},
runtime::{ty::Reflect, Runtime, RuntimeConstant},
typechecker::{
error::{Level, TypeError},
info::TypeInfo,
Expand Down Expand Up @@ -84,6 +84,7 @@ pub struct TypeChecked {
pub struct Lowered {
pub ir: Vec<ir::Function>,
runtime_functions: HashMap<usize, IrFunction>,
runtime_constants: Vec<RuntimeConstant>,
label_store: LabelStore,
type_info: TypeInfo,
}
Expand Down Expand Up @@ -397,9 +398,12 @@ impl TypeChecked {
println!("{s}");
}

let runtime_constants = runtime.constants.values().cloned().collect();

Lowered {
ir,
runtime_functions,
runtime_constants,
label_store,
type_info: type_infos.remove(0),
}
Expand All @@ -412,13 +416,14 @@ impl Lowered {
mem: &mut Memory,
rx: Vec<IrValue>,
) -> Option<IrValue> {
eval::eval(&self.ir, "main", mem, rx)
eval::eval(&self.ir, "main", mem, &self.runtime_constants, rx)
}

pub fn codegen(self) -> Compiled {
let module = codegen::codegen(
&self.ir,
&self.runtime_functions,
&self.runtime_constants,
self.label_store,
self.type_info,
);
Expand Down
Loading

0 comments on commit 6e193b1

Please sign in to comment.