Skip to content

Commit

Permalink
feat: Make the features testable and threadsafe.
Browse files Browse the repository at this point in the history
tl;dr: No idea why we need this, but if we don't do it, the main thread
will attempt to release the `MatrixDescriptor` and fail.

This fix is a dirty hack which simply leaks those `MatrixDescriptor` objects.
This *seems* OK since they get cleaned up by the autorelease pool at the end of
the program, but it is definitely a temporary workaround, which should
at least enable adding a few tests to the test suite.
  • Loading branch information
Narsil committed Oct 31, 2023
1 parent 3a4bd86 commit 5a66ce1
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 0 deletions.
68 changes: 68 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,69 @@ macro_rules! foreign_obj_type {
}
}

impl ::std::fmt::Debug for $owned_ident {
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
::std::ops::Deref::deref(self).fmt(f)
}
}
};
// Arm for a foreign object type that declares a parent type and is marked
// `nodrop`. It matches the same shape as the `nodrop` arm plus a
// `type ParentType = ...;` line.
{
type CType = $raw_ident:ident;
pub struct $owned_ident:ident;
type ParentType = $parent_ident:ident;
nodrop;
} => {
// Generate the type itself by re-invoking the macro without the parent
// line, which selects the plain `nodrop` arm.
foreign_obj_type! {
type CType = $raw_ident;
pub struct $owned_ident;
nodrop;
}

// Let `<Owned>Ref` deref to `<Parent>Ref` so that methods defined on the
// parent reference type are callable on the child reference type.
impl ::std::ops::Deref for paste!{[<$owned_ident Ref>]} {
type Target = paste!{[<$parent_ident Ref>]};

#[inline]
fn deref(&self) -> &Self::Target {
// Reinterprets the reference through a raw pointer cast.
// NOTE(review): assumes both `Ref` types are layout-compatible
// wrappers around the same Objective-C object -- TODO confirm.
unsafe { &*(self as *const Self as *const Self::Target) }
}
}

// Conversion from the owned child type into the owned parent type.
impl ::std::convert::From<$owned_ident> for $parent_ident {
fn from(item: $owned_ident) -> Self {
// Re-wraps the raw pointer returned by `into_ptr()` as the parent
// type; the `transmute` only changes the pointee type.
// NOTE(review): the child is `nodrop` while the parent may not be,
// so the release behaviour after conversion should be verified.
unsafe { Self::from_ptr(::std::mem::transmute(item.into_ptr())) }
}
}
};
{
type CType = $raw_ident:ident;
pub struct $owned_ident:ident;
nodrop;
} => {
foreign_type! {
pub unsafe type $owned_ident: Sync + Send {
type CType = $raw_ident;
// TODO This is not really OK, but somehow the release for Drop
// makes the autoreleasepool drop it a second time at the end of the
// program leading to a crash.
fn drop = crate::obj_nodrop;
fn clone = crate::obj_clone;
}
}

unsafe impl ::objc::Message for $raw_ident {
}
unsafe impl ::objc::Message for paste!{[<$owned_ident Ref>]} {
}

impl ::std::fmt::Debug for paste!{[<$owned_ident Ref>]} {
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
unsafe {
let string: *mut ::objc::runtime::Object = msg_send![self, debugDescription];
write!(f, "{}", crate::nsstring_as_str(&*string))
}
}
}

impl ::std::fmt::Debug for $owned_ident {
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
::std::ops::Deref::deref(self).fmt(f)
Expand Down Expand Up @@ -612,6 +675,11 @@ unsafe fn obj_drop<T>(p: *mut T) {
msg_send![(p as *mut Object), release]
}

/// No-op counterpart of `obj_drop`, used as the `drop` hook of `nodrop`
/// foreign types.
///
/// Unlike `obj_drop`, this deliberately does NOT send `release` to the
/// object, so the wrapper never gives up its reference when it is dropped.
/// This is a workaround (see the TODO on the `nodrop` macro arm) and leaks
/// the object from Rust's point of view.
///
/// # Safety
///
/// The pointer is never dereferenced, so any pointer value is accepted. The
/// function is `unsafe` only to match the signature expected of a drop hook.
#[inline]
unsafe fn obj_nodrop<T>(p: *mut T) {
    // Intentionally empty: consuming `p` here documents the no-op and
    // silences the unused-parameter warning.
    let _ = p;
}

#[inline]
unsafe fn obj_clone<T: 'static>(p: *mut T) -> *mut T {
msg_send![(p as *mut Object), retain]
Expand Down
151 changes: 151 additions & 0 deletions src/mps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ foreign_obj_type! {
type CType = MPSMatrixDescriptor;
pub struct MatrixDescriptor;
type ParentType = NsObject;
nodrop;
}

impl MatrixDescriptor {
Expand Down Expand Up @@ -1241,3 +1242,153 @@ fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger, B_K: NSUInteger)
// The right matrix must be larger or equal to result columns * interior columns
assert!(K * N >= M * N);
}

#[cfg(test)]
mod tests{
use super::*;
use rand::{Rng, thread_rng};
use std::ops::{Add, AddAssign, Mul};
/// Naive reference matrix multiplication used to validate other results.
///
/// Computes `C = A * B`, where `a` is read as a row-major `m x k` matrix,
/// `b` as a row-major `k x n` matrix, and the returned vector is the
/// row-major `m x n` product.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `a` or `b` holds fewer elements than
/// the given dimensions require.
fn matrix_mul<T: MPSDataType>(
    a: Vec<T::Type>,
    b: Vec<T::Type>,
    m: usize,
    n: usize,
    k: usize,
) -> Vec<T::Type>
where
    T::Type: AddAssign + Mul<Output = T::Type> + Copy,
{
    let mut c = Vec::with_capacity(m * n);

    // Iterate rows then columns explicitly. The previous flat-index version
    // derived the row as `idx / m`, which is only correct when `m == n`.
    for i in 0..m {
        for j in 0..n {
            let mut sum = T::from_f64(0.0);
            for di in 0..k {
                sum += a[(i * k) + di] * b[(di * n) + j];
            }
            c.push(sum);
        }
    }

    c
}

/// Returns the Euclidean (L2) distance between `a` and `b`, computed in `f64`.
///
/// # Panics
///
/// Panics with "Lengths not equal" when the two vectors differ in length.
fn euclidean_distance<T>(a: Vec<T>, b: Vec<T>) -> f64
where
    T: Into<f64> + Clone + Copy,
{
    assert_eq!(a.len(), b.len(), "Lengths not equal");

    // Accumulate squared differences left to right, starting from 0.0.
    let squared_sum = a
        .iter()
        .zip(b.iter())
        .fold(0.0_f64, |acc, (&lhs, &rhs)| {
            let lhs: f64 = lhs.into();
            let rhs: f64 = rhs.into();
            acc + (lhs - rhs).powi(2)
        });

    squared_sum.sqrt()
}

/// Asserts that `a` and `b` are approximately equal.
///
/// The Euclidean distance between the two vectors must be strictly below
/// `max(0.004, sqrt(len) * 3e-7)`.
///
/// # Panics
///
/// Panics if the lengths differ or if the distance is not below the tolerance.
fn approx_eq<T>(a: Vec<T>, b: Vec<T>)
where
    T: Into<f64> + Clone + Copy,
{
    assert_eq!(a.len(), b.len(), "Lengths not equal");

    // Fixed lower bound for the tolerance, and a term growing with sqrt(len).
    let floor = 0.004f64;
    let scaled = (a.len() as f64).sqrt() * 3e-7;
    let tolerance = f64::max(floor, scaled);

    let distance = euclidean_distance(a, b);
    assert!(
        distance < tolerance,
        "Distance not less than tolerance: {} < {} ",
        distance,
        tolerance
    );
}


/// Builds a `ROWS x COLS` `MatrixBuffer` on `device`, filled with random
/// values produced by `thread_rng()` and converted through `T::from_f64`.
fn generate_matrix<T, const ROWS: u64, const COLS: u64>(device: &Device) -> MatrixBuffer<T>
where
    T: MPSDataType,
    GEMMInput<T>: Valid,
{
    let mut rng = thread_rng();

    // One random element per matrix entry.
    let values: Vec<T::Type> = (0..ROWS * COLS)
        .map(|_| T::from_f64(rng.gen()))
        .collect();

    // Presumably the per-row byte stride for this column count and data type.
    let row_stride = MatrixDescriptor::row_bytes_for_columns(COLS, T::TYPE_ID);

    // NOTE(review): the buffer length is `ROWS * row_stride` bytes while
    // `values` holds exactly `ROWS * COLS` elements. If the stride is ever
    // padded beyond `COLS` elements this would read past `values` -- TODO confirm.
    let buffer = device.new_buffer_with_data(
        values.as_ptr().cast(),
        ROWS * row_stride,
        MTLResourceOptions::StorageModeShared,
    );

    MatrixBuffer::from_buffer(buffer, ROWS, COLS)
}

/// Validates the naive CPU reference against hand-computed values, then
/// checks `encode_gemm` against that reference on random matrices.
#[test]
fn correctness() {
    // First verify the correctness of the naive solution
    let m = 3;
    let n = 3;
    let k = 2;
    let a = vec![1, 2, 6, 24, 120, 720];
    let b = vec![1, 2, 6, 24, 120, 720];
    let result = matrix_mul::<Int32>(a, b, m, n, k);
    assert_eq!(
        result,
        &[49, 242, 1446, 582, 2892, 17316, 17400, 86640, 519120]
    );

    const M: u64 = 100;
    const N: u64 = 100;
    const K: u64 = 100;
    const ITERATIONS: usize = 50;

    let device = Device::system_default().expect("No device found");
    let command_queue = device.new_command_queue();

    println!("Correctness: ");
    for _ in 0..ITERATIONS {
        // A is M x K and B is K x N, so the result C must be M x N.
        let a = generate_matrix::<Float32, M, K>(&device);
        let b = generate_matrix::<Float32, K, N>(&device);
        let mut c = generate_matrix::<Float32, M, N>(&device);

        let command_buffer = command_queue.new_command_buffer();
        encode_gemm(
            &device,
            command_buffer,
            false,
            false,
            &a,
            &b,
            &mut c,
            1.0,
            0.0,
        )
        .expect("Encoding failed");
        command_buffer.commit();
        command_buffer.wait_until_completed();

        // `matrix_mul` takes its dimensions in (m, n, k) order.
        let expected = matrix_mul::<Float32>(
            a.contents(),
            b.contents(),
            M as usize,
            N as usize,
            K as usize,
        );
        approx_eq(c.contents(), expected);
    }
}
}

0 comments on commit 5a66ce1

Please sign in to comment.