diff --git a/src/lib.rs b/src/lib.rs index c4962eb..76b364c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -190,6 +190,69 @@ macro_rules! foreign_obj_type { } } + impl ::std::fmt::Debug for $owned_ident { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + ::std::ops::Deref::deref(self).fmt(f) + } + } + }; + { + type CType = $raw_ident:ident; + pub struct $owned_ident:ident; + type ParentType = $parent_ident:ident; + nodrop; + } => { + foreign_obj_type! { + type CType = $raw_ident; + pub struct $owned_ident; + nodrop; + } + + impl ::std::ops::Deref for paste!{[<$owned_ident Ref>]} { + type Target = paste!{[<$parent_ident Ref>]}; + + #[inline] + fn deref(&self) -> &Self::Target { + unsafe { &*(self as *const Self as *const Self::Target) } + } + } + + impl ::std::convert::From<$owned_ident> for $parent_ident { + fn from(item: $owned_ident) -> Self { + unsafe { Self::from_ptr(::std::mem::transmute(item.into_ptr())) } + } + } + }; + { + type CType = $raw_ident:ident; + pub struct $owned_ident:ident; + nodrop; + } => { + foreign_type! { + pub unsafe type $owned_ident: Sync + Send { + type CType = $raw_ident; + // TODO This is not really OK, but somehow the release for Drop + // makes the autoreleasepool drop it a second time at the end of the + // program leading to a crash. + fn drop = crate::obj_nodrop; + fn clone = crate::obj_clone; + } + } + + unsafe impl ::objc::Message for $raw_ident { + } + unsafe impl ::objc::Message for paste!{[<$owned_ident Ref>]} { + } + + impl ::std::fmt::Debug for paste!{[<$owned_ident Ref>]} { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + unsafe { + let string: *mut ::objc::runtime::Object = msg_send![self, debugDescription]; + write!(f, "{}", crate::nsstring_as_str(&*string)) + } + } + } + impl ::std::fmt::Debug for $owned_ident { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { ::std::ops::Deref::deref(self).fmt(f) @@ -612,6 +675,11 @@ unsafe fn obj_drop(p: *mut T) { msg_send![(p as *mut Object), release] } +#[inline] +unsafe fn obj_nodrop(p: *mut T) { + // msg_send![(p as *mut Object), release] +} + #[inline] unsafe fn obj_clone(p: *mut T) -> *mut T { msg_send![(p as *mut Object), retain] diff --git a/src/mps.rs b/src/mps.rs index 5d57833..38facf2 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -837,6 +837,7 @@ foreign_obj_type! { type CType = MPSMatrixDescriptor; pub struct MatrixDescriptor; type ParentType = NsObject; + nodrop; } impl MatrixDescriptor { @@ -1241,3 +1242,153 @@ fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger, B_K: NSUInteger) // The right matrix must be larger or equal to result columns * interior columns assert!(K * N >= M * N); } + +#[cfg(test)] +mod tests{ + use super::*; + use rand::{Rng, thread_rng}; + use std::ops::{Add, AddAssign, Mul}; +// Naive matrix multiplication for testing +fn matrix_mul( + a: Vec, + b: Vec, + m: usize, + n: usize, + k: usize, +) -> Vec +where + T::Type: AddAssign + Mul + Copy, +{ + let size = m * n; + + let mut c = Vec::with_capacity(size); + + for idx in 0..size { + let i = idx / m; + let j = idx % n; + + let mut sum = T::from_f64(0.0); + for di in 0..k { + sum += a[(i * k) + di] * b[(di * n) + j]; + } + c.push(sum); + } + + c +} + +fn euclidean_distance(a: Vec, b: Vec) -> f64 +where + T: Into + Clone + Copy, +{ + assert_eq!(a.len(), b.len(), "Lengths not equal"); + + let mut sum = 0.0; + + for i in 0..a.len() { + sum += (a[i].into() - b[i].into()).powi(2); + } + + sum.sqrt() +} + +fn approx_eq(a: Vec, b: Vec) +where + T: Into + Clone + Copy, +{ + assert_eq!(a.len(), b.len(), "Lengths not equal"); + + let avg_magnitude = 0.004f64; + let avg_deviation = (a.len() as f64).sqrt(); + let tolerance = avg_magnitude.max(avg_deviation * 3e-7); + + let distance = euclidean_distance(a, b); + assert!( + distance < tolerance, + "Distance not less than tolerance: {} < {} ", + distance, + tolerance + ); +} + + +fn generate_matrix(device: &Device) -> MatrixBuffer +where + T: MPSDataType, + GEMMInput: Valid, +{ + let mut rng = thread_rng(); + + // Create descriptors for the matrices. + let row_bytes_for_columns = MatrixDescriptor::row_bytes_for_columns(COLS, T::TYPE_ID); + + // Create buffers + let options = MTLResourceOptions::StorageModeShared; + let data = (0..ROWS * COLS) + .map(|_| T::from_f64(rng.gen())) + .collect::>(); + let buffer = + device.new_buffer_with_data(data.as_ptr().cast(), ROWS * row_bytes_for_columns, options); + + MatrixBuffer::from_buffer(buffer, ROWS, COLS) +} + + #[test] + fn correctness(){ + // First verify the correctness of the naive solution + let m = 3; + let n = 3; + let k = 2; + let a = vec![1, 2, 6, 24, 120, 720]; + let b = vec![1, 2, 6, 24, 120, 720]; + let result = matrix_mul::(a, b, m, n, k); + assert_eq!( + result, + &[49, 242, 1446, 582, 2892, 17316, 17400, 86640, 519120] + ); + + const M: u64 = 100; + const N: u64 = 100; + const K: u64 = 100; + const ITERATIONS: usize = 50; + + let device = Device::system_default().expect("No device found"); + let command_queue = device.new_command_queue(); + + println!("Correctness: "); + for i in 0..ITERATIONS { + // progress_bar(i, ITERATIONS); + + let a = generate_matrix::(&device); + let b = generate_matrix::(&device); + let mut c = generate_matrix::(&device); + + let command_buffer = command_queue.new_command_buffer(); + encode_gemm( + &device, + command_buffer, + false, + false, + &a, + &b, + &mut c, + 1.0, + 0.0, + ) + .expect("Encoding failed"); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let expected = matrix_mul::( + a.contents(), + b.contents(), + M as usize, + K as usize, + N as usize, + ); + approx_eq(c.contents(), expected); + } + + // println!(" ✅\n"); + } +}