Skip to content

Commit

Permalink
fix kernel_id
Browse files (browse the repository at this point in the history)
  • Loading branch information
shiinamiyuki committed Oct 10, 2023
1 parent 926e830 commit 5998289
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 14 deletions.
21 changes: 12 additions & 9 deletions luisa_compute/src/lang.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -291,6 +291,7 @@ impl_aggregate_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15)
pub(crate) struct FnRecorder {
pub(crate) parent: Option<FnRecorderPtr>,
pub(crate) scopes: Vec<IrBuilder>,
pub(crate) kernel_id: usize,
pub(crate) captured_resources: IndexMap<Binding, (usize, NodeRef, Binding, Arc<dyn Any>)>,
pub(crate) cpu_custom_ops: IndexMap<u64, (usize, CArc<CpuCustomOp>)>,
pub(crate) callables: IndexMap<u64, CallableModuleRef>,
Expand Down Expand Up @@ -335,7 +336,7 @@ impl FnRecorder {
node
}
}
pub(crate) fn new() -> Self {
pub(crate) fn new(kernel_id: usize) -> Self {
FnRecorder {
scopes: vec![],
captured_resources: IndexMap::new(),
Expand All @@ -349,6 +350,7 @@ impl FnRecorder {
arena: Bump::new(),
building_kernel: false,
callable_ret_type: None,
kernel_id,
parent: None,
}
}
Expand Down Expand Up @@ -398,7 +400,7 @@ fn make_safe_node(node: NodeRef) -> SafeNodeRef {
with_recorder(|r| SafeNodeRef {
recorder: r as *mut _,
node,
kernel_id: KERNEL_ID.load(std::sync::atomic::Ordering::Relaxed),
kernel_id: r.kernel_id,
})
}
/// check if the node belongs to the current kernel/callable
Expand All @@ -407,12 +409,13 @@ fn process_potential_capture(node: SafeNodeRef) -> SafeNodeRef {
if node.node.is_user_data() {
return node;
}
let cur_kernel_id = KERNEL_ID.load(std::sync::atomic::Ordering::Relaxed);
assert_eq!(
cur_kernel_id, node.kernel_id,
"Referencing node from another kernel!"
);

with_recorder(|r| {
let cur_kernel_id = r.kernel_id;
assert_eq!(
cur_kernel_id, node.kernel_id,
"Referencing node from another kernel!"
);
let ptr = r as *mut _;
// defined in same callable, no need to capture
if ptr == node.recorder {
Expand All @@ -421,10 +424,10 @@ fn process_potential_capture(node: SafeNodeRef) -> SafeNodeRef {
r.map_captured_vars(node)
})
}
pub(crate) fn push_recorder() {
let mut new = Rc::new(RefCell::new(FnRecorder::new()));
pub(crate) fn push_recorder(kernel_id: usize) {
RECORDER.with(|r| {
let mut r = r.borrow_mut();
let new = Rc::new(RefCell::new(FnRecorder::new(kernel_id)));
let old = std::mem::replace(&mut *r, Some(new.clone()));
new.borrow_mut().parent = old;
})
Expand Down
6 changes: 5 additions & 1 deletion luisa_compute/src/runtime.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -1305,8 +1305,12 @@ impl<S: CallableSignature> DynCallable<S> {
let r = r_ptr.as_mut().unwrap();
r.borrow().device.clone().unwrap()
};
let kernel_id = r_ptr.as_ref().unwrap().borrow().kernel_id;
(
std::mem::replace(&mut *r_ptr, Some(Rc::new(RefCell::new(FnRecorder::new())))),
std::mem::replace(
&mut *r_ptr,
Some(Rc::new(RefCell::new(FnRecorder::new(kernel_id)))),
),
device.upgrade().unwrap(),
)
});
Expand Down
16 changes: 12 additions & 4 deletions luisa_compute/src/runtime/kernel.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -187,14 +187,21 @@ macro_rules! impl_kernel_param_for_tuple {
impl_kernel_param_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);
impl KernelBuilder {
pub fn new(device: Option<crate::runtime::Device>, is_kernel: bool) -> Self {
RECORDER.with(|r| {
let kernel_id = RECORDER.with(|r| {
let r = r.borrow();
if is_kernel {
assert!(r.is_none(), "Cannot record a kernel inside another kernel");
KERNEL_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
KERNEL_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
} else {
let r = r.as_ref();
if let Some(r) = r {
r.borrow().kernel_id
} else {
KERNEL_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
}
});
push_recorder();
push_recorder(kernel_id);
with_recorder(|r| {
r.device = device.as_ref().map(|d| WeakDevice::new(d));
r.pools = CArc::new(ModulePools::new());
Expand Down Expand Up @@ -554,7 +561,8 @@ macro_rules! impl_callable {
pub fn new_static(f:fn($($Ts,)*)->R)->Self where fn($($Ts,)*)->R :CallableBuildFn<fn($($Ts,)*)->R> {
let r_backup = RECORDER.with(|r| {
let mut r = r.borrow_mut();
std::mem::replace(&mut *r, Some(Rc::new(RefCell::new(FnRecorder::new()))))
let kernel_id = r.as_ref().unwrap().borrow().kernel_id;
std::mem::replace(&mut *r, Some(Rc::new(RefCell::new(FnRecorder::new(kernel_id)))))
});
let mut builder = KernelBuilder::new(None, false);
let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder);
Expand Down

0 comments on commit 5998289

Please sign in to comment.