Skip to content

Commit

Permalink
feat: add get_extern_func
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed Jun 14, 2024
1 parent 4ae1b8a commit b55ce54
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 15 deletions.
58 changes: 44 additions & 14 deletions src/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use anyhow::{anyhow, Result};
use delegate::delegate;
use hugr::{
ops::{FuncDecl, FuncDefn, NamedOp as _, OpType},
types::PolyFuncType,
HugrView, Node, NodeIndex,
};
use inkwell::{
context::Context,
module::Module,
module::{Linkage, Module},
types::{BasicTypeEnum, FunctionType},
values::{BasicValueEnum, FunctionValue},
};
Expand Down Expand Up @@ -205,39 +206,68 @@ impl<'c, H: HugrView> EmitModuleContext<'c, H> {
fn get_func_impl(
&self,
name: impl AsRef<str>,
node: Node,
func_ty: &hugr::types::PolyFuncType,
func_ty: FunctionType<'c>,
linkage: Option<Linkage>,
) -> Result<FunctionValue<'c>> {
let sig = (func_ty.params().is_empty())
.then_some(func_ty.body())
.ok_or(anyhow!("function has type params"))?;
let llvm_func_ty = self.llvm_func_type(sig)?;
let name = self.name_func(name, node);
let func = self
.module()
.get_function(&name)
.unwrap_or_else(|| self.module.add_function(&name, llvm_func_ty, None));
if func.get_type() != llvm_func_ty {
.get_function(name.as_ref())
.unwrap_or_else(|| self.module.add_function(name.as_ref(), func_ty, linkage));
if func.get_type() != func_ty {
Err(anyhow!(
"Function '{name}' has wrong type: hugr: {func_ty} expected: {llvm_func_ty} actual: {}",
"Function '{}' has wrong type: expected: {func_ty} actual: {}",
name.as_ref(),
func.get_type()
))?
}
Ok(func)
}

fn get_hugr_func_impl(
&self,
name: impl AsRef<str>,
node: Node,
func_ty: &PolyFuncType,
) -> Result<FunctionValue<'c>> {
let func_ty = (func_ty.params().is_empty())
.then_some(func_ty.body())
.ok_or(anyhow!("function has type params"))?;
let llvm_func_ty = self.llvm_func_type(func_ty)?;
let name = self.name_func(name, node);
self.get_func_impl(name, llvm_func_ty, None)
}

/// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDefn].
///
/// The name of the result is mangled by [EmitModuleContext::name_func].
pub fn get_func_defn(&self, node: FatNode<'c, FuncDefn, H>) -> Result<FunctionValue<'c>> {
self.get_func_impl(&node.name, node.node(), &node.signature)
self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
}

/// Adds or gets the [FunctionValue] in the [Module] corresponding to the given [FuncDecl].
///
/// The name of the result is mangled by [EmitModuleContext::name_func].
pub fn get_func_decl(&self, node: FatNode<'c, FuncDecl, H>) -> Result<FunctionValue<'c>> {
self.get_func_impl(&node.name, node.node(), &node.signature)
self.get_hugr_func_impl(&node.name, node.node(), &node.signature)
}

/// Adds or get the [FunctionValue] in the [Module] with the given symbol
/// and function type.
///
/// The name undergoes no mangling. The [FunctionValue] will have
/// [Linkage::External].
///
/// If this function is called multiple times with the same arguments it
/// will return the same [FunctionValue].
///
/// If a function with the given name exists but the type does not match
/// then an Error is returned.
pub fn get_extern_func(
&self,
symbol: impl AsRef<str>,
typ: FunctionType<'c>,
) -> Result<FunctionValue<'c>> {
self.get_func_impl(symbol, typ, Some(Linkage::External))
}

/// Consumes the `EmitModuleContext` and returns the internal [Module].
Expand Down
13 changes: 13 additions & 0 deletions src/emit/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub struct EmitFuncContext<'c, H: HugrView> {
impl<'c, H: HugrView> EmitFuncContext<'c, H> {
delegate! {
to self.emit_context {
/// Returns the inkwell [Context].
fn iw_context(&self) -> &'c Context;
/// Returns the internal [CodegenExtsMap] .
pub fn extensions(&self) -> Rc<CodegenExtsMap<'c,H>>;
Expand All @@ -78,6 +79,18 @@ impl<'c, H: HugrView> EmitFuncContext<'c, H> {
///
/// The name of the result may have been mangled.
pub fn get_func_decl(&self, node: FatNode<'c, FuncDecl, H>) -> Result<FunctionValue<'c>>;
/// Adds or get the [FunctionValue] in the [Module] with the given symbol
/// and function type.
///
/// The name undergoes no mangling. The [FunctionValue] will have
/// [Linkage::External].
///
/// If this function is called multiple times with the same arguments it
/// will return the same [FunctionValue].
///
/// If a function with the given name exists but the type does not match
/// then an Error is returned.
pub fn get_extern_func(&self, symbol: impl AsRef<str>, typ: FunctionType<'c>,) -> Result<FunctionValue<'c>>;
}
}

Expand Down
13 changes: 13 additions & 0 deletions src/emit/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,16 @@ fn emit_hugr_custom_op(#[with(-1, add_int_extensions)] llvm_ctx: TestContext) {
});
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn get_external_func(llvm_ctx: TestContext) {
llvm_ctx.with_emit_module_context(|emc| {
let func_type1 = emc.iw_context().i32_type().fn_type(&[], false);
let func_type2 = emc.iw_context().f64_type().fn_type(&[], false);
let foo1 = emc.get_extern_func("foo", func_type1).unwrap();
assert_eq!(foo1.get_name().to_str().unwrap(), "foo");
let foo2 = emc.get_extern_func("foo", func_type1).unwrap();
assert_eq!(foo1, foo2);
assert!(emc.get_extern_func("foo", func_type2).is_err());
});
}
17 changes: 16 additions & 1 deletion src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use rstest::fixture;

use crate::{
custom::CodegenExtsMap,
emit::EmitHugr,
emit::{EmitHugr, EmitModuleContext, Namer},
types::{TypeConverter, TypingSession},
};

Expand Down Expand Up @@ -120,6 +120,21 @@ impl TestContext {
(r, ectx.finish())
})
}

pub fn with_emit_module_context<'c, T>(
&'c self,
f: impl FnOnce(EmitModuleContext<'c, THugrView>) -> T,
) -> T {
self.with_context(|ctx| {
let m = ctx.create_module("test_module");
f(EmitModuleContext::new(
m,
Namer::default().into(),
self.extensions(),
TypeConverter::new(ctx),
))
})
}
}

#[fixture]
Expand Down

0 comments on commit b55ce54

Please sign in to comment.