diff --git a/melior/src/context.rs b/melior/src/context.rs index ffcedbd5b6..1c292e283f 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -12,7 +12,7 @@ use mlir_sys::{ mlirContextIsRegisteredOperation, mlirContextLoadAllAvailableDialects, mlirContextSetAllowUnregisteredDialects, MlirContext, MlirDiagnostic, MlirLogicalResult, }; -use std::{ffi::c_void, marker::PhantomData}; +use std::{ffi::c_void, marker::PhantomData, mem::transmute}; /// A context of IR, dialects, and passes. /// @@ -166,6 +166,22 @@ impl<'c> ContextRef<'c> { _reference: Default::default(), } } + + /// Returns a context. + /// + /// This function is different from `deref` because the correct lifetime is + /// kept for the return type. + /// + /// # Safety + /// + /// The returned reference is safe to use only in the lifetime scope of the + /// context reference. + pub unsafe fn to_ref(&self) -> &'c Context { + // As we can't deref ContextRef<'a> into `&'a Context`, we forcibly cast its + // lifetime here to extend it from the lifetime of `ObjectRef<'a>` itself into + // `'a`. + transmute(self) + } } impl<'c> PartialEq for ContextRef<'c> { @@ -299,4 +315,13 @@ mod tests { assert_ne!(&other, &one_ref); assert_ne!(&one_ref, &other); } + + #[test] + fn context_to_ref() { + let ctx = Context::new(); + let ctx_ref = ctx.to_ref(); + let ctx_ref_to_ref: &Context = unsafe { &ctx_ref.to_ref() }; + + assert_eq!(&ctx_ref, ctx_ref_to_ref); + } }