diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 5eb6764..23ca818 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -291,6 +291,7 @@ impl_aggregate_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15) pub(crate) struct FnRecorder { pub(crate) parent: Option, pub(crate) scopes: Vec, + pub(crate) kernel_id: usize, pub(crate) captured_resources: IndexMap)>, pub(crate) cpu_custom_ops: IndexMap)>, pub(crate) callables: IndexMap, @@ -335,7 +336,7 @@ impl FnRecorder { node } } - pub(crate) fn new() -> Self { + pub(crate) fn new(kernel_id: usize) -> Self { FnRecorder { scopes: vec![], captured_resources: IndexMap::new(), @@ -349,6 +350,7 @@ impl FnRecorder { arena: Bump::new(), building_kernel: false, callable_ret_type: None, + kernel_id, parent: None, } } @@ -398,7 +400,7 @@ fn make_safe_node(node: NodeRef) -> SafeNodeRef { with_recorder(|r| SafeNodeRef { recorder: r as *mut _, node, - kernel_id: KERNEL_ID.load(std::sync::atomic::Ordering::Relaxed), + kernel_id: r.kernel_id, }) } /// check if the node belongs to the current kernel/callable @@ -407,12 +409,13 @@ fn process_potential_capture(node: SafeNodeRef) -> SafeNodeRef { if node.node.is_user_data() { return node; } - let cur_kernel_id = KERNEL_ID.load(std::sync::atomic::Ordering::Relaxed); - assert_eq!( - cur_kernel_id, node.kernel_id, - "Referencing node from another kernel!" - ); + with_recorder(|r| { + let cur_kernel_id = r.kernel_id; + assert_eq!( + cur_kernel_id, node.kernel_id, + "Referencing node from another kernel!" + ); let ptr = r as *mut _; // defined in same callable, no need to capture if ptr == node.recorder { @@ -421,10 +424,10 @@ fn process_potential_capture(node: SafeNodeRef) -> SafeNodeRef { r.map_captured_vars(node) }) } -pub(crate) fn push_recorder() { - let mut new = Rc::new(RefCell::new(FnRecorder::new())); +pub(crate) fn push_recorder(kernel_id: usize) { RECORDER.with(|r| { let mut r = r.borrow_mut(); + let new = Rc::new(RefCell::new(FnRecorder::new(kernel_id))); let old = std::mem::replace(&mut *r, Some(new.clone())); new.borrow_mut().parent = old; }) diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index eb25e08..92fb78c 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1305,8 +1305,12 @@ impl DynCallable { let r = r_ptr.as_mut().unwrap(); r.borrow().device.clone().unwrap() }; + let kernel_id = r_ptr.as_ref().unwrap().borrow().kernel_id; ( - std::mem::replace(&mut *r_ptr, Some(Rc::new(RefCell::new(FnRecorder::new())))), + std::mem::replace( + &mut *r_ptr, + Some(Rc::new(RefCell::new(FnRecorder::new(kernel_id)))), + ), device.upgrade().unwrap(), ) }); diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 3748879..41ca6b7 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -187,14 +187,21 @@ macro_rules! impl_kernel_param_for_tuple { impl_kernel_param_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15); impl KernelBuilder { pub fn new(device: Option, is_kernel: bool) -> Self { - RECORDER.with(|r| { + let kernel_id = RECORDER.with(|r| { let r = r.borrow(); if is_kernel { assert!(r.is_none(), "Cannot record a kernel inside another kernel"); - KERNEL_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + KERNEL_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + } else { + let r = r.as_ref(); + if let Some(r) = r { + r.borrow().kernel_id + } else { + KERNEL_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + } } }); - push_recorder(); + push_recorder(kernel_id); with_recorder(|r| { r.device = device.as_ref().map(|d| WeakDevice::new(d)); r.pools = CArc::new(ModulePools::new()); @@ -554,7 +561,8 @@ macro_rules! impl_callable { pub fn new_static(f:fn($($Ts,)*)->R)->Self where fn($($Ts,)*)->R :CallableBuildFnR> { let r_backup = RECORDER.with(|r| { let mut r = r.borrow_mut(); - std::mem::replace(&mut *r, Some(Rc::new(RefCell::new(FnRecorder::new())))) + let kernel_id = r.as_ref().unwrap().borrow().kernel_id; + std::mem::replace(&mut *r, Some(Rc::new(RefCell::new(FnRecorder::new(kernel_id))))) }); let mut builder = KernelBuilder::new(None, false); let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder);