Skip to content

Commit

Permalink
feat(threads): deadlock prevention
Browse files Browse the repository at this point in the history
  • Loading branch information
elenaf9 committed Nov 3, 2024
1 parent cdf75ac commit ddf30b3
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 81 deletions.
28 changes: 16 additions & 12 deletions src/riot-rs-threads/src/arch/cortex_m.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{cleanup, Arch, Thread, THREADS};
use crate::{cleanup, Arch, Thread, Threads, THREADS};
use core::arch::asm;
use core::ptr::write_volatile;
use cortex_m::peripheral::{scb::SystemHandler, SCB};
Expand Down Expand Up @@ -195,10 +195,15 @@ unsafe extern "C" fn PendSV() {
unsafe fn sched() -> u128 {
loop {
if let Some(res) = THREADS.with(|threads| {
let (mut guard, mut tcbs) = threads.with_tcbs();
let (mut guard, mut current_threads) = guard.with_current_threads();
let current_pid = current_threads.current_pid_mut();
let mut runqueue = guard.runqueue();

#[cfg(feature = "multi-core")]
threads.add_current_thread_to_rq();
Threads::add_current_thread_to_rq(&mut runqueue, &tcbs, *current_pid);

let next_pid = match threads.get_next_pid() {
let next_pid = match Threads::get_next_pid(&mut runqueue, &tcbs) {
Some(pid) => pid,
None => {
#[cfg(feature = "multi-core")]
Expand All @@ -213,26 +218,25 @@ unsafe fn sched() -> u128 {
}
}
};
runqueue.release();

let mut tcbs = threads.tcbs();
let mut current_threads = threads.current_threads();
let old_pid = *current_pid;
*current_pid = Some(next_pid);
current_threads.release();

// `current_high_regs` will be null if there is no current thread.
// This is only the case once, when the very first thread starts running.
// The returned `r1` therefore will be null, and saving/ restoring
// the context is skipped.
let mut current_high_regs = core::ptr::null();
let current_pid = current_threads.current_pid_mut();
if let Some(current_pid) = current_pid {
if next_pid == *current_pid {
if let Some(current_pid) = old_pid {
if next_pid == current_pid {
return Some(0);
}
let current = tcbs.get_unchecked_mut(*current_pid);
let current = tcbs.get_unchecked_mut(current_pid);
current.sp = cortex_m::register::psp::read() as usize;
current_high_regs = current.data.as_ptr();
}
*current_pid = Some(next_pid);
current_threads.release();
};

let next = tcbs.get_unchecked(next_pid);
let next_sp = next.sp;
Expand Down
28 changes: 15 additions & 13 deletions src/riot-rs-threads/src/arch/riscv.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{cleanup, Arch, Thread, THREADS};
use crate::{cleanup, Arch, Thread, Threads, THREADS};
#[cfg(context = "esp32c6")]
use esp_hal::peripherals::INTPRI as SYSTEM;
#[cfg(context = "esp32c3")]
Expand Down Expand Up @@ -129,29 +129,31 @@ extern "C" fn FROM_CPU_INTR0(trap_frame: &mut TrapFrame) {
unsafe fn sched(trap_frame: &mut TrapFrame) {
loop {
if THREADS.with(|threads| {
let next_pid = match threads.get_next_pid() {
let (mut guard, mut tcbs) = threads.with_tcbs();
let (mut guard, mut current_threads) = guard.with_current_threads();
let mut runqueue = guard.runqueue();

let next_pid = match Threads::get_next_pid(&mut runqueue, &tcbs) {
Some(pid) => pid,
None => {
Cpu::wfi();
return false;
}
};
runqueue.release();

let mut tcbs = threads.tcbs();
let mut current_threads = threads.current_threads();
let current_pid = current_threads.current_pid_mut();
if let Some(current_pid) = current_pid {
if next_pid == *current_pid {
return true;
}
let current = tcbs.get_unchecked_mut(*current_pid);
copy_registers(trap_frame, &mut current.data);
}
let old_pid = *current_pid;
*current_pid = Some(next_pid);
current_threads.release();

let next = tcbs.get_unchecked(next_pid);
copy_registers(&next.data, trap_frame);
if let Some(current_pid) = old_pid {
if next_pid == current_pid {
return true;
}
copy_registers(trap_frame, &mut tcbs.get_unchecked_mut(current_pid).data);
}
copy_registers(&tcbs.get_unchecked(next_pid).data, trap_frame);
true
}) {
break;
Expand Down
32 changes: 17 additions & 15 deletions src/riot-rs-threads/src/arch/xtensa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use esp_hal::{
trapframe::TrapFrame,
};

use crate::{cleanup, Arch, THREADS};
use crate::{cleanup, Arch, Threads, THREADS};

pub struct Cpu;

Expand Down Expand Up @@ -104,29 +104,31 @@ extern "C" fn FROM_CPU_INTR1(trap_frame: &mut TrapFrame) {
unsafe fn sched(trap_frame: &mut TrapFrame) {
loop {
if THREADS.with(|threads| {
let (mut guard, mut tcbs) = threads.with_tcbs();
let (mut guard, mut current_threads) = guard.with_current_threads();
let current_pid = current_threads.current_pid_mut();
let mut runqueue = guard.runqueue();

#[cfg(feature = "multi-core")]
threads.add_current_thread_to_rq();
Threads::add_current_thread_to_rq(&mut runqueue, &tcbs, *current_pid);

let Some(next_pid) = threads.get_next_pid() else {
let Some(next_pid) = Threads::get_next_pid(&mut runqueue, &tcbs) else {
return false;
};
runqueue.release();

let mut tcbs = threads.tcbs();
let mut current_threads = threads.current_threads();
let current_pid = current_threads.current_pid_mut();
if let Some(current_pid) = current_pid {
if next_pid == *current_pid {
let old_pid = *current_pid;
*current_pid = Some(next_pid);
current_threads.release();

if let Some(current_pid) = old_pid {
if next_pid == current_pid {
return true;
}

let current = tcbs.get_unchecked_mut(*current_pid);
current.data = *trap_frame;
tcbs.get_unchecked_mut(current_pid).data = *trap_frame;
}
*current_pid = Some(next_pid);
current_threads.release();
*trap_frame = tcbs.get_unchecked(next_pid).data;

let next = tcbs.get_unchecked(next_pid);
*trap_frame = next.data;
true
}) {
break;
Expand Down
22 changes: 16 additions & 6 deletions src/riot-rs-threads/src/scheduler_lock.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,39 @@
//! This module provides a Mutex-protected RefCell --- basically a way to ensure
//! at runtime that some reference is used only once.
use core::cell::UnsafeCell;

use critical_section::CriticalSection;

use crate::critical_section;

pub(crate) struct SchedulerLock<T> {
inner: T,
inner: UnsafeCell<T>,
}

impl<T> SchedulerLock<T> {
pub const fn new(inner: T) -> Self {
Self { inner }
Self {
inner: UnsafeCell::new(inner),
}
}

pub fn with<F, R>(&self, f: F) -> R
where
F: FnOnce(&T) -> R,
F: FnOnce(&mut T) -> R,
{
critical_section::no_preemption_with(|| f(&self.inner))
critical_section::no_preemption_with(|| {
let inner = unsafe { &mut *self.inner.get() };
f(inner)
})
}

pub fn with_cs<F, R>(&self, _cs: CriticalSection, f: F) -> R
where
F: FnOnce(&T) -> R,
F: FnOnce(&mut T) -> R,
{
f(&self.inner)
let inner = unsafe { &mut *self.inner.get() };
f(inner)
}
}

unsafe impl<T> Sync for SchedulerLock<T> {}
10 changes: 5 additions & 5 deletions src/riot-rs-threads/src/thread_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub fn get() -> ThreadFlags {

impl Threads {
// thread flags implementation
fn flag_set(&self, thread_id: ThreadId, mask: ThreadFlags) {
fn flag_set(&mut self, thread_id: ThreadId, mask: ThreadFlags) {
let mut tcbs = self.tcbs();
let thread = tcbs.get_unchecked_mut(thread_id);
thread.flags |= mask;
Expand All @@ -114,7 +114,7 @@ impl Threads {
self.set_state(thread_id, ThreadState::Running);
}

fn flag_wait<F>(&self, cond: F, mode: WaitMode) -> Option<ThreadFlags>
fn flag_wait<F>(&mut self, cond: F, mode: WaitMode) -> Option<ThreadFlags>
where
F: Fn(u16) -> Option<u16>,
{
Expand All @@ -134,7 +134,7 @@ impl Threads {
}
}

fn flag_wait_all(&self, mask: ThreadFlags) -> Option<ThreadFlags> {
fn flag_wait_all(&mut self, mask: ThreadFlags) -> Option<ThreadFlags> {
self.flag_wait(
|thread_flags| {
let res = thread_flags & mask;
Expand All @@ -144,7 +144,7 @@ impl Threads {
)
}

fn flag_wait_any(&self, mask: ThreadFlags) -> Option<ThreadFlags> {
fn flag_wait_any(&mut self, mask: ThreadFlags) -> Option<ThreadFlags> {
self.flag_wait(
|thread_flags| {
let res = thread_flags & mask;
Expand All @@ -154,7 +154,7 @@ impl Threads {
)
}

fn flag_wait_one(&self, mask: ThreadFlags) -> Option<ThreadFlags> {
fn flag_wait_one(&mut self, mask: ThreadFlags) -> Option<ThreadFlags> {
self.flag_wait(
|thread_flags| {
let res = thread_flags & mask;
Expand Down
9 changes: 6 additions & 3 deletions src/riot-rs-threads/src/threadlist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,25 @@ impl ThreadList {
let pid = threads
.current_pid()
.expect("Function should be called inside a thread context.");
let prio = threads.get_priority(pid);
let (mut guard, tcbs) = threads.with_tcbs();
let prio = tcbs.get_unchecked(pid).prio;
let mut curr = None;
let mut next = self.head;
let mut thread_blocklist = threads.thread_blocklist();
let mut thread_blocklist = guard.thread_blocklist();
while let Some(n) = next {
if threads.get_priority(n) < prio {
if tcbs.get_unchecked(n).prio < prio {
break;
}
curr = next;
next = thread_blocklist[usize::from(n)];
}
tcbs.release();
thread_blocklist[usize::from(pid)] = next;
match curr {
Some(curr) => thread_blocklist[usize::from(curr)] = Some(pid),
_ => self.head = Some(pid),
}
thread_blocklist.release();
threads.set_state(pid, state);
});
}
Expand Down
Loading

0 comments on commit ddf30b3

Please sign in to comment.