From b56052f0789232a62853d4711cbc34c0e3402183 Mon Sep 17 00:00:00 2001 From: Terts Diepraam Date: Wed, 4 Dec 2024 15:16:30 +0100 Subject: [PATCH 1/2] implement the codegen for global constants defined by the runtime --- examples/simple.rs | 4 +-- src/codegen/mod.rs | 50 +++++++++++++++++++++++++--- src/codegen/tests.rs | 23 +++++++++++++ src/lower/eval.rs | 13 +++++++- src/lower/ir.rs | 10 ++++++ src/lower/mod.rs | 27 +++++++++++---- src/lower/test_eval.rs | 35 ------------------- src/pipeline.rs | 9 +++-- src/runtime/mod.rs | 72 ++++++++++++++++++++++++++++++++++++++-- src/typechecker/mod.rs | 6 ++-- src/typechecker/scope.rs | 6 +++- src/typechecker/tests.rs | 15 +++++++++ src/typechecker/types.rs | 19 ----------- 13 files changed, 213 insertions(+), 76 deletions(-) diff --git a/examples/simple.rs b/examples/simple.rs index fbf18b59..cac21819 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -9,11 +9,11 @@ fn main() -> Result<(), roto::RotoReport> { let mut arguments = args(); let _program_name = arguments.next().unwrap(); - + let subcommand = arguments.next(); if Some("doc") == subcommand.as_deref() { runtime.print_documentation(); - return Ok(()) + return Ok(()); } let mut compiled = read_files(["examples/simple.roto"])? diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index f4fe5cd8..f308db34 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -13,7 +13,10 @@ use crate::{ value::IrType, IrFunction, }, - runtime::ty::{Reflect, GLOBAL_TYPE_REGISTRY}, + runtime::{ + ty::{Reflect, GLOBAL_TYPE_REGISTRY}, + RuntimeConstant, + }, typechecker::{info::TypeInfo, scope::ScopeRef, types}, IrValue, }; @@ -36,7 +39,7 @@ use cranelift::{ Variable, }, jit::{JITBuilder, JITModule}, - module::{DataDescription, FuncId, Linkage, Module as _}, + module::{DataDescription, FuncId, FuncOrDataId, Linkage, Module as _}, prelude::Signature, }; use cranelift_codegen::ir::SigRef; @@ -47,7 +50,7 @@ pub mod check; mod tests; /// A wrapper around a cranelift [`JITModule`] that cleans up after itself -/// +/// /// This is achieved by wrapping the module in an [`Arc`]. #[derive(Clone)] pub struct ModuleData(Arc>); @@ -105,10 +108,10 @@ pub struct Module { } /// A function extracted from Roto -/// +/// /// A [`TypedFunc`] can be retrieved from a compiled script using /// [`Compiled::get_function`](crate::Compiled::get_function). -/// +/// /// The function can be called with one of the [`TypedFunc::call`] functions. #[derive(Clone, Debug)] pub struct TypedFunc { @@ -226,6 +229,7 @@ const MEMFLAGS: MemFlags = MemFlags::new().with_aligned(); pub fn codegen( ir: &[ir::Function], runtime_functions: &HashMap, + constants: &[RuntimeConstant], label_store: LabelStore, type_info: TypeInfo, ) -> Module { @@ -274,6 +278,10 @@ pub fn codegen( clone_signature, }; + for constant in constants { + module.declare_constant(constant); + } + for (roto_func_id, func) in runtime_functions { let mut sig = module.inner.make_signature(); for ty in &func.params { @@ -308,6 +316,23 @@ pub fn codegen( } impl ModuleBuilder { + fn declare_constant(&mut self, constant: &RuntimeConstant) { + let data_id = self + .inner + .declare_data( + constant.name.as_str(), + Linkage::Local, + false, + false, + ) + .unwrap(); + + let mut description = DataDescription::new(); + description.define(constant.bytes.clone()); + + self.inner.define_data(data_id, &description).unwrap(); + } + /// Declare a function and its signature (without the body) fn declare_function(&mut self, func: &ir::Function) { let ir::Function { @@ -837,6 +862,21 @@ impl<'c> FuncGen<'c> { let var = self.variable(to, I32); self.def(var, val); } + ir::Instruction::LoadConstant { to, name, ty } => { + let Some(FuncOrDataId::Data(data_id)) = + self.module.inner.get_name(name.as_str()) + else { + panic!(); + }; + let val = self + .module + .inner + .declare_data_in_func(data_id, self.builder.func); + let ty = self.module.cranelift_type(ty); + let val = self.ins().global_value(ty, val); + let to = self.variable(to, ty); + self.def(to, val); + } } } diff --git a/src/codegen/tests.rs b/src/codegen/tests.rs index b25665bc..dcadf9a4 100644 --- a/src/codegen/tests.rs +++ b/src/codegen/tests.rs @@ -904,3 +904,26 @@ fn arc_type() { output.drops.load(Ordering::Relaxed) ); } + +#[test] +fn use_constant() { + let s = src!( + " + filter-map main() { + define { + safi = 127.0.0.1; + } + apply { + if safi == LOCALHOSTV4 { + reject + } + accept + } + }" + ); + + let mut p = compile(s); + let f = p.get_function::<(), Verdict<(), ()>>("main").unwrap(); + let output = f.call(); + assert_eq!(output, Verdict::Reject(())); +} diff --git a/src/lower/eval.rs b/src/lower/eval.rs index b0a59fa7..9bcbc739 100644 --- a/src/lower/eval.rs +++ b/src/lower/eval.rs @@ -13,7 +13,7 @@ use crate::{ ir::{Instruction, IntCmp, VarKind}, value::IrValue, }, - runtime::RuntimeFunction, + runtime::{RuntimeConstant, RuntimeFunction}, }; use std::collections::HashMap; @@ -247,6 +247,7 @@ pub fn eval( p: &[Function], filter_map: &str, mem: &mut Memory, + constants: &[RuntimeConstant], rx: Vec, ) -> Option { let filter_map_ident = Identifier::from(filter_map); @@ -267,6 +268,11 @@ pub fn eval( instructions.extend(block.instructions.clone()); } + let constants: HashMap = constants + .iter() + .map(|g| (g.name, g.bytes.as_ref())) + .collect(); + // This is our working memory for the interpreter let mut vars = HashMap::::new(); @@ -335,6 +341,11 @@ pub fn eval( let val = eval_operand(&vars, val); vars.insert(to.clone(), val.clone()); } + Instruction::LoadConstant { to, name, ty } => { + let val = constants.get(name).unwrap(); + let val = IrValue::from_slice(ty, val); + vars.insert(to.clone(), val.clone()); + } Instruction::Call { to, func, args } => { let f = p.iter().find(|f| f.name == *func).unwrap(); diff --git a/src/lower/ir.rs b/src/lower/ir.rs index fcdc599f..79220614 100644 --- a/src/lower/ir.rs +++ b/src/lower/ir.rs @@ -82,6 +82,13 @@ pub enum Instruction { ty: IrType, }, + /// Load a constant + LoadConstant { + to: Var, + name: Identifier, + ty: IrType, + }, + /// Call a function. Call { to: Option<(Var, IrType)>, @@ -339,6 +346,9 @@ impl<'a> IrPrinter<'a> { Assign { to, val, ty } => { format!("{}: {ty} = {}", self.var(to), self.operand(val),) } + LoadConstant { to, name, ty } => { + format!("{}: {ty} = LoadConstant(\"{}\")", self.var(to), name) + } Call { to: Some((to, ty)), func, diff --git a/src/lower/mod.rs b/src/lower/mod.rs index 614836b6..7cb4e0da 100644 --- a/src/lower/mod.rs +++ b/src/lower/mod.rs @@ -650,13 +650,26 @@ impl<'r> Lowerer<'r> { ast::Expr::Var(x) => { let DefinitionRef(scope, ident) = self.type_info.resolved_name(x); - Some( - Var { - scope, - kind: VarKind::Explicit(ident), - } - .into(), - ) + + if ScopeRef::ROOT == scope { + let var = self.new_tmp(); + let ty = self.type_info.type_of(id); + let ty = self.lower_type(&ty); + self.add(Instruction::LoadConstant { + to: var.clone(), + name: ident, + ty, + }); + Some(var.into()) + } else { + Some( + Var { + scope, + kind: VarKind::Explicit(ident), + } + .into(), + ) + } } ast::Expr::TypedRecord(_, record) | ast::Expr::Record(record) => { let ty = self.type_info.type_of(id); diff --git a/src/lower/test_eval.rs b/src/lower/test_eval.rs index 0e2a87fd..c85b3846 100644 --- a/src/lower/test_eval.rs +++ b/src/lower/test_eval.rs @@ -448,38 +448,3 @@ fn ip_addr_method() { assert_eq!(expected, u8::from_ne_bytes(res)); } } - -// #[test] -// fn prefix_addr() { -// let s = " -// filter-map main(x: Prefix) { -// apply { -// if x.address() == 0.0.0.0 { -// accept -// } -// reject -// } -// } -// "; - -// assert_eq!( -// p(IrValue::from_any(Box::new( -// Prefix::from_str("0.0.0.0/8").unwrap() -// ))), -// Ok(()) -// ); - -// assert_eq!( -// p(IrValue::from_any(Box::new( -// Prefix::from_str("127.0.0.0/8").unwrap() -// ))), -// Err(()) -// ); - -// let mut mem = Memory::new(); -// let program = compile(s); -// let pointer = mem.allocate(1); -// program.eval(&mut mem, vec![IrValue::Pointer(pointer), IrValue::U32(0)]); -// let res = mem.read(pointer, 1); -// assert_eq!(&[1], res); -// } diff --git a/src/pipeline.rs b/src/pipeline.rs index 6d895bc9..6f9b224e 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -21,7 +21,7 @@ use crate::{ meta::{Span, Spans}, ParseError, Parser, }, - runtime::{ty::Reflect, Runtime}, + runtime::{ty::Reflect, Runtime, RuntimeConstant}, typechecker::{ error::{Level, TypeError}, info::TypeInfo, @@ -84,6 +84,7 @@ pub struct TypeChecked { pub struct Lowered { pub ir: Vec, runtime_functions: HashMap, + runtime_constants: Vec, label_store: LabelStore, type_info: TypeInfo, } @@ -397,9 +398,12 @@ impl TypeChecked { println!("{s}"); } + let runtime_constants = runtime.constants.values().cloned().collect(); + Lowered { ir, runtime_functions, + runtime_constants, label_store, type_info: type_infos.remove(0), } @@ -412,13 +416,14 @@ impl Lowered { mem: &mut Memory, rx: Vec, ) -> Option { - eval::eval(&self.ir, "main", mem, rx) + eval::eval(&self.ir, "main", mem, &self.runtime_constants, rx) } pub fn codegen(self) -> Compiled { let module = codegen::codegen( &self.ir, &self.runtime_functions, + &self.runtime_constants, self.label_store, self.type_info, ); diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 8c6970ac..3a4cbfa8 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -32,13 +32,19 @@ pub mod ty; pub mod val; pub mod verdict; -use std::{any::TypeId, net::IpAddr}; +use std::{ + any::{type_name, TypeId}, + collections::HashMap, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, +}; use func::{Func, FunctionDescription}; use inetnum::{addr::Prefix, asn::Asn}; use roto_macros::{roto_method, roto_static_method}; use ty::{Ty, TypeDescription, TypeRegistry}; +use crate::ast::Identifier; + /// Provides the types and functions that Roto can access via FFI /// /// Even some types that can be written as literals should be provided here. @@ -48,6 +54,7 @@ use ty::{Ty, TypeDescription, TypeRegistry}; pub struct Runtime { pub runtime_types: Vec, pub functions: Vec, + pub constants: HashMap, pub type_registry: TypeRegistry, } @@ -153,6 +160,13 @@ pub struct DocumentedFunc { pub argument_names: &'static [&'static str], } +#[derive(Clone)] +pub struct RuntimeConstant { + pub name: Identifier, + pub ty: TypeId, + pub bytes: Box<[u8]>, +} + impl Runtime { /// Register a type with a default name /// @@ -373,6 +387,41 @@ impl Runtime { Ok(()) } + pub fn register_constant( + &mut self, + name: &str, + x: T, + ) -> Result<(), String> { + let type_id = TypeId::of::(); + self.find_type(type_id, type_name::())?; + let mut bytes: Vec = vec![0; size_of::()]; + unsafe { + std::ptr::copy_nonoverlapping( + &x as *const T as *const _, + bytes.as_mut_ptr(), + bytes.len(), + ) + }; + + let symbol = Identifier::from(name); + self.constants.insert( + symbol, + RuntimeConstant { + name: symbol, + ty: type_id, + bytes: bytes.into_boxed_slice(), + }, + ); + + Ok(()) + } + + pub fn iter_constants( + &self, + ) -> impl Iterator + '_ { + self.constants.values().map(|g| (g.name, g.ty)) + } + pub fn get_runtime_type(&self, id: TypeId) -> Option<&RuntimeType> { let ty = self.type_registry.get(id)?; let id = match ty.description { @@ -533,6 +582,7 @@ impl Runtime { runtime_types: Default::default(), functions: Default::default(), type_registry: Default::default(), + constants: Default::default(), }; rt.register_copy_type_with_name::<()>( @@ -667,6 +717,18 @@ impl Runtime { ip.to_canonical() } + rt.register_constant( + "LOCALHOSTV4", + IpAddr::from(Ipv4Addr::LOCALHOST), + ) + .unwrap(); + + rt.register_constant( + "LOCALHOSTV6", + IpAddr::from(Ipv6Addr::LOCALHOST), + ) + .unwrap(); + Ok(rt) } @@ -686,7 +748,7 @@ pub mod tests { use roto_macros::{roto_function, roto_method}; use routecore::bgp::{ aspath::{AsPath, HopPath}, - communities::Community, + communities::{Community, Wellknown}, types::{LocalPref, OriginType}, }; @@ -709,6 +771,12 @@ pub mod tests { x % 2 == 0 } + rt.register_constant( + "BLACKHOLE", + Community::from(Wellknown::Blackhole), + ) + .unwrap(); + Ok(rt) } diff --git a/src/typechecker/mod.rs b/src/typechecker/mod.rs index c1cdfd32..fd9e1f7c 100644 --- a/src/typechecker/mod.rs +++ b/src/typechecker/mod.rs @@ -102,14 +102,16 @@ impl TypeChecker<'_> { let root_scope = checker.scope_graph.root(); - for (v, t) in types::globals() { + for (v, t) in runtime.iter_constants() { checker.insert_var( root_scope, Meta { id: MetaId(0), node: v, }, - t, + Type::Name( + runtime.get_runtime_type(t).unwrap().name().into(), + ), )?; } diff --git a/src/typechecker/scope.rs b/src/typechecker/scope.rs index 6b4db469..7514e1d2 100644 --- a/src/typechecker/scope.rs +++ b/src/typechecker/scope.rs @@ -13,6 +13,10 @@ use super::Type; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ScopeRef(Option); +impl ScopeRef { + pub const ROOT: Self = Self(None); +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct DefinitionRef(pub ScopeRef, pub Identifier); @@ -47,7 +51,7 @@ impl ScopeGraph { /// Create a new root scope pub fn root(&mut self) -> ScopeRef { - ScopeRef(None) + ScopeRef::ROOT } /// Create a new scope over `scope` diff --git a/src/typechecker/tests.rs b/src/typechecker/tests.rs index fce926d3..ba9ae7e1 100644 --- a/src/typechecker/tests.rs +++ b/src/typechecker/tests.rs @@ -872,3 +872,18 @@ fn issue_51() { typecheck(s).unwrap_err(); } + +#[test] +fn use_globals() { + let s = src!( + " + filter-map main() { + apply { + accept BLACKHOLE + } + } + " + ); + + typecheck(s).unwrap(); +} diff --git a/src/typechecker/types.rs b/src/typechecker/types.rs index 902e737f..7e8a1b29 100644 --- a/src/typechecker/types.rs +++ b/src/typechecker/types.rs @@ -252,25 +252,6 @@ impl Function { } } -pub fn globals() -> Vec<(Identifier, Type)> { - let community = Identifier::from("Community"); - let safi = Identifier::from("Safi"); - let afi = Identifier::from("Afi"); - - [ - ("BLACKHOLE", Type::Name(community)), - ("UNICAST", Type::Name(safi)), - ("MULTICAST", Type::Name(safi)), - ("IPV4", Type::Name(afi)), - ("IPV6", Type::Name(afi)), - ("VPNV4", Type::Name(afi)), - ("VPNV6", Type::Name(afi)), - ] - .into_iter() - .map(|(s, t)| (Identifier::from(s), t)) - .collect() -} - pub fn default_types(runtime: &Runtime) -> Vec<(Identifier, Type)> { use Primitive::*; From d2940625bda90906e8d7be699016a2dbafa50789 Mon Sep 17 00:00:00 2001 From: Terts Diepraam Date: Fri, 6 Dec 2024 14:01:43 +0100 Subject: [PATCH 2/2] add global constants to the documentation --- src/runtime/mod.rs | 42 ++++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 3a4cbfa8..082e40b9 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -164,6 +164,7 @@ pub struct DocumentedFunc { pub struct RuntimeConstant { pub name: Identifier, pub ty: TypeId, + pub docstring: String, pub bytes: Box<[u8]>, } @@ -389,7 +390,8 @@ impl Runtime { pub fn register_constant( &mut self, - name: &str, + name: impl Into, + docstring: &str, x: T, ) -> Result<(), String> { let type_id = TypeId::of::(); @@ -403,12 +405,13 @@ impl Runtime { ) }; - let symbol = Identifier::from(name); + let symbol = Identifier::from(name.into()); self.constants.insert( symbol, RuntimeConstant { name: symbol, ty: type_id, + docstring: docstring.into(), bytes: bytes.into_boxed_slice(), }, ); @@ -458,12 +461,12 @@ impl Runtime { Ok(()) } - fn print_function(&self, f: &RuntimeFunction) { - let print_ty = |ty: TypeId| { - let ty = self.get_runtime_type(ty).unwrap(); - ty.name.as_ref() - }; + fn print_ty(&self, ty: TypeId) -> &str { + let ty = self.get_runtime_type(ty).unwrap(); + ty.name.as_ref() + } + fn print_function(&self, f: &RuntimeFunction) { let RuntimeFunction { name, description, @@ -475,7 +478,7 @@ impl Runtime { let mut params = description .parameter_types() .iter() - .map(|ty| print_ty(*ty)) + .map(|ty| self.print_ty(*ty)) .collect::>(); let ret = params.remove(0); @@ -488,7 +491,7 @@ impl Runtime { format!("{}.", params.next().unwrap()) } FunctionKind::StaticMethod(id) => { - format!("{}.", print_ty(id)) + format!("{}.", self.print_ty(id)) } FunctionKind::Free => "".into(), }; @@ -532,6 +535,22 @@ impl Runtime { self.print_function(f); } + for RuntimeConstant { + name, + ty, + docstring, + .. + } in self.constants.values() + { + println!( + "`````{{roto::constant}} {name}: {}", + self.print_ty(*ty) + ); + for line in docstring.lines() { + println!("{line}"); + } + println!("`````\n"); + } for RuntimeType { name, type_id, @@ -541,7 +560,7 @@ impl Runtime { { println!("`````{{roto:type}} {name}"); for line in docstring.lines() { - println!("{line}") + println!("{line}"); } println!(); @@ -719,12 +738,14 @@ impl Runtime { rt.register_constant( "LOCALHOSTV4", + "The IPv4 address pointing to localhost: `127.0.0.1`", IpAddr::from(Ipv4Addr::LOCALHOST), ) .unwrap(); rt.register_constant( "LOCALHOSTV6", + "The IPv6 address pointing to localhost: `::1`", IpAddr::from(Ipv6Addr::LOCALHOST), ) .unwrap(); @@ -773,6 +794,7 @@ pub mod tests { rt.register_constant( "BLACKHOLE", + "The well-known BLACKHOLE community.", Community::from(Wellknown::Blackhole), ) .unwrap();