Skip to content

Commit

Permalink
Improve generator performance
Browse files Browse the repository at this point in the history
  • Loading branch information
bkolobara committed Aug 23, 2020
1 parent 4794114 commit d44a1a8
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 61 deletions.
8 changes: 5 additions & 3 deletions example/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use async_wormhole::AsyncWormhole;
fn main() {
let task = AsyncWormhole::new(|mut yielder| {
panic!();
let x = yielder.async_suspend(async { 5 });
assert_eq!(x, 5);
panic!("Will a longer panic also fail. What about a really long one.");
let y = yielder.async_suspend(async { true });
assert_eq!(y, true);
42
}).unwrap();
})
.unwrap();

let output = futures::executor::block_on(task);
assert_eq!(output, 42);
assert_eq!(output.unwrap(), 42);
}
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::task::Poll;
use std::task::Waker;

pub struct AsyncWormhole<'a, Output> {
generator: Generator<'a, std::task::Waker, Output, EightMbStack>,
generator: Generator<'a, std::task::Waker, Option<Output>, EightMbStack>,
}

impl<'a, Output> AsyncWormhole<'a, Output> {
Expand All @@ -29,7 +29,7 @@ impl<'a, Output> AsyncWormhole<'a, Output> {
}

impl<'a, Output> Future for AsyncWormhole<'a, Output> {
type Output = Output;
type Output = Option<Output>;

fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
match self.generator.resume(cx.waker().clone()) {
Expand All @@ -40,12 +40,12 @@ impl<'a, Output> Future for AsyncWormhole<'a, Output> {
}

pub struct AsyncYielder<'a, Output> {
yielder: &'a Yielder<Waker, Output>,
yielder: &'a Yielder<Waker, Option<Output>>,
waker: Waker,
}

impl<'a, Output> AsyncYielder<'a, Output> {
pub(crate) fn new(yielder: &'a Yielder<Waker, Output>, waker: Waker) -> Self {
pub(crate) fn new(yielder: &'a Yielder<Waker, Option<Output>>, waker: Waker) -> Self {
Self { yielder, waker }
}

Expand Down
6 changes: 3 additions & 3 deletions switcheroo/benches/switcheroo_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use criterion::{criterion_group, criterion_main, Criterion};
use criterion::{black_box, criterion_group, criterion_main, Criterion};

use switcheroo::stack::*;
use switcheroo::Generator;
Expand All @@ -10,9 +10,9 @@ fn switcheroo(c: &mut Criterion) {
c.bench_function("switch stacks", |b| {
let stack = EightMbStack::new().unwrap();
let mut gen = Generator::new(stack, |yielder, input| {
yielder.suspend(Some(input + 1));
black_box(yielder.suspend(input + 1));
});
b.iter(|| gen.resume(2))
b.iter(|| black_box(gen.resume(2)))
});
}

Expand Down
32 changes: 25 additions & 7 deletions switcheroo/src/arch/unix_aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::stack;

pub unsafe fn init<S: stack::Stack>(
stack: &S,
f: unsafe extern "C" fn(usize, *mut usize) -> !,
f: unsafe extern "C" fn(usize, *mut usize),
) -> *mut usize {
unsafe fn push(mut sp: *mut usize, val: usize) -> *mut usize {
sp = sp.offset(-1);
Expand All @@ -22,7 +22,8 @@ pub unsafe fn init<S: stack::Stack>(
".cfi_def_cfa x29, 16",
".cfi_offset x30, -8",
".cfi_offset x29, -16",
"nop",)
"nop",
)
}

// Call frame for trampoline_2. The CFA slot is updated by swap::trampoline
Expand Down Expand Up @@ -56,22 +57,20 @@ pub unsafe fn swap_and_link_stacks(
new_sp: *mut usize,
sp: *const usize,
) -> (usize, *mut usize) {
let cfa_of_caller = sp.offset(-4);

let ret_val: usize;
let ret_sp: *mut usize;

asm!(
"adr lr, 1337f",
"stp x29, x30, [sp, #-16]!",
"mov x1, sp",
"str x1, [x3]",
"str x1, [x3 - 32]",
"mov sp, x2",
"ldp x29, x30, [sp], #16",
"br x30",
"1337:",

inout("x3") cfa_of_caller => _,
inout("x3") sp => _,
inout("x2") new_sp => _,
inout("x0") arg => ret_val,
out("x1") ret_sp,
Expand All @@ -82,6 +81,16 @@ pub unsafe fn swap_and_link_stacks(
out("x16") _, out("x17") _, out("x18") _, out("x19") _,
out("x20") _, out("x21") _, out("x22") _, out("x23") _,
out("x24") _, out("x25") _, out("x26") _, out("x27") _,
out("x28") _, out("lr") _,

out("v0") _, out("v1") _, out("v2") _, out("v3") _,
out("v4") _, out("v5") _, out("v6") _, out("v7") _,
out("v8") _, out("v9") _, out("v10") _, out("v11") _,
out("v12") _, out("v13") _, out("v14") _, out("v15") _,
out("v16") _, out("v17") _, out("v18") _, out("v19") _,
out("v20") _, out("v21") _, out("v22") _, out("v23") _,
out("v24") _, out("v25") _, out("v26") _, out("v27") _,
out("v28") _, out("v29") _, out("v30") _, out("v31") _,
);

(ret_val, ret_sp)
Expand Down Expand Up @@ -111,7 +120,16 @@ pub unsafe fn swap(arg: usize, new_sp: *mut usize) -> (usize, *mut usize) {
out("x16") _, out("x17") _, out("x18") _, out("x19") _,
out("x20") _, out("x21") _, out("x22") _, out("x23") _,
out("x24") _, out("x25") _, out("x26") _, out("x27") _,
out("x28") _,
out("x28") _, out("lr") _,

out("v0") _, out("v1") _, out("v2") _, out("v3") _,
out("v4") _, out("v5") _, out("v6") _, out("v7") _,
out("v8") _, out("v9") _, out("v10") _, out("v11") _,
out("v12") _, out("v13") _, out("v14") _, out("v15") _,
out("v16") _, out("v17") _, out("v18") _, out("v19") _,
out("v20") _, out("v21") _, out("v22") _, out("v23") _,
out("v24") _, out("v25") _, out("v26") _, out("v27") _,
out("v28") _, out("v29") _, out("v30") _, out("v31") _,
);

(ret_val, ret_sp)
Expand Down
10 changes: 4 additions & 6 deletions switcheroo/src/arch/unix_x64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::stack;

pub unsafe fn init<S: stack::Stack>(
stack: &S,
f: unsafe extern "C" fn(usize, *mut usize) -> !,
f: unsafe extern "C" fn(usize, *mut usize),
) -> *mut usize {
unsafe fn push(mut sp: *mut usize, val: usize) -> *mut usize {
sp = sp.offset(-1);
Expand Down Expand Up @@ -50,8 +50,6 @@ pub unsafe fn swap_and_link_stacks(
new_sp: *mut usize,
sp: *const usize,
) -> (usize, *mut usize) {
let cfa_of_caller = sp.offset(-4);

let ret_val: usize;
let ret_sp: *mut usize;

Expand All @@ -61,8 +59,8 @@ pub unsafe fn swap_and_link_stacks(
"push rax",
// Save the frame pointer as it can't be marked as an output register.
"push rbp",
// Link stacks
"mov [rcx], rsp",
// Link stacks by swapping the CFA value
"mov [rcx-32], rsp",
// Set the current pointer as the 2nd element (rsi) of the function we are jumping to.
"mov rsi, rsp",
// Change the stack pointer to the passed value.
Expand All @@ -76,7 +74,7 @@ pub unsafe fn swap_and_link_stacks(
"1337:",
// Mark all registers as clobbered as we don't know what the code we are jumping to is going to use.
// The compiler will optimise this out and just save the registers it actually knows it must.
inout("rcx") cfa_of_caller => _,
inout("rcx") sp => _,
inout("rdx") new_sp => _,
inout("rdi") arg => ret_val, // 1st argument to called function
out("rsi") ret_sp, // 2nd argument to called function
Expand Down
8 changes: 3 additions & 5 deletions switcheroo/src/arch/windows_x64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::stack;
/// call.
pub unsafe fn init<S: stack::Stack>(
stack: &S,
f: unsafe extern "C" fn(usize, *mut usize) -> !,
f: unsafe extern "C" fn(usize, *mut usize),
) -> *mut usize {
unsafe fn push(mut sp: *mut usize, val: usize) -> *mut usize {
sp = sp.offset(-1);
Expand Down Expand Up @@ -83,8 +83,6 @@ pub unsafe fn swap_and_link_stacks(
new_sp: *mut usize,
sp: *const usize,
) -> (usize, *mut usize) {
let cfa_of_caller = sp.offset(-6);

let ret_val: usize;
let ret_sp: *mut usize;

Expand All @@ -108,7 +106,7 @@ pub unsafe fn swap_and_link_stacks(
"push rax",

// Link stacks
"mov [rdi], rsp",
"mov [rdi-48], rsp",

// Set the current pointer as the 2nd element (rdx) of the function we are jumping to.
"mov rdx, rsp",
Expand All @@ -134,7 +132,7 @@ pub unsafe fn swap_and_link_stacks(
"1337:",
// Mark all registers as clobbered as we don't know what the code we are jumping to is going to use.
// The compiler will optimise this out and just save the registers it actually knows it must.
in("rdi") cfa_of_caller => _,
in("rdi") sp => _,
in("rsi") new_sp => _,
inout("rcx") arg => ret_val, // 1st argument to called function
out("rdx") ret_sp, // 2nd argument to called function
Expand Down
76 changes: 52 additions & 24 deletions switcheroo/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ pub mod stack;

use std::cell::Cell;
use std::marker::PhantomData;
use std::{mem, ptr};

pub struct Generator<'a, Input: 'a, Output: 'a, Stack: stack::Stack> {
stack: Stack,
stack_ptr: *mut usize,
stack_ptr: Option<ptr::NonNull<usize>>,
phantom: PhantomData<(&'a (), *mut Input, *const Output)>,
}

Expand All @@ -25,48 +26,60 @@ where
unsafe extern "C" fn generator_wrapper<Input, Output, Stack, F>(
f_ptr: usize,
stack_ptr: *mut usize,
) -> !
where
) where
Stack: stack::Stack,
F: FnOnce(&Yielder<Input, Output>, Input),
{
let f = std::ptr::read(f_ptr as *const F);
let (data, stack_ptr) = arch::swap(0, stack_ptr);
let input = std::ptr::read(data as *const Input);
let yielder = Yielder::new(stack_ptr);

f(&yielder, input);
// Any other call to resume will just yield back.
loop {
yielder.suspend(None);
}
// On last invocation of `suspend` return None
yielder.suspend_(None);
}

let stack_ptr = unsafe { arch::init(&stack, generator_wrapper::<Input, Output, Stack, F>) };
let f = mem::ManuallyDrop::new(f);
let stack_ptr = unsafe {
arch::swap_and_link_stacks(&f as *const F as usize, stack_ptr, stack.bottom()).1
arch::swap_and_link_stacks(
&f as *const mem::ManuallyDrop<F> as usize,
stack_ptr,
stack.bottom(),
)
.1
};
// We can't drop f when returning from this function. Maybe store it inside the Generator struct so it
// doesn't get dropped before the generator.
std::mem::forget(f);

Generator {
stack,
stack_ptr,
stack_ptr: Some(ptr::NonNull::new(stack_ptr).unwrap()),
phantom: PhantomData,
}
}

#[inline(always)]
pub fn resume(&mut self, input: Input) -> Option<Output> {
if self.stack_ptr.is_none() {
return None;
};
let stack_ptr = self.stack_ptr.unwrap();
self.stack_ptr = None;
unsafe {
let input = mem::ManuallyDrop::new(input);
let (data_out, stack_ptr) = arch::swap_and_link_stacks(
&input as *const Input as usize,
self.stack_ptr,
&input as *const mem::ManuallyDrop<Input> as usize,
stack_ptr.as_ptr(),
self.stack.bottom(),
);
self.stack_ptr = stack_ptr;
std::mem::forget(input);
std::ptr::read(data_out as *const Option<Output>)

// Should always be a pointer and never 0
if data_out == 0 {
return None;
} else {
self.stack_ptr = Some(ptr::NonNull::new_unchecked(stack_ptr));
Some(std::ptr::read(data_out as *const Output))
}
}
}
}
Expand All @@ -85,13 +98,28 @@ impl<Input, Output> Yielder<Input, Output> {
}

#[inline(always)]
pub fn suspend(&self, val: Option<Output>) -> Input {
unsafe {
let (data, stack_ptr) =
arch::swap(&val as *const Option<Output> as usize, self.stack_ptr.get());
self.stack_ptr.set(stack_ptr);
std::mem::forget(val);
std::ptr::read(data as *const Input)
pub fn suspend(&self, val: Output) -> Input {
unsafe { self.suspend_(Some(val)) }
}

#[inline(always)]
unsafe fn suspend_(&self, val: Option<Output>) -> Input {
match val {
None => {
// Let the resume know we are done here
arch::swap(0, self.stack_ptr.get());
unreachable!();
}
Some(val) => {
let val = mem::ManuallyDrop::new(val);
let (data, stack_ptr) = arch::swap(
&val as *const mem::ManuallyDrop<Output> as usize,
self.stack_ptr.get(),
);
self.stack_ptr.set(stack_ptr);

std::ptr::read(data as *const Input)
}
}
}
}
6 changes: 4 additions & 2 deletions switcheroo/src/stack/eight_mb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use winapi::ctypes::c_void;
#[cfg(target_family = "windows")]
use winapi::um::memoryapi::{VirtualAlloc, VirtualFree, VirtualProtect};
#[cfg(target_family = "windows")]
use winapi::um::winnt::{MEM_COMMIT, MEM_RELEASE, MEM_RESERVE, PAGE_GUARD, PAGE_READWRITE, PAGE_NOACCESS};
use winapi::um::winnt::{
MEM_COMMIT, MEM_RELEASE, MEM_RESERVE, PAGE_GUARD, PAGE_NOACCESS, PAGE_READWRITE,
};

use super::Stack;

Expand Down Expand Up @@ -83,7 +85,7 @@ impl Stack for EightMbStack {

let old_protect: u32 = 0;
let bottom_1 = VirtualProtect(
ptr.add((EIGHT_MB + EXCEPTION_ZONE - 1 * 4096) / size_of::<usize>()) as *mut c_void,
ptr.add((EIGHT_MB + EXCEPTION_ZONE - 1 * 4096) / size_of::<usize>()) as *mut c_void,
1 * 4096,
PAGE_READWRITE,
&old_protect as *const u32 as *mut u32,
Expand Down
5 changes: 2 additions & 3 deletions switcheroo/tests/switch_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ fn switch_stack() {
if input == 0 {
break;
}
input = yielder.suspend(Some(input + 1));
input = yielder.suspend(input + 1);
}
});
assert_eq!(add_one.resume(2), Some(3));
Expand Down Expand Up @@ -41,7 +41,6 @@ fn rec(n: u64) -> u8 {
}
}


#[test]
#[should_panic]
fn panic_on_different_stack() {
Expand All @@ -50,4 +49,4 @@ fn panic_on_different_stack() {
panic!("Ups");
});
let _: u32 = add_one.resume(0).unwrap();
}
}
Loading

0 comments on commit d44a1a8

Please sign in to comment.